(* Theory Sepref_Parallel *)

section ‹Automatic Refinement to Parallel Execution›
theory Sepref_Parallel
imports Sepref_Tool
begin

  text ‹Abstractly, we define an annotation that maps to sequential execution.›
  (* nres_par f g x y: run f on x, then g on y, and return both results as a
     pair. Abstractly this is plain sequential composition; the parallel
     implementation only enters via the refinement rule hnr_nres_par, whose
     concrete program is llc_par. *)
  definition "nres_par f g x y ≡ do { r1←f x; r2←g y; RETURN (r1,r2) }"

  subsection ‹Setup Boilerplate›
  text ‹Boilerplate required by the IRF to set up the higher-order combinator.›
  (* Arity rule for the monadify phase (sepref_monadify_arity): nres_par is
     rewritten into its fully applied form, with the function arguments f and g
     given as explicit λ⇩2 abstractions and the constant wrapped in SP. *)
  lemma nres_par_arity[sepref_monadify_arity]:
    "nres_par ≡ λ⇩2f g x y. SP nres_par $ (λ⇩2x. f$x) $ (λ⇩2y. g$y) $ x $ y" 
    by (simp_all)
  
  (* Combinator rule for the monadify phase (sepref_monadify_comb): the plain
     arguments x and y are evaluated (EVAL) and bound first, and then the
     SP-wrapped nres_par is applied to the bound values. *)
  lemma nres_par_comb[sepref_monadify_comb]:
    "nres_par$f$g$x$y ≡ 
      Refine_Basic.bind$(EVAL$x)$(λ⇩2x. Refine_Basic.bind$(EVAL$y)$(λ⇩2y. 
        (SP nres_par$f$g$x$y)))"
    by (simp_all)
  
  (* Registers the type of nres_par with the identification phase (id_rules):
     two nres-valued functions, their two arguments, and a pair result. *)
  lemma nres_par_id[id_rules]: 
    "nres_par ::⇩i TYPE(('a ⇒ 'b nres) ⇒ ('c ⇒ 'd nres) ⇒ 'a ⇒ 'c ⇒ ('b × 'd) nres)"
    by simp
  
    
  (* Pointwise characterisation of nres_par (registered as refine_pw_simps),
     obtained by unfolding the definition:
     - it does not fail iff f x does not fail and, if f x has any result,
       g y does not fail;
     - r is a possible result iff (when f x does not fail) some result ya of
       f x exists such that (when g y does not fail) some result yb of g y
       gives r = (ya, yb). *)
  lemma nres_par_pw[refine_pw_simps]:
    "nofail (nres_par f g x y) = (nofail (f x) ∧ ((∃xa. inres (f x) xa) ⟶ nofail (g y)))"
    "inres (nres_par f g x y) r ⟷ (nofail (f x) ⟶
     (∃ya. inres (f x) ya ∧ (nofail (g y) ⟶ (∃yb. inres (g y) yb ∧ (ya, yb) = r))))"
    unfolding nres_par_def
    by (simp_all add: refine_pw_simps)
    
  (* VCG rule (refine_vcg): nres_par f g x y satisfies SPEC Φ if every result
     r1 of f x, combined with every result r2 of g y, satisfies Φ (r1,r2). *)
  lemma nres_par_vcg[refine_vcg]:
    assumes "f x ≤ SPEC (λr1. g y ≤ SPEC (λr2. Φ (r1,r2)))"
    shows "nres_par f g x y ≤ SPEC Φ"
    using assms by (auto simp: refine_pw_simps pw_le_iff nres_par_def)
    
  (* Refinement rule (refine): nres_par refines componentwise. If f x refines
     f' x' w.r.t. Rx and g y refines g' y' w.r.t. Ry, then the pair results
     are related by any R that contains all pairs of Rx- and Ry-related
     components. *)
  lemma nres_par_refine[refine]:
    assumes "f x ≤ ⇓Rx (f' x')"
    assumes "g y ≤ ⇓Ry (g' y')"
    assumes "⋀rx rx' ry ry'. (rx,rx')∈Rx ⟹ (ry,ry')∈Ry ⟹ ((rx,ry), (rx',ry'))∈R"
    shows "nres_par f g x y ≤⇓R (nres_par f' g' x' y')"
    using assms 
    by (simp add: refine_pw_simps pw_le_iff nres_par_def; blast)


  (* Monotonicity of nres_par in its function arguments f and g w.r.t. the
     refinement order ≤ (registered as refine_mono). *)
  lemma nres_par_mono[refine_mono]:
    assumes "⋀x. f x ≤ f' x" 
    assumes "⋀y. g y ≤ g' y"  
    shows "nres_par f g x y ≤ nres_par f' g' x y"
    using assms 
    by (simp add: refine_pw_simps pw_le_iff nres_par_def; blast)
    
  (* Monotonicity of nres_par in f and g w.r.t. the flat ordering flat_ge
     (registered as refine_mono). *)
  lemma nres_par_flat_mono[refine_mono]:
    assumes "⋀x. flat_ge (f x) (f' x)" 
    assumes "⋀y. flat_ge (g y) (g' y)"  
    shows "flat_ge (nres_par f g x y) (nres_par f' g' x y)"
    using assms 
    by (simp add: refine_pw_simps pw_flat_ge_iff nres_par_def; blast)


    
  subsection ‹Refinement Rule›
  
  text ‹Refinement rule from the annotation to actual parallel execution.›
  
  (* TODO: Move *)
  (* Hoare triple for llc_par: if m1 x1 satisfies a triple with precondition P1
     and postcondition Q1, and m2 x2 one with P2 and Q2, then
     llc_par m1 m2 x1 x2, started on P1**P2, returns a pair (r1,r2) with
     Q1 r1 ** Q2 r2.
     Proof: unfold llc_par_def and let the VCG use ht_par, instantiated with
     both assumptions. *)
  lemma ht_llc_par:
    assumes "llvm_htriple P1 (m1 x1) Q1"  
    assumes "llvm_htriple P2 (m2 x2) Q2"  
    shows "llvm_htriple (P1**P2) (llc_par m1 m2 x1 x2) (λ(r1,r2). Q1 r1 ** Q2 r2)"
    unfolding llc_par_def
    supply [vcg_rules] = ht_par[OF assms]
    by (vcg)
    
    
  (* Core refinement lemma: if fi xi refines f x in context Ax, and gi yi
     refines g y in context Ay, then llc_par fi gi xi yi refines
     nres_par f g x y in the combined context Ax ** Ay. Results are related by
     the product Rx ×⇩a Ry, and the constraint predicates CP1/CP2 are
     conjoined per component. *)
  lemma hnr_nres_par_aux:
    assumes A: "hn_refine (Ax) (fi xi) (Ax') Rx CP1 (f x)"
    assumes B: "hn_refine (Ay) (gi yi) (Ay') Ry CP2 (g y)"
    shows "hn_refine (Ax ** Ay) (llc_par fi gi xi yi) (Ax' ** Ay') (Rx×⇩aRy) (λ(rxi,ryi). CP1 rxi ∧ CP2 ryi) (nres_par f g x y)"
  proof -
    (* Make the llc_par Hoare rule, instantiated with the triples extracted
       from A and B, available to the VCG. *)
    note [vcg_rules] = ht_llc_par[where m1=fi and m2=gi, OF A[THEN hn_refineD] B[THEN hn_refineD]]

    (* NSUCCA: under a realizable Ax, f x has at least one possible result.
       The case f x ≠ SUCCEED follows pointwise; the case f x = SUCCEED is
       discharged with htriple_false. *)
    from A[THEN hn_refineD] have 
      NSUCCA: "∃r. inres (f x) r" if "realizable Ax"
      using that
      apply (cases "f x ≠ SUCCEED")
      subgoal by (auto simp: refine_pw_simps pw_eq_iff)
      apply (clarsimp simp: htriple_false)
      done
    
    show ?thesis
      apply sepref_to_hoare
      apply (rule htriple_realizable_preI; drule realizable_conjD; clarify)
      supply [simp] = refine_pw_simps NSUCCA inres_def[symmetric]
      apply (vcg)
      done
  qed
      
  (* Sepref combinator rule for the nres_par annotation (sepref_comb_rules).
     fi refines f in the context of x, and gi refines g in the context of y.
     The annotated term is then implemented by llc_par fi gi xi yi. FR states
     that the current context P entails the two argument contexts plus a
     frame F; F is carried through unchanged to the postcondition. *)
  lemma hnr_nres_par[sepref_comb_rules]:
    assumes FR: "P ⊢ hn_ctxt Ax x xi ** hn_ctxt Ay y yi ** F"
    assumes A: "hn_refine (hn_ctxt Ax x xi) (fi xi) Ax' Rx CP1 (f x)"
    assumes B: "hn_refine (hn_ctxt Ay y yi) (gi yi) Ay' Ry CP2 (g y)"
    shows "hn_refine 
      P 
      (llc_par fi gi xi yi) 
      (Ax' ** Ay' ** F) (Rx×⇩aRy) (λ(rxi,ryi). CP1 rxi ∧ CP2 ryi)
      (nres_par$(λ⇩2x. f x)$(λ⇩2y. g y)$x$y)"
    apply (rule hn_refine_cons_post)
    (* unfold the annotation tags before applying the framed core lemma *)
    unfolding autoref_tag_defs PROTECT2_def
    apply (rule hn_refine_frame[OF hnr_nres_par_aux, where F=F])
    apply fact
    apply fact
    subgoal using FR unfolding hn_ctxt_def by (auto simp: algebra_simps)
    subgoal unfolding hn_ctxt_def by (auto simp: algebra_simps)
    done
  

end