Theory Sorting_Sample_Partition

theory Sorting_Sample_Partition
imports Sorting_Quicksort_Scheme Sorting_Quicksort_Partition Sorting_Ex_Array_Idxs
begin

context weak_ordering begin

  (* Pivot selection by sampling: draw a bounded number of equidistant
     sample indices, sort them by the elements they point to, and return
     the index of the middle sample as pivot for a partition step.

     NOTE(review): in this copy many operator symbols are missing (e.g. the
     function arrow in the type of num_samples, the definitional equality in
     each definition, and most conjunctions, comparisons and monadic binds).
     They are left untouched here; restore them from the original theory
     rather than by guessing. *)

  (* Number of samples to draw from an n-element list: all elements,
     but at most 64. *)
  definition num_samples :: "nat  nat nres" 
    where "num_samples n  RETURN (min n 64)"

  (* Bounds on the sample count. Hypothesis and postcondition read n4 and
     i3 here; presumably n >= 4 and 3 <= i <= n -- confirm. *)
  lemma num_samples_correct: "n4  num_samples n  SPEC (λi. i3  i  n)"
    unfolding num_samples_def by auto

  (* Specification of sample selection: assuming n = length xs (and a bound
     relating ns and n), nondeterministically return ns distinct indices
     from 0..<n, sorted by the elements they point to.
     Not used elsewhere in this theory; sorted_samples_spec below is the
     variant that the pivot selection relies on. *)
  definition "idxs_spec xs n ns  doN {
    ASSERT (n=length xs  nsn);
    SPEC (λidxs. 
        ns = length idxs 
       distinct idxs 
       set idxs  {0..<n} 
       sorted_wrt (λi j. xs!i  xs!j) idxs)
  }"  

  (* Compute ns equidistant indices into the range 0..<n.
     With incr = n div ns and extra = n mod ns, the first index is 0 and
     consecutive indices differ by incr + 1 for the first extra steps and by
     incr afterwards. Hence the indices are strictly increasing (so distinct)
     and all stay below n. The result overwrites a list of ns zeros slot by
     slot.
     Loop state (i, j, extra, idxs): next slot to fill, value written into
     it, number of enlarged gaps still to use, and the partially filled list.
     The last invariant conjunct, j + (ns-i)*incr + extra = n, accounts for
     the part of the range not yet covered; it is what keeps j below n. *)
  definition "equidist n ns  doN {
    ASSERT (2ns  nsn);
    let idxs = replicate ns 0;
    
    let incr = n div ns;
    let extra = n mod ns;

    ASSERT (incr>0);
          
    (_,_,_,idxs)  WHILEIT 
      (λ(i,j,extra,idxs). 
          ns=length idxs
         ins  extra0  extra < ns
         set (take i idxs)  {0..<j}
         distinct (take i idxs)
         j + (ns-i)*incr + extra = n
      ) 
      (λ(i,j,extra,idxs). i<ns) (λ(i,j,extra,idxs). doN {
        idxs  mop_list_set idxs i j;
        ASSERT(i+1  ns);
        let i=i+1;
        ASSERT(j+incr  n);
        let j=j+incr;
        (j,extra)  if extra>0 then doN {
          ASSERT (j+1n);
          RETURN (j+1,extra-1)
        } else 
          RETURN (j,extra);
          
        RETURN (i,j,extra,idxs) 
      }) 
      (0,0,extra,idxs);
    
    RETURN idxs
  }"     

  (* Set inclusion of a list, phrased via its indices. Used as a rewrite
     rule in the proof of equidist_correct. *)
  lemma set_subset_conv_nth: "set xs  S  (i<length xs. xs!i  S)"
    by (auto 0 3 simp: in_set_conv_nth in_mono)

  (* equidist returns ns distinct indices below n. The premises read 2ns and
     nsn here; presumably 2 <= ns and ns <= n.
     Termination uses the measure ns - i on the loop counter. The trailing
     subgoal proofs discharge, one by one, the goals that auto leaves over;
     they depend on the exact goal shapes produced by the auto call. *)
  lemma equidist_correct[refine_vcg]: "2ns; nsn  
    equidist n ns  SPEC (λidxs. length idxs = ns  distinct idxs  set idxs  {0..<n})"
    unfolding equidist_def
    apply (refine_vcg WHILEIT_rule[where R="measure (λ(i,j,extra,idxs). ns-i)"])
    apply (simp_all add: div_positive)
    apply auto []
    
    apply (auto simp: in_set_conv_nth nth_list_update' set_subset_conv_nth distinct_conv_nth)
    subgoal by (metis div_le_dividend div_mult_self1_is_m less_eq_Suc_le nat_add_left_cancel_le trans_le_add1 zero_less_diff)
    subgoal by (smt (z3) Euclidean_Division.div_less add_less_cancel_left diff_is_0_eq div_mult_self_is_m div_positive le_neq_implies_less less_add_same_cancel1 less_trans not_less_eq_eq)
    subgoal by (meson less_SucI less_antisym trans_less_add1)
    subgoal by (metis less_SucE nat_less_le)
    subgoal by (metis less_SucE nat_less_le)
    subgoal by (metis Suc_diff_le diff_Suc_Suc group_cancel.add1 mult_Suc) 
    subgoal by (meson div_positive less_le_trans pos2) 
    subgoal by (meson less_antisym trans_less_add1)
    subgoal by (metis less_SucE nat_less_le)
    subgoal by (metis less_SucE nat_less_le)
    subgoal by (metis Suc_diff_le diff_Suc_Suc group_cancel.add1 mult_Suc) 
    done      

  (* Specification of the sorted sample: ns distinct indices from 0..<n,
     ordered by the elements of xs they point to. Compared with idxs_spec,
     the precondition additionally bounds ns from below (it reads 2ns here;
     presumably 2 <= ns). *)
  definition "sorted_samples_spec xs n ns  doN {ASSERT (2ns  nsn  n = length xs); SPEC (λidxs. 
      ns = length idxs 
     distinct idxs 
     set idxs  {0..<n} 
     sorted_wrt (λi j. xs!i  xs!j) idxs)}"

  (* Algorithm for sorted_samples_spec: take ns equidistant indices, then
     sort the slice 0..<ns of that index list by comparing the elements of
     xs the indices point to. The index list is converted to the woarray
     representation around the sort and back afterwards (presumably the
     representation the sort implementation expects -- confirm). *)
  definition "sorted_samples xs n ns  doN {
    ASSERT (n = length xs);
    idxs  equidist n ns;
    idxs  mop_array_to_woarray idxs;
    idxs  pslice_sort_spec (λxs. {0..<length xs}) (λxs i j. xs!i < xs!j) xs idxs 0 ns;
    idxs  mop_woarray_to_array idxs;
    RETURN idxs
  }"  

  (* Lists with the same multiset are equally distinct. Used to carry the
     distinctness of the indices through the sort. *)
  lemma mset_eq_imp_distinct_eq: "mset xs = mset ys  distinct xs  distinct ys"
    by (metis count_mset_0_iff distinct_count_atmost_1)

  (* sorted_samples refines sorted_samples_spec. *)
  lemma sorted_samples_correct: "sorted_samples xs n ns  sorted_samples_spec xs n ns"
    unfolding sorted_samples_def sorted_samples_spec_def pslice_sort_spec_def slice_sort_spec_def
    apply simp
    apply refine_vcg
    apply (auto simp: slice_complete' sort_spec_def)
    subgoal by (metis mset_eq_imp_distinct_eq)
    subgoal by (metis atLeastLessThan_iff mset_eq_setD subsetD)
    subgoal by (simp add: le_by_lt_def sorted_wrt_iff_nth_less wo_leI)
    done

  (* Choose a pivot index for xs: draw num_samples n sorted samples and
     return the middle one, idxs!(ns div 2).
     The assertions record that the middle sample has a neighbouring sample
     on each side, that the three neighbouring indices are valid indices of
     xs, and that the elements they point to are ordered accordingly. They
     are discharged in sample_pivot_correct from the specification of the
     sorted samples. *)
  definition "sample_pivot xs n  doN {
    ASSERT (n = length xs);
    ASSERT (n4);
    
    ns  num_samples n;
    
    idxs  sorted_samples_spec xs n ns;
    
    let mi = ns div 2;
    ASSERT (1mi  mi < length idxs-1);
    
    ASSERT (idxs!(mi-1)<length xs  idxs!(mi)<length xs  idxs!(mi+1)<length xs);
    
    ASSERT (xs!(idxs!(mi-1))  xs!(idxs!mi) 
           xs!(idxs!mi)  xs!(idxs!(mi+1)));
    
    RETURN (idxs!mi)
  }"

  (* For presentation in paper: a simplified rendering of sample_pivot, with
     num_samples and sorted_samples inlined and the assertions and array
     conversions dropped, related to sample_pivot (the relation symbol is
     not visible here; presumably refinement). *)
  lemma "doN {
    let ns = min (length xs) 64;
    idxs  equidist (length xs) ns;
    idxs  slice_sort_spec (λi j. xs!i < xs!j) idxs 0 ns;
    RETURN (idxs!(ns div 2))
  }  sample_pivot xs n"
    using sorted_samples_correct
    unfolding sample_pivot_def num_samples_def sorted_samples_def pslice_sort_spec_def
    apply (simp only: pw_le_iff refine_pw_simps mop_array_to_woarray_def mop_woarray_to_array_def Let_def)
    apply safe
    apply simp
    apply blast
    apply blast
    apply blast
    apply blast
    by (metis (no_types, lifting))

  (* sample_pivot returns a valid index i of xs, and for each of the two
     orderings there is another index j whose element is related to xs!i
     in that ordering (relation symbols not visible here). The witnesses
     are the next and the previous sorted sample, respectively. *)
  lemma sample_pivot_correct[refine_vcg]: "
    n=length xs; length xs  4  sample_pivot xs n  SPEC (λi. 
      i{0..<length xs} 
     (j{0..<length xs}. ij  xs!ixs!j)
     (j{0..<length xs}. ij  xs!ixs!j))"
      
    unfolding sample_pivot_def sorted_samples_spec_def
    apply (refine_vcg num_samples_correct)
    apply (clarsimp_all simp: sort_spec_def)
    apply simp_all
    subgoal for idxs
      by (meson atLeastLessThan_iff diff_le_self less_imp_diff_less less_le_trans nth_mem subset_code(1))
    subgoal for idxs
      by (meson atLeastLessThan_iff diff_le_self less_imp_diff_less less_le_trans nth_mem subset_code(1))
    subgoal for idxs
      by (simp add: subset_code(1))
    subgoal 
      by  (auto simp: sorted_wrt_iff_nth_less) 
    subgoal by (auto simp: sorted_wrt_iff_nth_less)
    subgoal for idxs
      apply (rule bexI[where x="idxs!(length idxs div 2 + 1)"])
      apply (simp_all add: distinct_conv_nth)
      done
    subgoal for idxs
      apply (rule bexI[where x="idxs!(length idxs div 2 - 1)"])
      apply (auto simp: distinct_conv_nth) 
      done
    done

  (* Swap the sampled pivot to position 0. The swap is skipped when the
     pivot is already there (the condition reads i0 here; presumably i is
     different from 0). *)
  definition "move_pivot_to_first_sample xs n  doN {
    i  sample_pivot xs n;
    if i0 then
      mop_list_swap xs 0 i
    else
      RETURN xs
  }"

  (* The result is a permutation of xs (same multiset) whose first element
     is related, for each of the two orderings, to some element at a
     position in 1..<length xs' (relation symbols not visible here). *)
  lemma move_pivot_to_first_sample_correct[refine_vcg]: 
    "n=length xs; length xs  4  
    move_pivot_to_first_sample xs n  SPEC (λxs'. 
      mset xs' = mset xs
     (j{1..<length xs'}. xs'!0xs'!j)
     (j{1..<length xs'}. xs'!0xs'!j)        
    )"      
    unfolding move_pivot_to_first_sample_def
    apply refine_vcg
    apply (auto simp: swap_nth)
    subgoal by (metis One_nat_def atLeastLessThan_iff le_neq_implies_less less_one nat_le_linear)
    subgoal by (metis One_nat_def atLeastLessThan_iff less_one not_le)
    done

  (* Partition step with a sampled pivot: move the pivot to position 0,
     then run qs_partition on the range 1..<n with pivot index 0.
     Returns the partitioned list and the split position m produced by
     qs_partition. *)
  definition "partition_pivot_sample xs n  doN {
    ASSERT (n=length xs);
    xs  move_pivot_to_first_sample xs n;
    (xs,m)  qs_partition 1 n 0 xs;
    RETURN (xs,m)
  }"

(* On the full range of xs', slice_eq_mset is plain multiset equality of
   the two lists. *)
lemma slice_eq_mset_all[simp]: 
  "slice_eq_mset 0 (length xs') xs xs'  mset xs = mset xs'"
  unfolding slice_eq_mset_def 
  apply (auto simp: Misc.slice_def dest: mset_eq_length)  
  using mset_eq_length by fastforce

(* slice_eq_mset is preserved when the slice l'..<h' is widened to an
   enclosing slice l..<h. *)
lemma slice_eq_mset_mono: "ll'  l'h'  h'h 
   slice_eq_mset l' h' xs xs'  slice_eq_mset l h xs xs'"
  unfolding slice_eq_mset_def
  apply auto  
  subgoal by (metis min_def take_take)
  subgoal by (meson slice_eq_mset_def slice_eq_mset_subslice)
  subgoal using drop_eq_mono by blast
  done  

(* partition_pivot_sample refines partition3_spec on the full range
   0..<n'. In the ordering goal, the element at position 0 of the
   intermediate list (the moved pivot) serves as the separating element p
   for slice_LT_I_aux. *)
lemma partition_pivot_sample_correct: "(xs,xs')Id; (n,n')Id; n'=length xs' 
   partition_pivot_sample xs n  (Idlist_rel ×r nat_rel) (partition3_spec xs' 0 n')"
  unfolding partition_pivot_sample_def partition3_spec_def
  apply simp
  apply (refine_vcg qs_partition_correct)
  apply auto
  apply (metis eq_imp_le mset_eq_length)
  subgoal by (metis atLeastLessThan_iff size_mset)
  subgoal by (metis atLeastLessThan_iff size_mset)
  subgoal by (smt (z3) le_SucI le_trans less_or_eq_imp_le mset_eq_length slice_eq_mset_all slice_eq_mset_subslice)
  subgoal for xs1 m xs2 j ja
    apply (rule slice_LT_I_aux[where p="xs1!0"])
    apply (auto dest: mset_eq_length slice_eq_mset_eq_length)[2]
    subgoal by (metis le_refl size_mset slice_eq_mset_eq_length)
    apply (metis One_nat_def atLeastLessThan_iff bot_nat_0.extremum less_one not_le not_less_eq_eq slice_eq_mset_eq_length slice_eq_mset_nth_outside wo_refl)
    apply (auto dest: mset_eq_length slice_eq_mset_eq_length)[1]
    done
  done  



end

context sort_impl_context begin

  (* Imperative implementations of the sample-based pivot selection,
     synthesized with sepref. Numeric constants are annotated with the
     size_t type before synthesis.

     NOTE(review): in this copy operator and cartouche symbols are missing
     (e.g. the text command below, the arrows in the sepref signatures and
     the mode superscripts on the assertions). They are left untouched. *)

  text Introsort for sorting samples
  (* Instance of idxs_comp, parameterised with the element comparison
     operators (symbols not visible here), lt_impl and elem_assn. It
     supplies SAMPLE_SORT.introsort_param_impl_correct, the implementation
     used for the index sort in sorted_samples_impl below. *)
  sublocale SAMPLE_SORT: idxs_comp "()" "(<)" lt_impl elem_assn
    by unfold_locales

  sepref_register equidist

  (* equidist on size_t values, producing an array of size_t indices.
     The initial list of zeros is folded into an array initialisation. *)
  sepref_def equidist_impl [llvm_inline] is 
    "uncurry equidist" :: "size_assnk *a size_assnk a array_assn size_assn"
    unfolding equidist_def
    apply (rewrite array.fold_replicate_init)
    apply (annot_snat_const "TYPE(size_t)")
    by sepref

  (* Rewrites the index-sort specification used in sorted_samples into the
     PR_CONST form built from the SAMPLE_SORT constants, so that sepref can
     match it against SAMPLE_SORT's introsort implementation. *)
  lemma fold_sample_sort_spec: "pslice_sort_spec (λxs. {0..<length xs}) (λxs i j. xs ! i < xs ! j) 
    = PR_CONST (pslice_sort_spec SAMPLE_SORT.idx_cdom SAMPLE_SORT.idx_less)"
    unfolding PR_CONST_def SAMPLE_SORT.idx_cdom_def SAMPLE_SORT.idx_less_def ..    
      
  sepref_register "PR_CONST (pslice_sort_spec SAMPLE_SORT.idx_cdom SAMPLE_SORT.idx_less)"  

  (* Implementation of sorted_samples: the element array and both size
     arguments are only read; the result is an array of size_t indices. *)
  sepref_register sorted_samples
  sepref_def sorted_samples_impl is "uncurry2 (PR_CONST sorted_samples)" 
    :: "arr_assnk *a size_assnk *a size_assnk a array_assn size_assn"
    unfolding PR_CONST_def
    unfolding sorted_samples_def fold_sample_sort_spec
    apply (annot_snat_const "TYPE(size_t)")
    supply [sepref_fr_rules] = SAMPLE_SORT.introsort_param_impl_correct
    apply sepref
    done

  sepref_register sorted_samples_spec  

  (* sorted_samples refines sorted_samples_spec (by sorted_samples_correct).
     Composing this with sorted_samples_impl yields an implementation of the
     specification itself, which is the form sample_pivot uses. *)
  lemma sorted_samples_refine: "(PR_CONST sorted_samples, PR_CONST sorted_samples_spec) Id  Id  Id  Idnres_rel"
    using sorted_samples_correct by (auto simp: nres_relI)
  
  lemmas sorted_samples_impl_correct = sorted_samples_impl.refine[FCOMP sorted_samples_refine]

  (* Implementation of sample_pivot. num_samples is inlined and min is
     unfolded to its definition before synthesis. *)
  sepref_register sample_pivot  
  sepref_def sample_pivot_impl is "uncurry (PR_CONST sample_pivot)" :: "arr_assnk *a size_assnk a size_assn" 
    unfolding sample_pivot_def num_samples_def min_def PR_CONST_def (* TODO: Include rule for min! *)
    apply (annot_snat_const "TYPE(size_t)")
    supply [sepref_fr_rules] = sorted_samples_impl_correct
    by sepref

  (* Implementation of partition_pivot_sample. The postcondition states that
     the returned array is the input array (r = ai), i.e. the partitioning
     reuses the input array; the split position is returned as a size_t. *)
  sepref_register partition_pivot_sample  
  sepref_def partition_pivot_sample_impl is "uncurry (PR_CONST partition_pivot_sample)" 
      :: "[λ_. True]c arr_assnd *a size_assnk  arr_assn×asize_assn [λ(ai,_) (r,_). r=ai]c" 
    unfolding partition_pivot_sample_def move_pivot_to_first_sample_def PR_CONST_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref


end

end