(* Theory Sorting_Parsort *)

section ‹Parallel Quicksort›
theory Sorting_Parsort
imports Sorting_Introsort Sorting_PDQ Sorting_Sample_Partition
begin



context weak_ordering begin

    (* NOTE(review): the ordering written as (<) below is intended to be the ordering of the
       weak_ordering locale. If that locale declares its own (e.g. bold) notation for it,
       these occurrences must use that notation -- confirm against the locale definition. *)

    subsection ‹Abstract Algorithm›
    text ‹We use a few straightforward refinement steps to develop the abstract parallel 
      quicksort algorithm.›
            
    (* A split of a slice of length n at position m is "bad" if one of the two parts
       (of length m or n-m) is shorter than n div 8. Requires m ≤ n (asserted). *)
    definition bad_partition :: "nat ⇒ nat ⇒ bool nres" where 
    "bad_partition m n ≡ do {
      ASSERT (m≤n);
      RETURN (m < n div 8 ∨ n-m < n div 8)
    }"
  
    (* Under m ≤ n, bad_partition does not fail; its result is left unconstrained. *)
    lemma bad_partition_triv: "m≤n ⟹ bad_partition m n ≤ SPEC (λ_. True)"
      unfolding bad_partition_def
      apply refine_vcg
      by simp
    
    (* Lists shorter than this are not split further but sorted sequentially
       (see the first test in par_sort_aux below). *)
    abbreviation "par_sort_seq_threshold::nat ≡ 100000"  
      
    (* Abstract quicksort with a recursion-depth budget d.
       If the budget is exhausted (d=0) or the list is shorter than par_sort_seq_threshold,
       the list is sorted by the specification slice_sort_spec. Otherwise it is partitioned
       at some index m (partition3_spec), both parts are sorted recursively with budget d-1,
       and the results are joined by WITH_SPLIT.
       The two branches of "if bad" are deliberately identical at this level; they are only
       distinguished in the refinement par_sort_aux2 below. *)
    definition "par_sort_aux xs d ≡ RECT (λpar_sort_aux (xs,d::nat). doN {
      let n = length xs;
      if d=0 ∨ n<par_sort_seq_threshold then
        slice_sort_spec (<) xs 0 n
      else doN {
        (xs',m) ← partition3_spec xs 0 n;
        bad ← bad_partition m n;
        
        ASSERT (length xs' = length xs);
        (_,xs'') ← WITH_SPLIT m xs' (λxs1 xs2. doN {
          ASSERT (length xs' = length xs1 + length xs2);
          (xs1',xs2') ← if bad then doN {
            xs1' ← par_sort_aux (xs1,d-1);
            ASSERT (length xs1' = length xs1);
            xs2' ← par_sort_aux (xs2,d-1);
            ASSERT (length xs2' = length xs2);
            RETURN (xs1',xs2')
          } else doN {
            xs1' ← par_sort_aux (xs1,d-1);
            ASSERT (length xs1' = length xs1);
            xs2' ← par_sort_aux (xs2,d-1);
            ASSERT (length xs2' = length xs2);
            RETURN (xs1',xs2')
          };
          RETURN ((),xs1',xs2')
        });
        RETURN xs''
      }
    }) (xs,d)"
    
    
    (* par_sort_aux sorts the whole list, for any depth budget d. *)
    lemma par_sort_aux_correct: "par_sort_aux xs d ≤ slice_sort_spec (<) xs 0 (length xs)"
      unfolding par_sort_aux_def 
      (* collapse the two identical branches of "if bad" *)
      apply (subst if_cancel)
      (* recursion argument: the depth budget d decreases in each recursive call *)
      apply (refine_vcg RECT_rule_arb[where V="measure (λ(_,d). d)" and pre="λxss (xs,d). xss=xs"])
      apply simp_all [2]
      unfolding slice_sort_spec_def partition3_spec_def
      apply (refine_vcg bad_partition_triv)
      apply simp_all 
      apply (meson slice_eq_mset_all slice_eq_mset_def)
      apply (rule order_trans) apply (rprems)
      apply simp
      apply simp
      apply refine_vcg
      apply (clarsimp)
      apply (rule order_trans) apply (rprems)
      apply simp
      apply simp
      apply refine_vcg
      apply (clarsimp_all)
      subgoal for xs' xs2' xs1 xs2 xc x1b
        
        unfolding sort_spec_def slice_eq_mset_def slice_LT_def
        apply (auto simp: slice_complete' sorted_wrt_append le_by_lt slice_append1' slice_append2')
        by (metis set_mset_mset)
      done
      
    text ‹Introducing explicit parameter for list length›
    (* Same algorithm with the list length passed explicitly as n.
       Here the two branches differ: after a bad partition the two parts are sorted one
       after the other; otherwise they are sorted with nres_par. *)
    definition "par_sort_aux2 xs n d ≡ RECT (λpar_sort_aux (xs,n,d::nat). doN {
      ASSERT (n = length xs);
      if d=0 ∨ n<par_sort_seq_threshold then
        slice_sort_spec (<) xs 0 n
      else doN {
        (xs,m) ← partition3_spec xs 0 n;
        bad ← bad_partition m n;
        (_,xs) ← WITH_SPLIT m xs (λxs1 xs2. doN {
          ASSERT (length xs2 = length xs - m);
          ASSERT (n≥m);
          (xs1,xs2) ← if bad then doN {
            xs1 ← par_sort_aux (xs1,m,d-1);
            xs2 ← par_sort_aux (xs2,n-m,d-1);
            RETURN (xs1,xs2)
          } else doN {
            nres_par par_sort_aux par_sort_aux (xs1,m,d-1) (xs2,n-m,d-1)
          };
          RETURN ((),xs1,xs2)
        });
        RETURN xs
      }
    }) (xs,n,d)"

    
    (* par_sort_aux2 refines par_sort_aux when n is the length of the list. *)
    lemma par_sort_aux2_refine: "n=length xs ⟹ par_sort_aux2 xs n d ≤⇓(⟨Id⟩list_rel) (par_sort_aux xs d)"
      unfolding par_sort_aux2_def par_sort_aux_def nres_par_def
      apply (refine_rcg)
      (* relate the recursion arguments (xs,n,d) to (xs',d') *)
      supply [refine_dref_RELATES] = RELATESI[where R="{((xs,n,d),(xs',d')). xs'=xs ∧ d'=d ∧ length xs'=n}"]
      apply refine_dref_type
      apply (simp_all (no_asm_use)) (* TODO: This is a hack against a yet unidentified simplifier loop *)
      apply auto
      done



  text ‹Fixing concrete algorithms to be used›
  (* par_sort_aux2 with concrete algorithms: pdqsort for the sequential base case and
     partition_pivot_sample for partitioning. *)
  definition "par_sort_aux3 xs n d ≡ RECT (λpar_sort_aux (xs,n,d::nat). doN {
    ASSERT (n = length xs);
    if d=0 ∨ n<par_sort_seq_threshold then
      pdqsort xs 0 n
    else doN {
      (xs,m) ← partition_pivot_sample xs n;
      bad ← bad_partition m n;
      (_,xs) ← WITH_SPLIT m xs (λxs1 xs2. doN {
        ASSERT (length xs2 = length xs - m);
        (* in this branch d=0 was excluded by the test above *)
        ASSERT (d>0);
        ASSERT (n≥m);
        (xs1,xs2) ← if bad then doN {
          xs1 ← par_sort_aux (xs1,m,d-1);
          xs2 ← par_sort_aux (xs2,n-m,d-1);
          RETURN (xs1,xs2)
        } else doN {
          nres_par par_sort_aux par_sort_aux (xs1,m,d-1) (xs2,n-m,d-1)
        };
        RETURN ((),xs1,xs2)
      });
      RETURN xs
    }
  }) (xs,n,d)"


  (* TODO: Move *)
  (* introsort4_correct, stated in the refinement form expected by refine_rcg. *)
  lemma introsort4_refines_spec: "⟦(xs',xs)∈⟨Id⟩list_rel; (l',l)∈nat_rel; (h',h)∈nat_rel⟧ ⟹ introsort4 xs' l' h' ≤ ⇓ Id (slice_sort_spec (<) xs l h)"
    using introsort4_correct by auto
  
  (* pdqsort_correct, stated in the refinement form expected by refine_rcg. *)
  lemma pdqsort_refines_spec: "⟦(xs',xs)∈⟨Id⟩list_rel; (l',l)∈nat_rel; (h',h)∈nat_rel⟧ ⟹ pdqsort xs' l' h' ≤ ⇓ Id (slice_sort_spec (<) xs l h)"
    using pdqsort_correct by auto
  
  (* par_sort_aux3 refines par_sort_aux2. *)
  lemma par_sort_aux3_refine: "par_sort_aux3 xs n d ≤⇓Id (par_sort_aux2 xs n d)"
    unfolding par_sort_aux3_def par_sort_aux2_def
    apply (refine_rcg partition_pivot_sample_correct introsort4_refines_spec pdqsort_refines_spec)
    apply refine_dref_type
    by auto
    
  (* TODO: Move *)  
  (* Sorting the complete slice [0, length xs) is the same as sorting the list. *)
  lemma slice_sort_spec_complete: "slice_sort_spec lt xs 0 (length xs) = SPEC (sort_spec lt xs)"  
    unfolding slice_sort_spec_def sort_spec_def
    apply clarsimp
    by (metis le_refl mset_eq_length slice_complete)
    
  (* Correctness of par_sort_aux3, obtained by chaining the refinement steps above. *)
  lemma par_sort_aux3_correct:
    assumes "n=length xs"  
    shows "par_sort_aux3 xs n d ≤ SPEC (sort_spec (<) xs)"
  proof -
    note par_sort_aux3_refine[where d=d]
    also note par_sort_aux2_refine[OF assms]
    also note par_sort_aux_correct
    also note slice_sort_spec_complete
    finally show ?thesis by simp 
  qed
    
  
  text ‹Initializing depth bound›
  (* Lists of length at most 1 are returned unchanged; longer lists are sorted by
     par_sort_aux3 with depth budget Discrete.log n * 2. *)
  definition "par_sort xs n ≡ doN {
    if n>1 then doN {
      par_sort_aux3 xs n (Discrete.log n * 2)
    } else RETURN xs
  }"
  
  (* Correctness of par_sort with respect to the sorting specification. *)
  lemma par_sort_correct: "n=length xs ⟹ par_sort xs n ≤ SPEC (sort_spec (<) xs)"
    unfolding par_sort_def 
    apply (refine_vcg par_sort_aux3_correct)
    by (simp add: sort_spec_def sorted_wrt01)
        
