Theory Sorting_Par_Partition

theory Sorting_Par_Partition
imports Sorting_Setup Sorting_Guarded_Partition IICF_DS_Interval_List IICF_Shared_Lists IICF_DS_Array_Idxs Sorting_Sample_Partition
begin

(* TODO: Move near set_drop_conv *)
(* Index-based characterization of the elements of the prefix take m xs:
   exactly the entries xs!i whose index is below both m and length xs. *)
lemma set_take_conv_nth: "set (take m xs) = {xs!i| i. i<m  i<length xs}"
  by (auto 0 3 simp: in_set_conv_nth) 

(* Index-based characterization of the elements of the suffix drop m xs.
   Proved directly by set_drop_conv; named in parallel to set_take_conv_nth. *)
lemma set_drop_conv_nth: "set (drop m xs) = {xs!i| i. im  i<length xs}" by (rule set_drop_conv)
      
(* TODO: Move *)  
(* A slice_eq_mset over the whole index range starting at 0 is the same as
   multiset equality of the two lists. Stated twice, once with length xs and
   once with length xs' as the upper bound of the slice. *)
lemma slice_eq_mset_whole_iff: 
  "slice_eq_mset 0 (length xs) xs' xs  mset xs' = mset xs"  
  "slice_eq_mset 0 (length xs') xs' xs  mset xs' = mset xs"  
  unfolding slice_eq_mset_def Misc.slice_def
  apply clarsimp
  apply (metis dual_order.refl mset_eq_length take_all_iff)
  apply clarsimp
  apply (metis dual_order.refl mset_eq_length take_all_iff)
  done


(* TODO: Move *)

(* Natural-number division rounded up.
   Returns p div q if p mod q = 0, and p div q + 1 otherwise.
   The first assertion presumably requires a non-zero divisor. The second asserts
   that the rounded-up result does not exceed p.
   NOTE(review): the second assertion looks like an overflow guard for the
   fixed-width implementation nat_div_round_up_impl -- confirm. *)
definition nat_div_round_up :: "nat  nat  nat nres" where "nat_div_round_up p q  doN {
  ASSERT (q0);
  let r = p div q;
  if p mod q = 0 then 
    RETURN r 
  else doN {
    ASSERT (r+1p);
    RETURN (r+1)
  }
}"  

(* Specification of nat_div_round_up, registered as a refine_vcg rule.
   For a (presumably non-zero) divisor q, the result r satisfies p <= r*q < p+q.
   That is, r is the smallest multiplier of q that reaches p. *)
lemma nat_div_round_up_correct[refine_vcg]: "q0  nat_div_round_up p q   SPEC (λr. r*q  {p..<p+q})"
  unfolding nat_div_round_up_def
  apply refine_vcg
  apply (auto simp: algebra_simps modulo_nat_def)
  subgoal 
    by (metis less_add_same_cancel1 nle_le times_div_less_eq_dividend)
  subgoal
    by (metis div_le_dividend gr0_implies_Suc leI le_antisym less_Suc_eq_le mult_Suc not_add_less1)
  subgoal
    using dividend_less_times_div less_imp_le_nat by presburger
  done
  

(* Register nat_div_round_up as an operation known to sepref. *)
sepref_register nat_div_round_up

(* LLVM implementation of nat_div_round_up on snat values of an arbitrary word
   length 'l. Both arguments are only read (k annotation). Numeric constants in
   the definition are annotated with the word type 'l before running sepref. *)
sepref_def nat_div_round_up_impl is "uncurry nat_div_round_up" :: "(snat_assn' TYPE('l::len2))k *a (snat_assn' TYPE('l))k a snat_assn' TYPE('l)"
  unfolding nat_div_round_up_def
  apply (annot_snat_const "TYPE('l)")
  by sepref



(* TODO: Move *)

(* Copy-nth implementation for woarray_slice_assn *)

(* Copying element read: reads entry i of array p with array_nth, passes the
   value read to the copy function cp, and returns cp's result.
   Marked llvm_inline. *)
definition [llvm_inline]: "array_cp_nth (cp :: 'a::llvm_rep  'a llM) p i  doM {
  rarray_nth p i;
  rcp r;
  Mreturn r
}"

(* Hoare rule for array_cp_nth on an array described by sao_assn A.
   Assumption copy_elem_rl: from an A-refined element, cp produces a second
   A-refined value and keeps the original.
   Conclusion: reading a valid index i whose abstract entry is not None returns an
   A-refinement of the entry at i, and the array assertion on p is preserved. *)
lemma sao_assn_cp_rl[vcg_rules]:
  fixes A :: "('a,'c::llvm_rep) dr_assn"
  fixes cp
  assumes copy_elem_rl[vcg_rules]: "a c. llvm_htriple (A a c) (cp c) (λr. A a c ** A a r)"
  shows "llvm_htriple 
    ((sao_assn A) xs p ∧* snat.assn i ii ∧* d(i < length xs  xs!iNone)) 
    (array_cp_nth cp p ii)
    (λri. A (the (xs!i)) ri ∧* (sao_assn A) (xs) p)"
  unfolding sao_assn_def array_cp_nth_def
  supply [simp] = lo_extract_elem
  by vcg

(* Abstract operation list_cp_get: an element lookup like nth, with precondition
   that the index is in bounds. It is declared as a separate operation so that it
   can be implemented by copying the element (see woarray_slice_cp_nth_hnr). *)
sepref_decl_op list_cp_get: nth :: "[λ(l,i). i<length l]f Alist_rel ×r nat_rel  A" .
  
(* Sepref refinement rule: array_cp_nth cp implements mop_list_cp_get on arrays
   refined by woarray_slice_assn A. Requires that cp is a copy operation for A
   (GEN_ALGO cp (is_copy A)). The array and the index are only read. *)
lemma woarray_slice_cp_nth_hnr:
  assumes "GEN_ALGO cp (is_copy A)"
  shows "(uncurry (array_cp_nth cp), uncurry mop_list_cp_get)  (woarray_slice_assn A)k *a snat_assnk a A"
proof -

  (* Turn the GEN_ALGO / is_copy assumption into a plain Hoare triple for cp. *)
  from assms have copy_elem_rl[vcg_rules]: "a c. llvm_htriple (A a c) (cp c) (λr. A a c ** A a r)" 
    unfolding GEN_ALGO_def is_copy_def
    apply -
    apply (drule hfrefD; simp)
    apply (drule hn_refineD; simp)
    apply (erule htriple_ent_post[rotated])
    apply (simp add: sep_algebra_simps)
    done
  

  (* Unfold the array assertions to their underlying form and discharge with the VCG. *)
  show ?thesis
    apply sepref_to_hoare
    unfolding woarray_slice_assn_def eoarray_slice_assn_def hr_comp_def in_snat_rel_conv_assn
    supply [dest] = list_rel_imp_same_length
    apply (clarsimp simp: refine_pw_simps some_list_rel_conv)
    by vcg
    
qed  

(* Register woarray_slice_cp_nth_hnr as the implementation of list_cp_get. The
   rule is composed with the operation's fref at element relation Id, with
   list_rel_id_simp added to fcomp_norm_simps only for this declaration. *)
context
  notes [fcomp_norm_simps] = list_rel_id_simp
begin
  sepref_decl_impl (ismop) woarray_slice_cp_nth_hnr uses mop_list_cp_get.fref[of Id] .
end

section ‹Abstract Algorithm›

  

context weak_ordering begin  

  (* Partitioning the whole list *)

  (* gpartition_all_spec: gpartition_spec instantiated to the whole index range
     0..<length xs. *)
  abbreviation "gpartition_all_spec p xs xs' m  gpartition_spec 0 (length xs) p xs xs' m"

  (* Unfolded form: xs' is a permutation of xs, m is at most length xs,
     the entries of xs' below m are small, and the entries from m on are big,
     both relative to the pivot p. *)
  lemma gpartition_all_spec_alt1: "gpartition_all_spec p xs xs' m  
      mset xs' = mset xs 
     mlength xs
     (i<m. xs'!i  p)  
     (i{m..<length xs}. xs'!i  p)
  "
    unfolding gpartition_spec_def
    by (auto simp: slice_eq_mset_whole_iff)
  
  (* The same characterization, phrased via the elements of take m xs' and
     drop m xs' instead of indexes. *)
  lemma gpartition_all_spec_alt2: "gpartition_all_spec p xs xs' m 
    mset xs' = mset xs 
   mlength xs
   (xset (take m xs'). xp)
   (xset (drop m xs'). px)
  "  
    unfolding gpartition_all_spec_alt1
    by (fastforce simp: set_take_conv_nth set_drop_conv_nth dest: mset_eq_length)  
    
  (* The specification depends on the input list only through its multiset. *)
  lemma gpartition_spec_permute: "mset xs = mset xs1  gpartition_all_spec p xs xs' m = gpartition_all_spec p xs1 xs' m"  
    unfolding gpartition_all_spec_alt1
    by (auto dest: mset_eq_length)
    
    
end

subsection ‹Step 1: Partition the Array›

(*
  Step 1: partition the array and keep track of the sets of small/big indices.
    In practice, the array will be partitioned into multiple slices, and the index sets will be represented by intervals.
*)

(* An array together with the sets of indices of its small/big elements *)
(* Result of step 1. The index sets ss and bs are disjoint and together cover all
   indexes of xs. Every index in ss holds a small element (at most p w.r.t. the
   weak ordering), and every index in bs holds a big element (p is at most it). *)
locale is_ppart = weak_ordering + 
  fixes p xs ss bs
  assumes complete: "ss  bs = {0..<length xs}"
  assumes disjoint: "ss  bs = {}"
  assumes ss_in_P1: "iss  xs!i  p"
  assumes bs_in_P2: "ibs  p  xs!i"
  


context weak_ordering begin
  
  (* ppart_spec p xs xs' ss bs: xs' is a permutation of xs, and ss/bs are the
     index sets of the small/big elements of xs' relative to p (see is_ppart). *)
  definition "ppart_spec p xs xs' ss bs  mset xs' = mset xs  is_ppart () (<) p xs' ss bs"  
    
  (* Nondeterministic step 1. Asserts that n is the length of xs, then returns
     any permuted list together with valid small/big index sets. *)
  definition "ppart_SPEC p n xs  doN { ASSERT(n=length xs); SPEC (λ(xs',ss,bs). ppart_spec p xs xs' ss bs) }"
  
  (* For paper *)
  lemma "ppart_SPEC p (length xs) xs = SPEC (λ(xs',ss,bs). 
    mset xs' = mset xs
   ss  bs = {0..<length xs'}  ss  bs = {}
   (iss. xs'!i  p)  (ibs. p  xs'!i)
  )"
    unfolding ppart_SPEC_def ppart_spec_def is_ppart_def is_ppart_axioms_def
    by (auto simp: weak_ordering_axioms)
    
  
  
  
  subsection Step 2: compute mid-index, filter misplaced indexes

  (*
    Step 2: compute position of bound (first position of big element)
  *)  
  
  (* The split point m is the number of small elements. *)
  definition ppart_mpos :: "nat set  nat" where "ppart_mpos ss  card ss"
  (* Misplaced indexes relative to the split point m: small elements at
     positions from m on, and big elements at positions below m. *)
  definition ppart_filter :: "nat  nat set  nat set  nat set × nat set" where "ppart_filter m ss bs  ({iss. mi},{ibs. i<m})"
end  
  
(*
  The numbers of misplaced small elements and misplaced big elements are the same.
  The lemma is stated generically for sets of indexes, independent of how they were obtained.
*)
(* Generic setting: ss and bs are disjoint and cover {0..<n}, and m = card ss.
   ss1/ss2 are the indexes of ss below / from m on.
   bs1/bs2 are the indexes of bs from m on / below m.
   ss2 and bs2 are the misplaced indexes. *)
locale misplaced_elements =
  fixes ss bs n
  assumes SSU: "ss  bs = {0..<n}" "ss  bs = {}" 
  fixes m ss1 ss2 bs1 bs2
  assumes m_def: "m = card ss"
  assumes ss1_def: "ss1 = {iss. i<m}"
  assumes ss2_def: "ss2 = {iss. mi}"
  assumes bs1_def: "bs1 = {ibs. mi}"
  assumes bs2_def: "bs2 = {ibs. i<m}"
begin

  (* ss and bs are finite, since their union is the finite range {0..<n}. *)
  lemma finiteIs1[simp, intro!]:
    "finite ss" "finite bs"
    using SSU
    by (metis finite_Un finite_atLeastLessThan)+
    
  (* Hence the four filtered subsets are finite as well. *)
  lemma finiteIs2[simp, intro!]:
    "finite ss1" "finite ss2" "finite bs1" "finite bs2"
    unfolding ss1_def ss2_def bs1_def bs2_def by auto

  (* The split point lies within the range, because card ss + card bs = n. *)
  lemma m_le_n: "mn" 
  proof - 
    have "card ss + card bs = n"
      using SSU
      by (simp add: card_Un_Int)
      
    thus "mn" unfolding m_def by auto 
  qed              

  (* ss splits disjointly into ss1 and ss2. *)
  lemma ss_split: "ss = ss1  ss2"
    and ss_dj: "ss1  ss2 = {}"
    unfolding ss1_def ss2_def 
    by auto
    
  (* bs splits disjointly into bs1 and bs2. *)
  lemma bs_split: "bs = bs1  bs2"
    and bs_dj: "bs1  bs2 = {}"
    unfolding bs1_def bs2_def 
    by auto

  (* The low range {0..<m} consists of the well-placed small indexes and the
     misplaced big indexes. *)
  lemma low_range_split: "{0..<m} = ss1  bs2"
    and low_range_dj: "ss1  bs2 = {}" 
    unfolding ss1_def bs2_def using SSU m_le_n
    by auto
    
  (* The high range {m..<n} consists of the well-placed big indexes and the
     misplaced small indexes. *)
  lemma high_range_split: "{m..<n} = bs1  ss2"
    and high_range_dj: "bs1  ss2 = {}" 
    unfolding bs1_def ss2_def using SSU
    by auto
    
  lemma same_djs: 
    "ss1  bs1 = {}"  
    "ss2  bs2 = {}"  
    unfolding bs1_def bs2_def ss1_def ss2_def
    by auto
    
  (* All four index sets lie within {0..<n}. *)
  lemma in_range: 
    "ss1  {0..<n}"          
    "ss2  {0..<n}"          
    "bs1  {0..<n}"          
    "bs2  {0..<n}"
    using SSU
    unfolding bs1_def bs2_def ss1_def ss2_def
    by auto
    
    
  (* Counting argument: card ss1 + card ss2 = m = card ss1 + card bs2, so both
     misplaced sets have the same size. *)
  lemma misplaced_same_card:
    shows "card ss2 = card bs2"
  proof -
    from ss_split ss_dj have 1: "card ss1 + card ss2 = m" by (simp add: card_Un_disjoint m_def)
    also from low_range_split[symmetric] have "m = card (ss1  bs2)" by simp
    also from low_range_dj have " = card ss1 + card bs2"   
      by (simp add: card_Un_disjoint)
    finally show ?thesis by simp  
  qed
end  
  
  
(* Step 2 applied to a step-1 result. m is the number of small indexes.
   ss1/bs1 are the well-placed small/big indexes, and ss2/bs2, obtained via
   ppart_filter, are the misplaced ones. *)
locale ppar_step2 = is_ppart
begin  
  abbreviation "m  ppart_mpos ss"
  
  
  definition "ss1 = {iss. i<m}"
  definition "ss2 = fst (ppart_filter m ss bs)"
  
  definition "bs1 = {ibs. mi}"
  definition "bs2 = snd (ppart_filter m ss bs)"


  (* Instantiate the generic counting argument with n = length xs. *)
  sublocale misplaced: misplaced_elements ss bs "length xs" m ss1 ss2 bs1 bs2 
    apply unfold_locales
    unfolding ppart_mpos_def ppart_filter_def ss1_def ss2_def bs1_def bs2_def
    using complete disjoint
    by auto

  (*
    Same number of misplaced small and big indexes
  *)  
  thm misplaced.misplaced_same_card
  
  (*
    All other indexes are well-placed
  *)  
  (* An index below m that is not in bs2 is in ss1, and so holds a small element. *)
  lemma low_nbs2_well_placed: assumes "i<m" "ibs2" shows "xs!i  p"
  proof -
    from assms misplaced.low_range_split have "iss1" by fastforce
    with misplaced.ss_split have "iss" by auto
    with ss_in_P1 show ?thesis by auto
  qed  
    
  (* An in-range index from m on that is not in ss2 is in bs1, and so holds a big element. *)
  lemma high_nss2_well_placed: assumes "mi" "i<length xs" "iss2" shows "p  xs!i"
  proof -
    from assms misplaced.high_range_split have "ibs1" by fastforce
    with misplaced.bs_split have "ibs" by auto
    with bs_in_P2 show ?thesis by auto
  qed  

end  
  
subsection ‹Step 3: compute swaps›
  
  
(* Precondition for swapping: the source and destination index sets are
   disjoint, both lie within the bounds of xs, and they have equal cardinality. *)
locale swap_spec_pre = 
  fixes src dst :: "nat set" and xs :: "'a list"
  assumes src_dst_dj: "src  dst = {}" 
  assumes src_ss: "src  {0..<length xs}"
  assumes dst_ss: "dst  {0..<length xs}"
  assumes card_eq: "card src = card dst"
begin
  (* Both index sets are finite, being subsets of the index range. *)
  lemma finite_src[simp,intro!]: "finite src" using src_ss by (blast intro: finite_subset)
  lemma finite_dst[simp,intro!]: "finite dst" using dst_ss by (blast intro: finite_subset)
end  

(* Postcondition of swapping. xs' is a permutation of xs, and in-range positions
   outside src and dst are unchanged. Every position in src receives an element
   that xs held at some position in dst, and vice versa. *)
locale swap_spec = swap_spec_pre + fixes xs' :: "'a list"
  assumes elems_outside: "isrc  dst  i<length xs  xs'!i=xs!i"
  assumes elems_src: "isrc  jdst. xs'!i=xs!j"
  assumes elems_dst: "idst  jsrc. xs'!i=xs!j"
  assumes permut: "mset xs' = mset xs"
begin
  (* Swapping preserves the length, as a consequence of permut. *)
  lemma length_xs'_eq[simp]: "length xs' = length xs"
    using mset_eq_length[OF permut] by blast

end  
  
(* For presentation in paper *)
(* Unfolded, single-formula form of swap_spec_pre. *)
lemma "swap_spec_pre src dst xs  
  src  dst = {}  src  dst  {0..<length xs}  card src = card dst
  "
  unfolding swap_spec_pre_def
  by blast

(* Unfolded, single-formula form of swap_spec. *)
lemma "swap_spec src dst xs xs'  
    swap_spec_pre src dst xs
   mset xs' = mset xs  (i. isrc  dst  i<length xs  xs'!i=xs!i)
   (isrc. jdst. xs'!i=xs!j)  (jdst. isrc. xs'!j=xs!i)
  "
  unfolding swap_spec_def swap_spec_axioms_def
  by blast




context swap_spec_pre begin

  (* If src is empty then, by card_eq, so is dst, and the unchanged list is a
     valid swap result. *)
  lemma swap_spec_refl: 
    assumes [simp]: "src={}"
    shows "swap_spec src dst xs xs"
    apply unfold_locales
    using card_eq
    by auto


end

(* Nondeterministic swap: asserts swap_spec_pre and returns any list satisfying swap_spec. *)
definition "swap_SPEC ss bs xs  do { ASSERT (swap_spec_pre ss bs xs); SPEC (swap_spec ss bs xs) }"


(* Sanity check lemma *)
(* Every instance satisfying swap_spec_pre has a solution.
   The proof is by induction on card src. One index i is removed from src and
   one index j from dst, the smaller instance is solved, and then positions i
   and j of that result are swapped. *)
lemma swap_spec_exists:
  assumes "swap_spec_pre src dst xs"
  shows "xs'. swap_spec src dst xs xs'"
  using assms
proof (induction "card src" arbitrary: src dst)
  case 0
  then interpret swap_spec_pre src dst xs by simp
  
  (* Both sets are empty, so xs itself is a solution. *)
  from 0 card_eq have [simp]: "src={}" "dst={}" by auto
  
  show ?case 
    apply (rule exI[where x=xs])
    apply unfold_locales
    by auto
  
next
  case (Suc n)
  then interpret swap_spec_pre src dst xs by simp

  (* Split one element off each set. *)
  from Suc n = card src[symmetric] card_eq obtain i j src' dst' where 
    [simp]: "src=insert i src'" "dst=insert j dst'" 
    and NI: "isrc'" "jdst'" 
    and CARD: "card src' = n" "card dst' = n"
    by (auto simp: card_Suc_eq_finite)
    
  have [simp]: "i<length xs" "j<length xs" using src_ss dst_ss by auto
    
  (* Solve the smaller instance by the induction hypothesis. *)
  have "swap_spec_pre src' dst' xs"
    apply unfold_locales
    using src_dst_dj src_ss dst_ss card_eq CARD
    by auto
  with Suc.hyps(1)[of src' dst'] obtain xs' where "swap_spec src' dst' xs xs'" 
    using CARD by blast
  then interpret IH: swap_spec src' dst' xs xs' .
    
  (* Swapping the two split-off positions extends the solution. *)
  have "swap_spec src dst xs (swap xs' i j)"
    apply unfold_locales
    subgoal for k by (auto simp: IH.elems_outside)
    subgoal for k
      apply (cases "k=i"; simp)
      subgoal
        using IH.elems_outside IH.elems_src NI(2) j < length xs by blast
      subgoal
        by (metis IH.elems_src dst = insert j dst' src = insert i src' disjoint_iff insertCI src_dst_dj swap_indep)
      done
    subgoal for k
      apply (cases "k=j"; simp)
      subgoal
        using IH.elems_dst IH.elems_outside NI(1) i < length xs by blast
      subgoal
        by (metis IH.elems_dst dst = insert j dst' src = insert i src' disjoint_iff insertCI src_dst_dj swap_indep)
      done
    subgoal
      by (simp add: IH.permut)
    done  
  thus ?case ..    
qed      
  
  
(* Connects step 2 to step 3: the misplaced index sets can be swapped. *)
context ppar_step2 begin    
  (*
    ss2 and bs2 satisfy precondition for swapping
  *)
  lemma swap_spec_pre: "swap_spec_pre ss2 bs2 xs"
    apply unfold_locales
    using misplaced.same_djs misplaced.in_range
    by (auto simp: misplaced.misplaced_same_card)
    
end    
  
(* Step 3: xs' is obtained from xs by swapping the misplaced indexes ss2 and bs2.
   The main result is that xs' is partitioned at m. *)
locale ppar_step3 = ppar_step2 + swap_spec ss2 bs2 xs
begin  

  lemma "mset xs' = mset xs" by (rule permut)

  
  (* Well-placed small indexes are not touched by the swap. *)
  lemma elems_ss1: "iss1  xs'!i = xs!i"
    using elems_outside[of i] misplaced.in_range misplaced.low_range_dj misplaced.ss_dj
    by fastforce

  (* Well-placed big indexes are not touched by the swap. *)
  lemma elems_bs1: "ibs1  xs'!i = xs!i"
    using elems_outside[of i] misplaced.in_range misplaced.high_range_dj misplaced.bs_dj
    by fastforce
        
  (* Below m, every element of xs' is small. The index is either untouched (ss1)
     or received an element from a misplaced small index (bs2 gets from ss2). *)
  lemma partitioned1: assumes "i<m" shows "xs'!i  p" 
  proof -
    have "iss1  bs2"
      using assms misplaced.low_range_split by fastforce
    then show ?thesis 
      apply rule
      subgoal 
        apply (simp add: elems_ss1)
        using misplaced.ss_split ss_in_P1 by auto
      subgoal 
        by (metis UnCI elems_dst misplaced.ss_split ss_in_P1)
      done
      
  qed  
      
  (* From m on, every element of xs' is big. This is symmetric to partitioned1. *)
  lemma partitioned2: assumes "mi" "i<length xs" shows "p  xs'!i"
  proof -
    have "iss2  bs1"
      using assms misplaced.high_range_split by fastforce
    then show ?thesis 
      apply rule
      subgoal
        by (metis bs_in_P2 dual_order.refl elems_src in_mono le_sup_iff misplaced.bs_split)
      subgoal 
        apply (simp add: elems_bs1)
        using misplaced.bs_split bs_in_P2 by auto
      done
      
  qed
    

  (* Together: xs' is a valid partition of xs at m. *)
  lemma is_valid_partition: "gpartition_all_spec p xs xs' m"
    unfolding gpartition_all_spec_alt1
    by (auto simp: permut misplaced.m_le_n partitioned1 partitioned2)

end


  


subsection ‹The Algorithm›

context weak_ordering begin
  (* Abstract partitioning algorithm:
     (1) compute any step-1 partition together with its small/big index sets,
     (2) take m = ppart_mpos ss and the misplaced index sets from ppart_filter,
     (3) swap the misplaced elements.
     Returns the split point m and the resulting list. *)
  definition "ppart1 p n xs  do {
    (xs,ss,bs)  ppart_SPEC p n xs;
    let m = ppart_mpos ss;

    let (ss2,bs2) = ppart_filter m ss bs;
    
    xs  swap_SPEC ss2 bs2 xs;
  
    RETURN (m,xs)
  }"

  (* For presentation in paper *)
  lemma "ppart1 p (length xs) xs = doN {
    (xs,ss,bs)  ppart_SPEC p (length xs) xs;
    let m = card ss;
    let (ss,bs) = ({iss. mi},{ibs. i<m});
    xs  swap_SPEC ss bs xs;
    RETURN (m,xs)
  }"
    unfolding ppart1_def ppart_mpos_def ppart_filter_def
    by simp
  
  
  
  (* Correctness of ppart1: the result is a valid partition of the whole list.
     The proof interprets the locales is_ppart, ppar_step2 and ppar_step3 on
     the intermediate results. *)
  lemma ppart1_valid_partitition: "n=length xs  ppart1 p n xs  SPEC (λ(m,xs'). gpartition_all_spec p xs xs' m)"
    unfolding ppart1_def ppart_spec_def swap_SPEC_def ppart_SPEC_def
    apply refine_vcg
    apply clarsimp_all
  proof -
    fix xs1 ss bs ss2X bs2X
    assume 
      pp_flt: "ppart_filter (ppart_mpos ss) ss bs = (ss2X, bs2X)" and
      [simp]: "mset xs1 = mset xs" and
      "is_ppart () (<) p xs1 ss bs"
  
      
    interpret is_ppart "()" "(<)" p xs1 ss bs by fact  
    interpret ppar_step2 "()" "(<)" p xs1 ss bs by unfold_locales
    
    (* The filtered sets coincide with the locale's ss2/bs2. *)
    have [simp]: "ss2X = ss2" "bs2X = bs2" unfolding ss2_def bs2_def using pp_flt
      by auto

    show "swap_spec_pre ss2X bs2X xs1" using swap_spec_pre by simp 
      
    fix xs'  
    assume sspec: "swap_spec ss2X bs2X xs1 xs'"
  
          
    interpret swap_spec ss2 bs2 xs1 xs' 
      using sspec by simp
      
    interpret ppar_step3 "()" "(<)" p xs1 ss bs xs'
      by unfold_locales 
      
    (* Transfer the partition property from xs1 back to the original xs. *)
    show "gpartition_all_spec p xs xs' m"  
      using mset xs1 = mset xs is_valid_partition gpartition_spec_permute by blast
  qed    
  
  (* For presentation in paper *)
  lemma "ppart1 p (length xs) xs  SPEC (λ(m, xs'). 
      mset xs' = mset xs  m  length xs 
     (i<m. xs' ! i  p)  (i{m..<length xs}. p  xs' ! i))"
    using ppart1_valid_partitition[OF refl, of p xs] 
    unfolding gpartition_all_spec_alt1 by simp
    
  
end
  
section ‹Refinement to Parallel Partitioning›

context weak_ordering begin
  
  (*
    Parallel partitioning with interval, abstract level
  *)
  
  (* Merging two step-1 results: partitions of xs1 and xs2 combine into a
     partition of xs1@xs2. The index sets of the second part are shifted by
     length xs1. *)
  lemma ppart_spec_merge:
    assumes "ppart_spec p xs1 xs1' ss1 bs1"
    assumes "ppart_spec p xs2 xs2' ss2 bs2"
    defines "ss2'  (+)(length xs1)`ss2"
    defines "bs2'  (+)(length xs1)`bs2"
    shows "ppart_spec p (xs1@xs2) (xs1'@xs2') (ss1  ss2') (bs1  bs2')"
    using assms(1,2)
    unfolding ppart_spec_def 
  proof clarsimp
    assume "mset xs1' = mset xs1" "mset xs2' = mset xs2"
    hence [simp]: "length xs1' = length xs1" "length xs2' = length xs2"
      by (auto dest: mset_eq_length)
  
  
    assume "is_ppart () (<) p xs1' ss1 bs1" "is_ppart () (<) p xs2' ss2 bs2"

    then interpret p1: is_ppart "()" "(<)" p xs1' ss1 bs1 + p2: is_ppart "()" "(<)" p xs2' ss2 bs2 .
    
    (* The shifted sets of the second part cover exactly the upper index range. *)
    from p2.complete have sb2_Un: "ss2'  bs2' = {length xs1' ..< length (xs1'@xs2')}"
      unfolding ss2'_def bs2'_def
      by (auto simp flip: image_Un)
      
    from p2.disjoint have sb2_dj: "ss2'  bs2' = {}"
      unfolding ss2'_def bs2'_def
      by auto
      
    have sb'_dj: "(ss1bs1)  (ss2'bs2') = {}"  
      apply (simp add: p1.complete)
      unfolding ss2'_def bs2'_def
      by auto
      
    
    (* Range facts used to select the correct side of nth_append. *)
    from p1.complete have ss1_in_range: "i. iss1  i<length xs1" by auto 
    from p1.complete have bs1_in_range: "i. ibs1  i<length xs1" by auto 

    from sb2_Un have ss2'_in_range: "i. iss2'  length xs1i" by auto 
    from sb2_Un have bs2'_in_range: "i. ibs2'  length xs1i" by auto 
    
          
    
    show "is_ppart () (<) p (xs1' @ xs2') (ss1  ss2') (bs1  bs2')"
      apply unfold_locales
      subgoal
        apply (rule HOL.trans[where s="(ss1bs1)(ss2'  bs2')"])
        apply blast
        apply (simp add: sb2_Un p1.complete) by auto
      subgoal
        using p1.disjoint p2.disjoint sb2_dj sb'_dj by blast     
      subgoal for i
        apply (auto simp: nth_append p1.ss_in_P1 ss1_in_range dest: ss2'_in_range)
        apply (auto simp: ss2'_def p2.ss_in_P1)
        done
      subgoal for i
        apply (auto simp: nth_append p1.bs_in_P2 bs1_in_range dest: bs2'_in_range)
        apply (auto simp: bs2'_def p2.bs_in_P2)
        done
      done    
  qed      
        
  
  (* Recursive, parallel version of step 1.
     If len and d pass the (stripped-symbol) comparison, presumably len <= d, the
     whole list is handled by a single ppart_SPEC.
     Otherwise the list is split at si = len - d (WITH_SPLIT). The first part is
     handled recursively and the second part by ppart_SPEC, in parallel
     (nres_par). The second part's index sets are then shifted by si
     (mop_set_incr_elems) and united with the first part's
     (mop_set_union_disj). *)
  definition "gpartition_slices d p len xs = RECT (λgpartition_slices (len,xs). do {
    ASSERT (d>0);
    ASSERT (len = length xs);
    if (len  d) then do {
      (xs,ss,bs)  ppart_SPEC p len xs;
      RETURN (xs,ss,bs)
    } else do {
      let si = len - d;
      (((ss1,bs1),(ss2,bs2)),xs)  WITH_SPLIT si xs (λxs1 xs2. do {
        ((xs1,ivs1),(xs2,ivs2))  nres_par (gpartition_slices) (ppart_SPEC p d) (si,xs1) xs2;
        RETURN (((ivs1,ivs2),xs1,xs2))
      });
      
      ASSERT(iv_incr_elems_abs_bound ss2 si len);
      ss2mop_set_incr_elems si ss2;
      
      ASSERT(iv_incr_elems_abs_bound bs2 si len);
      bs2mop_set_incr_elems si bs2;
      
      ssmop_set_union_disj ss1 ss2;
      bsmop_set_union_disj bs1 bs2;
      
      RETURN (xs,ss,bs)
    }
  }) (len,xs)"
  
  (* A step-1 result has the same length as its input. *)
  lemma ppart_spec_imp_len_eq: "ppart_spec p xs xs' ss bs  length xs' = length xs"
    unfolding ppart_spec_def
    by (auto dest: mset_eq_length)
  
  (* The index sets of a step-1 result lie within the input's index range. *)
  lemma ppart_spec_len_bound: 
    assumes "ppart_spec p xs xs' ss bs" 
    shows "ss{0..<length xs}" "bs{0..<length xs}"  
    using assms unfolding ppart_spec_def is_ppart_def is_ppart_axioms_def 
    by (auto dest!: mset_eq_length)
    
  (* Hence they are disjoint from any set shifted by length xs. *)
  lemma ppart_spec_lb_imp_disj:
    assumes "ppart_spec p xs xs' ss bs"
    shows "ss  (+)(length xs)`ss' = {}" "bs  (+)(length xs)`bs' = {}"
    using ppart_spec_len_bound[OF assms]
    by auto
    
  
  (* gpartition_slices refines ppart_SPEC. The proof uses RECT_rule with the
     list length as termination measure. *)
  lemma gpartition_slices_refine_aux: "d>0  gpartition_slices d p len xs  ppart_SPEC p len xs"
    unfolding gpartition_slices_def ppart_SPEC_def
    
    thm RECT_rule
    
    apply refine_vcg
    
    apply (refine_vcg RECT_rule[
      where 
            V="measure (λ(_,xs). length xs)" 
        and pre="λ(len,xs). len=length xs" 
        and M="λ(d,xs). Refine_Basic.SPEC (λ(xs', ss, bs). ppart_spec p xs xs' ss bs)", 
      THEN order_trans])
    apply (all (thin_tac "RECT _ = _")?)
    
    subgoal by simp  
    subgoal by auto
    subgoal by simp  
    subgoal by clarsimp
    subgoal by simp  
    
    apply (drule sym[of "length _" "_ - _"]) (* Turn around problematic premise for simplifier *)
    
    (* Recursive call on the first part. *)
    apply (rule order_trans)
    apply rprems 
    
    subgoal by force
    subgoal by auto
    
    apply refine_vcg
    subgoal by (auto dest: ppart_spec_imp_len_eq)
    subgoal by (auto dest: ppart_spec_imp_len_eq)
    subgoal by (auto dest: ppart_spec_imp_len_eq)
    subgoal
      apply clarsimp
      apply (rule iv_incr_elems_abs_bound_card_boundI) 
      apply (erule ppart_spec_len_bound)
      by simp
    subgoal 
      apply clarsimp
      apply (rule iv_incr_elems_abs_bound_card_boundI) 
      apply (erule ppart_spec_len_bound)
      by simp
    subgoal by (simp add: ppart_spec_lb_imp_disj)
    subgoal by (simp add: ppart_spec_lb_imp_disj)
    subgoal by (auto intro: ppart_spec_merge)
    subgoal by auto
    done    
    
    
  (* gpartition_slices_refine_aux stated as a refinement between Id-related lists. *)
  lemma gpartition_slices_refine: " (xs,xs')Idlist_rel; d>0  gpartition_slices d p n xs  Id (ppart_SPEC p n xs')"
    by (auto simp: gpartition_slices_refine_aux)
  
    
end
    

section ‹Refinement to Parallel Swap›

    
(* Swap precondition on slice lists ('a option list). It is analogous to
   swap_spec_pre, but both index sets must lie within sl_indexes' xs
   (presumably the indexes of the entries available in this slice -- confirm). *)
locale swap_opt_spec_pre = 
  fixes src dst :: "nat set" and xs :: "'a option list"
  assumes src_dst_dj: "src  dst = {}" 
  assumes src_ss: "src  sl_indexes' xs"
  assumes dst_ss: "dst  sl_indexes' xs"
  assumes card_eq: "card src = card dst"
begin  

  lemma finite_src[simp,intro!]: "finite src" apply (rule finite_subset[OF src_ss]) by auto
  lemma finite_dst[simp,intro!]: "finite dst" apply (rule finite_subset[OF dst_ss]) by auto

  lemma idxs_in_bounds: "srcdst  sl_indexes' xs" using src_ss dst_ss by auto

end
  
(* Swap postcondition on slice lists:
   - xs' has the same slice structure as xs and is a permutation of it;
   - entries outside src and dst are unchanged;
   - every src entry comes from some dst position of xs, and vice versa. *)
locale swap_opt_spec = swap_opt_spec_pre +  
  fixes xs'
  assumes struct_eq[simp]: "sl_struct xs' = sl_struct xs"
  assumes elems_outside: "isrc  dst  isl_indexes' xs  sl_get xs' i=sl_get xs i"
  assumes elems_src: "isrc  jdst. sl_get xs' i = sl_get xs j"
  assumes elems_dst: "idst  jsrc. sl_get xs' i = sl_get xs j"
  assumes permut: "mset xs' = mset xs"
begin

  lemma length_eq[simp]: "length xs' = length xs" using mset_eq_length[OF permut] .




end  
  
  

(* Splitting a swap instance on a slice list. ss and bs are each split
   disjointly into two parts, the first parts have equal cardinality, and ss1
   is non-empty.
   The first parts are then a swap instance on sl_split (ss1 Un bs1) xs (p1), and
   the second parts one on the complementary split (p2). The lemma join combines
   results of both. *)
locale swap_opt_spec_pre_split = swap_opt_spec_pre ss bs xs for ss bs xs + 
  fixes ss1 ss2 bs1 bs2
  assumes split_complete: "ss = ss1  ss2" "bs = bs1  bs2"
  assumes split_dj: "ss1  ss2 = {}" "bs1  bs2 = {}"
  assumes split_card_eq1: "card bs1 = card ss1"
  assumes ss1_ne: "ss1{}"
begin  
  lemma finites[simp,intro!]: 
    "finite ss1"
    "finite ss2"
    "finite bs1"
    "finite bs2"
    using finite_src finite_dst
    by (auto simp: split_complete)

  lemma card_ss_eq: "card ss = card ss1 + card ss2" using split_complete split_dj by (auto simp: card_Un_disjoint)
  lemma card_bs_eq: "card bs = card bs1 + card bs2" using split_complete split_dj by (auto simp: card_Un_disjoint)

  (* The second parts also have equal cardinality. *)
  lemma split_card_eq2: "card bs2 = card ss2"
    by (metis card_bs_eq card_ss_eq diff_add_inverse card_eq split_card_eq1)

  lemmas split_card_eq = split_card_eq1 split_card_eq2
  
  
  (* The first parts form a swap instance on the part of xs split off by ss1 and bs1. *)
  sublocale p1: swap_opt_spec_pre ss1 bs1 "(sl_split (ss1bs1) xs)"
    apply unfold_locales
    using split_dj src_dst_dj idxs_in_bounds
    apply (auto simp: split_card_eq)
    subgoal using split_complete(1) split_complete(2) by blast
    subgoal by (auto simp: sl_indexes_split split_complete)
    subgoal by (auto simp: sl_indexes_split split_complete)
    done

  (* The second parts form a swap instance on the complementary part. *)
  sublocale p2: swap_opt_spec_pre ss2 bs2 "(sl_split (-ss1-bs1) xs)"
    apply unfold_locales
    using split_dj src_dst_dj idxs_in_bounds
    apply (auto simp: split_card_eq)
    subgoal using split_complete(1) split_complete(2) by blast
    subgoal by (auto simp: sl_indexes_split split_complete)
    subgoal by (auto simp: sl_indexes_split split_complete)
    done
  
  (* If the second small part is empty, the first parts are the whole sets. *)
  lemma extreme:
    assumes "ss2={}"  
    shows "bs2={}" "ss1=ss" "bs1=bs"
    using assms split_complete split_dj split_card_eq
    by auto
    
  lemma idxs1_in_bounds: "ss1  bs1  sl_indexes' xs"  
    using dst_ss split_complete(1) split_complete(2) src_ss by blast

  (* The remaining instance is strictly smaller, since ss1 is non-empty. *)
  lemma decreasing: "card ss2 < card ss"  
    using card_ss_eq ss1_ne by fastforce
    
    
  (* Joining swap results for the two parts gives a swap result for the whole instance. *)
  lemma join:
    assumes "swap_opt_spec ss1 bs1 (sl_split (ss1bs1) xs) xs1'"
    assumes "swap_opt_spec ss2 bs2 (sl_split (-ss1-bs1) xs) xs2'"
    shows "swap_opt_spec ss bs xs (sl_join xs1' xs2')"
  proof -
    interpret p1: swap_opt_spec ss1 bs1 "(sl_split (ss1bs1) xs)" xs1' by fact
    interpret p2: swap_opt_spec ss2 bs2 "(sl_split (-ss1-bs1) xs)" xs2' by fact
  
    
    have COMPAT[simp]: "sl_compat (sl_struct_split (ss1  bs1) (sl_struct xs)) (sl_struct_split (- ss1 - bs1) (sl_struct xs))"
      by (auto intro: sl_compat_splitI)
    
      
    (* Multiset of xs expressed in terms of the multisets of the two split parts. *)
    have "mset xs = mset (sl_join (sl_split (ss1bs1) xs) (sl_split (-ss1-bs1) xs))"  
      using sl_join_split_eq[of "ss1bs1" xs]
      by simp
    also have " = mset (sl_split (ss1  bs1) xs) + mset (sl_split (- ss1 - bs1) xs) - replicate_mset (length xs) None"  
      by (simp add: mset_join_idxs_eq)
    finally have mset_xs_conv: "mset xs = mset (sl_split (ss1  bs1) xs) + mset (sl_split (- ss1 - bs1) xs) - replicate_mset (length xs) None" .
      
      
    show ?thesis
      apply unfold_locales
      subgoal
        using sl_struct_join_split[of "ss1bs1"]
        by auto
      subgoal
        apply simp
        apply (subst sl_get_join)
        by (auto dest: sl_indexes_lengthD simp: split_complete sl_indexes_split p2.elems_outside sl_get_split)
      subgoal for i
        apply (simp add: split_complete; safe)  
        subgoal
          by (metis Un_iff COMPAT p1.elems_src p1.src_ss p1.struct_eq p2.struct_eq sl_get_join1 sl_get_split sl_struct_split subset_iff)
        subgoal 
          apply (frule p2.elems_src; clarsimp)
          by (metis COMPAT IntD1 UnI2 p1.struct_eq p2.dst_ss p2.src_ss p2.struct_eq sl_get_join2 sl_get_split sl_indexes_split sl_struct_split subsetD)
        done
      subgoal
        apply (simp add: split_complete; safe)  
        subgoal by (metis COMPAT in_mono p1.dst_ss p1.elems_dst p1.struct_eq p2.struct_eq sl_get_join1 sl_get_split sl_struct_split sup_ge1)
        subgoal 
          apply (frule p2.elems_dst; clarsimp)
          by (metis COMPAT IntD1 Un_Int_eq(2) in_mono p1.struct_eq p2.dst_ss p2.src_ss p2.struct_eq sl_get_join2 sl_get_split sl_indexes_split sl_struct_split)
        done
      subgoal
        by (simp add: mset_join_idxs_eq mset_xs_conv p1.permut p2.permut)
      done  
  qed        

end
  

(* Nondeterministic swap on slice lists: asserts swap_opt_spec_pre and returns
   any result satisfying swap_opt_spec. *)
definition "swap_opt_SPEC ss bs xs  do {ASSERT (swap_opt_spec_pre ss bs xs); SPEC (swap_opt_spec ss bs xs)}"

(* Nondeterministically split ss and bs each into two disjoint parts, such that
   the first part of ss is non-empty and the first parts have equal cardinality.
   Asserts that both sets are (presumably) non-empty and finite. *)
definition "split_sets_eq_SPEC ss bs = do {
  ASSERT (ss{}  bs{}  finite ss  finite bs);
  SPEC (λ((ss1,ss2),(bs1,bs2)). 
    ss = ss1  ss2  bs = bs1  bs2
   ss1  ss2 = {}  bs1  bs2 = {}
   ss1  {}
   card ss1 = card bs1
  )
}"


(* Under a swap precondition with non-empty src, every split returned by
   split_sets_eq_SPEC satisfies swap_opt_spec_pre_split. *)
lemma (in swap_opt_spec_pre) split_sets_eq_SPEC_swap_rl:
  shows "src{}  split_sets_eq_SPEC src dst  SPEC (λ((ss1,ss2),(bs1,bs2)). swap_opt_spec_pre_split src dst xs ss1 ss2 bs1 bs2)"
  unfolding split_sets_eq_SPEC_def
  apply refine_vcg
  subgoal using card_eq by force
  subgoal by simp
  subgoal by simp
  apply unfold_locales
  apply clarsimp_all
  done


(* Recursive parallel swap on slice lists.
   The index sets are first split with split_sets_eq_SPEC.
   - If the remaining part ss2 is empty, the first parts are swapped directly.
   - Otherwise xs is split by the indexes ss1 and bs1 (WITH_IDXS). The first
     parts are swapped on one side, and the remaining parts are handled
     recursively on the other side, in parallel (nres_par). *)
definition "par_swap_aux ss bs xs  RECT (λpar_swap (ss,bs,xs). do {
  ASSERT (swap_opt_spec_pre ss bs xs);
    
  ((ss1,ss2),(bs1,bs2))  split_sets_eq_SPEC ss bs;
  
  if (ss2={}) then do {
    ASSERT (bs2={}  ss1=ss  bs1=bs);
    swap_opt_SPEC ss1 bs1 xs
  } else do {
    (_,xs)  WITH_IDXS (ss1bs1) xs (λxs1 xs2. do {
      (xs1,xs2)  nres_par (λ(ss1,bs1,xs1). (swap_opt_SPEC ss1 bs1 xs1)) (par_swap) (ss1,bs1,xs1) (ss2,bs2,xs2);
      RETURN ((),xs1,xs2)
    });
    RETURN xs
  }
}) (ss,bs,xs)"  



(* Correctness of par_swap_aux for a non-empty source set. The proof uses
   RECT_rule with card of the source set as termination measure; the lemma
   decreasing provides the measure decrease for the recursive call. *)
lemma par_swap_aux_correct:
  shows "ss{}  par_swap_aux ss bs xs  swap_opt_SPEC ss bs xs"
  unfolding par_swap_aux_def swap_opt_SPEC_def
  supply R = RECT_rule[where V="measure (card o fst)" and pre="λ(ss,bs,xs). ss{}  swap_opt_spec_pre ss bs xs"]
  apply refine_vcg
  apply (refine_vcg R[where M="λ(ss,bs,xs). SPEC (swap_opt_spec ss bs xs)", THEN order_trans])
  apply (all (thin_tac "RECT _ = _")?)
  subgoal by simp
  subgoal by simp
  subgoal by simp
  subgoal by simp
  
  apply (clarsimp)
  subgoal for par_swap ss bs xs proof goal_cases
    case 1
    
    (* Induction hypothesis for the recursive call. *)
    note IH = _;_  par_swap _   _
    
    note [simp] = ss{}
    
    note SOS_PRE[simp] = swap_opt_spec_pre ss bs xs
    
    interpret swap_opt_spec_pre ss bs xs by fact
    
    show ?case 
      apply (rule split_sets_eq_SPEC_swap_rl[THEN order_trans])
      apply simp
      apply (rule refine_vcg)
      apply clarsimp
      subgoal for ss1 ss2 bs1 bs2 proof goal_cases
        case 1
        then interpret swap_opt_spec_pre_split ss bs xs ss1 ss2 bs1 bs2 .
        
        show ?thesis
          apply (refine_vcg)
          (* Base case: nothing remains, handled via extreme. *)
          subgoal by (simp add: extreme)
          subgoal by (simp add: extreme)
          subgoal by (simp add: extreme)
          subgoal by (simp add: extreme)
          subgoal by (simp add: extreme)
          subgoal by (rule idxs1_in_bounds)
          subgoal
            using p1.swap_opt_spec_pre_axioms
            by clarsimp 
          apply clarsimp  
          subgoal for xs1'
            (* Recursive part, via the induction hypothesis. *)
            apply (rule IH[THEN order_trans])
            subgoal
              using p2.swap_opt_spec_pre_axioms
              by clarsimp 
            subgoal by (rule decreasing)
            apply clarsimp
            subgoal for xs2' proof goal_cases
              case 1
              then interpret 
                p1: swap_opt_spec ss1 bs1 "sl_split (ss1  bs1) xs" xs1' +
                p2: swap_opt_spec ss2 bs2 "sl_split (- ss1 - bs1) xs" xs2' by simp_all
              (* Combine the two partial results. *)
              from 1 join have "swap_opt_spec ss bs xs (sl_join xs1' xs2')" by simp
              then show ?thesis
                by simp
            qed  
            done
          done
      qed
      done
  qed
  subgoal by simp
done  
            

(* Swap on plain lists. With an empty source set the list is returned
   unchanged. Otherwise the list is converted to a slice list, swapped with
   par_swap_aux, and converted back. *)
definition "par_swap ss bs xs  do {
  if (ss={}) then RETURN xs
  else do {
    xs  mop_sl_of_list xs;
    xs  par_swap_aux ss bs xs;
    mop_list_of_sl xs
  }
}"



context swap_spec_pre begin

  (* A plain swap precondition carries over to the slice-list view sl_of_list xs. *)
  lemma to_opt: "swap_opt_spec_pre src dst (sl_of_list xs)"
    apply unfold_locales
    using src_dst_dj src_ss dst_ss card_eq
    by auto
  


end



context swap_opt_spec begin

  (* If the input slice list comes from a plain list, the result's slice
     structure is complete (it equals the input's, by struct_eq). *)
  lemma to_plain_complete:
    assumes [simp]: "xs = sl_of_list xs0"
    shows "sl_complete (sl_struct xs')"
    by simp


  (* A slice-list swap result on sl_of_list xs0, converted back with
     list_of_sl, is a plain swap result for xs0. *)
  lemma to_plain:
    assumes [simp]: "xs = sl_of_list xs0"
    shows "swap_spec src dst xs0 (list_of_sl xs')"
    apply unfold_locales
    subgoal by (simp add: src_dst_dj)
    subgoal using src_ss by simp
    subgoal using dst_ss by simp
    subgoal by (simp add: card_eq)
    subgoal for i using elems_outside[of i] by simp
    subgoal for i using elems_src[of i] by simp
    subgoal for i using elems_dst[of i] by simp
    subgoal using permut by (simp add: mset_of_list_permut)
    done

end
  
(* par_swap implements swap_SPEC. The empty case uses swap_spec_refl; the
   non-empty case goes through par_swap_aux_correct and the conversions
   to_opt, to_plain_complete and to_plain. *)
lemma par_swap_refine_aux: "par_swap ss bs xs  swap_SPEC ss bs xs"
  unfolding par_swap_def swap_SPEC_def
  apply refine_vcg
  apply (simp add: swap_spec_pre.swap_spec_refl)
  apply (rule order_trans[OF par_swap_aux_correct])
  apply simp
  apply (simp add: swap_opt_SPEC_def)
  apply refine_vcg
  apply (simp add: swap_spec_pre.to_opt)
  apply (simp add: swap_opt_spec.to_plain_complete)
  apply (simp add: swap_opt_spec.to_plain)
  done

(* par_swap_refine_aux stated as a refinement between Id-related arguments. *)
lemma par_swap_refine: "(ss,ss')Idset_rel; (bs,bs')Idset_rel; (xs,xs')Idlist_rel 
   par_swap ss bs xs (Idlist_rel) (swap_SPEC ss' bs' xs')"
  by (auto simp: par_swap_refine_aux)
  

  
  
(* Swap the index interval {l1..<h1} of a slice list with the equally long
   interval {l2..<h2}, one position pair at a time, using mop_slist_swap.
   Loop invariant: i1 and i2 advance in lockstep within their intervals, and
   the current list is a swap_opt_spec result for the already processed
   prefixes {l1..<i1} and {l2..<i2}. *)
definition "swap_o_intv_aux  λ(l1,h1) (l2,h2) xs0. doN {
  ASSERT (l1h1  l2h2  h2-l2 = h1-l1);
  (xs,_,_)WHILEIT 
    (λ(xs,i1,i2). i1{l1..h1}  i2{l2..h2}  i1-l1 = i2-l2  swap_opt_spec {l1..<i1} {l2..<i2} xs0 xs)
    (λ(xs,i1,i2). i1<h1) 
    (λ(xs,i1,i2). doN {
      ASSERT(i1<h1  i2<h2);
      xsmop_slist_swap xs i1 i2;
      RETURN (xs,i1+1,i2+1)
    }) (xs0,l1,l2);
  RETURN xs
}"  

(* Swapping two empty index sets leaves the shared list unchanged.
   Serves as the base case of the loop invariant of swap_o_intv_aux. *)
lemma swap_opt_spec_empty[simp]: "swap_opt_spec {} {} xs0 xs0"
  by unfold_locales auto

  
  
(* Correctness of the sequential interval swap w.r.t. swap_opt_SPEC,
   provided both intervals have ordered bounds (l1<=h1, l2<=h2). *)
lemma swap_o_intv_aux_correct:
  assumes "l1h1" "l2h2"
  shows "swap_o_intv_aux (l1,h1) (l2,h2) xs0   swap_opt_SPEC {l1..<h1} {l2..<h2} xs0"
  unfolding swap_opt_SPEC_def
  apply refine_vcg
proof -  
  assume "swap_opt_spec_pre {l1..<h1} {l2..<h2} xs0"

  then interpret swap_opt_spec_pre "{l1..<h1}" "{l2..<h2}" xs0 .
  
  (* aux: one loop step. Swapping positions i1 and i2 extends a swap of
     {l1..<i1},{l2..<i2} to a swap of {l1..<Suc i1},{l2..<Suc i2}. *)
  {
    fix xs i1 i2
    assume B1: "l1  i1" "i1 < h1"
    assume B2: "l2  i2" "i2 < h2" 
    
    assume "swap_opt_spec {l1..<i1} {l2..<i2} xs0 xs"
    then interpret this: swap_opt_spec "{l1..<i1}" "{l2..<i2}" xs0 xs .
    
    have [simp]: "i1  sl_indexes' xs0" "i2  sl_indexes' xs0" using B1 B2 src_ss dst_ss by auto
    
    
    (* One goal per locale assumption of swap_opt_spec; the last one is the permutation property. *)
    have "swap_opt_spec {l1..<Suc i1} {l2..<Suc i2} xs0 (swap xs i1 i2)"
      apply unfold_locales
      subgoal using B1 B2 src_dst_dj by fastforce
      subgoal using B1 B2 src_ss by fastforce
      subgoal using B1 B2 dst_ss by fastforce
      subgoal using B1 B2 this.card_eq by (metis Suc_diff_le card_atLeastLessThan)
      subgoal using B1 B2 src_ss dst_ss by fastforce
      subgoal using B1 B2 this.elems_outside by (auto simp: sl_get_swap_other) 
      subgoal for j
        apply (subgoal_tac "ji2  jsl_indexes' xs0")
        apply clarsimp
        apply (cases "j<i1")
        subgoal using this.elems_src[of j] by (auto simp: sl_get_swap_iff)
        subgoal using B1 B2 src_dst_dj by (auto simp: this.elems_outside sl_get_swap_iff)
        using B1 B2 src_dst_dj src_ss by auto  
      subgoal for j
        apply (subgoal_tac "ji1  jsl_indexes' xs0")
        apply clarsimp
        apply (cases "j<i2")
        subgoal using this.elems_dst[of j] by (auto simp: sl_get_swap_iff)
        subgoal using B1 B2 src_dst_dj by (auto simp: this.elems_outside sl_get_swap_iff)
        using B1 B2 src_dst_dj dst_ss by auto  
      by (metis i1  sl_indexes' xs0 i2  sl_indexes' xs0 sl_indexes_lengthD sl_struct_length swap_multiset this.length_eq this.permut)  
  } note aux=this
  
  (* Main loop proof. Termination uses the measure h1-i1; the step case uses aux. *)
  show "swap_o_intv_aux (l1, h1) (l2, h2) xs0  SPEC (swap_opt_spec {l1..<h1} {l2..<h2} xs0)"
    unfolding swap_o_intv_aux_def
    apply (refine_vcg WHILEIT_rule[where R="measure (λ(_,i,_). h1-i)"])
    apply (clarsimp_all simp: assms swap_opt_spec.struct_eq aux)
    subgoal using card_eq by simp
    subgoal by auto
    subgoal using src_ss by fastforce
    subgoal using dst_ss by fastforce
    subgoal using Suc_diff_le by presburger
    subgoal using diff_less_mono2 lessI by presburger
    subgoal for xs i2 using assms(2) eq_diff_iff by blast
    done
    
qed

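(* Lift the shared-list swap (swap_opt_SPEC) to interval arguments: a pair (l,h) denotes the index set {l..<h}. *)

(* The specification is stated on sets {l..<h}, so the bounds may technically be inverted (l > h),
   which denotes an empty set. swap_o_intv_aux asserts l1 <= h1 and l2 <= h2, so it is only called
   when both bounds are ordered; otherwise xs is returned unchanged. Empty intervals with l = h
   are fine and are passed on to swap_o_intv_aux. *)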
(* Delegates to swap_o_intv_aux if both bound pairs are ordered; otherwise returns xs as is. *)
definition "swap_o_intv  λ(l1,h1) (l2,h2) xs. 
  if l1h1  l2h2 then swap_o_intv_aux (l1,h1) (l2,h2) xs else RETURN xs"

(* If the source set is empty, the swap precondition holds exactly when the
   destination set is empty, too. *)
lemma swap_opt_spec_pre_empty1_conv: "swap_opt_spec_pre {} bs xs0  bs={}"  
  unfolding swap_opt_spec_pre_def 
  by (auto dest: finite_subset[OF _ sl_indexes_finite])
  
(* Symmetric to swap_opt_spec_pre_empty1_conv: an empty destination set forces an
   empty source set. *)
lemma swap_opt_spec_pre_empty2_conv: "swap_opt_spec_pre ss {} xs0  ss={}"  
  unfolding swap_opt_spec_pre_def 
  by (auto dest: finite_subset[OF _ sl_indexes_finite])
  
  
(* swap_o_intv meets swap_opt_SPEC for arbitrary bounds.
   The ordered case is swap_o_intv_aux_correct. In the cases with inverted bounds, at
   least one set is empty, and the *_empty*_conv lemmas force the other one to be empty. *)
lemma swap_o_intv_correct:
  shows "swap_o_intv (l1,h1) (l2,h2) xs0   swap_opt_SPEC {l1..<h1} {l2..<h2} xs0"
  unfolding swap_o_intv_def
  apply (cases "l1  h1"; cases "l2  h2"; simp)
  apply (simp add: swap_o_intv_aux_correct)
  unfolding swap_opt_SPEC_def
  apply (all refine_vcg)
  apply (simp_all add: swap_opt_spec_pre_empty1_conv swap_opt_spec_pre_empty2_conv)
  done
  
  
  
(* Relational form of swap_o_intv_correct: intervals are related to their index sets by iv_rel. *)
lemma swap_o_intv_refine: "(swap_o_intv, swap_opt_SPEC)  iv_rel  iv_rel  Id  Idnres_rel"
  by (clarsimp simp: iv_rel_def in_br_conv iv_α_def intro!: nres_relI swap_o_intv_correct)
  
context
  notes [fcomp_norm_simps] = iv_assn_def[symmetric]
begin  
  
  (* Raw interval assertion with size_t bounds, as used for the implementation below. *)
  private abbreviation (input) "ivA  iv_assn_raw :: _  size_t word × _  _"

  (* LLVM implementation of swap_o_intv on oarray_idxs_assn arrays. The postcondition
     states that the returned array pointer equals the input pointer (in-place update). *)
  sepref_definition swap_o_intv_impl [llvm_code] is "uncurry2 swap_o_intv" 
    :: "[λ_. True]c ivAk *a ivAk *a (oarray_idxs_assn A)d  oarray_idxs_assn A [λ((_,_),xsi) ri. ri=xsi]c"
    unfolding swap_o_intv_def swap_o_intv_aux_def iv_assn_raw_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref
  
  (* Composition with swap_o_intv_refine gives an hnr rule that implements swap_opt_SPEC. *)
  lemmas swap_o_intv_oidxs_hnr[sepref_fr_rules] = swap_o_intv_impl.refine[FCOMP swap_o_intv_refine]  
  
end  
  
    
section Wrap-Up: Parallel Algorithm on Abstract Level
  
context weak_ordering begin

(* Parallel partitioning of xs (length n) around pivot p, with chunk size d.
   - gpartition_slices partitions slices of xs and returns two index sets ss and bs.
   - m is the cardinality of ss.
   - ppart_filter selects the misplaced indices (ss restricted to {m..} and bs restricted
     to {0..<m}, cf. ppart_filter_alt); par_swap then exchanges them.
   Returns (m, xs). Refines ppart1 (ppart2_refine). *)
definition "ppart2 d p n xs  do {
  (xs,ss,bs)  gpartition_slices d p n xs;
  m  mop_set_card ss;

  let (ss2,bs2) = ppart_filter m ss bs;
  
  xs  par_swap ss2 bs2 xs;

  RETURN (m,xs)
}"
  

lemma ppart2_refine: " 0<d; (xs,xs')Idlist_rel; n=length xs'   ppart2 d p n xs  Id (ppart1 p n xs')"
  unfolding ppart2_def ppart1_def ppart_mpos_def
  apply (refine_rcg gpartition_slices_refine par_swap_refine)
  by auto

  
(* Combined with ppart1_valid_partitition: ppart2 yields a complete partition
   (gpartition_all_spec) with split point m. *)
lemma ppart2_refine_p_all_spec: 
  assumes "0<d" "n=length xs"  
  shows "ppart2 d p n xs  (SPEC(λ(m,xs'). gpartition_all_spec p xs xs' m))"
proof -  
  note ppart2_refine
  also note ppart1_valid_partitition
  finally show ?thesis using assms by simp
qed

(* Parallel partition with pivot, requires at least 4 elements.
   The sampled pivot is moved to position 0 and copied to p, and the remaining n-1
   elements are partitioned by ppart2.
   If all of them ended up on the left side (m+1 = n), the pivot is swapped to
   position m and m is returned; otherwise m+1 is returned. *)
definition "ppart_partition_pivot_par d xs n  doN {
  ASSERT (n=length xs  4length xs);
  
  xs  move_pivot_to_first_sample xs n;

  p  mop_list_cp_get xs 0;
    
  (m,xs)  WITH_SPLIT 1 xs (λxs1 xs2. doN {
    ASSERT (n>0);
    (m,xs2)  ppart2 d p (n-1) xs2;
    RETURN (m,xs1,xs2)
  });
  
  ASSERT (m+1  n);
  
  let m'=m+1;
  
  if m'=n then doN {
    xsmop_list_swap xs 0 m;
    RETURN (xs,m)
  } else RETURN (xs,m')
}"

(* Only the permutation property of move_pivot_to_first_sample is needed here. *)
lemma move_pivot_to_first_sample_weak_correct: 
  "n=length xs; length xs  4  move_pivot_to_first_sample xs n  SPEC (λxs'. mset xs' = mset xs)"
  apply refine_vcg by simp


lemma mset_eq_sum_mset_lengthD: "mset xs = mset xs1 + mset xs2  length xs = length xs1+length xs2"
  by (metis size_mset size_union)


(* Sufficient condition for slice_LT: every element in [l,m) is bounded above by p,
   and every element in [m,h) is bounded below by p. *)
lemma slice_LT_by_nthI:
  assumes "lm" "mh" "hlength xs"
  assumes "i{l..<m}. xs!ip" "i{m..<h}. pxs!i"
  shows "slice_LT () (Misc.slice l m xs) (Misc.slice m h xs)"  
  using assms
  unfolding slice_LT_def Ball_def in_set_conv_nth
  by (force simp: slice_nth intro: trans[of _ p]) 
  
  
(* ppart_partition_pivot_par refines partition3_spec on the whole list. The two final
   subgoals cover the "swap pivot to position m" and the "return m+1" branches. *)
lemma ppart_partition_pivot_par_refines_part3: "0<d; n=length xs  ppart_partition_pivot_par d xs n  partition3_spec xs 0 n"
  unfolding partition3_spec_def ppart_partition_pivot_par_def
  apply (refine_vcg ppart2_refine_p_all_spec move_pivot_to_first_sample_weak_correct)
  apply (clarsimp_all dest!: sym[of "_+_" "mset _"] simp: gpartition_spec_def slice_eq_mset_whole_iff)
  subgoal by (auto dest: mset_eq_length)
  subgoal by (auto dest: mset_eq_sum_mset_lengthD)
  subgoal by (auto dest: mset_eq_length mset_eq_sum_mset_lengthD)
  subgoal by (auto dest: mset_eq_length mset_eq_sum_mset_lengthD)
  subgoal by (auto dest: mset_eq_length mset_eq_sum_mset_lengthD)
  
  (* Swapping: multiset equal *)
  subgoal by (subst swap_multiset; auto dest: mset_eq_length)
    
  (* Swapping: partitioning *)
  subgoal for xs1 xs2 xs2' m
    apply (rule slice_LT_by_nthI[where p="xs1 ! 0"])
    apply clarsimp_all
    subgoal by (drule mset_eq_length)+ auto
    subgoal by (drule mset_eq_length)+ (auto simp: length_Suc_conv swap_nth)
    subgoal by (drule mset_eq_length)+ (auto simp: length_Suc_conv swap_nth)
    done
    
  (* Shifting: partitioning *)
  subgoal for xs2 m xs2' xs1
    apply (rule slice_LT_by_nthI[where p="xs1 ! 0"])
    apply clarsimp_all
    subgoal by (drule mset_eq_length mset_eq_sum_mset_lengthD)+ (auto)
    subgoal by (drule mset_eq_length mset_eq_sum_mset_lengthD)+ (auto simp: length_Suc_conv nth_Cons')
    subgoal by (drule mset_eq_length mset_eq_sum_mset_lengthD)+ (force simp: length_Suc_conv nth_Cons')
    done
  done    
  
  
  



  
(* Balance the chunk size. p = ceil(n/d) is the number of chunks, and the result
   ceil(n/p) is a chunk size that divides n into p nearly equal parts.
   For n=0 or d=1, d is returned unchanged. For n=0 this also avoids calling
   nat_div_round_up with the divisor p=0, which its ASSERT forbids. *)
definition align_chunk_size :: "nat  nat  nat nres" where "align_chunk_size d n = doN {
  ASSERT (d>0);
  if n=0  d=1 then RETURN d
  else doN {
    p  nat_div_round_up n d;
    nat_div_round_up n p
  }
}"  

(* The aligned chunk size is positive. *)
lemma align_chunk_size_correct[refine_vcg]: "d>0  align_chunk_size d n  SPEC (λr. r>0)"
  unfolding align_chunk_size_def
  apply refine_vcg
  by (auto intro: ccontr[of "0<_"])

  
(* Top-level partitioning: sequential partition_pivot_sample if n does not exceed d,
   otherwise the parallel version with the aligned chunk size. *)
definition "ppart_partition_pivot d xs n  
  if nd then partition_pivot_sample xs n
  else doN {
    d  align_chunk_size d n;
    ppart_partition_pivot_par d xs n
  }"  
  
  
lemma ppart_partition_pivot_refines_part3: "0<d; n=length xs  ppart_partition_pivot d xs n  partition3_spec xs 0 n"
  unfolding ppart_partition_pivot_def
  apply (split if_split, intro conjI impI)
  subgoal
    using partition_pivot_sample_correct[of xs xs n n]
    by (auto)
  subgoal
    apply (rule specify_left)  
    apply (rule align_chunk_size_correct, simp)
    apply (rule ppart_partition_pivot_par_refines_part3)
    by auto
  done    
  
    
end  

  



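(* Take one chunk from each of two sets, cut both chunks down to the smaller of their
   two cardinalities, and put the cut-off parts back into their respective sets. *)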

(* Removes a subset c1 from s1 and a subset c2 from s2 (leaving remainders r1, r2).
   Both subsets are split at mc = min |c1| |c2|, and the cut-off parts c12 and c22 are
   put back into r1 and r2.
   Result: ((c11,r1),(c21,r2)) with |c11| = |c21| = mc.
   Refines split_sets_eq_SPEC (ivls_split_refine). *)
definition "split_sets_eq1 s1 s2 = do {

  (c1,r1)  mop_set_rm_subset s1;
  (c2,r2)  mop_set_rm_subset s2;
    
  let card1 = op_set_card c1;
  let card2 = op_set_card c2;

  let mc = min card1 card2;
  
  (c11,c12)  mop_set_split_card mc c1;
  (c21,c22)  mop_set_split_card mc c2;
  
  r1  mop_set_union_disj r1 c12;
  r2  mop_set_union_disj r2 c22;

  RETURN ((c11,r1),(c21,r2))
}"



(* split_sets_eq1 refines split_sets_eq_SPEC (identity relations on all arguments). *)
lemma ivls_split_refine: "(split_sets_eq1, split_sets_eq_SPEC)  Id  Id  Idnres_rel"
  unfolding split_sets_eq1_def split_sets_eq_SPEC_def
  apply clarsimp
  apply (refine_vcg SPEC_refine)
  apply simp_all
  subgoal by blast
  subgoal by blast
  apply (intro conjI)
  subgoal by blast
  subgoal by blast
  subgoal by blast
  subgoal by (metis Int_Un_distrib inf_commute)
  subgoal by (metis card_gt_0_iff min_less_iff_conj sup_bot_left)
  done

(* Shorthands for iv_assn', iv_lb_assn' and ivl_assn', instantiated to size_t. *)
abbreviation "iv_assn_sz  iv_assn' TYPE(size_t)"  
abbreviation "iv_lb_assn_sz  iv_lb_assn' TYPE(size_t)"  
abbreviation "ivl_assn_sz  ivl_assn' TYPE(size_t)"  
  

(* LLVM implementation of split_sets_eq1 on size_t interval lists. Both input lists are
   consumed (d); min is unfolded via min_def before running sepref. *)
sepref_definition split_sets_eq_impl [llvm_inline] is "uncurry split_sets_eq1" :: "ivl_assn_szd *a ivl_assn_szd a (iv_assn_sz ×a ivl_assn_sz) ×a (iv_assn_sz ×a ivl_assn_sz)"
  unfolding split_sets_eq1_def min_def
  by sepref
  
(* Register the specification and compose the implementation with ivls_split_refine,
   so that sepref can implement split_sets_eq_SPEC directly. *)
sepref_register split_sets_eq_SPEC
lemmas split_sets_eq_impl_spec_refine[sepref_fr_rules] = split_sets_eq_impl.refine[FCOMP ivls_split_refine]


  
context weak_ordering begin

  (* Sequential partitioning of the whole list (h = length xs) around pivot p.
     Returns the resulting list together with the index intervals {0..<m} and {m..<h}
     of the two parts. Refines ppart_SPEC (ppart_seq1_refine). *)
  definition "ppart_seq1 p h xs  do {
    ASSERT (h = length xs);
    (xs,m)  gpartition_SPEC 0 h p xs;
    RETURN (xs,{0..<m},{m..<h})
  }"
  
  
  (* A complete partition with split point m satisfies ppart_spec with the
     intervals {0..<m} and {m..<length xs}. *)
  lemma gpartition_all_imp_ppart_spec: 
    "gpartition_all_spec p xs xs' m  ppart_spec p xs xs' {0..<m} {m..<length xs}"
    unfolding gpartition_all_spec_alt1 ppart_spec_def
    apply (clarsimp) 
    apply (frule mset_eq_length)
    apply unfold_locales
    apply auto
    done

  sepref_register ppart_SPEC   
       
  lemma ppart_seq1_refine: "(ppart_seq1, PR_CONST ppart_SPEC)  Id  Id  Id  Idnres_rel"
    unfolding ppart_seq1_def gpartition_SPEC_def ppart_SPEC_def PR_CONST_def
    apply refine_vcg
    by (clarsimp_all simp: gpartition_all_imp_ppart_spec)

  (*
    Refinement to match technical requirements of Sepref:
    
    Recursion must not rely on frame (all used parameters must be arguments to recursive function)
    Pivot element must be explicitly copied
    
    If doing sequential partitioning, the result must be converted from iv to ivl. 
    We currently do that by op_set_union_disj _ {}, and a custom_fold to include the length parameter.
    TODO: an iv_to_ivl conversion operator would not need the custom fold, 
    and save the (trivial) disjoint side-condition that, nevertheless, has to be solved by sepref
  *)  
  (* Structure: if len <= d, partition sequentially. Otherwise split off the last d elements
     at si = len - d, and in parallel (nres_par) partition the prefix recursively and the suffix
     sequentially. The suffix's index sets are then shifted by si and merged with the prefix's. *)
  definition "gpartition_slices2 d p len xs = RECT (λgpartition_slices (d,p,len,xs). do {
    ASSERT (d>0);
    ASSERT (len = length xs);
    if (len  d) then do {
      (xs,ss,bs)  ppart_SPEC p len xs;
      RETURN (xs,op_set_union_disj ss {},op_set_union_disj bs {})
    } else do {
      let si = len - d;
      (((ss1,bs1),(ss2,bs2)),xs)  WITH_SPLIT si xs (λxs1 xs2. do {
        let p_copy = COPY p;
        ((xs1,ivs1),(xs2,ivs2))  nres_par (gpartition_slices) (λ(p,d,xs). ppart_SPEC p d xs) (d,p,si,xs1) (p_copy,d,xs2);
        RETURN (((ivs1,ivs2),xs1,xs2))
      });
      
      ASSERT(iv_incr_elems_abs_bound ss2 si len);
      ss2mop_set_incr_elems si ss2;
      
      ASSERT(iv_incr_elems_abs_bound bs2 si len);
      bs2mop_set_incr_elems si bs2;
      
      ssmop_set_union_disj ss1 ss2;
      bsmop_set_union_disj bs1 bs2;
      
      RETURN (xs,ss,bs)
    }
  }) (d,COPY p,len,xs)"
    
  (* gpartition_slices2 refines gpartition_slices. R1 relates the recursion arguments
     (d0,p0,len,xs) of gpartition_slices2 to the arguments (len,xs) of gpartition_slices,
     with d0 and p0 fixed. *)
  lemma gpartition_slices2_refine: "(gpartition_slices2, PR_CONST gpartition_slices) 
     nat_rel  Id  nat_rel  Idlist_rel  Idlist_rel ×r Id ×r Idnres_rel"
  proof (intro fun_relI, clarsimp, goal_cases)
    case (1 d0 p0 len0 xs0)
    
    define R1 where "R1={((d0,p0,len,xs),(len,xs)) | (len::nat) (xs::'a list). True }"
    note [refine_dref_RELATES] = RELATESI[of R1]
    
    
    show ?case
      unfolding gpartition_slices2_def gpartition_slices_def
      apply (rewrite at "let _ = COPY _ in  _" Let_def)
      apply refine_rcg
      apply refine_dref_type
      unfolding R1_def
      by simp_all
    
  qed
    
end






   
context sort_impl_copy_context begin




  (* Uncurried wrapper around swap_o_intv_impl. It is folded back in via sepref_opt_simps
     in par_swap_aux_impl (same workaround as ppart_seq1_impl_uncurried below). *)
  definition [llvm_code]: "swap_o_intv_impl_uncurried  (λ(a, x, y). swap_o_intv_impl a x y)"


  (* The recursion is annotated (RECT_cp_annot) with the fact that the returned array
     pointer equals the input one. *)
  sepref_def par_swap_aux_impl is "uncurry2 par_swap_aux" 
    :: "[λ_. True]c ivl_assn_szd *a ivl_assn_szd *a (oarray_idxs_assn elem_assn)d  oarray_idxs_assn elem_assn [λ((_,_),xsi) xsi'. xsi'=xsi]c"
    unfolding par_swap_aux_def
    supply [[goals_limit = 1]]
    apply (rewrite RECT_cp_annot[where CP="λ(_,_,xsi) xsi'. xsi'=xsi"])
    supply [sepref_comb_rules] = hn_RECT_cp_annot_noframe
    supply [sepref_opt_simps] = swap_o_intv_impl_uncurried_def[symmetric]
    by sepref    
    
  sepref_def par_swap_impl is "uncurry2 par_swap" 
    :: "[λ_. True]c ivl_assn_szd *a ivl_assn_szd *a (arr_assn)d  arr_assn [λ((_,_),xsi) xsi'. xsi'=xsi]c"
    unfolding par_swap_def
    by sepref
    
    
  sepref_definition ppart_seq_impl [llvm_inline] is "uncurry2 ppart_seq1" :: "[λ_. True]c elem_assnk *a size_assnk *a arr_assnd  arr_assn ×a iv_assn_sz ×a iv_assn_sz [λ(_,xsi) (xsi',_,_). xsi'=xsi]c" 
    unfolding ppart_seq1_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref
  
  (* Composition with ppart_seq1_refine gives an hnr rule for ppart_SPEC. *)
  lemmas ppart_seq_impl_hnr[sepref_fr_rules] = ppart_seq_impl.refine[FCOMP ppart_seq1_refine]
    

  (* We fold that, because the parallel operator works only with functions right now.
    TODO: use automatic extraction (like for REC) to extract functions from parallel operator upon code generation!
  *)
  definition [llvm_code]: "ppart_seq1_impl_uncurried  λ(p,d,xs). ppart_seq_impl p d xs"
    
  
  sepref_register gpartition_slices
  
  (* Implementation of gpartition_slices2. The recursion is annotated with the fact that
     the returned array pointer equals the input one. The two empty interval lists are
     created via ivl_fold_custom_empty. *)
  sepref_definition gpartition_slices_impl [llvm_code] is "uncurry3 gpartition_slices2" :: 
    "[λ_. True]c size_assnk *a elem_assnk *a size_assnk *a arr_assnd  arr_assn ×a ivl_assn_sz ×a ivl_assn_sz [λ(_,xsi) (xsi',_,_). xsi'=xsi]c"
    unfolding gpartition_slices2_def
    supply [[goals_limit = 1]]
    
    apply (rewrite RECT_cp_annot[where CP="λ(_,_,_,xsi) (xsi',_,_). xsi'=xsi"])
    apply (rewrite ivl_fold_custom_empty[where 'l=size_t])
    apply (rewrite ivl_fold_custom_empty[where 'l=size_t])
  
    supply [sepref_comb_rules] = hn_RECT_cp_annot_noframe
    
    supply [sepref_opt_simps] = ppart_seq1_impl_uncurried_def[symmetric]
    
    
    by sepref (* Takes looooong ~30 sec *)

  context
    notes [fcomp_norm_simps] = list_rel_id_simp
  begin    

    lemmas gpartition_slices_hnr[sepref_fr_rules] = gpartition_slices_impl.refine[FCOMP gpartition_slices2_refine]

  end  
    

  (* Closed form of ppart_filter, used to derive its implementation. *)
  lemma ppart_filter_alt: "ppart_filter m ss bs = ( ss{m..}, bs  {0..<m} )"
    unfolding ppart_filter_def by auto
  
  sepref_def ppart_filter_impl is "uncurry2 (RETURN ooo ppart_filter)" :: "size_assnk *a ivl_assn_szd *a ivl_assn_szd a ivl_assn_sz ×a ivl_assn_sz"
    unfolding ppart_filter_alt
    apply (annot_snat_const "TYPE(size_t)")
    by sepref
    
        
  sepref_register ppart2  
  sepref_def ppart_impl is "uncurry3 (PR_CONST ppart2)" 
    :: "[λ_. True]c size_assnk *a elem_assnk *a size_assnk *a arr_assnd  size_assn ×a arr_assn [λ(((_,_),_),xsi) (_,ri). ri=xsi]c"
    unfolding ppart2_def PR_CONST_def
    by sepref

  (* TODO: Find out what is needed for sorting, and assemble! *)      
  
  sepref_register ppart_partition_pivot_par
  sepref_def ppart_partition_pivot_par_impl is "uncurry2 (PR_CONST ppart_partition_pivot_par)" 
    :: "[λ_. True]c size_assnk *a arr_assnd *a size_assnk  arr_assn ×a size_assn [λ((_,xsi),_) (xsi',_). xsi'=xsi]c"
    unfolding ppart_partition_pivot_par_def move_pivot_to_first_sample_def PR_CONST_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref


  sepref_register ppart_partition_pivot
  sepref_def ppart_partition_pivot_impl is "uncurry2 (PR_CONST ppart_partition_pivot)" 
    :: "[λ_. True]c size_assnk *a arr_assnd *a size_assnk  arr_assn ×a size_assn [λ((_,xsi),_) (xsi',_). xsi'=xsi]c"
    unfolding ppart_partition_pivot_def align_chunk_size_def PR_CONST_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref
    
    
  
end      
    

(*
global_interpretation test: sort_impl_copy_context "(≤)" "(<)" ll_icmp_ult "unat_assn' TYPE(size_t)" Mreturn free_pure
  defines test_ppart_impl = test.ppart_impl
      and test_ppart_partition_pivot_par_impl = test.ppart_partition_pivot_par_impl
      and test_par_swap_impl = test.par_swap_impl
      and test_gpartition_slices_impl = test.gpartition_slices_impl
      and test_ppart_seq1_impl_uncurried = test.ppart_seq1_impl_uncurried (* TODO: Workaround, as templating seems to choke on llc_par! *)
  apply unfold_locales
  apply (auto simp: mk_free_pure free_pure_def[abs_def])
  apply rule apply sepref
  apply (rule is_copy_pure_gen_algo) apply solve_constraint
  done
  
declare [[llc_compile_par_call=true]]
export_llvm (timing) test_ppart_partition_pivot_par_impl
*)


end