header ‹Setup for Foreach Combinator›
theory Sepref_Foreach
imports Sepref_HOL_Bindings Sepref_IICF_Bindings Pf_Add
begin

(* Sets up Sepref for FOREACH loops: during monadification, the FOREACH
   variants are rewritten to monadic_FOREACH, which is then implemented by
   the imperative imp_nfoldli running over a sorted list of the iterated
   set. *)

subsection "Foreach Loops"

subsubsection "Monadic Version of Foreach"

text {*
  In a first step, we define a version of foreach where the continuation condition
  is also monadic, and show that it is equal to the standard version for
  continuation conditions of the form @{text "λx. RETURN (c x)"}
*}

(* Invariant of the foreach loop over the loop state s = (it, σ):
   the list of remaining items it is a suffix of the full item list xs,
   and the user invariant Φ holds for the set of remaining items and the
   current state σ. monadic_FOREACH asserts this before every iteration. *)
definition "FOREACH_inv xs Φ s ≡ 
  case s of (it, σ) => ∃xs'. xs = xs' @ it ∧ Φ (set it) σ"

(* Variant of FOREACHoci whose continuation condition c is itself monadic:
   c σ is bound in the nres monad and yields a bool.
   Steps:
   - Assert that S is finite.
   - Obtain a list xs0 from it_to_sorted_list R S (presumably a distinct
     list of the elements of S, sorted w.r.t. R).
   - Recurse with RECT over the pair (remaining list, state). At each step,
     assert FOREACH_inv. Stop if the list is empty. Otherwise evaluate c σ
     and either run one FOREACH_body step and recurse, or stop.
   - Return the final state σ. *)
definition "monadic_FOREACH R Φ S c f σ0 ≡ do {
  ASSERT (finite S);
  xs0 \<leftarrow> it_to_sorted_list R S;
  (_,σ) \<leftarrow> RECT (λW (xs,σ). do {
    ASSERT (FOREACH_inv xs0 Φ (xs,σ));
    if xs≠[] then do {
      b \<leftarrow> c σ;
      if b then
        FOREACH_body f (xs,σ) »= W
      else
        RETURN (xs,σ)
    } else RETURN (xs,σ)
  }) (xs0,σ0);
  RETURN σ
}"

(* FOREACHoci coincides with monadic_FOREACH when the pure continuation
   condition c is lifted into the monad via RETURN.
   Proof: unfold both definitions, the WHILEIT/foreach auxiliaries and
   FOREACH_inv. Then close the goal by congruence reasoning (arg_cong/cong,
   refl, ext) and simplification with prod.split. *)
lemma FOREACH_oci_to_monadic:
  "FOREACHoci R Φ S c f σ0 = monadic_FOREACH R Φ S (λσ. RETURN (c σ)) f σ0"
  unfolding FOREACHoci_def monadic_FOREACH_def WHILEIT_def WHILEI_body_def
  unfolding it_to_sorted_list_def FOREACH_cond_def FOREACH_inv_def
  apply simp
  apply (fo_rule arg_cong[THEN cong] | rule refl ext)+
  apply (simp split: prod.split)
  apply (rule refl)+
  done


text {* Next, we define a characterization w.r.t. @{text "nfoldli"} *}
(* Monadic analogue of nfoldli: folds f over the list l, starting from
   state s, where the continuation condition c is a monadic computation.
   Before each element, c is evaluated on the current state. Iteration
   returns the current state when c yields False or the list is exhausted.
   Defined with the recursion combinator RECT. *)
definition "monadic_nfoldli l c f s ≡ RECT (λD (l,s). case l of 
    [] => RETURN s
  | x#ls => do {
      b \<leftarrow> c s;
      if b then do { s'\<leftarrow>f x s; D (ls,s')} else RETURN s
    }
  ) (l,s)"

(* One-step unfolding equation for monadic_nfoldli.
   Proof: unfold the definition and RECT once (RECT_unfold). The side
   condition, presumably monotonicity of the body, is closed by
   tagged_solver. Then fold the recursive occurrence back into
   monadic_nfoldli. *)
lemma monadic_nfoldli_eq:
  "monadic_nfoldli l c f s = (
    case l of 
      [] => RETURN s 
    | x#ls => do {
        b\<leftarrow>c s; 
        if b then f x s »= monadic_nfoldli ls c f else RETURN s
      }
  )"
  apply (subst monadic_nfoldli_def)
  apply (subst RECT_unfold)
  apply (tagged_solver)
  apply (subst monadic_nfoldli_def[symmetric])
  apply simp
  done
  
(* Simp rules for monadic_nfoldli on the empty and the non-empty list.
   Both are instances of monadic_nfoldli_eq. *)
lemma monadic_nfoldli_simp[simp]:
  "monadic_nfoldli [] c f s = RETURN s"
  "monadic_nfoldli (x#ls) c f s = do {
    b\<leftarrow>c s;
    if b then f x s »= monadic_nfoldli ls c f else RETURN s
  }"
  apply (subst monadic_nfoldli_eq, simp)
  apply (subst monadic_nfoldli_eq, simp)
  done

(* The standard nfoldli is the special case of monadic_nfoldli whose
   continuation condition is a pure function c, lifted via RETURN.
   Proof: structural induction on the list, with both cases closed by
   automation. monadic_nfoldli_simp is a [simp] rule and unfolds the cons
   case. *)
lemma nfoldli_to_monadic:
  "nfoldli l c f = monadic_nfoldli l (λx. RETURN (c x)) f"
  by (induct l) auto

(* Alternative RECT-based formulation of nfoldli. The pure continuation
   condition is bound with let instead of being lifted into the monad.
   Lemma nfoldli_alt shows that this equals nfoldli. *)
definition "nfoldli_alt l c f s ≡ RECT (λD (l,s). case l of 
    [] => RETURN s
  | x#ls => do {
      let b = c s;
      if b then do { s'\<leftarrow>f x s; D (ls,s')} else RETURN s
    }
  ) (l,s)"

(* One-step unfolding equation for nfoldli_alt.
   Proof: same scheme as monadic_nfoldli_eq. Unfold RECT once, close its
   side condition with tagged_solver, and fold the recursive call back. *)
lemma nfoldli_alt_eq:
  "nfoldli_alt l c f s = (
    case l of 
      [] => RETURN s 
    | x#ls => do {let b=c s; if b then f x s »= nfoldli_alt ls c f else RETURN s}
  )"
  apply (subst nfoldli_alt_def)
  apply (subst RECT_unfold)
  apply (tagged_solver)
  apply (subst nfoldli_alt_def[symmetric])
  apply simp
  done
  
(* Simp rules for nfoldli_alt on the empty and the non-empty list.
   Both are instances of nfoldli_alt_eq. *)
lemma nfoldli_alt_simp[simp]:
  "nfoldli_alt [] c f s = RETURN s"
  "nfoldli_alt (x#ls) c f s = do {
    let b = c s;
    if b then f x s »= nfoldli_alt ls c f else RETURN s
  }"
  apply (subst nfoldli_alt_eq, simp)
  apply (subst nfoldli_alt_eq, simp)
  done


(* nfoldli (at the type given explicitly below) and nfoldli_alt are the
   same function.
   Proof: extensionality. Then induction on the list proves the partially
   applied functions equal, using the simp rules of both sides. *)
lemma nfoldli_alt:
  "(nfoldli::'a list => ('b => bool) => ('a => 'b => 'b nres) => 'b => 'b nres)
  = nfoldli_alt"
proof (intro ext)
  fix l::"'a list" and c::"'b => bool" and f::"'a => 'b => 'b nres" and s :: 'b
  have "nfoldli l c f = nfoldli_alt l c f"
    by (induct l) auto
  thus "nfoldli l c f s = nfoldli_alt l c f s" by simp
qed

(* monadic_nfoldli over the list x' refines (w.r.t. Id) the RECT loop that
   occurs inside monadic_FOREACH, started at (x', σ) and projected to its
   final state. The right-hand side asserts FOREACH_inv, so xs0 and I can
   be arbitrary; the proof may assume the asserted invariant (le_ASSERTI).
   Proof: induction on x' for arbitrary σ. *)
lemma monadic_nfoldli_rec:
  "monadic_nfoldli x' c f σ
          ≤\<Down>Id (RECT
             (λW (xs, σ).
                 ASSERT (FOREACH_inv xs0 I (xs, σ)) »=
                 (λ_. if xs = [] then RETURN (xs, σ)
                      else c σ »=
                           (λb. if b then FOREACH_body f (xs, σ) »= W
                                else RETURN (xs, σ))))
             (x', σ) »=
            (λ(_, y). RETURN y))"
  apply (induct x' arbitrary: σ)

  (* Case x' = []: unfold RECT once and discharge the assertion. *)
  apply (subst RECT_unfold, refine_mono)
  apply (simp)
  apply (rule le_ASSERTI)
  apply simp

  (* Case x' = x # xs: unfold RECT and monadic_nfoldli once, decompose
     with refine_rcg, and close the recursive call with the induction
     hypothesis via bind_mono. *)
  apply (subst RECT_unfold, refine_mono)
  apply (subst monadic_nfoldli_simp)
  apply (simp del: conc_Id)
  apply refine_rcg
  apply (clarsimp simp add: FOREACH_body_def)
  apply (rule bind_mono(1)[OF order_refl])
  apply assumption
  done


(* Choosing a sorted list with it_to_sorted_list and then running
   monadic_nfoldli over it refines monadic_FOREACH.
   Proof: reduce the Id-refinement (refine_IdD), decompose with
   refine_rcg, and close the loop part with monadic_nfoldli_rec.
   (The unused fixed variable tsl has been dropped from the fixes clause;
   it occurred neither in the statement nor in the proof.) *)
lemma monadic_FOREACH_itsl:
  fixes R I
  shows 
    "do { l \<leftarrow> it_to_sorted_list R s; monadic_nfoldli l c f σ } 
     ≤ monadic_FOREACH R I s c f σ"
    apply (rule refine_IdD)
    unfolding monadic_FOREACH_def it_to_sorted_list_def
    apply (refine_rcg)
    apply simp
    apply (rule monadic_nfoldli_rec[simplified])
    done

(* Non-monadic counterpart of monadic_FOREACH_itsl: choosing a sorted list
   with it_to_sorted_list and then running nfoldli over it refines
   FOREACHoci.
   Proof: reduce the Id-refinement, decompose with refine_rcg, and close
   the loop part with the library lemma nfoldli_while (presumably relating
   nfoldli to the WHILEIT loop inside FOREACHoci).
   (The unused fixed variable tsl has been dropped from the fixes clause.) *)
lemma FOREACHoci_itsl:
  fixes R I
  shows 
    "do { l \<leftarrow> it_to_sorted_list R s; nfoldli l c f σ } 
     ≤ FOREACHoci R I s c f σ"
    apply (rule refine_IdD)
    unfolding FOREACHoci_def it_to_sorted_list_def
    apply refine_rcg
    apply simp
    apply (rule nfoldli_while)
    done

(* Definition-pattern rules: rewrite FOREACHc, FOREACHci, FOREACHi and
   FOREACH into the most general form FOREACHoci, wrapped in PR_CONST.
   - The order argument is always the trivial relation λ_ _. True.
   - A missing invariant becomes λ_ _. True.
   - A missing continuation condition becomes λx. True. *)
lemma [def_pat_rules]:
  "FOREACHc ≡ PR_CONST (FOREACHoci (λ_ _. True) (λ_ _. True))"
  "FOREACHci$I ≡ PR_CONST (FOREACHoci (λ_ _. True) I)"
  "FOREACHi$I ≡ λ2s. PR_CONST (FOREACHoci (λ_ _. True) I)$s$(λ2x. True)"
  "FOREACH ≡ FOREACHi$(λ2_ _. True)"
  by (simp_all add: 
    FOREACHci_def FOREACHi_def[abs_def] FOREACHc_def FOREACH_def[abs_def])
  
(* Identification rule: PR_CONST (FOREACHoci R I) is treated as a single
   operator taking a 'c set, a continuation condition 'd => bool, a body
   'c => 'd => 'd nres and an initial state 'd, and returning a 'd nres.
   (A leftover diagnostic `term` command that preceded this lemma has been
   removed.) *)
lemma id_FOREACHoci[id_rules]: "PR_CONST (FOREACHoci R I) ::i 
  TYPE('c set => ('d => bool) => ('c => 'd => 'd nres) => 'd => 'd nres)"
  by simp

text {* We set up the monadify-phase such that all FOREACH-loops get
  rewritten to the monadic version of FOREACH *}
(* Arity rule for the monadify phase: eta-expands PR_CONST (FOREACHoci R I)
   to its four arguments (set s, condition c, body f, initial state σ),
   wraps the operator in SP, and eta-expands c and f.
   The commented-out equations are older arity rules for the FOREACH
   variants; presumably they are superseded by the def_pat_rules, which
   rewrite all variants to FOREACHoci. *)
lemma FOREACH_arities[sepref_monadify_arity]:
  (*"FOREACHc ≡ FOREACHoci$(λ2_ _. True)$(λ2_ _. True)"
  "FOREACHci ≡ FOREACHoci$(λ2_ _. True)"
  "FOREACHi ≡ λ2I s. FOREACHci$I$s$(λ2x. True)"
  "FOREACH ≡ FOREACHi$(λ2_ _. True)"*)
  "PR_CONST (FOREACHoci R I) ≡ λ2s c f σ. SP (PR_CONST (FOREACHoci R I))$s$(λ2x. c$x)$(λ2x σ. f$x$σ)$σ"
  by (simp_all)

(* Monadify combinator rule: a FOREACHoci application is rewritten into
   monadic_FOREACH.
   - s and σ are evaluated first (bind/EVAL).
   - Each condition evaluation c x is lifted into the monad with EVAL.
   Correctness follows from FOREACH_oci_to_monadic. *)
lemma FOREACHoci_comb[sepref_monadify_comb]:
  "!!s c f σ. (PR_CONST (FOREACHoci R I))$s$(λ2x. c x)$f$σ ≡ 
    bind$(EVAL$s)$(λ2s. bind$(EVAL$σ)$(λ2σ. 
      SP (PR_CONST (monadic_FOREACH R I))$s$(λ2x. (EVAL$(c x)))$f$σ
    ))"
  by (simp_all add: FOREACH_oci_to_monadic)

text {* Setup for linearity analysis. *}
(* Linearity-analysis skeleton for monadic_FOREACH. It is a sequence of:
   - an operation on (s, σ), then
   - a recursive loop (la_rec) whose body runs the skeleton of the
     condition c, then the skeleton of the body f, then a recursive call
     (la_rcall). *)
lemma monadic_FOREACH_skel[sepref_la_skel]:
  "!!s c f σ. SKEL ((PR_CONST (monadic_FOREACH R I))$s$c$f$σ) = 
    la_seq 
      (la_op (s,σ)) 
      (la_rec (λD. la_seq (SKEL c) (la_seq (SKEL f) (la_rcall D)))
      )" by simp


subsubsection "Imperative Version of nfoldli"
text {* We define an imperative version of @{text "nfoldli"}. It is the
  imperative counterpart of the monadic version in the nres-monad. *}

(* Imperative (Heap monad) counterpart of monadic_nfoldli. It has the same
   shape, with RECT replaced by the fixed point heap.fixp_fun and RETURN
   by return.
   The fixed-point definition is excluded from code generation; the code
   equations are the unfolding rules imp_nfoldli_simps. *)
definition "imp_nfoldli l c f s ≡ heap.fixp_fun (λD (l,s). case l of 
    [] => return s
  | x#ls => do {
      b\<leftarrow>c s;
      if b then do { s'\<leftarrow>f x s; D (ls,s')} else return s
    }
  ) (l,s)"

declare imp_nfoldli_def[code del]

(* Unfolding equations for imp_nfoldli on the empty and the non-empty
   list, used both as simp rules and as code equations.
   Proof of each: unfold the fixed point once (heap.mono_body_fixp),
   prove the monotonicity side goal with Pf_Mono_Prover.mono_tac, then
   simplify. *)
lemma imp_nfoldli_simps[simp,code]:
  "imp_nfoldli [] c f s = return s"
  "imp_nfoldli (x#ls) c f s = (do {
    b \<leftarrow> c s;
    if b then do { 
      s'\<leftarrow>f x s; 
      imp_nfoldli ls c f s'
    } else return s
  })"
  apply -
  unfolding imp_nfoldli_def
  apply (subst heap.mono_body_fixp)
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  apply simp
  apply (subst heap.mono_body_fixp)
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  apply simp
  done

(* imp_nfoldli l c f s refines monadic_nfoldli l' c' f' s', where the
   abstract list l' is related to the concrete l by hn_list Rl.
   Assumptions:
   - c_ref: the condition c refines c'. It keeps the state valid and
     returns a pure bool.
   - f_ref: the body f refines f'. It invalidates the element and the
     input state.
   Afterwards, both the list and the initial state are invalidated.
   Proof: induction along hn_list_aux.induct. pRl does not occur in the
   goal; it only fills the relation argument of the induction rule. Frame
   side conditions are partly solved by Sepref_Frame.frame_tac and partly
   by manual rotation (fr_rot / fr_rot_rhs / fr_refl). *)
lemma monadic_nfoldli_refine_aux:
  assumes c_ref: "!!s s'. hn_refine 
    (Γ * hn_ctxt Rs s' s) 
    (c s) 
    (Γ * hn_ctxt Rs s' s) 
    (pure bool_rel)
    (c' s')"
  assumes f_ref: "!!x x' s s'. hn_refine 
    (Γ * hn_ctxt Rl x' x * hn_ctxt Rs s' s)
    (f x s)
    (Γ * hn_invalid x' x * hn_invalid s' s) Rs
    (f' x' s')"

  shows "hn_refine 
    (Γ * hn_list Rl l' l * hn_ctxt Rs s' s) 
    (imp_nfoldli l c f s) 
    (Γ * hn_invalid l' l * hn_invalid s' s) Rs
    (monadic_nfoldli l' c' f' s')"

  apply (induct pRl l' l 
    arbitrary: s s'
    rule: hn_list_aux.induct)

  (* Both lists empty: the result is the initial state (hnr_RETURN_pass). *)
  apply simp
  apply (rule hn_refine_cons_post)
  apply (rule hn_refine_frame[OF hnr_RETURN_pass])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})
  apply (simp add: hn_ctxt_def ent_true_drop)

  (* Both lists non-empty: unfold one step and bind the condition. *)
  apply (simp only: imp_nfoldli_simps monadic_nfoldli_simp)
  apply (rule hnr_bind)
  apply (rule hn_refine_frame[OF c_ref])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Branch on the condition. In the then-branch, bind the body f. *)
  apply (rule hnr_If)
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})
  apply (rule hnr_bind)
  apply (rule hn_refine_frame[OF f_ref])
  apply (simp add: assn_aci)
  apply (fr_rot_rhs 1)
  apply (fr_rot 2)
  apply (rule fr_refl)
  apply (rule fr_refl)
  apply (rule fr_refl)
  apply (rule ent_refl)

  (* Recursive call: the induction hypothesis (rprems). *)
  apply (rule hn_refine_frame)
  apply rprems

  apply (simp add: assn_aci)
  apply (fr_rot_rhs 1)
  apply (fr_rot 2)
  apply (rule fr_refl)
  apply (fr_rot 1)
  apply (rule fr_refl)
  apply (rule fr_refl)
  apply (rule ent_refl)
  
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Else-branch: return the current state. *)
  apply (rule hn_refine_frame[OF hnr_RETURN_pass])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Presumably: merge the postconditions of the two branches. *)
  apply (simp add: assn_assoc)
  apply (tactic {* Sepref_Frame.merge_tac @{context} 1 *})
  apply (simp only: sup.idem, rule ent_refl)
  apply (simp add: hn_ctxt_def, (rule fr_refl ent_refl)+) []

  (* Remaining induction cases, presumably the mismatched-length cases. *)
  apply (rule, sep_auto)
  apply (rule, sep_auto)
  done

(* Form of monadic_nfoldli_refine_aux as used by sepref:
   - the abstract operations appear with tagged application ($);
   - an arbitrary precondition P is accepted, provided it entails the
     required context (FR).
   Proof: weaken the precondition with FR, unfold APP_def, and apply
   monadic_nfoldli_refine_aux. *)
lemma hn_monadic_nfoldli:
  assumes FR: "P ==>A Γ * hn_list Rl l' l * hn_ctxt Rs s' s"
  assumes c_ref: "!!s s'. hn_refine 
    (Γ * hn_ctxt Rs s' s) 
    (c s) 
    (Γ * hn_ctxt Rs s' s)
    (pure bool_rel) 
    (c'$s')"
  assumes f_ref: "!!x x' s s'. hn_refine 
    (Γ * hn_ctxt Rl x' x * hn_ctxt Rs s' s)
    (f x s)
    (Γ * hn_invalid x' x * hn_invalid s' s) Rs
    (f'$x'$s')"
  shows "hn_refine 
    P 
    (imp_nfoldli l c f s) 
    (Γ * hn_invalid l' l * hn_invalid s' s)
    Rs
    (monadic_nfoldli$l'$c'$f'$s')
    "
  apply (rule hn_refine_cons_pre[OF FR])
  unfolding APP_def
  apply (rule monadic_nfoldli_refine_aux)
  apply (rule c_ref[unfolded APP_def])
  apply (rule f_ref[unfolded APP_def])
  done  

(* it_to_sorted_list ordR is refined by returning tsl s, provided tsl is a
   set-to-sorted-list implementation for ordR, Rk and Rs (ITSL). The input
   set and the result list are both related by pure relations.
   Proof: unfold the pure assertions, then run vcg and clarsimp. Use the
   elimination rule of ITSL and finish with sep_auto. *)
lemma hn_itsl:
  assumes ITSL: "is_set_to_sorted_list ordR Rk Rs tsl"
  shows "hn_refine 
    (hn_val (⟨Rk⟩Rs) s' s) 
    (return (tsl s)) 
    (hn_val (⟨Rk⟩Rs) s' s) 
    (pure (⟨Rk⟩list_rel))
    (it_to_sorted_list ordR s')"
  apply rule
  unfolding hn_ctxt_def pure_def
  apply vcg
  apply clarsimp
  apply (erule is_set_to_sorted_listE[OF ITSL])
  apply sep_auto
  done

(* Main sepref combinator rule for monadic_FOREACH ordR I.
   - The loop ranges over a set s' that is purely related (⟨Rk⟩Rs) to the
     concrete s.
   - It is implemented by imp_nfoldli over tsl s, where tsl is an algorithm
     obtained through GEN_ALGO_tag (STL).
   - c_ref and f_ref refine the condition and the body.
   - C_FR and F_FR are the frame side conditions that relate their
     postconditions back to the loop context. They are tagged with
     TERM monadic_FOREACH, presumably so that frame inference can identify
     their origin.
   Proof:
   - Replace monadic_FOREACH by its it_to_sorted_list + monadic_nfoldli
     form (monadic_FOREACH_itsl).
   - Implement the list choice with hn_itsl.
   - Implement the loop with hn_monadic_nfoldli.
   - Discharge the frames with Sepref_Frame.frame_tac. *)
lemma hn_monadic_FOREACH[sepref_comb_rules]:
  assumes "INDEP Rk" "INDEP Rs" "INDEP Rσ"
  assumes FR: "P ==>A Γ * hn_val (⟨Rk⟩Rs) s' s * hn_ctxt Rσ σ' σ"
  assumes STL: "GEN_ALGO_tag (is_set_to_sorted_list ordR Rk Rs tsl)"
  assumes c_ref: "!!σ σ'. hn_refine 
    (Γ * hn_val (⟨Rk⟩Rs) s' s * hn_ctxt Rσ σ' σ) 
    (c σ) 
    (Γc σ' σ) 
    (pure bool_rel) 
    (c' σ')"
  assumes C_FR: 
    "!!σ' σ. TERM monadic_FOREACH ==> 
      Γc σ' σ ==>A Γ * hn_val (⟨Rk⟩Rs) s' s * hn_ctxt Rσ σ' σ"

  assumes f_ref: "!!x' x σ' σ. hn_refine 
    (Γ * hn_val (⟨Rk⟩Rs) s' s * hn_val Rk x' x * hn_ctxt Rσ σ' σ)
    (f x σ)
    (Γf x' x σ' σ) Rσ
    (f' x' σ')"
  assumes F_FR: "!!x' x σ' σ. TERM monadic_FOREACH ==> Γf x' x σ' σ ==>A 
    Γ * hn_val (⟨Rk⟩Rs) s' s * hn_ctxt Pfx x' x * hn_ctxt Pfσ σ' σ"

  shows "hn_refine 
    P 
    (imp_nfoldli (tsl$s) c f σ) (* Important: Using tagged application to avoid ho-unifier problems *)
    (Γ * hn_val (⟨Rk⟩Rs) s' s * hn_invalid σ' σ)
    Rσ
    ((PR_CONST (monadic_FOREACH ordR I))
      $(LIN_ANNOT s' a)$(λ2σ'. c' σ')$(λ2x' σ'. f' x' σ')$(σ'L)
    )"
  unfolding APP_def PROTECT2_def LIN_ANNOT_def PR_CONST_def
  (* Switch to the it_to_sorted_list + monadic_nfoldli form. *)
  apply (rule hn_refine_ref[OF monadic_FOREACH_itsl])
  apply (rule hn_refine_guessI)
  apply (rule hnr_bind)
  (* First bind: the sorted list, via hn_itsl. *)
  apply (rule hn_refine_cons_pre[OF FR])
  apply (rule hn_refine_frame)
  apply (rule hn_itsl[OF STL[unfolded GEN_ALGO_tag_def]])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  (* Second bind: the loop itself, via hn_monadic_nfoldli. *)
  apply (rule hn_monadic_nfoldli[unfolded APP_def])
  apply (simp add: pure_hn_list_eq_list_rel)
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  (* Condition: c_ref, with its postcondition weakened by C_FR. *)
  apply (rule hn_refine_cons_post)
  apply (rule c_ref[unfolded APP_def])
  apply (rule C_FR)
  apply (rule TERMI)
  (* Body: f_ref, with its postcondition weakened by F_FR. *)
  apply (rule hn_refine_cons_post)
  apply (rule f_ref[unfolded APP_def])
  apply (rule ent_trans[OF F_FR])
  apply (rule TERMI)
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  apply simp
  done

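(* TODO resolved: imp_nfoldli_mono is now proved directly below, by unfolding imp_nfoldli_def and running the monotonicity prover with heap_fixp_mono. The old manual proof is kept for reference: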
lemma imp_nfoldli_mono[partial_function_mono]:
  assumes "!!x σ. mono_Heap (λfa. f fa x σ)"
  shows "mono_Heap (λx. imp_nfoldli l c (f x) σ)"
  apply rule
  unfolding imp_nfoldli_def
  apply (rule ccpo.fixp_mono[OF heap.ccpo, THEN fun_ordD])
  apply (rule mono_fun_fun_cnv)
  apply (erule thin_rl)
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  apply (rule mono_fun_fun_cnv)
  apply (erule thin_rl)
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})

  apply (rule fun_ordI)
  apply (erule monotoneD[of "fun_ord Heap_ord" Heap_ord, rotated])
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  by fact*)


(* The Heap fixed point heap.fixp_fun is monotone in a parameter x,
   provided the functional B is monotone in two ways:
   - in its recursive-function argument, for each fixed x (first
     assumption);
   - in x itself, for each fixed recursive function (second assumption).
   The rule is registered as partial_function_mono, so
   Pf_Mono_Prover.mono_tac can use it for functions defined via
   heap.fixp_fun, such as imp_nfoldli.
   Proof: ccpo.fixp_mono for the Heap ccpo, with the remaining
   monotonicity goals closed by mono_tac. *)
lemma heap_fixp_mono[partial_function_mono]:
  assumes [partial_function_mono]: 
    "!!x d. mono_Heap (λxa. B x xa d)"
    "!!Z xa. mono_Heap (λa. B a Z xa)" 
  shows "mono_Heap (λx. heap.fixp_fun (λD σ. B x D σ) σ)"
  apply rule
  apply (rule ccpo.fixp_mono[OF heap.ccpo, THEN fun_ordD])
  apply (rule mono_fun_fun_cnv, 
    erule thin_rl, tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})+
  apply (rule fun_ordI)
  apply (erule monotoneD[of "fun_ord Heap_ord" Heap_ord, rotated])
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  done

(* imp_nfoldli is monotone in a parameter x of its body f, provided f is.
   The rule is registered as partial_function_mono, presumably so that
   partial_function (heap) definitions may place recursive calls inside
   an imp_nfoldli loop body.
   Proof: unfold the fixed-point definition and run the monotonicity
   prover, which can use heap_fixp_mono. *)
lemma imp_nfoldli_mono[partial_function_mono]:
  assumes [partial_function_mono]: "!!x σ. mono_Heap (λfa. f fa x σ)"
  shows "mono_Heap (λx. imp_nfoldli l c (f x) σ)"
  unfolding imp_nfoldli_def
  by (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})

(* Inline nfoldli as fixed-points: adding imp_nfoldli_def to
   sepref_opt_simps presumably makes sepref's optimization phase unfold
   imp_nfoldli into its underlying heap.fixp_fun definition. *)
declare imp_nfoldli_def[sepref_opt_simps]

end