end
    
subsection ‹Refining to LLVM›

(* TODO: Move *)
(* Run the computation m and apply f to its result. *)
definition "map_res f m ≡ doM { x←m; Mreturn (f x) }"

(* map_res applied to a plain return; registered as a sepref_opt_simps2 rewrite. *)
lemma map_res_return[sepref_opt_simps2]: "map_res φ (Mreturn x) = Mreturn (φ x)"
  unfolding map_res_def by auto

(* map_res can be pushed through a bind into the continuation; registered as a
   sepref_opt_simps2 rewrite. *)
lemma map_res_bind[sepref_opt_simps2]: "map_res φ (doM {x←m; f x}) = doM {x←m; map_res φ (f x)}"  
  unfolding map_res_def by auto

(* map_res distributes over a case split on a pair; registered as a sepref_opt_simps2 rewrite. *)
lemma map_res_prod_case[sepref_opt_simps2]: "map_res φ (case p of (a,b) ⇒ f a b) = (case p of (a,b) ⇒ map_res φ (f a b))" 
  by (rule prod.case_distrib)

(* Add the pair selector equations prod.sel to the sepref_opt_simps2 rewrite set. *)
lemmas [sepref_opt_simps2] = prod.sel  
  
  
(* Split the array a at index i, run m on the two halves (discarding its results),
   join the halves again and return a. Related to ars_with_split by the rewrite
   lemma ars_with_split_bind_unit. *)
