(* Theory Sepref_HOL_Bindings *)
section ‹HOL Setup›
theory Sepref_HOL_Bindings
imports Sepref_Tool
begin

subsection ‹Assertion Annotation›
text ‹Annotates a term with an assertion. The term must then be refined with this assertion.›
(* TODO: Version for monadic expressions.*)
(* ASSN_ANNOT A x is logically just x (the defining equation is declared [simp]);
   the assertion A only occurs as an annotation argument. *)
definition ASSN_ANNOT :: "('a ⇒ 'ai ⇒ assn) ⇒ 'a ⇒ 'a" where [simp]: "ASSN_ANNOT A x ≡ x"
context fixes A :: "'a ⇒ 'ai ⇒ assn" begin
  (* Register the annotation, for the fixed assertion A, wrapped in PR_CONST. *)
  sepref_register "PR_CONST (ASSN_ANNOT A)"
  (* Pattern rule: rewrite the applied annotation to its UNPROTECT-wrapped form. *)
  lemma [def_pat_rules]: "ASSN_ANNOT$A ≡ UNPROTECT (ASSN_ANNOT A)" by simp
  (* Refinement rule: the annotation is implemented by returning the argument unchanged.
     NOTE(review): "AdaA" looks like a flattened rendering of a sub-/superscripted
     signature over A -- confirm against the original theory source. *)
  lemma [sepref_fr_rules]: "(ureturn o (λx. x), RETURNT o PR_CONST (ASSN_ANNOT A)) ∈ AdaA"
    apply rule
    unfolding hn_refine_def
    apply (auto simp: execute_ureturn' invalid_assn_def relH_def zero_enat_def)     
    by (metis (full_types) entails_def entt_refl' mult.right_neutral pure_true star_aci(3)) 
end  

(* Any term x may be rewritten to its annotated form ASSN_ANNOT A x. *)
lemma annotate_assn: "x ≡ ASSN_ANNOT A x" by simp

subsection ‹Shortcuts›
(* Abbreviations that fix the argument type of id_assn to nat, int and bool. *)
abbreviation "nat_assn ≡ (id_assn::nat ⇒ _)"
abbreviation "int_assn ≡ (id_assn::int ⇒ _)"
abbreviation "bool_assn ≡ (id_assn::bool ⇒ _)"

subsection ‹Identity Relations›
(* IS_ID R: R is exactly the identity relation Id. *)
definition "IS_ID R ≡ R=Id"
(* IS_BELOW_ID R: R is a subset of the identity relation Id. *)
definition "IS_BELOW_ID R ≡ R⊆Id"

(* IS_ID holds for Id and is preserved by the function, option, list, product
   and sum relators. *)
lemma [safe_constraint_rules]: 
  "IS_ID Id"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (R1 → R2)"
  "IS_ID R ⟹ IS_ID (⟨R⟩option_rel)"
  "IS_ID R ⟹ IS_ID (⟨R⟩list_rel)"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (R1 ×r R2)"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (⟨R1,R2⟩sum_rel)"
  by (auto simp: IS_ID_def)

(* IS_BELOW_ID holds for Id and is preserved by the option, product and sum relators.
   The function and list relators are handled by separate lemmas. *)
lemma [safe_constraint_rules]: 
  "IS_BELOW_ID Id"
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (⟨R⟩option_rel)"
  "IS_BELOW_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 ×r R2)"
  "IS_BELOW_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (⟨R1,R2⟩sum_rel)"
  by (auto simp: IS_ID_def IS_BELOW_ID_def option_rel_def sum_rel_def list_rel_def)

(* Function relator: if the argument relation contains Id and the result relation
   is below Id, then R1 → R2 is below Id. *)
lemma IS_BELOW_ID_fun_rel_aux: "R1⊇Id ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 → R2)"
  by (auto simp: IS_BELOW_ID_def dest: fun_relD)

(* Instance of the auxiliary lemma for R1 = Id, stated with IS_ID so that it can be
   used as a constraint rule. *)
corollary IS_BELOW_ID_fun_rel[safe_constraint_rules]: 
  "IS_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 → R2)"
  using IS_BELOW_ID_fun_rel_aux[of Id R2]
  by (auto simp: IS_ID_def)


(* The list relator preserves containment in Id: lists related element-wise by a
   sub-identity relation are equal. Proved by induction on the list_rel derivation. *)
lemma IS_BELOW_ID_list_rel[safe_constraint_rules]:
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (⟨R⟩list_rel)"
  unfolding IS_BELOW_ID_def
proof safe
  fix xs ys
  assume sub: "R ⊆ Id"
  assume "(xs, ys) ∈ ⟨R⟩list_rel"
  then show "xs = ys"
    apply induction
    using sub by auto
qed

(* Every identity relation is in particular below the identity. Declared as a
   constraint_rules rule rather than a safe_constraint_rules one. *)
lemma IS_ID_imp_BELOW_ID[constraint_rules]: 
  "IS_ID R ⟹ IS_BELOW_ID R"
  by (auto simp: IS_ID_def IS_BELOW_ID_def )



subsection ‹Inverse Relation›

(* Converse distributes over the function relator. *)
lemma inv_fun_rel_eq[simp]: "(A→B)¯ = A¯→B¯"
  by (auto dest: fun_relD)

(* Converse distributes over the option relator. *)
lemma inv_option_rel_eq[simp]: "(⟨K⟩option_rel)¯ = ⟨K¯⟩option_rel"
  by (auto simp: option_rel_def)

(* Converse distributes over the product relator. *)
lemma inv_prod_rel_eq[simp]: "(P ×r Q)¯ = P¯ ×r Q¯"
  by (auto)

(* Converse distributes over the sum relator. *)
lemma inv_sum_rel_eq[simp]: "(⟨P,Q⟩sum_rel)¯ = ⟨P¯,Q¯⟩sum_rel"
  by (auto simp: sum_rel_def)

(* Converse distributes over the list relator. After unfolding list_rel_def, each of
   the two inclusions is closed by flipping the list relator with list.rel_flip. *)
lemma inv_list_rel_eq[simp]: "(⟨R⟩list_rel)¯ = ⟨R¯⟩list_rel"
  unfolding list_rel_def
  apply safe
  apply (subst list.rel_flip[symmetric])
  apply (simp add: conversep_iff[abs_def])
  apply (subst list.rel_flip[symmetric])
  apply (simp add: conversep_iff[abs_def])
  done

(* Declare converse_Id and the converse-distribution lemmas for the individual
   relators as constraint_simps. *)
lemmas [constraint_simps] =
  Relation.converse_Id
  inv_fun_rel_eq
  inv_option_rel_eq
  inv_prod_rel_eq
  inv_sum_rel_eq
  inv_list_rel_eq


subsection ‹Single Valued and Total Relations›

(* TODO: Link to other such theories: Transfer, Autoref *)
(* Left-uniqueness: the converse relation is single-valued. *)
definition "IS_LEFT_UNIQUE R ≡ single_valued (R¯)"
(* Left-totality: every element occurs as a first component. *)
definition "IS_LEFT_TOTAL R ≡ Domain R = UNIV"
(* Right-totality: every element occurs as a second component. *)
definition "IS_RIGHT_TOTAL R ≡ Range R = UNIV"
(* Right-uniqueness is just single_valued; input-only abbreviation. *)
abbreviation (input) "IS_RIGHT_UNIQUE ≡ single_valued"

(* Destruction rules for right- and left-uniqueness. *)
lemmas IS_RIGHT_UNIQUED = single_valuedD
lemma IS_LEFT_UNIQUED: "⟦IS_LEFT_UNIQUE r; (y, x) ∈ r; (z, x) ∈ r⟧ ⟹ y = z"
  by (auto simp: IS_LEFT_UNIQUE_def dest: single_valuedD)

(* Express the set-based uniqueness/totality properties of R as the corresponding
   predicate-level properties (left_unique, right_unique, left_total, right_total)
   of rel2p R. *)
lemma prop2p:
  "IS_LEFT_UNIQUE R = left_unique (rel2p R)"
  "IS_RIGHT_UNIQUE R = right_unique (rel2p R)"
  "right_unique (rel2p (R¯)) = left_unique (rel2p R)"
  "IS_LEFT_TOTAL R = left_total (rel2p R)"
  "IS_RIGHT_TOTAL R = right_total (rel2p R)"
  by (auto 
    simp: IS_LEFT_UNIQUE_def left_unique_def single_valued_def
    simp: right_unique_def
    simp: IS_LEFT_TOTAL_def left_total_def
    simp: IS_RIGHT_TOTAL_def right_total_def
    simp: rel2p_def
    )

(* Opposite direction of prop2p: express the predicate-level properties of P as the
   set-based properties of p2rel P. The last equation splits bi_unique into its
   left- and right-unique parts. *)
lemma p2prop:
  "left_unique P = IS_LEFT_UNIQUE (p2rel P)"
  "right_unique P = IS_RIGHT_UNIQUE (p2rel P)"
  "left_total P = IS_LEFT_TOTAL (p2rel P)"
  "right_total P = IS_RIGHT_TOTAL (p2rel P)"
  "bi_unique P ⟷ left_unique P ∧ right_unique P"
  by (auto 
    simp: IS_LEFT_UNIQUE_def left_unique_def single_valued_def
    simp: right_unique_def bi_unique_alt_def
    simp: IS_LEFT_TOTAL_def left_total_def
    simp: IS_RIGHT_TOTAL_def right_total_def
    simp: p2rel_def
    )

(* Existing single-valuedness (right-uniqueness) facts for Id and the product, list,
   option and sum relators, declared as safe constraint rules. *)
lemmas [safe_constraint_rules] = 
  single_valued_Id  
  prod_rel_sv 
  list_rel_sv 
  option_rel_sv 
  sum_rel_sv

(* The same closure properties for left-uniqueness; obtained from the single-valuedness
   facts after unfolding IS_LEFT_UNIQUE_def. *)
lemma [safe_constraint_rules]:
  "IS_LEFT_UNIQUE Id"
  "IS_LEFT_UNIQUE R1 ⟹ IS_LEFT_UNIQUE R2 ⟹ IS_LEFT_UNIQUE (R1×rR2)"
  "IS_LEFT_UNIQUE R1 ⟹ IS_LEFT_UNIQUE R2 ⟹ IS_LEFT_UNIQUE (⟨R1,R2⟩sum_rel)"
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (⟨R⟩option_rel)"
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (⟨R⟩list_rel)"
  by (auto simp: IS_LEFT_UNIQUE_def prod_rel_sv sum_rel_sv option_rel_sv list_rel_sv)

(* Pointwise characterisation of left-totality. *)
lemma IS_LEFT_TOTAL_alt: "IS_LEFT_TOTAL R ⟷ (∀x. ∃y. (x,y)∈R)"
  by (auto simp: IS_LEFT_TOTAL_def)

(* Pointwise characterisation of right-totality. *)
lemma IS_RIGHT_TOTAL_alt: "IS_RIGHT_TOTAL R ⟷ (∀x. ∃y. (y,x)∈R)"
  by (auto simp: IS_RIGHT_TOTAL_def)

(* Left-totality holds for Id and is preserved by the product, sum and option relators.
   The two goals remaining after auto are closed by a case distinction on the
   renamed variable x. *)
lemma [safe_constraint_rules]:
  "IS_LEFT_TOTAL Id"
  "IS_LEFT_TOTAL R1 ⟹ IS_LEFT_TOTAL R2 ⟹ IS_LEFT_TOTAL (R1×rR2)"
  "IS_LEFT_TOTAL R1 ⟹ IS_LEFT_TOTAL R2 ⟹ IS_LEFT_TOTAL (⟨R1,R2⟩sum_rel)"
  "IS_LEFT_TOTAL R ⟹ IS_LEFT_TOTAL (⟨R⟩option_rel)"
  apply (auto simp: IS_LEFT_TOTAL_alt sum_rel_def option_rel_def list_rel_def)
  apply (rename_tac x; case_tac x; auto)
  apply (rename_tac x; case_tac x; auto)
  done

(* The list relator preserves left-totality: for every list a related list exists.
   Proved by induction on the given list. *)
lemma [safe_constraint_rules]: "IS_LEFT_TOTAL R ⟹ IS_LEFT_TOTAL (⟨R⟩list_rel)"
  unfolding IS_LEFT_TOTAL_alt
proof safe
  assume tot: "∀x. ∃y. (x, y) ∈ R"
  fix xs
  show "∃ys. (xs, ys) ∈ ⟨R⟩list_rel"
    apply (induction xs)
    using tot
    by (auto simp: list_rel_split_right_iff)
qed

(* Right-totality holds for Id and is preserved by the product, sum and option relators.
   The first three auto calls are each restricted to a single subgoal ([]); the
   remaining goals are closed by case distinctions on the renamed variable x. *)
lemma [safe_constraint_rules]:
  "IS_RIGHT_TOTAL Id"
  "IS_RIGHT_TOTAL R1 ⟹ IS_RIGHT_TOTAL R2 ⟹ IS_RIGHT_TOTAL (R1×rR2)"
  "IS_RIGHT_TOTAL R1 ⟹ IS_RIGHT_TOTAL R2 ⟹ IS_RIGHT_TOTAL (⟨R1,R2⟩sum_rel)"
  "IS_RIGHT_TOTAL R ⟹ IS_RIGHT_TOTAL (⟨R⟩option_rel)"
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (rename_tac x; case_tac x; auto)
  apply (clarsimp simp: IS_RIGHT_TOTAL_alt option_rel_def)
  apply (rename_tac x; case_tac x; auto)
  done

(* The list relator preserves right-totality: every list is the image of some
   related list. Proved by induction on the given list. *)
lemma [safe_constraint_rules]: "IS_RIGHT_TOTAL R ⟹ IS_RIGHT_TOTAL (⟨R⟩list_rel)"
  unfolding IS_RIGHT_TOTAL_alt
proof safe
  assume tot: "∀x. ∃y. (y, x) ∈ R"
  fix ys
  show "∃xs. (xs, ys) ∈ ⟨R⟩list_rel"
    apply (induction ys)
    using tot
    by (auto simp: list_rel_split_left_iff)
qed
  
(* Taking the converse swaps left and right: totality and uniqueness of R¯ are
   the mirrored properties of R. *)
lemma [constraint_simps]:
  "IS_LEFT_TOTAL (R¯) ⟷ IS_RIGHT_TOTAL R "
  "IS_RIGHT_TOTAL (R¯) ⟷ IS_LEFT_TOTAL R  "
  "IS_LEFT_UNIQUE (R¯) ⟷ IS_RIGHT_UNIQUE R"
  "IS_RIGHT_UNIQUE (R¯) ⟷ IS_LEFT_UNIQUE R "
  by (auto simp: IS_RIGHT_TOTAL_alt IS_LEFT_TOTAL_alt IS_LEFT_UNIQUE_def)

(* Totality and uniqueness of the function relator A→B, derived from a uniqueness
   property of A and the matching property of B (or vice versa). The goals are first
   translated to predicate form (prop2p, rel2p) and then solved with the
   transfer_raw rules. *)
lemma [safe_constraint_rules]:
  "IS_RIGHT_UNIQUE A ⟹ IS_RIGHT_TOTAL B ⟹ IS_RIGHT_TOTAL (A→B)"
  "IS_RIGHT_TOTAL A ⟹ IS_RIGHT_UNIQUE B ⟹ IS_RIGHT_UNIQUE (A→B)"
  "IS_LEFT_UNIQUE A ⟹ IS_LEFT_TOTAL B ⟹ IS_LEFT_TOTAL (A→B)"
  "IS_LEFT_TOTAL A ⟹ IS_LEFT_UNIQUE B ⟹ IS_LEFT_UNIQUE (A→B)"
  apply (simp_all add: prop2p rel2p)
  (*apply transfer_step TODO: Isabelle 2016 *)
  apply (blast intro!: transfer_raw)+
  done

(* Weakening rules between the relation properties: a relation below Id is both
   right- and left-unique, and an identity relation is both right- and left-total.
   Declared as constraint_rules rather than safe_constraint_rules. *)
lemma [constraint_rules]: 
  "IS_BELOW_ID R ⟹ IS_RIGHT_UNIQUE R"
  "IS_BELOW_ID R ⟹ IS_LEFT_UNIQUE R"
  "IS_ID R ⟹ IS_RIGHT_TOTAL R"
  "IS_ID R ⟹ IS_LEFT_TOTAL R"
  by (auto simp: IS_BELOW_ID_def IS_ID_def IS_LEFT_UNIQUE_def IS_RIGHT_TOTAL_def IS_LEFT_TOTAL_def
    intro: single_valuedI)

subsubsection ‹Additional Parametricity Lemmas›
(* TODO: Move. Problem: These depend on IS_LEFT_UNIQUE, which has to be moved too! *)

(* Parametricity of distinct for relations that are both left- and right-unique.
   The goal is folded into rel2p form so that distinct_transfer applies; the
   remaining side condition is discharged via p2prop. *)
lemma param_distinct[param]: "⟦IS_LEFT_UNIQUE A; IS_RIGHT_UNIQUE A⟧ ⟹ (distinct, distinct) ∈ ⟨A⟩list_rel → bool_rel"  
  apply (fold rel2p_def)
  apply (simp add: rel2p)
  apply (rule distinct_transfer)
  apply (simp add: p2prop)
  done

(* Parametricity of relational image (``) over set_rel, for a key relation A that is
   both left- and right-unique. After unfolding set_rel_def, the two conjuncts use
   right-uniqueness and left-uniqueness of A respectively. *)
lemma param_Image[param]: 
  assumes "IS_LEFT_UNIQUE A" "IS_RIGHT_UNIQUE A"
  shows "((``), (``)) ∈ ⟨A×rB⟩set_rel → ⟨A⟩set_rel → ⟨B⟩set_rel"
  apply (clarsimp simp: set_rel_def; intro conjI)  
  apply (fastforce dest: IS_RIGHT_UNIQUED[OF assms(2)])
  apply (fastforce dest: IS_LEFT_UNIQUED[OF assms(1)])
  done

(* Equality is parametric w.r.t. K exactly when K and its converse are single-valued. *)
lemma pres_eq_iff_svb: "((=),(=))∈K→K→bool_rel ⟷ (single_valued K ∧ single_valued (K¯))"
  apply (safe intro!: single_valuedI)
  apply (metis (full_types) IdD fun_relD1)
  apply (metis (full_types) IdD fun_relD1)
  by (auto dest: single_valuedD)

(* IS_PRES_EQ R: equality is parametric w.r.t. R. *)
definition "IS_PRES_EQ R ≡ ((=), (=))∈R→R→bool_rel"
(* Sufficient condition for IS_PRES_EQ, via pres_eq_iff_svb. *)
lemma [constraint_rules]: "⟦single_valued R; single_valued (R¯)⟧ ⟹ IS_PRES_EQ R"
  by (simp add: pres_eq_iff_svb IS_PRES_EQ_def)


subsection ‹Bounded Assertions›
(* b_rel R P: R restricted to pairs whose second component satisfies P. *)
definition "b_rel R P ≡ R ∩ UNIV×Collect P"
(* b_assn A P: the assertion A, strengthened by the pure condition P on its first argument. *)
definition "b_assn A P ≡ λx y. A x y * ↑(P x)"

(* For pure assertions, b_assn coincides with pure over the bounded relation. *)
lemma b_assn_pure_conv[constraint_simps]: "b_assn (pure R) P = pure (b_rel R P)"
  by (auto intro!: ext simp: b_rel_def b_assn_def pure_def)
(* The reverse orientation (pure (b_rel R P) to b_assn (pure R) P) is declared for the
   import-rewrite, frame-normalisation and fcomp-normalisation rule sets. *)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] 
  = b_assn_pure_conv[symmetric]

(* Nested bounds on a relation collapse into the conjunction of the predicates. *)
lemma b_rel_nesting[simp]: 
  "b_rel (b_rel R P1) P2 = b_rel R (λx. P1 x ∧ P2 x)"
  by (auto simp: b_rel_def)
(* A trivially true bound leaves the relation unchanged. *)
lemma b_rel_triv[simp]: 
  "b_rel R (λ_. True) = R"
  by (auto simp: b_rel_def)
(* Nested bounds on an assertion collapse into the conjunction of the predicates. *)
lemma b_assn_nesting[simp]: 
  "b_assn (b_assn A P1) P2 = b_assn A (λx. P1 x ∧ P2 x)"
  by (auto simp: b_assn_def pure_def mult.assoc intro!: ext)
(* A trivially true bound leaves the assertion unchanged. *)
lemma b_assn_triv[simp]: 
  "b_assn A (λ_. True) = A"
  by (auto simp: b_assn_def pure_def intro!: ext)

(* Add the nesting lemmas to the constraint, import-rewrite, frame-normalisation and
   fcomp-normalisation rule sets. Both facts are already declared [simp] at their
   definition, so simp is not repeated here (a second declaration would only be
   reported as a duplicate rewrite rule). *)
lemmas [constraint_simps,sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold]
  = b_rel_nesting b_assn_nesting

(* Membership in a bounded relation: membership in R plus the bound on the second component. *)
lemma b_rel_simp[simp]: "(x,y)∈b_rel R P ⟷ (x,y)∈R ∧ P y"
  by (auto simp: b_rel_def)

(* Applied form of b_assn_def. *)
lemma b_assn_simp[simp]: "b_assn A P x y = A x y * ↑(P x)"
  by (auto simp: b_assn_def)

(* Range of a bounded relation, and rdomp of a bounded assertion. *)
lemma b_rel_Range[simp]: "Range (b_rel R P) = Range R ∩ Collect P" by auto
lemma b_assn_rdom[simp]: "rdomp (b_assn R P) x ⟷ rdomp R x ∧ P x"
  by (auto simp: rdomp_def)


(* Bounding a relation preserves being below Id. *)
lemma b_rel_below_id[constraint_rules]: 
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (b_rel R P)"
  by (auto simp: IS_BELOW_ID_def)

(* Bounding a relation preserves left-uniqueness. *)
lemma b_rel_left_unique[constraint_rules]: 
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (b_rel R P)"
  by (auto simp: IS_LEFT_UNIQUE_def single_valued_def)
  
(* Bounding a relation preserves right-uniqueness. *)
lemma b_rel_right_unique[constraint_rules]: 
  "IS_RIGHT_UNIQUE R ⟹ IS_RIGHT_UNIQUE (b_rel R P)"
  by (auto simp: single_valued_def)

― ‹Registered as safe rule, although it may lose information in the 
    odd case that purity depends on the condition.›
(* Bounding a pure assertion yields a pure assertion; follows from b_assn_pure_conv
   after rewriting is_pure with is_pure_conv. *)
lemma b_assn_is_pure[safe_constraint_rules]:
  "is_pure A ⟹ is_pure (b_assn A P)"
  by (auto simp: is_pure_conv b_assn_pure_conv)

― ‹Most general form›
(* Frame matching between two bounded assertions: requires that the bounded source
   entails the unbounded target A', and that P implies P' for x (where validity of both
   A and A' for x y may be assumed via vassn_tag). *)
lemma b_assn_subtyping_match[sepref_frame_match_rules]:
  assumes "hn_ctxt (b_assn A P) x y ⟹t hn_ctxt A' x y"
  assumes "⟦vassn_tag (hn_ctxt A x y); vassn_tag (hn_ctxt A' x y); P x⟧ ⟹ P' x"
  shows "hn_ctxt (b_assn A P) x y ⟹t hn_ctxt (b_assn A' P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (auto simp: vassn_tag_def move_back_pure' dest: mod_starD )
    
― ‹Simplified forms:›
(* Special case of b_assn_subtyping_match with the same underlying assertion A on both
   sides: only the implication between the bounds remains as a premise. The first
   subgoal discharges the entailment premise of the general rule directly. *)
lemma b_assn_subtyping_match_eqA[sepref_frame_match_rules]:
  assumes "⟦vassn_tag (hn_ctxt A x y); P x⟧ ⟹ P' x"
  shows "hn_ctxt (b_assn A P) x y ⟹t hn_ctxt (b_assn A P') x y"
  apply (rule b_assn_subtyping_match)
  subgoal 
    unfolding hn_ctxt_def b_assn_def entailst_def entails_def
    by (auto simp: vassn_tag_def intro:mod_star_trueI)
  subgoal
    using assms .
  done  

(* Target without bound: the bound P on the source may be assumed when matching the
   underlying assertions. *)
lemma b_assn_subtyping_match_tR[sepref_frame_match_rules]:
  assumes "⟦P x⟧ ⟹ hn_ctxt A x y ⟹t hn_ctxt A' x y"
  shows "hn_ctxt (b_assn A P) x y ⟹t hn_ctxt A' x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (auto simp: vassn_tag_def  )

(* Source without bound: the target bound P' has to be established, where validity of
   the source assertion may be assumed via vassn_tag. *)
lemma b_assn_subtyping_match_tL[sepref_frame_match_rules]:
  assumes "hn_ctxt A x y ⟹t hn_ctxt A' x y"
  assumes "⟦vassn_tag (hn_ctxt A x y)⟧ ⟹ P' x"
  shows "hn_ctxt A x y ⟹t hn_ctxt (b_assn A' P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def  )


(* Same underlying assertion, target without bound: the bound can simply be dropped. *)
lemma b_assn_subtyping_match_eqA_tR[sepref_frame_match_rules]: 
  "hn_ctxt (b_assn A P) x y ⟹t hn_ctxt A x y"
  unfolding hn_ctxt_def b_assn_def
  by (auto intro!: enttI  ) 

(* Same underlying assertion, source without bound: only the target bound P' has to be
   established. *)
lemma b_assn_subtyping_match_eqA_tL[sepref_frame_match_rules]:
  assumes "⟦vassn_tag (hn_ctxt A x y)⟧ ⟹ P' x"
  shows "hn_ctxt A x y ⟹t hn_ctxt (b_assn A P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (auto simp: vassn_tag_def move_back_pure' intro:mod_star_trueI )

― ‹General form›
(* Merging two bounded assertions: given a merge Am of the underlying assertions, the
   result is Am bounded by the disjunction of the two bounds.
   NOTE(review): the lemma is named b_rel_* although it is stated for b_assn. *)
lemma b_rel_subtyping_merge[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨A hn_ctxt A' x y ⟹t hn_ctxt Am x y"
  shows "hn_ctxt (b_assn A P) x y ∨A hn_ctxt (b_assn A' P') x y ⟹t hn_ctxt (b_assn Am (λx. P x ∨ P' x)) x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def)
  
― ‹Simplified forms›
(* Same underlying assertion on both sides: no premise about merging A is needed. *)
lemma b_rel_subtyping_merge_eqA[sepref_frame_merge_rules]:
  shows "hn_ctxt (b_assn A P) x y ∨A hn_ctxt (b_assn A P') x y ⟹t hn_ctxt (b_assn A (λx. P x ∨ P' x)) x y"
  apply (rule b_rel_subtyping_merge)
  apply (auto simp add: entailst_def intro!: ent_true_drop(2))
  apply(rule ent_disjE) by auto

(* The following four rules are instances of the general merge rules in which one
   side carries the trivial bound (λ_. True); the [simplified] attribute removes it. *)

(* Only the right disjunct is bounded. *)
lemma b_rel_subtyping_merge_tL[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨A hn_ctxt A' x y ⟹t hn_ctxt Am x y"
  shows "hn_ctxt A x y ∨A hn_ctxt (b_assn A' P') x y ⟹t hn_ctxt Am x y"
  using b_rel_subtyping_merge[of A x y A' Am "λ_. True" P', simplified] assms .

(* Only the left disjunct is bounded. *)
lemma b_rel_subtyping_merge_tR[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨A hn_ctxt A' x y ⟹t hn_ctxt Am x y"
  shows "hn_ctxt (b_assn A P) x y ∨A hn_ctxt A' x y ⟹t hn_ctxt Am x y"
  using b_rel_subtyping_merge[of A x y A' Am P "λ_. True", simplified] assms .

(* Same underlying assertion, only the right disjunct is bounded. *)
lemma b_rel_subtyping_merge_eqA_tL[sepref_frame_merge_rules]:
  shows "hn_ctxt A x y ∨A hn_ctxt (b_assn A P') x y ⟹t hn_ctxt A x y"
  using b_rel_subtyping_merge_eqA[of A "λ_. True" x y P', simplified] .

(* Same underlying assertion, only the left disjunct is bounded. *)
lemma b_rel_subtyping_merge_eqA_tR[sepref_frame_merge_rules]:
  shows "hn_ctxt (b_assn A P) x y ∨A hn_ctxt A x y ⟹t hn_ctxt A x y"
  using b_rel_subtyping_merge_eqA[of A P x y "λ_. True", simplified] .

(* TODO: Combinatorial explosion :( *)
(* Merge rules for the combinations of hn_invalid / hn_ctxt with bounded / unbounded
   versions of the same assertion A. In every case the result is hn_invalid; the result
   keeps a (disjunctive) bound only when both sides are bounded. All proofs unfold
   hn_ctxt_def, invalid_assn_def and entailst_def. *)

(* invalid bounded  vs.  invalid bounded *)
lemma b_assn_invalid_merge1: "hn_invalid (b_assn A P) x y ∨A hn_invalid (b_assn A P') x y
  ⟹t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def) 

(* invalid bounded  vs.  invalid unbounded *)
lemma b_assn_invalid_merge2: "hn_invalid (b_assn A P) x y ∨A hn_invalid A x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)

(* invalid unbounded  vs.  invalid bounded *)
lemma b_assn_invalid_merge3: "hn_invalid A x y ∨A hn_invalid (b_assn A P) x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)

(* invalid bounded  vs.  valid bounded, and its mirror image *)
lemma b_assn_invalid_merge4: "hn_invalid (b_assn A P) x y ∨A hn_ctxt (b_assn A P') x y
  ⟹t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge5: "hn_ctxt (b_assn A P') x y ∨A hn_invalid (b_assn A P) x y
  ⟹t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)

(* invalid bounded  vs.  valid unbounded, and its mirror image *)
lemma b_assn_invalid_merge6: "hn_invalid (b_assn A P) x y ∨A hn_ctxt A x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge7: "hn_ctxt A x y ∨A hn_invalid (b_assn A P) x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)

(* valid bounded  vs.  invalid unbounded, and its mirror image *)
lemma b_assn_invalid_merge8: "hn_ctxt (b_assn A P) x y ∨A hn_invalid A x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge9: "hn_invalid A x y ∨A hn_ctxt (b_assn A P) x y
  ⟹t hn_invalid A x y"
  by (auto simp: hn_ctxt_def invalid_assn_def entailst_def)

(* Declare all nine rules as frame merge rules under one collective name. *)
lemmas b_assn_invalid_merge[sepref_frame_merge_rules] = 
  b_assn_invalid_merge1
  b_assn_invalid_merge2
  b_assn_invalid_merge3
  b_assn_invalid_merge4
  b_assn_invalid_merge5
  b_assn_invalid_merge6
  b_assn_invalid_merge7
  b_assn_invalid_merge8
  b_assn_invalid_merge9




(*
lemma list_rel_b_id: "∀x∈set l. B x ⟹ (l,l)∈⟨b_rel B⟩list_rel"
  by (induction l) auto
*)


(* Relation-level and assertion-level abbreviations for naturals strictly below n,
   built from b_rel / b_assn over nat_rel / nat_assn. *)
abbreviation nbn_rel :: "nat ⇒ (nat × nat) set" 
  ― ‹Natural numbers with upper bound.›
  where "nbn_rel n ≡ b_rel nat_rel (λx::nat. x<n)"  

abbreviation nbn_assn :: "nat ⇒ nat ⇒ nat ⇒ assn" 
  ― ‹Natural numbers with upper bound.›
  where "nbn_assn n ≡ b_assn nat_assn (λx::nat. x<n)"  

(*
subsection ‹Bounded Identity Relations›
definition "b_rel B ≡ {(x,x) | x. B x}"

lemma b_rel_simp[simp]: "(x,y)∈b_rel B ⟷ x=y ∧ B y"
  by (auto simp: b_rel_def)

lemma b_rel_Range[simp]: "Range (b_rel B) = Collect B" by auto

lemma b_rel_below_id[safe_constraint_rules]: "IS_BELOW_ID (b_rel B)"
  by (auto simp: IS_BELOW_ID_def)

lemma list_rel_b_id: "∀x∈set l. B x ⟹ (l,l)∈⟨b_rel B⟩list_rel"
  by (induction l) auto

lemma b_rel_subtyping_match[sepref_frame_match_rules]:
  "P x ⟹ hn_val Id x y ⟹t hn_val (b_rel P) x y"
  "⟦P1 x ⟹ P2 x⟧ ⟹ hn_val (b_rel P1) x y ⟹t hn_val (b_rel P2) x y"
  "hn_val (b_rel P) x y ⟹t hn_val Id x y"
  by (auto simp: hn_ctxt_def pure_def intro: enttI)

lemma b_rel_subtyping_merge[sepref_frame_merge_rules]:
  "hn_val Id x y ∨A hn_val (b_rel P) x y ⟹t hn_val Id x y"
  "hn_val (b_rel P) x y ∨A hn_val Id x y ⟹t hn_val Id x y"
  "hn_val (b_rel P1) x y ∨A hn_val (b_rel P2) x y ⟹t hn_val (b_rel (λx. P1 x ∨ P2 x)) x y"
  by (auto simp: hn_ctxt_def pure_def intro: enttI)


abbreviation nbn_rel :: "nat ⇒ (nat × nat) set" 
  -- ‹Natural numbers with upper bound.›
  where "nbn_rel n ≡ b_rel (λx::nat. x<n)"  


*)


subsection ‹Tool Setup›
(* Instantiate sepref_relpropI for each relation/assertion property defined in this
   theory (plus is_pure and IS_PURE Φ for arbitrary Φ) and declare the instances as
   sepref_relprops. *)
lemmas [sepref_relprops] = 
  sepref_relpropI[of IS_LEFT_UNIQUE]
  sepref_relpropI[of IS_RIGHT_UNIQUE]
  sepref_relpropI[of IS_LEFT_TOTAL]
  sepref_relpropI[of IS_RIGHT_TOTAL]
  sepref_relpropI[of is_pure]
  sepref_relpropI[of "IS_PURE Φ" for Φ]
  sepref_relpropI[of IS_ID]
  sepref_relpropI[of IS_BELOW_ID]
 


(* Implications between IS_PURE constraints: IS_ID implies IS_BELOW_ID and both
   totality properties; IS_BELOW_ID implies both uniqueness properties. *)
lemma [sepref_relprops_simps]:
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_BELOW_ID) A"
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_LEFT_TOTAL) A"
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_RIGHT_TOTAL) A"
  "CONSTRAINT (IS_PURE IS_BELOW_ID) A ⟹ CONSTRAINT (IS_PURE IS_LEFT_UNIQUE) A"
  "CONSTRAINT (IS_PURE IS_BELOW_ID) A ⟹ CONSTRAINT (IS_PURE IS_RIGHT_UNIQUE) A"
  by (auto 
    simp: IS_ID_def IS_BELOW_ID_def IS_PURE_def IS_LEFT_UNIQUE_def
    simp: IS_LEFT_TOTAL_def IS_RIGHT_TOTAL_def
    simp: single_valued_below_Id)

(* Also add True_implies_equals to the same simp set. *)
declare True_implies_equals[sepref_relprops_simps]

(* Transformation rule: fold single_valued (R¯) into IS_LEFT_UNIQUE R. *)
lemma [sepref_relprops_transform]: "single_valued (R¯) = IS_LEFT_UNIQUE R"
  by (auto simp: IS_LEFT_UNIQUE_def)


subsection ‹HOL Combinators›
(* Combinator rule for If. P extracts the refinement of the condition from the context;
   RT and RE refine the two branches under the respective assumption on the abstract
   condition a; IMP merges the two resulting contexts into Γ'. The proof instantiates
   hnr_If after unfolding the APP and PROTECT2 tags. *)
lemma hn_if[sepref_comb_rules]:
  assumes P: "Γ ⟹t Γ1 * hn_val bool_rel a a'"
  assumes RT: "a ⟹ hn_refine (Γ1 * hn_val bool_rel a a') b' Γ2b R b"
  assumes RE: "¬a ⟹ hn_refine (Γ1 * hn_val bool_rel a a') c' Γ2c R c"
  assumes IMP: "TERM If ⟹ Γ2b ∨A Γ2c ⟹t Γ'"
  shows "hn_refine Γ (if a' then b' else c') Γ' R (If$a$b$c)"
  using P RT RE IMP[OF TERMI]
  unfolding APP_def PROTECT2_def 
  by (rule hnr_If)

(* Simplify conditionals with constant condition during optimisation. *)
lemmas [sepref_opt_simps] = if_True if_False

(* Combinator rule for Let. P extracts the refinement of the bound value; R refines the
   body for any x equal to v; F splits the resulting context into the remaining frame Γ2
   and the (possibly changed) assertion R' for the bound value. *)
lemma hn_let[sepref_comb_rules]:
  assumes P: "Γ ⟹t Γ1 * hn_ctxt R v v'"
  assumes R: "⋀x x'. x=v ⟹ hn_refine (Γ1 * hn_ctxt R x x') (f' x') 
    (Γ' x x') R2 (f x)"
  assumes F: "⋀x x'. Γ' x x' ⟹t Γ2 * hn_ctxt R' x x'"
  shows 
    "hn_refine Γ (Let v' f') (Γ2 * hn_ctxt R' v v') R2 (Let$v$(λ2x. f x))"
  apply (rule hn_refine_cons[OF P _ F entt_refl])
  apply (simp)
  apply (rule R)
  by simp

subsection ‹Basic HOL types›

(* default and the unit value are related to themselves. *)
lemma hnr_default[sepref_import_param]: "(default,default)∈Id" by simp

lemma unit_hnr[sepref_import_param]: "((),())∈unit_rel" by auto
    
(* Import existing parametricity facts for bool, nat and int operations. *)
lemmas [sepref_import_param] = 
  param_bool
  param_nat1
  param_int

(* Interface types for the constants 0, 1 and numeral at types nat and int, and for the
   num constructors. *)
lemmas [id_rules] = 
  itypeI[Pure.of 0 "TYPE (nat)"]
  itypeI[Pure.of 0 "TYPE (int)"]
  itypeI[Pure.of 1 "TYPE (nat)"]
  itypeI[Pure.of 1 "TYPE (int)"]
  itypeI[Pure.of numeral "TYPE (num ⇒ nat)"]
  itypeI[Pure.of numeral "TYPE (num ⇒ int)"]
  itype_self[of num.One]
  itype_self[of num.Bit0]
  itype_self[of num.Bit1]

(* min / max on nat and int. *)
lemma param_min_nat[param,sepref_import_param]: "(min,min)∈nat_rel → nat_rel → nat_rel" by auto
lemma param_max_nat[param,sepref_import_param]: "(max,max)∈nat_rel → nat_rel → nat_rel" by auto

lemma param_min_int[param,sepref_import_param]: "(min,min)∈int_rel → int_rel → int_rel" by auto
lemma param_max_int[param,sepref_import_param]: "(max,max)∈int_rel → int_rel → int_rel" by auto

(* Unary minus on int. *)
lemma uminus_hnr[sepref_import_param]: "(uminus,uminus)∈int_rel → int_rel" by auto
    
(* Conversions between int and nat. *)
lemma nat_param[param,sepref_import_param]: "(nat,nat) ∈ int_rel → nat_rel" by auto
lemma int_param[param,sepref_import_param]: "(int,int) ∈ nat_rel → int_rel" by auto
      
      
      
subsection "Product"


(* Use prod_assn_pure_conv in its reverse orientation for import rewriting, frame
   normalisation and fcomp normalisation. *)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = prod_assn_pure_conv[symmetric]

(* The product assertion of two precise assertions is precise. After unfolding
   prod_assn_def, each component equality is obtained from prec_frame of the
   respective component; rotatel / match_first bring the relevant conjunct to the front. *)
lemma prod_assn_precise[constraint_rules]: 
  "precise P1 ⟹ precise P2 ⟹ precise (prod_assn P1 P2)"
  apply rule
  apply (clarsimp simp: prod_assn_def mult.assoc)
  apply safe
  subgoal apply (erule (1) prec_frame) by(rule match_first, rule entails_triv)+
  subgoal apply (erule (1) prec_frame)
      apply rotatel apply (rule match_first) apply (rule entails_triv)
    apply rotatel apply (rule match_first) apply (rule entails_triv)
    done
  done
(*
lemma  
  "precise P1 ⟹ precise P2 ⟹ precise (prod_assn P1 P2)" ― ‹Original proof›
  apply rule
  apply (clarsimp simp: prod_assn_def)
proof (rule conjI)
  fix F F' h as a b a' b' ap bp
  assume P1: "precise P1" and P2: "precise P2"
  assume F: "(h, as) ⊨ P1 a ap * P2 b bp * F ∧A P1 a' ap * P2 b' bp * F'"

  from F have "(h, as) ⊨ P1 a ap * (P2 b bp * F) ∧A P1 a' ap * (P2 b' bp * F')"
    by (simp only: mult.assoc)
  with preciseD[OF P1] show "a=a'" .
  from F have "(h, as) ⊨ P2 b bp * (P1 a ap * F) ∧A P2 b' bp * (P1 a' ap * F')"
    by (simp only: mult.assoc[where 'a=assn] mult.commute[where 'a=assn] mult.left_commute[where 'a=assn])
  with preciseD[OF P2] show "b=b'" .
qed *)

(* TODO Add corresponding rules for other types and add to datatype snippet *)
(* Interface type of a product assertion: the product of the component interface types.
   The statement is proved by simp alone. *)
lemma intf_of_prod_assn[intf_of_assn]:
  assumes "intf_of_assn A TYPE('a)" "intf_of_assn B TYPE('b)"
  shows "intf_of_assn (prod_assn A B) TYPE('a * 'b)"
by simp

(* The product assertion of two pure assertions is pure. The witnesses Q1, Q2 obtained
   from purity of the components are combined component-wise into the witness for the
   product. *)
lemma pure_prod[constraint_rules]:
  assumes pure1: "is_pure P1" and pure2: "is_pure P2"
  shows "is_pure (prod_assn P1 P2)"
proof -
  from pure1 obtain Q1 where Q1: "⋀a c. P1 a c = ↑(Q1 a c)"
    using is_pureE by blast
  from pure2 obtain Q2 where Q2: "⋀a c. P2 a c = ↑(Q2 a c)"
    using is_pureE by blast

  show ?thesis
  proof
    fix p p'
    show "prod_assn P1 P2 p p' =
         ↑ (case (p, p') of ((a1, a2), c1, c2) ⇒ Q1 a1 c1 ∧ Q2 a2 c2)"
      unfolding prod_assn_def
      by (simp add: Q1 Q2 split: prod.split)
  qed
qed

(* Frame matching for product assertions is component-wise: it suffices to match the
   first and the second components. *)
lemma prod_frame_match[sepref_frame_match_rules]:
  assumes "hn_ctxt A (fst x) (fst y) ⟹t hn_ctxt A' (fst x) (fst y)"
  assumes "hn_ctxt B (snd x) (snd y) ⟹t hn_ctxt B' (snd x) (snd y)"
  shows "hn_ctxt (prod_assn A B) x y ⟹t hn_ctxt (prod_assn A' B') x y"
  apply (cases x; cases y; simp)
  apply (simp add: hn_ctxt_def)
  apply (rule entt_star_mono)
  using assms apply (auto simp: hn_ctxt_def)
  done

(* Frame merging for product assertions is component-wise as well; each disjunct is
   handled by prod_frame_match using the two halves of the component merge premises. *)
lemma prod_frame_merge[sepref_frame_merge_rules]:   
  assumes "hn_ctxt A (fst x) (fst y) ∨A hn_ctxt A' (fst x) (fst y) ⟹t hn_ctxt Am (fst x) (fst y)"
  assumes "hn_ctxt B (snd x) (snd y) ∨A hn_ctxt B' (snd x) (snd y) ⟹t hn_ctxt Bm (snd x) (snd y)"
  shows "hn_ctxt (prod_assn A B) x y ∨A hn_ctxt (prod_assn A' B') x y ⟹t hn_ctxt (prod_assn Am Bm) x y"
  by (blast intro: entt_disjE prod_frame_match 
    entt_disjD1[OF assms(1)] entt_disjD2[OF assms(1)]
    entt_disjD1[OF assms(2)] entt_disjD2[OF assms(2)])
  
(* An invalidated product assertion entails the product of the invalidated components. *)
lemma entt_invalid_prod: "hn_invalid (prod_assn A B) p p' ⟹t hn_ctxt (prod_assn (invalid_assn A) (invalid_assn B)) p p'"
    apply (simp add: hn_ctxt_def invalid_assn_def[abs_def])
    apply (rule enttI)
    apply clarsimp
    apply (cases p; cases p'; auto simp:  pure_def dest: mod_starD) 
    done

(* Merge rules for invalidated products, derived from entt_invalid_prod via gen_merge_cons. *)
lemmas invalid_prod_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_prod]

(* Lift an equation about prod_assn to the hn_ctxt-wrapped form. *)
lemma prod_assn_ctxt: "prod_assn A1 A2 x y = z ⟹ hn_ctxt (prod_assn A1 A2) x y = z"
  by (simp add: hn_ctxt_def)

(* Combinator rule for case_prod. FR extracts the product assertion for the scrutinee
   from the context; Pair refines the body for the components, in a context where the
   components carry P1 / P2 and the original pair is invalidated. In the resulting
   context the components' assertions P1' / P2' are recombined into prod_assn P1' P2'
   for the pair; whatever assertion XX1 the body leaves for the pair itself is dropped. *)
lemma hn_case_prod'[sepref_prep_comb_rule,sepref_comb_rules]:
  assumes FR: "Γ⟹thn_ctxt (prod_assn P1 P2) p' p * Γ1"
  assumes Pair: "⋀a1 a2 a1' a2'. ⟦p'=(a1',a2')⟧ 
    ⟹ hn_refine (hn_ctxt P1 a1' a1 * hn_ctxt P2 a2' a2 * Γ1 * hn_invalid (prod_assn P1 P2) p' p) (f a1 a2) 
          (hn_ctxt P1' a1' a1 * hn_ctxt P2' a2' a2 * hn_ctxt XX1 p' p * Γ1') R (f' a1' a2')"
  shows "hn_refine Γ (case_prod f p) (hn_ctxt (prod_assn P1' P2') p' p * Γ1')
    R (case_prod$(λ2a b. f' a b)$p')" (is "?G Γ")
    apply1 (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    apply1 (cases p; cases p'; simp add: prod_assn_pair_conv[THEN prod_assn_ctxt])
    apply (rule hn_refine_cons[OF _ Pair _ entt_refl])
   
    applyS (simp add: hn_ctxt_def  ) 
    applyS simp
      subgoal  
        apply  (simp only: hn_ctxt_def entailst_def mult.assoc)
    apply(rule match_first)
    apply(rule match_first) apply(rotatel)
    apply(rule match_first)  by simp
      done
(*
lemma hn_case_prod_old:
  assumes P: "Γ⟹tΓ1 * hn_ctxt (prod_assn P1 P2) p' p"
  assumes R: "⋀a1 a2 a1' a2'. ⟦p'=(a1',a2')⟧ 
    ⟹ hn_refine (Γ1 * hn_ctxt P1 a1' a1 * hn_ctxt P2 a2' a2 * hn_invalid (prod_assn P1 P2) p' p) (f a1 a2) 
          (Γh a1 a1' a2 a2') R (f' a1' a2')"
  assumes M: "⋀a1 a1' a2 a2'. Γh a1 a1' a2 a2' 
    ⟹t Γ' * hn_ctxt P1' a1' a1 * hn_ctxt P2' a2' a2 * hn_ctxt Pxx p' p"
  shows "hn_refine Γ (case_prod f p) (Γ' * hn_ctxt (prod_assn P1' P2') p' p)
    R (case_prod$(λ2a b. f' a b)$p')"
  apply1 (cases p; cases p'; simp)  
  apply1 (rule hn_refine_cons_pre[OF P])
  apply (rule hn_refine_preI)
  apply (simp add: hn_ctxt_def assn_aci)
  apply (rule hn_refine_cons[OF _ R])
  apply1 (rule enttI)
  applyS (sep_auto simp add: hn_ctxt_def invalid_assn_def mod_star_conv)

  applyS simp
  apply1 (rule entt_trans[OF M])
  applyS (sep_auto intro!: enttI simp: hn_ctxt_def)

  applyS simp
  done *)

(* Refinement rule for Pair: returning the concrete pair refines returning the abstract
   pair; the component assertions are invalidated in the post-context and the result is
   related by prod_assn P1 P2. The time bound is instantiated with 0 (exI[where x=0]). *)
lemma hn_Pair[sepref_fr_rules]: "hn_refine 
  (hn_ctxt P1 x1 x1' * hn_ctxt P2 x2 x2')
  (ureturn (x1',x2'))
  (hn_invalid P1 x1 x1' * hn_invalid P2 x2 x2')
  (prod_assn P1 P2)
  (RETURNT$(Pair$x1$x2))"
  unfolding hn_refine_def apply (auto simp: execute_ureturn pure_def hn_ctxt_def)
   apply(rule exI[where x=0]) apply (auto simp: zero_enat_def relH_def )      
  apply(rule entailsD)  prefer 2 by (auto simp: entt_refl' invalid_assn_def dest: mod_starD ) 

(* Refinement rules for the projections fst and snd on product assertions.
   NOTE(review): "(prod_assn A B)da A" looks like a flattened rendering of a
   sub-/superscripted signature -- confirm against the original theory source. *)
lemma fst_hnr[sepref_fr_rules]: "(ureturn o fst,RETURNT o fst) ∈ (prod_assn A B)da A"
  apply rule apply(auto simp: hn_refine_def relH_def execute_ureturn invalid_assn_def zero_enat_def  )
  apply(rule entailsD) apply auto apply (rule match_first) by auto 

lemma snd_hnr[sepref_fr_rules]: "(ureturn o snd,RETURNT o snd) ∈ (prod_assn A B)da B"
  apply rule apply(auto simp: hn_refine_def relH_def execute_ureturn invalid_assn_def zero_enat_def  )
  apply(rule entailsD) apply auto  apply rotatel apply (rule match_first) by auto 


(* prod_assn_pure_conv as a constraint simp rule; param_prod_swap as an import parameter. *)
lemmas [constraint_simps] = prod_assn_pure_conv
lemmas [sepref_import_param] = param_prod_swap

(* If a pair lies in the rdomp of a product assertion, then each component lies in the
   rdomp of the corresponding component assertion. *)
lemma rdomp_prodD[dest!]: "rdomp (prod_assn A B) (a,b) ⟹ rdomp A a ∧ rdomp B b"
  unfolding rdomp_def prod_assn_def
  by (auto dest!: mod_starD   )


subsection "Option"
(* Lift an assertion P to option values: None is related to None by the empty
   assertion, Some a to Some c by P a c, and every mixed combination is false. *)
fun option_assn :: "('a ⇒ 'c ⇒ assn) ⇒ 'a option ⇒ 'c option ⇒ assn" where
  "option_assn P None None = emp"
| "option_assn P (Some a) (Some c) = P a c"
| "option_assn _ _ _ = false"

(* If either side is None, the assertion reduces to the pure fact that the other side
   is None as well. *)
lemma option_assn_simps[simp]:
  "option_assn P None v' = ↑(v'=None)"
  "option_assn P v None = ↑(v=None)"
  apply (cases v', simp_all)
  apply (cases v, simp_all)
  done

(* Closed-form characterisation by a single case distinction on both arguments. *)
lemma option_assn_alt_def: "option_assn R a b = 
  (case (a,b) of (Some x, Some y) ⇒ R x y
  | (None,None) ⇒ emp
  | _ ⇒ false)"
  by (auto split: option.split)


(* For pure assertions, option_assn coincides with pure over the option relator.
   Proved by the case distinction induced by the defining equations of option_assn. *)
lemma option_assn_pure_conv[constraint_simps]: "option_assn (pure R) = pure (⟨R⟩option_rel)"
  apply (intro ext)      
  apply (rename_tac a c)
  apply (case_tac "(pure R,a,c)" rule: option_assn.cases)  
  by (auto simp: pure_def)
                                                
(* The reverse orientation is used for import rewriting and normalisation. *)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = option_assn_pure_conv[symmetric]

(* Composing option_assn R with an option relator is the same as option_assn over the
   composed assertion hr_comp R R'. Both entailment directions are shown separately
   (ent_iffI): the first in the subgoal block, the second by the remaining script. *)
lemma hr_comp_option_conv[simp, fcomp_norm_unfold]: "
  hr_comp (option_assn R) (⟨R'⟩option_rel) 
  = option_assn (hr_comp R R')"
  unfolding hr_comp_def[abs_def]
  apply (intro ext ent_iffI)
  subgoal for a c
  apply (auto  intro!: ent_ex_preI  )
  apply (case_tac "(R,b,c)" rule: option_assn.cases)
       apply clarsimp_all apply (cases a) apply auto
    apply (rule ent_ex_postI) apply auto
    done
  
  apply (auto simp: option_assn_alt_def split: option.splits)
    apply (rule ent_ex_postI) apply auto 
  apply (intro ent_ex_preI) 
  apply (rule ent_ex_postI)
  apply (auto split: option.splits)
  done
      
 

(* option_assn preserves preciseness. Only the Some/Some case of the defining
   equations needs an argument, which reduces to preciseness of P; all other cases
   are closed by simplification. *)
lemma option_assn_precise[safe_constraint_rules]:
  assumes "precise P"
  shows "precise (option_assn P)"
proof
  fix a a' p h F F'
  assume both: "h ⊨ option_assn P a p * F ∧A option_assn P a' p * F'"
  then show "a=a'"
  proof (cases "(P,a,p)" rule: option_assn.cases)
    case (2 _ u c)
    then have [simp]: "a=Some u" "p=Some c" by simp_all

    from both obtain u' where [simp]: "a'=Some u'" by (cases a', simp_all)

    from both have "h ⊨ P u c * F ∧A P u' c * F'" by simp
    with ‹precise P› have "u=u'" by (rule preciseD)
    then show ?thesis by simp
  qed simp_all
qed

(* option_assn preserves purity. The witness Q obtained from purity of P is lifted to
   options: None/None is True, Some/Some uses Q, everything else is False. *)
lemma pure_option[safe_constraint_rules]:
  assumes P_pure: "is_pure P"
  shows "is_pure (option_assn P)"
proof -
  from P_pure obtain Q where Q: "⋀a c. P a c = ↑(Q a c)"
    using is_pureE by blast

  show ?thesis
  proof
    fix x x'
    show "option_assn P x x' =
         ↑ (case (x, x') of 
             (None,None) ⇒ True | (Some v, Some v') ⇒ Q v v' | _ ⇒ False
           )"
      by (simp add: Q split: prod.split option.split)
  qed
qed

(* Lift an equation about option_assn to the hn_ctxt-wrapped form. *)
lemma hn_ctxt_option: "option_assn A x y = z ⟹ hn_ctxt (option_assn A) x y = z"
  by (simp add: hn_ctxt_def)

(* Combinator rule for case_option. FR extracts the option assertion of the scrutinee;
   Rn refines the None branch, Rs the Some branch (with the content refined by P and the
   original option invalidated, abbreviated INVE); MERGE1 merges the remaining frames of
   the two branches. The proof distinguishes the cases of p and p'; the two subgoal
   blocks below are the None/None and Some/Some cases. *)
lemma hn_case_option[sepref_prep_comb_rule, sepref_comb_rules]:
  fixes p p' P
  defines [simp]: "INVE ≡ hn_invalid (option_assn P) p p'"
  assumes FR: "Γ ⟹t hn_ctxt (option_assn P) p p' * F"
  assumes Rn: "p=None ⟹ hn_refine (hn_ctxt (option_assn P) p p' * F) f1' (hn_ctxt XX1 p p' * Γ1') R f1"
  assumes Rs: "⋀x x'. ⟦ p=Some x; p'=Some x' ⟧ ⟹ 
    hn_refine (hn_ctxt P x x' * INVE * F) (f2' x') (hn_ctxt P' x x' * hn_ctxt XX2 p p' * Γ2') R (f2 x)"
  assumes MERGE1: "Γ1' ∨A Γ2' ⟹t Γ'"  
  shows "hn_refine Γ (case_option f1' f2' p') (hn_ctxt (option_assn P') p p' * Γ') R (case_option$f1$(λ2x. f2 x)$p)"
    apply (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    apply (cases p; cases p'; simp add: option_assn.simps[THEN hn_ctxt_option])
    subgoal 
      apply (rule hn_refine_cons[OF _ Rn _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)

      apply (subst mult.commute, rule entt_fr_drop)
      apply (rule entt_trans[OF _ MERGE1])
      apply (simp add: ent_disjI1' ent_disjI2')
    done  

    subgoal
      apply (rule hn_refine_cons[OF _ Rs _ entt_refl]; assumption?)
      applyS (auto simp add: hn_ctxt_def  )
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      applyS (simp add: hn_ctxt_def)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp add: hn_ctxt_def)
    done
    done

lemma hn_None[sepref_fr_rules]:
  "hn_refine emp (ureturn None) emp (option_assn P) (RETURNT$None)"
  apply (auto simp: hn_refine_def execute_ureturn zero_enat_def relH_def  )
  using mod_star_trueI by force 


(* Refinement of the Some constructor: the argument is taken with assertion
   hn_ctxt P v v' in the precondition; the postcondition contains
   hn_invalid P v v', and the result is related by option_assn P. *)
lemma hn_Some[sepref_fr_rules]: "hn_refine 
  (hn_ctxt P v v')
  (ureturn (Some v'))
  (hn_invalid P v v')
  (option_assn P)
  (RETURNT$(Some$v))"
  apply (auto simp: hn_refine_def relH_def invalid_assn_def zero_enat_def execute_ureturn hn_ctxt_def invalidate_clone')
  apply(rule entailsD) by auto  

(* Equality test on two option values, parameterised by a comparison eq for
   the elements: two None values yield True, two Some values delegate to eq,
   and any mixed combination yields False. *)
definition "imp_option_eq eq a b ≡ case (a,b) of 
  (None,None) ⇒ return True
| (Some a, Some b) ⇒ eq a b
| _ ⇒ return False"

(*
(* TODO: This is some kind of generic algorithm! Use GEN_ALGO here, and 
  let GEN_ALGO re-use the registered operator rules *)
lemma option_assn_eq[sepref_comb_rules]:
  fixes a b :: "'a option"
  assumes F1: "Γ ⟹t hn_ctxt (option_assn P) a a' * hn_ctxt (option_assn P) b b' * Γ1"
  assumes EQ: "⋀va va' vb vb'. hn_refine 
    (hn_ctxt P va va' * hn_ctxt P vb vb' * Γ1)
    (eq' va' vb') 
    (Γ' va va' vb vb') 
    bool_assn
    (RETURNT$((=) $va$vb))"
  assumes F2: 
    "⋀va va' vb vb'. 
      Γ' va va' vb vb' ⟹t hn_ctxt P va va' * hn_ctxt P vb vb' * Γ1"
  shows "hn_refine 
    Γ 
    (imp_option_eq eq' a' b') 
    (hn_ctxt (option_assn P) a a' * hn_ctxt (option_assn P) b b' * Γ1)
    bool_assn 
    (RETURNT$((=) $a$b))"
  apply (rule hn_refine_cons_pre[OF F1])
  unfolding imp_option_eq_def
  apply rule
  apply (simp split: option.split add: hn_ctxt_def, intro impI conjI)

  apply (sep_auto split: option.split simp: hn_ctxt_def pure_def)
  apply (cases a, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (cases a, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (cases b, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF EQ[unfolded hn_ctxt_def]])
  apply simp
  apply (rule ent_frame_fwd[OF F2[THEN enttD,unfolded hn_ctxt_def]])
  apply (fr_rot 2)
  apply (fr_rot_rhs 1)
  apply (rule fr_refl)
  apply (rule ent_refl)
  apply (sep_auto simp: pure_def)
  done
*)
(* Pattern rewrite rules: an equality test against None, with None on either
   side, is rewritten to is_None. *)
lemma [pat_rules]: 
  "(=) $a$None ≡ is_None$a"
  "(=) $None$a ≡ is_None$a"
  subgoal by (rule eq_reflection) (simp split: option.split)
  subgoal by (rule eq_reflection) (simp split: option.split)
  done

(* Refinement of is_None: the option assertion for a/a' appears unchanged in
   pre- and postcondition; the result is related by bool_assn. *)
lemma hn_is_None[sepref_fr_rules]: "hn_refine 
  (hn_ctxt (option_assn P) a a')
  (ureturn (is_None a'))
  (hn_ctxt (option_assn P) a a')
  bool_assn
  (RETURNT$(is_None$a))" 
  apply (auto simp: top_assn_rule hn_refine_def execute_ureturn zero_enat_def relH_def split: option.split simp: hn_ctxt_def pure_def)
    using mod_star_trueI by blast 
   

(* Refinement of the (the projection out of Some) under the side condition
   x≠None. The option assertion is invalidated in the postcondition and the
   result is related by R. *)
lemma (in -) sepref_the_complete[sepref_fr_rules]:
  assumes "x≠None"
  shows "hn_refine 
    (hn_ctxt (option_assn R) x xi) 
    (ureturn (the xi)) 
    (hn_invalid (option_assn R) x xi)
    (R)
    (RETURNT$(the$x))"
    using assms
    (* x = None contradicts the assumption; then split on the concrete value *)
    apply (cases x)
    apply simp
    apply (cases xi)
    apply (simp add: hn_ctxt_def) 
    apply (auto simp: hn_refine_def execute_ureturn zero_enat_def relH_def hn_ctxt_def invalidate_clone' vassn_tagI invalid_assn_const)
    by (metis assn_times_comm mod_star_trueI)

(* As the sepref_the_complete rule does not work for us 
  --- the assertion ensuring the side-condition gets decoupled from its variable by a copy-operation ---
  we use the following rule that only works for the identity relation.

  Differences to sepref_the_complete visible in the statement: there is no
  x≠None assumption, R is constrained by CONSTRAINT (IS_PURE IS_ID), and the
  option assertion is kept (not invalidated) in the postcondition.
  Note that this lemma carries no sepref_fr_rules attribute here. *)
lemma (in -) sepref_the_id:
  assumes "CONSTRAINT (IS_PURE IS_ID) R"
  shows "hn_refine 
    (hn_ctxt (option_assn R) x xi) 
    (ureturn (the xi)) 
    (hn_ctxt (option_assn R) x xi)
    (R)
    (RETURNT$(the$x))"
    using assms 
    apply (clarsimp simp: IS_PURE_def IS_ID_def hn_ctxt_def is_pure_conv)
    (* Case split on abstract and concrete option *)
    apply (cases x; cases xi)
    apply (auto simp add: hn_ctxt_def invalid_assn_def)
      apply (auto simp: hn_refine_def  execute_ureturn zero_enat_def relH_def  pure_def) 
    subgoal using mod_star_trueI by force   
    subgoal by (simp add: top_assn_rule)  
    done


subsection "Lists"

(* Lifts an element assertion P to lists, element by element:
   - two empty lists are related by emp,
   - two non-empty lists are related by P on the heads, separated (by *) from
     the recursive assertion on the tails,
   - every other combination (catch-all equation) is false. *)
fun list_assn :: "('a ⇒ 'c ⇒ assn) ⇒ 'a list ⇒ 'c list ⇒ assn" where
  "list_assn P [] [] = emp"
| "list_assn P (a#as) (c#cs) = P a c * list_assn P as cs"
| "list_assn _ _ _ = false"

(* If one argument is the empty list, list_assn reduces to the pure assertion
   that the other list is empty as well. *)
lemma list_assn_aux_simps[simp]:
  "list_assn P [] l' = (↑(l'=[]))"
  "list_assn P l [] = (↑(l=[]))"
  subgoal by (cases l') simp_all
  subgoal by (cases l) simp_all
  done

(* list_assn distributes over append when the two prefixes have the same
   length. Simultaneous induction over both prefixes. *)
lemma list_assn_aux_append[simp]:
  "length l1=length l1' ⟹ 
    list_assn P (l1@l2) (l1'@l2') 
    = list_assn P l1 l1' * list_assn P l2 l2'"
  by (induct rule: list_induct2) (simp_all add: mult.assoc)

(* Lists of different length are never related: list_assn is false then. *)
lemma list_assn_aux_ineq_len: "length l ≠ length li ⟹ list_assn A l li = false"
proof (induction l arbitrary: li)
  case Nil
  then show ?case by simp
next
  case (Cons x l li)
  then show ?case by (cases li; auto)
qed

(* Variant of list_assn_aux_append that requires the suffixes (rather than the
   prefixes) to have equal length. *)
lemma list_assn_aux_append2[simp]:
  assumes "length l2=length l2'"  
  shows "list_assn P (l1@l2) (l1'@l2') 
    = list_assn P l1 l1' * list_assn P l2 l2'"
proof (cases "length l1 = length l1'")
  case True
  (* Equal prefix lengths: this is exactly list_assn_aux_append *)
  then show ?thesis by (rule list_assn_aux_append)
next
  case False
  (* Different prefix lengths: handled via list_assn_aux_ineq_len *)
  then show ?thesis by (simp add: list_assn_aux_ineq_len assms)
qed

(* list_assn applied to a pure element assertion is the pure assertion of the
   corresponding list relation. Shown pointwise (ext) by induction along the
   equations of list_assn. *)
lemma list_assn_pure_conv[constraint_simps]: "list_assn (pure R) = pure (⟨R⟩list_rel)"
proof (intro ext)
  fix l li
  show "list_assn (pure R) l li = pure (⟨R⟩list_rel) l li"
    apply (induction "pure R" l li rule: list_assn.induct)
    by (auto simp: pure_def)
qed

(* Registers the reversed equation (pure list relation to list_assn of a pure
   element assertion) under the three attributes listed. *)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = list_assn_pure_conv[symmetric]


(* Simp rules for list_assn wrapped in hn_ctxt: the empty-list cases, the
   cons/cons case, and the two mismatching constructor cases (false).
   After unfolding hn_ctxt, the first two statements are proved by case
   distinction on the remaining list; simp_all closes the other four. *)
lemma list_assn_simps[simp]:
  "hn_ctxt (list_assn P) [] l' = (↑(l'=[]))"
  "hn_ctxt (list_assn P) l [] = (↑(l=[]))"
  "hn_ctxt (list_assn P) [] [] = emp"
  "hn_ctxt (list_assn P) (a#as) (c#cs) = hn_ctxt P a c * hn_ctxt (list_assn P) as cs"
  "hn_ctxt (list_assn P) (a#as) [] = false"
  "hn_ctxt (list_assn P) [] (c#cs) = false"
  unfolding hn_ctxt_def
  apply (cases l')
  apply simp
  apply simp
  apply (cases l)
  apply simp
  apply simp
  apply simp_all
  done

(* Preciseness lifts from the element assertion P to list_assn P.
   Proof: induction on the concrete list l, for arbitrary abstract lists
   l1, l2 and frames F1, F2. *)
lemma list_assn_precise[constraint_rules]: "precise P ⟹ precise (list_assn P)"
proof
  fix l1 l2 l h F1 F2
  assume P: "precise P"
  assume "h⊨list_assn P l1 l * F1 ∧A list_assn P l2 l * F2"
  thus "l1=l2"
  proof (induct l arbitrary: l1 l2 F1 F2)
    case Nil thus ?case by simp
  next
    case (Cons a ls)
    (* Both abstract lists are decomposed into head and tail *)
    from Cons obtain a1 ls1 where [simp]: "l1=a1#ls1"
      by (cases l1, simp)
    from Cons obtain a2 ls2 where [simp]: "l2=a2#ls2"
      by (cases l2, simp)
    
    from Cons.prems have M:
      "h ⊨ P a1 a * list_assn P ls1 ls * F1 
        ∧A P a2 a * list_assn P ls2 ls * F2" by simp
    (* Heads agree by preciseness of P; the tail assertions go into the frames *)
    have "a1=a2"
      apply (rule preciseD[OF P, where a=a1 and a'=a2 and p=a
        and F= "list_assn P ls1 ls * F1" 
        and F'="list_assn P ls2 ls * F2"
        ])
      using M
      by (simp add: mult.assoc)
    
    (* Tails agree by the induction hypothesis; the head assertions go into the frames *)
    moreover have "ls1=ls2"
      apply (rule Cons.hyps[where ?F1.0="P a1 a * F1" and ?F2.0="P a2 a * F2"])
      using M
      by (simp only: star_aci)
    ultimately show ?case by simp
  qed
qed
(* Pureness lifts from P to list_assn P: with P' the predicate that P is the
   pure assertion of, list_assn P is the pure assertion of list_all2 P'. *)
lemma list_assn_pure[constraint_rules]: 
  assumes P: "is_pure P" 
  shows "is_pure (list_assn P)"
proof -
  from P obtain P' where P_eq: "⋀x x'. P x x' = ↑(P' x x')" 
    by (rule is_pureE) blast

  {
    fix l l'
    have "list_assn P l l' = ↑(list_all2 P' l l')"
      by (induct PP l l' rule: list_assn.induct)
         (simp_all add: P_eq)
  } thus ?thesis by rule
qed

(* Monotonicity of list_assn w.r.t. entailment of the element assertion. *)
lemma list_assn_mono: 
  "⟦⋀x x'. P x x'⟹AP' x x'⟧ ⟹ list_assn P l l' ⟹A list_assn P' l l'"
  by (induct P l l' rule: list_assn.induct)
     (auto intro: ent_star_mono)

(* Same as list_assn_mono, but for the ⟹t entailment. *)
lemma list_assn_monot: 
  "⟦⋀x x'. P x x'⟹tP' x x'⟧ ⟹ list_assn P l l' ⟹t list_assn P' l l'"
  by (induct P l l' rule: list_assn.induct)
     (auto intro: entt_star_mono)

(* Frame-match congruence for lists: if A entails A' for all elements of the
   two lists, then list_assn A entails list_assn A' on these lists. *)
lemma list_match_cong[sepref_frame_match_rules]: 
  "⟦⋀x x'. ⟦x∈set l; x'∈set l'⟧ ⟹ hn_ctxt A x x' ⟹t hn_ctxt A' x x' ⟧ ⟹ hn_ctxt (list_assn A) l l' ⟹t hn_ctxt (list_assn A') l l'"
  unfolding hn_ctxt_def
  by (induct A l l' rule: list_assn.induct) (simp_all add: entt_star_mono)

(* Frame-merge congruence for lists: if the disjunction of A and A' can be
   merged into Am on all elements, then the disjunction of the list
   assertions can be merged into list_assn Am. *)
lemma list_merge_cong[sepref_frame_merge_rules]:
  assumes "⋀x x'. ⟦x∈set l; x'∈set l'⟧ ⟹ hn_ctxt A x x' ∨A hn_ctxt A' x x' ⟹t hn_ctxt Am x x'"
  shows "hn_ctxt (list_assn A) l l' ∨A hn_ctxt (list_assn A') l l' ⟹t hn_ctxt (list_assn Am) l l'"
  by (blast intro: entt_disjE list_match_cong
        entt_disjD1[OF assms] entt_disjD2[OF assms])
  
(* The invalidated assertion of two non-empty lists entails the invalidated
   assertion of the heads, separated from the invalidated assertion of the
   tails. *)
lemma invalid_list_split: 
  "invalid_assn (list_assn A) (x#xs) (y#ys) ⟹t invalid_assn A x y * invalid_assn (list_assn A) xs ys"
  by (fastforce simp: invalid_assn_def intro!: enttI simp: mod_star_conv)

(* An invalidated list assertion entails the list assertion over invalidated
   elements. Induction along the equations of list_assn. *)
lemma entt_invalid_list: "hn_invalid (list_assn A) l l' ⟹t hn_ctxt (list_assn (invalid_assn A)) l l'"
  apply (induct A l l' rule: list_assn.induct)
  applyS simp

  (* Cons/cons case: split with invalid_list_split, then use the induction
     hypothesis (by assumption) for the tails *)
  subgoal
    apply1 (simp add: hn_ctxt_def cong del: invalid_assn_cong)
    apply1 (rule entt_trans[OF invalid_list_split])
    apply (rule entt_star_mono)
      applyS simp

      apply (rule entt_trans)
        applyS assumption
        applyS simp
    done
    
  (* The two remaining subgoals are closed by unfolding the definitions *)
  applyS (simp add: hn_ctxt_def invalid_assn_def) 
  applyS (simp add: hn_ctxt_def invalid_assn_def) 
  done

(* Frame-merge rule obtained by instantiating gen_merge_cons with
   entt_invalid_list. *)
lemmas invalid_list_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_list]


(* Composition (hr_comp) of a list assertion with a list relation is the list
   assertion over the composed element assertion. The auxiliary fact aux is
   the cons/cons unfolding step; the main statement is shown pointwise by
   induction on the abstract list with a case split on the concrete list. *)
lemma list_assn_comp[fcomp_norm_unfold]: "hr_comp (list_assn A) (⟨B⟩list_rel) = list_assn (hr_comp A B)"
proof (intro ext)  
  { fix x l y m
    have "hr_comp (list_assn A) (⟨B⟩list_rel) (x # l) (y # m) = 
      hr_comp A B x y * hr_comp (list_assn A) (⟨B⟩list_rel) l m"
      apply (auto 
        simp: hr_comp_def list_rel_split_left_iff
        intro!:  ent_ex_preI ent_iffI)
      apply(rule ent_ex_postI)
       apply(rule ent_ex_postI) apply(simp only: mult.assoc) apply(rule match_first) 
       apply(rule match_rest) apply simp
      subgoal for b ba
        (* Witness for the existential: the list ba#b *)
        apply(rule ent_ex_postI[where x="ba#b"]) by simp  
      done
       
 (* TODO: ent_ex_preI should be applied by default, before ent_ex_postI!*)
  } note aux = this

  fix l li
  show "hr_comp (list_assn A) (⟨B⟩list_rel) l li = list_assn (hr_comp A B) l li"
    apply (induction l arbitrary: li; case_tac li; intro ent_iffI)
    apply ((auto simp add: hr_comp_def intro!: ent_ex_preI ent_ex_postI ; fail)[1]) + 
    by (simp_all add: aux)
qed  

(* Transfers an equation about an arbitrary assertion A to the same equation
   with A wrapped in hn_ctxt. *)
lemma hn_ctxt_eq: "A x y = z ⟹ hn_ctxt A x y = z"
  unfolding hn_ctxt_def by simp

(* Instance of hn_ctxt_eq for list assertions; used below (via THEN) to lift
   list_assn.simps. *)
lemmas hn_ctxt_list = hn_ctxt_eq[of "list_assn A" for A]

(* Refinement rule for a case distinction on a list.
   FR:     the precondition Γ provides the list assertion for p/p' and a frame F.
   Rn:     refinement of the empty-list branch (f1' against f1).
   Rs:     refinement of the cons branch (f2' x' l' against f2 x l); besides
           the assertions for head and tail, the branch precondition contains
           INVE, the invalidated assertion of the whole list.
   MERGE1: merges the element assertions P1' (head) and P2' (tail elements)
           into P'.
   MERGE2: merges the residual frames Γ1' and Γ2' of the two branches into Γ'.
   The hn_ctxt XX1 / hn_ctxt XX2 parts of the branch postconditions are
   dropped in the proof (entt_fr_drop). *)
lemma hn_case_list[sepref_prep_comb_rule, sepref_comb_rules]:
  fixes p p' P
  defines [simp]: "INVE ≡ hn_invalid (list_assn P) p p'"
  assumes FR: "Γ ⟹t hn_ctxt (list_assn P) p p' * F"
  assumes Rn: "p=[] ⟹ hn_refine (hn_ctxt (list_assn P) p p' * F) f1' (hn_ctxt XX1 p p' * Γ1') R f1"
  assumes Rs: "⋀x l x' l'. ⟦ p=x#l; p'=x'#l' ⟧ ⟹ 
    hn_refine (hn_ctxt P x x' * hn_ctxt (list_assn P) l l' * INVE * F) (f2' x' l') (hn_ctxt P1' x x' * hn_ctxt (list_assn P2') l l' * hn_ctxt XX2 p p' * Γ2') R (f2 x l)"
  assumes MERGE1[unfolded hn_ctxt_def]: "⋀x x'. hn_ctxt P1' x x' ∨A hn_ctxt P2' x x' ⟹t hn_ctxt P' x x'"  
  assumes MERGE2: "Γ1' ∨A Γ2' ⟹t Γ'"  
  shows "hn_refine Γ (case_list f1' f2' p') (hn_ctxt (list_assn P') p p' * Γ') R (case_list$f1$(λ2x l. f2 x l)$p)"
    (* Replace the precondition by the framed form given by FR *)
    apply (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    (* Case split on both the abstract and the concrete list *)
    apply (cases p; cases p'; simp add: list_assn.simps[THEN hn_ctxt_list])
    (* Branch handled with Rn (p = []) *)
    subgoal 
      apply (rule hn_refine_cons[OF _ Rn _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)

      apply (subst mult.commute, rule entt_fr_drop)
      apply (rule entt_trans[OF _ MERGE2])
      apply (simp add: ent_disjI1' ent_disjI2')
    done  

    (* Branch handled with Rs (p = x#l, p' = x'#l') *)
    subgoal
      apply (rule hn_refine_cons[OF _ Rs _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)       
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      (* Head element: via MERGE1 *)
      apply1 (simp add: hn_ctxt_def)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp)

      (* Tail: MERGE1 lifted to lists with list_assn_monot *)
      apply1 (simp add: hn_ctxt_def)
      apply (rule list_assn_monot)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp)

      (* Residual frame: via MERGE2 *)
      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp)
    done
    done

(* Refinement of the empty list: ureturn [] refines RETURNT$[], with emp as
   pre- and postcondition and list_assn P as result assertion. *)
lemma hn_Nil[sepref_fr_rules]: 
  "hn_refine emp (ureturn []) emp (list_assn P) (RETURNT$[])"
  unfolding hn_refine_def
  apply(auto simp: hn_refine_def execute_ureturn zero_enat_def relH_def )
  using mod_star_trueI by force 

(* Refinement of list cons: the assertions of head and tail are taken from the
   precondition and appear invalidated in the postcondition; the result is
   related by list_assn P. The time witness supplied in the proof is 0. *)
lemma hn_Cons[sepref_fr_rules]: "hn_refine (hn_ctxt P x x' * hn_ctxt (list_assn P) xs xs') 
  (ureturn (x'#xs')) (hn_invalid P x x' * hn_invalid (list_assn P) xs xs') (list_assn P)
  (RETURNT$((#) $x$xs))"
  unfolding hn_refine_def apply (auto simp:  execute_ureturn pure_def hn_ctxt_def)
   apply(rule exI[where x=0]) apply (auto simp: zero_enat_def invalid_assn_def relH_def intro: mod_star_trueI)
  using mod_star_convE apply blast
  using mod_star_convE apply blast done
 

(* list_assn implies equal lengths: the pure fact length l = length l' can be
   added to the assertion without changing it. *)
lemma list_assn_aux_len: 
  "list_assn P l l' = list_assn P l l' * ↑(length l = length l')"
  apply (induct PP l l' rule: list_assn.induct)
  apply simp_all
  subgoal for a as c cs
    by (erule_tac t="list_assn P as cs" in subst[OF sym]) simp
  done

(* If list_assn P l l' is valid (stated via vassn_tag, or via a model h), then
   both lists have the same length. Both statements follow from
   list_assn_aux_len. *)
lemma list_assn_aux_eqlen_simp: 
  "vassn_tag (list_assn P l l') ⟹ length l' = length l"
  "h ⊨ (list_assn P l l') ⟹ length l' = length l"
  apply (subst (asm) list_assn_aux_len; auto simp: vassn_tag_def)+
  done


(* Refinement of list append: both argument assertions are taken from the
   precondition and appear invalidated in the postcondition; the result is
   related by list_assn P. The time witness supplied in the proof is 0. *)
lemma hn_append[sepref_fr_rules]: "hn_refine (hn_ctxt (list_assn P) l1 l1' * hn_ctxt (list_assn P) l2 l2')
  (ureturn (l1'@l2')) (hn_invalid (list_assn P) l1 l1' * hn_invalid (list_assn P) l2 l2') (list_assn P)
  (RETURNT$((@) $l1$l2))"
  unfolding hn_refine_def apply (auto simp:  execute_ureturn pure_def hn_ctxt_def)
   apply(rule exI[where x=0]) apply (auto simp: zero_enat_def invalid_assn_def relH_def intro!: mod_star_trueI)
  subgoal
    (* Add the length equality to the assumption via list_assn_aux_len *)
    apply (subst (asm) list_assn_aux_len) 
    apply(rule entailsD) prefer 2 by (auto) 
  using mod_star_convE apply blast
  using mod_star_convE apply blast done 


(* A cons on the first list argument forces the second list to be a cons as
   well: it is decomposed existentially into head b and tail m'. *)
lemma list_assn_aux_cons_conv1:
  "list_assn R (a#l) m = (∃Ab m'. R a b * list_assn R l m' * ↑(m=b#m'))"
  apply (cases m)
  by (auto intro!: assn_ext ) 

(* Symmetric to list_assn_aux_cons_conv1: a cons on the second list argument
   forces the first list to be a cons with head a and tail l'. *)
lemma list_assn_aux_cons_conv2:
  "list_assn R l (b#m) = (∃Aa l'. R a b * list_assn R l' m * ↑(l=a#l'))"
  apply (cases l)
  by (auto intro!: assn_ext)  

(* Both cons decomposition lemmas under a common name. *)
lemmas list_assn_aux_cons_conv = list_assn_aux_cons_conv1 list_assn_aux_cons_conv2

(* An append on the first list argument corresponds to an existential split
   of the second list into m1 and m2. Induction on l1; in the step case both
   entailment directions are shown by giving explicit witnesses for the
   existentials. *)
lemma list_assn_aux_append_conv1:
  "list_assn R (l1@l2) m = (∃Am1 m2. list_assn R l1 m1 * list_assn R l2 m2 * ↑(m=m1@m2))"
  apply (induction l1 arbitrary: m)
  apply (auto intro!: ent_iffI ent_ex_postI ent_ex_preI)[1]
  apply (auto intro!: ent_iffI ent_ex_preI simp: list_assn_aux_cons_conv)
  subgoal for a l1 b m1 m2
    apply(rule ent_ex_postI[where x="b#m1"])
    apply(rule ent_ex_postI[where x="m2"])
    apply(rule ent_ex_postI[where x="b"])
    apply(rule ent_ex_postI[where x="m1"]) apply simp
    apply(simp only: mult.assoc) by (rule entails_triv)
  subgoal for a l1 m2 b m'
    apply(rule ent_ex_postI[where x="b"])
    apply(rule ent_ex_postI[where x="m' @ m2"])
    apply(rule ent_ex_postI[where x="m'"])
    apply(rule ent_ex_postI[where x="m2"]) apply simp
    apply(simp only: mult.assoc) by (rule entails_triv)
  done

(* Symmetric to list_assn_aux_append_conv1: an append on the second list
   argument corresponds to an existential split of the first list into l1 and
   l2. Induction on m1; the proof script has the same shape as the one of
   list_assn_aux_append_conv1 (the names given in "subgoal for" are just
   labels for the goal parameters). *)
lemma list_assn_aux_append_conv2:
  "list_assn R l (m1@m2) = (∃Al1 l2. list_assn R l1 m1 * list_assn R l2 m2 * ↑(l=l1@l2))"
  apply (induction m1 arbitrary: l)
  apply (auto intro!: ent_iffI ent_ex_postI ent_ex_preI)[1]
  apply (auto intro!: ent_iffI ent_ex_preI simp: list_assn_aux_cons_conv)
  subgoal for a l1 b m1 m2
    apply(rule ent_ex_postI[where x="b#m1"])
    apply(rule ent_ex_postI[where x="m2"])
    apply(rule ent_ex_postI[where x="b"])
    apply(rule ent_ex_postI[where x="m1"]) apply simp
    apply(simp only: mult.assoc) by (rule entails_triv)
  subgoal for a l1 m2 b m'
    apply(rule ent_ex_postI[where x="b"])
    apply(rule ent_ex_postI[where x="m' @ m2"])
    apply(rule ent_ex_postI[where x="m'"])
    apply(rule ent_ex_postI[where x="m2"]) apply simp
    apply(simp only: mult.assoc) by (rule entails_triv)
  done

(* Both append decomposition lemmas under a common name. *)
lemmas list_assn_aux_append_conv = list_assn_aux_append_conv1 list_assn_aux_append_conv2  

(* Adds the existing fact param_upt to the sepref_import_param rules. *)
declare param_upt[sepref_import_param]
  
  
subsection ‹Sum-Type›    

(* Lifts two assertions A and B to the sum type: Inl values are related by A,
   Inr values by B, and every mixed combination (catch-all equation) is false.
   The notation below introduces the right-associative infix "+a" with
   priority 67 for it. *)
fun sum_assn :: "('ai ⇒ 'a ⇒ assn) ⇒ ('bi ⇒ 'b ⇒ assn) ⇒ ('ai+'bi) ⇒ ('a+'b) ⇒ assn" where
  "sum_assn A B (Inl ai) (Inl a) = A ai a"
| "sum_assn A B (Inr bi) (Inr b) = B bi b"
| "sum_assn A B _ _ = false"  

notation sum_assn (infixr "+a" 67)
  
(* Pureness lifts to sums: if A and B are pure, so is sum_assn A B.
   Proved by case distinction on both sum values. *)
lemma sum_assn_pure[safe_constraint_rules]: "⟦is_pure A; is_pure B⟧ ⟹ is_pure (sum_assn A B)"
  apply (auto simp: is_pure_iff_pure_assn)
  apply (rename_tac x x')
  apply (case_tac x; case_tac x'; simp add: pure_def)
  done
  
(* The sum of two identity assertions is the identity assertion on the sum
   type. Shown pointwise by case distinction on both values. *)
lemma sum_assn_id[simp]: "sum_assn id_assn id_assn = id_assn"
proof (intro ext)
  fix x y
  show "sum_assn id_assn id_assn x y = id_assn x y"
    by (cases x; cases y; simp add: pure_def)
qed

(* The sum of two pure assertions is the pure assertion of the corresponding
   sum relation. Shown pointwise by case distinction on both values. *)
lemma sum_assn_pure_conv[simp]: "sum_assn (pure A) (pure B) = pure (⟨A,B⟩sum_rel)"
proof (intro ext)
  fix a b
  show "sum_assn (pure A) (pure B) a b = pure (⟨A,B⟩sum_rel) a b"
    by (cases a; cases b; auto simp: pure_def)
qed
    
    
(* Frame-match congruence for sums: entailment between the component
   assertions (A to A' for Inl values, B to B' for Inr values) lifts to
   entailment between the sum assertions. *)
lemma sum_match_cong[sepref_frame_match_rules]: 
  "⟦
    ⋀x y. ⟦e = Inl x; e'=Inl y⟧ ⟹ hn_ctxt A x y ⟹t hn_ctxt A' x y;
    ⋀x y. ⟦e = Inr x; e'=Inr y⟧ ⟹ hn_ctxt B x y ⟹t hn_ctxt B' x y
  ⟧ ⟹ hn_ctxt (sum_assn A B) e e' ⟹t hn_ctxt (sum_assn A' B') e e'"
  by (cases e; cases e'; simp add: hn_ctxt_def entt_star_mono)

(* Frame-merge congruence for sums: if the component assertions can be merged
   (A, A' into Am for Inl values; B, B' into Bm for Inr values), then the
   disjunction of the two sum assertions can be merged into sum_assn Am Bm.
   NOTE(review): the lemma is named enum_merge_cong although it is about
   sum_assn; the name is kept because it may be referenced elsewhere. *)
lemma enum_merge_cong[sepref_frame_merge_rules]:
  assumes "⋀x y. ⟦e=Inl x; e'=Inl y⟧ ⟹ hn_ctxt A x y ∨A hn_ctxt A' x y ⟹t hn_ctxt Am x y"
  assumes "⋀x y. ⟦e=Inr x; e'=Inr y⟧ ⟹ hn_ctxt B x y ∨A hn_ctxt B' x y ⟹t hn_ctxt Bm x y"
  shows "hn_ctxt (sum_assn A B) e e' ∨A hn_ctxt (sum_assn A' B') e e' ⟹t hn_ctxt (sum_assn Am Bm) e e'"
  apply (rule entt_disjE)
  (* First disjunct: sum_match_cong with the left parts of the assumptions *)
  apply (rule sum_match_cong)
  apply (rule entt_disjD1[OF assms(1)]; simp)
  apply (rule entt_disjD1[OF assms(2)]; simp)

  (* Second disjunct: sum_match_cong with the right parts of the assumptions *)
  apply (rule sum_match_cong)
  apply (rule entt_disjD2[OF assms(1)]; simp)
  apply (rule entt_disjD2[OF assms(2)]; simp)
  done

(* An invalidated sum assertion entails the sum assertion over the
   invalidated component assertions. *)
lemma entt_invalid_sum: "hn_invalid (sum_assn A B) e e' ⟹t hn_ctxt (sum_assn (invalid_assn A) (invalid_assn B)) e e'"
  apply (simp add: hn_ctxt_def invalid_assn_def[abs_def])
  apply (rule enttI)
  apply clarsimp
  apply (cases e; cases e'; auto simp: mod_star_conv pure_def) 
  done

(* Frame-merge rule obtained by instantiating gen_merge_cons with
   entt_invalid_sum. *)
lemmas invalid_sum_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_sum]

(* Register both sum constructors as operations *)
sepref_register Inr Inl  

(* Refinement rule for Inl: ureturn o Inl refines RETURNT o Inl; the result
   assertion is sum_assn A B. *)
lemma [sepref_fr_rules]: "(ureturn o Inl,RETURNT o Inl) ∈ Ada sum_assn A B"
  apply rule
  by(auto simp: hn_refine_def execute_ureturn zero_enat_def relH_def invalid_assn_def intro: mod_star_trueI)
(* Refinement rule for Inr: ureturn o Inr refines RETURNT o Inr; the result
   assertion is sum_assn A B. *)
lemma [sepref_fr_rules]: "(ureturn o Inr,RETURNT o Inr) ∈ Bda sum_assn A B"
  apply rule
  by(auto simp: hn_refine_def execute_ureturn zero_enat_def relH_def invalid_assn_def intro: mod_star_trueI)

(* Register the case combinator of the sum type *)
sepref_register case_sum

text ‹In the monadify phase, this eta-expands to make visible all required arguments›
(* Arity rule: case_sum takes three arguments (the two branch functions and
   the scrutinee); both branch functions are eta-expanded. *)
lemma [sepref_monadify_arity]: "case_sum ≡ λ2f1 f2 x. SP case_sum$(λ2x. f1$x)$(λ2x. f2$x)$x"
  by simp

text ‹This determines an evaluation order for the first-order operands›  
(* The scrutinee x is evaluated first (EVAL$x) and bound before the case
   distinction is made. *)
lemma [sepref_monadify_comb]: "case_sum$f1$f2$x ≡ (⤜) $(EVAL$x)$(λ2x. SP case_sum$f1$f2$x)" by simp

text ‹This enables translation of the case-distinction in a non-monadic context.›  
(* EVAL of a case_sum expression: the scrutinee is evaluated first, and EVAL
   is pushed into both branches. *)
lemma [sepref_monadify_comb]: "EVAL$(case_sum$(λ2x. f1 x)$(λ2x. f2 x)$x) 
  ≡ (⤜) $(EVAL$x)$(λ2x. SP case_sum$(λ2x. EVAL $ f1 x)$(λ2x. EVAL $ f2 x)$x)"
  apply (rule eq_reflection)
  by (simp split: sum.splits)

text ‹Auxiliary lemma, to lift simp-rule over ‹hn_ctxt››  
(* Transfers an equation about sum_assn to the same equation with the
   assertion wrapped in hn_ctxt; used below (via THEN) to lift
   sum_assn.simps. *)
lemma sum_assn_ctxt: "sum_assn A B x y = z ⟹ hn_ctxt (sum_assn A B) x y = z"
  unfolding hn_ctxt_def by simp

text ‹The cases lemma first extracts the refinement for the datatype from the precondition.
  Next, it generates proof obligations to refine the functions for every case. 
  Finally the postconditions of the refinement are merged. 

  Note that we handle the
  destructed values separately, to allow reconstruction of the original datatype after the case-expression.

  Moreover, we provide (invalidated) versions of the original compound value to the cases,
  which allows access to pure compound values from inside the case.
  ›  
(* Refinement rule for a case distinction on a sum value.
   FR:    the precondition Γ provides the sum assertion for e/e' and a frame F.
   E1:    refinement of the Inl branch (f1' x1a against f1 x1).
   E2:    refinement of the Inr branch (f2' x2a against f2 x2).
          In both branches the precondition also contains INVe, the
          invalidated assertion of the whole sum value.
   MERGE: merges the residual frames Γ1' and Γ2' of the two branches into Γ'.
   The hn_ctxt XX1 / hn_ctxt XX2 parts of the branch postconditions are
   dropped in the proof (entt_fr_drop). *)
lemma sum_cases_hnr:
  fixes A B e e'
  defines [simp]: "INVe ≡ hn_invalid (sum_assn A B) e e'"
  assumes FR: "Γ ⟹t hn_ctxt (sum_assn A B) e e' * F"
  assumes E1: "⋀x1 x1a. ⟦e = Inl x1; e' = Inl x1a⟧ ⟹ hn_refine (hn_ctxt A x1 x1a * INVe * F) (f1' x1a) (hn_ctxt A' x1 x1a * hn_ctxt XX1 e e' * Γ1') R (f1 x1)"
  assumes E2: "⋀x2 x2a. ⟦e = Inr x2; e' = Inr x2a⟧ ⟹ hn_refine (hn_ctxt B x2 x2a * INVe * F) (f2' x2a) (hn_ctxt B' x2 x2a * hn_ctxt XX2 e e' * Γ2') R (f2 x2)"
  assumes MERGE[unfolded hn_ctxt_def]: "Γ1' ∨A Γ2' ⟹t Γ'"
  shows "hn_refine Γ (case_sum f1' f2' e') (hn_ctxt (sum_assn A' B') e e' * Γ') R (case_sum$(λ2x. f1 x)$(λ2x. f2 x)$e)"
  (* Replace the precondition by the framed form given by FR *)
  apply (rule hn_refine_cons_pre[OF FR])
  apply1 extract_hnr_invalids
  (* Case split on both the abstract and the concrete sum value *)
  apply (cases e; cases e'; simp add: sum_assn.simps[THEN sum_assn_ctxt])
  (* Branch handled with E1 (Inl) *)
  subgoal
    apply (rule hn_refine_cons[OF _ E1 _ entt_refl]; assumption?)
    (* NOTE(review): the marginal comment on the next line mentions enum_assn;
       the assertion occurring in this lemma is sum_assn. Presumably the text
       was carried over from a similar rule -- confirm. *)
    applyS (simp add: hn_ctxt_def) ― ‹Match precondition for case, get ‹enum_assn› from assumption generated by ‹extract_hnr_invalids››
    apply (rule entt_star_mono) ― ‹Split postcondition into pairs for compounds and frame, drop ‹hn_ctxt XX››
    apply1 (rule entt_fr_drop)
    applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')
    apply1 (rule entt_trans[OF _ MERGE])
    applyS (simp add: entt_disjI1' entt_disjI2')
  done
  (* Branch handled with E2 (Inr); same steps as above *)
  subgoal 
    apply (rule hn_refine_cons[OF _ E2 _ entt_refl]; assumption?)
    applyS (simp add: hn_ctxt_def)
    apply (rule entt_star_mono)
    apply1 (rule entt_fr_drop)
    applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')
    apply1 (rule entt_trans[OF _ MERGE])
    applyS (simp add: entt_disjI1' entt_disjI2')
  done    
  done 

text ‹After some more preprocessing (adding extra frame-rules for non-atomic postconditions, 
  and splitting the merge-terms into binary merges), this rule can be registered›
(* sum_cases_hnr, processed by the sepref_prep_comb_rule attribute, is added
   to the sepref_comb_rules. *)
lemmas [sepref_comb_rules] = sum_cases_hnr[sepref_prep_comb_rule]

(* Register the discriminator and the two selectors of the sum type *)
sepref_register isl projl projr
(* Refinement rule for isl: ureturn o isl refines RETURNT o isl on an argument
   with assertion sum_assn A B; the result is related by bool_assn. *)
lemma isl_hnr[sepref_fr_rules]: "(ureturn o isl,RETURNT o isl) ∈ (sum_assn A B)ka bool_assn"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done

(* Refinement rule for projl, stated with precondition isl on the argument;
   the result is related by A. *)
lemma projl_hnr[sepref_fr_rules]: "(ureturn o projl,RETURNT o projl) ∈ [isl]a (sum_assn A B)d → A"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done

(* Refinement rule for projr, stated with precondition Not o isl on the
   argument; the result is related by B. *)
lemma projr_hnr[sepref_fr_rules]: "(ureturn o projr,RETURNT o projr) ∈ [Not o isl]a (sum_assn A B)d → B"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done
  
subsection ‹String Literals›  

(* The empty string literal is registered as a protected constant (PR_CONST) *)
sepref_register "PR_CONST String.empty_literal"

(* Import rule: the empty literal is related to itself by Id *)
lemma empty_literal_hnr [sepref_import_param]:
  "(String.empty_literal, PR_CONST String.empty_literal) ∈ Id"
  by simp

(* Pattern rule: occurrences of the empty literal are wrapped in UNPROTECT *)
lemma empty_literal_pat [def_pat_rules]:
  "String.empty_literal ≡ UNPROTECT String.empty_literal"
  by simp

(* Setup for non-empty string literals. The seven boolean arguments b0..b6 and
   the remaining literal s of String.Literal are fixed in this context, so
   that the fully applied term can be registered as one protected constant. *)
context
  fixes b0 b1 b2 b3 b4 b5 b6 :: bool
  and s :: String.literal
begin

sepref_register "PR_CONST (String.Literal b0 b1 b2 b3 b4 b5 b6 s)"

(* Import rule: the fully applied literal is related to itself by Id *)
lemma Literal_hnr [sepref_import_param]:
  "(String.Literal b0 b1 b2 b3 b4 b5 b6 s,
    PR_CONST (String.Literal b0 b1 b2 b3 b4 b5 b6 s)) ∈ Id"
  by simp

end

(* Pattern rule: a fully applied String.Literal term is wrapped in UNPROTECT
   as a whole. *)
lemma Literal_pat [def_pat_rules]:
  "String.Literal $ b0 $ b1 $ b2 $ b3 $ b4 $ b5 $ b6 $ s ≡
    UNPROTECT (String.Literal $ b0 $ b1 $ b2 $ b3 $ b4 $ b5 $ b6 $ s)"
  by simp

(* Import rules for equality and strict less-than, both on the identity
   relation. *)
lemma [sepref_import_param]: 
  "((=),(=))∈Id→Id→Id" 
  "((<),(<))∈Id→Id→Id" 
  by simp_all

end