definition [llvm_inline]: "ars_with_split_nores i a m ≡ doM {
  (a1,a2) ← ars_split i a;
  (_,_) ← m a1 a2;
  ars_join a1 a2;
  Mreturn a
}"


(* Optimization rewrite: when the first component returned by ars_with_split is unit,
   use ars_with_split_nores and keep only the second component of m's result via
   map_res snd. Proved pointwise (pw). *)
lemma ars_with_split_bind_unit[sepref_opt_simps2]: "doM {
  (uu::unit,xs) ← ars_with_split i a m;
  mm uu xs
} = doM {
  xs←ars_with_split_nores i a (λxs1 xs2. map_res snd (m xs1 xs2));
  mm () xs
}"
  unfolding ars_with_split_def ars_with_split_nores_def map_res_def 
  apply pw
  done
    
  
(* Optimization rewrite: a case over a pair that was only used to prepend a constant c
   is collapsed into a single case split that passes c directly. *)
lemma sepref_adhoc_opt_case_add_const[sepref_opt_simps]:
  "(case case x of (a1c, a2c) ⇒ (c, a1c, a2c) of (uu, a, b) ⇒ m uu a b) = (case x of (a,b) ⇒ m c a b)" by simp

context sort_impl_context begin

  (* NOTE(review): (<) and (≤) in the statements below are intended to be the locale's
     ordering. If the underlying locale declares its own (e.g. bold) notation for it,
     use that notation here -- confirm against the locale definition. *)
  
  (* LLVM implementation of bad_partition on size_t values, marked llvm_inline. *)
  sepref_register bad_partition
  sepref_def bad_partition_impl [llvm_inline] is "uncurry bad_partition" :: "size_assn⇧k *⇩a size_assn⇧k →⇩a bool1_assn"
    unfolding bad_partition_def
    apply (annot_snat_const "TYPE(size_t)")
    by sepref

  sepref_register par_sort_aux3
  
  (* LLVM implementation of par_sort_aux3. The input array is consumed (⇧d), and the
     returned pointer is guaranteed to equal the input pointer (r=ai), i.e. the array is
     sorted in place. *)
  sepref_def par_sort_aux_impl is "uncurry2 (PR_CONST par_sort_aux3)" 
    :: "[λ_. True]⇩c (arr_assn)⇧d *⇩a size_assn⇧k *⇩a size_assn⇧k →
     arr_assn [λ((ai,_),_) r. r=ai]⇩c"
  unfolding par_sort_aux3_def PR_CONST_def
  supply [[goals_limit = 1]]
  apply (annot_snat_const "TYPE(size_t)")
  (* annotate the recursion with the same result-pointer condition as the signature *)
  apply (rewrite RECT_cp_annot[where CP="λ(ai,_,_) r. r=ai"])
  
  supply [sepref_comb_rules] = hn_RECT_cp_annot_noframe
  apply sepref
(* 
  (* debugging boilerplate: *)
  apply sepref_dbg_preproc
  apply sepref_dbg_cons_init
  apply sepref_dbg_id
  apply sepref_dbg_monadify
  apply sepref_dbg_opt_init
  apply sepref_dbg_trans
  apply sepref_dbg_opt
  apply sepref_dbg_cons_solve
  apply sepref_dbg_cons_solve
  apply sepref_dbg_cons_solve_cp
  apply sepref_dbg_constraints
*)  
  done
  

  (* LLVM implementation of par_sort; as above, the result pointer equals the input pointer. *)
  sepref_def par_sort_impl is "uncurry (PR_CONST par_sort)" 
    :: "[λ_. True]⇩c (arr_assn)⇧d *⇩a size_assn⇧k →
     arr_assn [λ(ai,_) r. r=ai]⇩c"
    unfolding par_sort_def PR_CONST_def
    apply (annot_snat_const "TYPE(size_t)")
    (* presumably discharges the size_t bound for the depth budget Discrete.log n * 2 -- verify *)
    supply [intro!] = introsort_depth_limit_in_bounds_aux 
    by sepref
    
  subsection ‹Final Correctness Theorem as Hoare-Triple›
  
  (* par_sort refines the specification "assert n = length xs, then return a sorted permutation". *)
  lemma par_sort_refine_aux: "(uncurry par_sort, uncurry (λxs n. doN {ASSERT (n=length xs); SPEC (sort_spec (<) xs) })) ∈ Id ×⇩r Id → ⟨Id⟩nres_rel"  
    using par_sort_correct[OF refl]
    by (auto simp: pw_nres_rel_iff pw_le_iff refine_pw_simps)
        
  text ‹We unfold the definition of ‹hnr›, to extract a correctness statement as Hoare-Triple›
  (* Starting from an array xsi holding xs and n = length xs, par_sort_impl returns the same
     pointer, which then holds some xs' with sort_spec (<) xs xs'. *)
  theorem par_sort_impl_correct: "llvm_htriple (arr_assn xs xsi ** snat_assn n ni ** ↑(n = length xs)) 
    (par_sort_impl xsi ni) 
    (λr. ↑(r=xsi) ** (EXS xs'. arr_assn xs' xsi ** ↑(sort_spec (<) xs xs')))"
    apply (cases "n=length xs"; simp)
    apply (rule cons_rule)
    supply R = par_sort_impl.refine[unfolded PR_CONST_def,FCOMP par_sort_refine_aux, to_hnr, simplified]
    supply R = R[of xs xsi n ni]
    apply (rule R[THEN hn_refineD])
    apply (simp)
    apply (simp add: hn_ctxt_def sep_algebra_simps)
    apply (auto simp add: hn_ctxt_def sep_algebra_simps pure_def invalid_assn_def)
    done
  
  text ‹With the sorting specification unfolded, too. 
    Note that \<^const>‹mset› and \<^const>‹sorted_wrt› are standard concepts from Isabelle's library.›
    
  (* Same as par_sort_impl_correct, with sort_spec unfolded into a permutation (equal
     multisets) and sortedness condition. *)
  theorem par_sort_impl_correct': "llvm_htriple (arr_assn xs xsi ** snat_assn n ni ** ↑(n = length xs)) 
    (par_sort_impl xsi ni) 
    (λr. ↑(r=xsi) ** (EXS xs'. arr_assn xs' xsi ** ↑(mset xs' = mset xs ∧ sorted_wrt (≤) xs')))"
    apply (rule cons_rule[OF par_sort_impl_correct])
    apply simp
    apply (clarsimp simp add: sep_algebra_simps)
    apply (auto simp: sort_spec_def le_by_lt)
    done
    
    

end

end