(* Theory Sepref_Rules

theory Sepref_Rules
imports Sepref_Constraints Sepref_Additional
*)
section ‹Refinement Rule Management›
theory Sepref_Rules
imports Sepref_Basic Sepref_Constraints Sepref_Additional
begin
  text ‹This theory contains tools for managing the refinement rules used by Sepref›

  text ‹The theories are based on uncurried functions, i.e.,
    every function has type @{typ "'a⇒'b"}, where @{typ 'a} is the 
    tuple of parameters, or unit if there are none.
    ›


  subsection ‹Assertion Interface Binding›
  text ‹Binding of interface types to refinement assertions›
  (* Marker predicate tying a refinement assertion to an interface type.
     By definition it is always True; the [simp] attribute lets the
     simplifier discharge it. *)
  definition intf_of_assn :: "('a ⇒ _ ⇒ assn) ⇒ 'b itself ⇒ bool" where
    [simp]: "intf_of_assn a b = True"

  (* Introduction rule: the marker holds for any assertion and any type. *)
  lemma intf_of_assnI: "intf_of_assn R TYPE('a)" by simp
  
  (* Named theorem collection for assertion/interface links. *)
  named_theorems_rev intf_of_assn ‹Links between refinement assertions and interface types›  

  (* Instance in which the interface type is the abstract type 'a of the
     assertion itself. NOTE(review): the name suggests it serves as the default
     when no link is registered - confirm against the ML code. *)
  lemma intf_of_assn_fallback: "intf_of_assn (R :: 'a ⇒ _ ⇒ assn) TYPE('a)" by simp

  subsection ‹Function Refinement with Precondition›
  (* [P]f R → S is the set of function pairs (f,g) such that for all x, y
     with P y and (x,y)∈R the results satisfy (f x, g y)∈S. The precondition
     P is stated on the argument of the second component g. *)
  definition fref :: "('c ⇒ bool) ⇒ ('a × 'c) set ⇒ ('b × 'd) set
           ⇒ (('a ⇒ 'b) × ('c ⇒ 'd)) set"
    ("[_]f _ → _" [0,60,60] 60)         
  where "[P]f R → S ≡ {(f,g). ∀x y. P y ∧ (x,y)∈R ⟶ (f x, g y)∈S}"
  
  (* R →f S: fref with the trivial precondition λ_. True. *)
  abbreviation freft ("_ →f _" [60,60] 60) where "R →f S ≡ ([λ_. True]f R → S)"
  
  (* Predicate (rel2p) view of fref: the set-based definition becomes the
     corresponding statement over rel2p R and rel2p S. Registered as a
     rel2p rule. *)
  lemma rel2p_fref[rel2p]: "rel2p (fref P R S) 
    = (λf g. (∀x y. P y ⟶ rel2p R x y ⟶ rel2p S (f x) (g y)))"  
    by (auto simp: fref_def rel2p_def[abs_def])

  (* Consequence rule for fref: the precondition may be replaced by Q provided
     Q implies P on arguments related by R', the argument relation may be
     narrowed (R' ⊆ R), and the result relation may be widened (S ⊆ S'). *)
  lemma fref_cons:  
    assumes "(f,g) ∈ [P]f R → S"
    assumes "⋀c a. (c,a)∈R' ⟹ Q a ⟹ P a"
    assumes "R' ⊆ R"
    assumes "S ⊆ S'"
    shows "(f,g) ∈ [Q]f R' → S'"
    using assms
    unfolding fref_def
    by fastforce

  (* Variant with both relation inclusions discharged by reflexivity, so only
     the precondition obligation remains. *)
  lemmas fref_cons' = fref_cons[OF _ _ order_refl order_refl]  

  (* Introduction rule for fref: it suffices to show that the results are
     related by S for every R-related argument pair whose second component
     satisfies P. *)
  lemma frefI[intro?]: 
    assumes step: "⋀x y. ⟦P y; (x,y)∈R⟧ ⟹ (f x, g y)∈S"
    shows "(f,g)∈fref P R S"
    unfolding fref_def using step by auto

  (* A pair in the plain function relation R→S is also in the
     precondition-free fref R→fS. *)
  lemma fref_ncI: "(f,g)∈R→S ⟹ (f,g)∈R→fS"  
    by (rule frefI) parametricity

  (* Destruction rule for fref: from membership, obtain the result relation
     for any R-related argument pair whose second component satisfies P. *)
  lemma frefD: 
    assumes "(f,g)∈fref P R S"
    shows "⟦P y; (x,y)∈R⟧ ⟹ (f x, g y)∈S"
    using assms
    unfolding fref_def
    by auto

  (* Converse of fref_ncI: a precondition-free fref membership yields
     membership in the plain function relation. *)
  lemma fref_ncD: "(f,g)∈R→fS ⟹ (f,g)∈R→S"  
    by (rule fun_relI, drule frefD, simp, assumption+)


  (* Relational composition of two fref sets is contained in the fref over the
     composed relations. The combined precondition demands Q on the outer
     argument and P on every S1-predecessor of it. *)
  lemma fref_compI: 
    "fref P R1 R2 O fref Q S1 S2 ⊆
      fref (λx. Q x ∧ (∀y. (y,x)∈S1 ⟶ P y)) (R1 O S1) (R2 O S2)"
    unfolding fref_def
    by auto blast

  (* Element-wise form of fref_compI: composes two concrete fref memberships
     that share the middle function g. *)
  lemma fref_compI':
    "⟦ (f,g)∈fref P R1 R2; (g,h)∈fref Q S1 S2 ⟧ 
      ⟹ (f,h) ∈ fref (λx. Q x ∧ (∀y. (y,x)∈S1 ⟶ P y)) (R1 O S1) (R2 O S2)"
    using fref_compI[of P R1 R2 Q S1 S2]   
    by auto

  (* For constant functions over unit_rel, fref membership reduces to an
     implication from P () to the relation on the constants. *)
  lemma fref_unit_conv:
    "(λ_. c, λ_. a) ∈ fref P unit_rel S ⟷ (P () ⟶ (c,a)∈S)"   
    by (auto simp: fref_def)

  (* For uncurried binary functions over a product relation, fref membership
     unfolds to the corresponding curried statement. *)
  lemma fref_uncurry_conv:
    "(uncurry c, uncurry a) ∈ fref P (R1×rR2) S 
    ⟷ (∀x1 y1 x2 y2. P (y1,y2) ⟶ (x1,y1)∈R1 ⟶ (x2,y2)∈R2 ⟶ (c x1 x2, a y1 y2) ∈ S)"
    by (auto simp: fref_def)

  (* Monotonicity of fref: a stronger precondition (P' implies P) and a smaller
     argument relation (R' ⊆ R) enlarge the set, as does a larger result
     relation (S ⊆ S'). *)
  lemma fref_mono: "⟦ ⋀x. P' x ⟹ P x; R' ⊆ R; S ⊆ S' ⟧ 
    ⟹ fref P R S ⊆ fref P' R' S'"  
    unfolding fref_def
    by auto blast

  (* Composition rule for fref with all side conditions split into separate
     premises: C1/C2 relate the target precondition P' to Q and P, R1/R2 relate
     the target relations to the composed ones, and FH lets the conclusion be
     stated for terms f', h' that are equal to f, h. *)
  lemma fref_composeI:
    assumes FR1: "(f,g)∈fref P R1 R2"
    assumes FR2: "(g,h)∈fref Q S1 S2"
    assumes C1: "⋀x. P' x ⟹ Q x"
    assumes C2: "⋀x y. ⟦P' x; (y,x)∈S1⟧ ⟹ P y"
    assumes R1: "R' ⊆ R1 O S1"
    assumes R2: "R2 O S2 ⊆ S'"
    assumes FH: "f'=f" "h'=h"
    shows "(f',h') ∈ fref P' R' S'"
    unfolding FH
    (* Weaken the result of fref_compI' via fref_mono; the three remaining
       goals are the precondition and the two relation inclusions. *)
    apply (rule set_mp[OF fref_mono fref_compI'[OF FR1 FR2]])
    using C1 C2 apply blast
    using R1 apply blast
    using R2 apply blast
    done

  (* Any function is related to itself when the argument relation is contained
     in Id and the result relation is Id. *)
  lemma fref_triv: "A⊆Id ⟹ (f,f)∈[P]f A → Id"
    by (auto simp: fref_def)


  subsection ‹Heap-Function Refinement›
  text ‹
    The following relates a heap-function with a pure function.
    It contains a precondition, a refinement assertion for the arguments
    before and after execution, and a refinement relation for the result.
    ›
  (* [P]a RS → T is the set of pairs (f,g) such that, for all c and a with P a,
     hn_refine holds with context fst RS a c before and snd RS a c after running
     f c, result assertion T, and g a as the second program. RS is the pair of
     argument assertions (before, after). *)
  (* TODO: We only use this with keep/destroy information, so we could model
    the parameter relations as such (('a⇒'ai ⇒ assn) × bool) *)
  definition hfref 
    :: "
      ('a ⇒ bool) 
   ⇒ (('a ⇒ 'ai ⇒ assn) × ('a ⇒ 'ai ⇒ assn)) 
   ⇒ ('b ⇒ 'bi ⇒ assn) 
   ⇒ (('ai ⇒ 'bi Heap) × ('a⇒'b nrest)) set"
   ("[_]a _ → _" [0,60,60] 60)
   where
    "[P]a RS → T ≡ { (f,g) . ∀c a.  P a ⟶ hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)}"

  (* RS →a T: hfref with the trivial precondition λ_. True. *)
  abbreviation hfreft ("_ →a _" [60,60] 60) where "RS →a T ≡ ([λ_. True]a RS → T)"

  (* Introduction rule for hfref: show the hn_refine statement for all c, a
     under the precondition P a. *)
  lemma hfrefI[intro?]: 
    assumes "⋀c a. P a ⟹ hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)"
    shows "(f,g)∈hfref P RS T"
    using assms unfolding hfref_def by blast

  (* Destruction rule for hfref: membership yields the hn_refine statement
     for any c, a with P a. *)
  lemma hfrefD: 
    assumes "(f,g)∈hfref P RS T"
    shows "⋀c a. P a ⟹ hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)"
    using assms unfolding hfref_def by blast
 
  (* Moves the precondition P of an hfref into an ASSERT placed in front of the
     second program, giving an equivalent precondition-free hfref. The NO_MATCH
     premise excludes the case where P is already λ_. True. *)
  lemma hfref_to_ASSERT_conv: 
    "NO_MATCH (λ_. True) P ⟹ (a,b)∈[P]a R → S ⟷ (a,λx. ASSERT (P x) ⪢ b x) ∈ R →a S"  
    unfolding hfref_def
    apply (clarsimp; safe; clarsimp?)
    apply (rule hn_refine_nofailI)
    apply (simp add: refine_pw_simps)
    subgoal for xc xa
      (* Instantiate the universally quantified hypothesis at xc, xa. *)
      apply (drule spec[of _ xc])
      apply (drule spec[of _ xa]) 
      by (simp add: hnr_ASSERT) 
    done

  text ‹
    A pair of argument refinement assertions can be created from the 
    input assertion and a flag stating whether the parameter is kept or 
    destroyed by the function.
    ›  
  (* Builds the (before, after) assertion pair from an assertion R and a flag:
     True yields (R, R); False yields (R, invalid_assn R). *)
  primrec hf_pres 
    :: "('a ⇒ 'b ⇒ assn) ⇒ bool ⇒ ('a ⇒ 'b ⇒ assn)×('a ⇒ 'b ⇒ assn)"
    where 
      "hf_pres R True = (R,R)" | "hf_pres R False = (R,invalid_assn R)"

  (* Rk: postfix notation for hf_pres R True (parameter kept). *)
  abbreviation hfkeep 
    :: "('a ⇒ 'b ⇒ assn) ⇒ ('a ⇒ 'b ⇒ assn)×('a ⇒ 'b ⇒ assn)" 
    ("(_k)" [1000] 999)
    where "Rk ≡ hf_pres R True"
  (* Rd: postfix notation for hf_pres R False (parameter destroyed). *)
  abbreviation hfdrop 
    :: "('a ⇒ 'b ⇒ assn) ⇒ ('a ⇒ 'b ⇒ assn)×('a ⇒ 'b ⇒ assn)" 
    ("(_d)" [1000] 999)
    where "Rd ≡ hf_pres R False"

  (* hn_ctxt applied to the "after" component of hf_pres R kd, plus the two
     instances for kd = True and kd = False. *)
  abbreviation "hn_kede R kd ≡ hn_ctxt (snd (hf_pres R kd))"
  abbreviation "hn_keep R ≡ hn_kede R True"
  abbreviation "hn_dest R ≡ hn_kede R False"

  (* Simp rules for the components of Rk and Rd. *)
  lemma keep_drop_sels[simp]:  
    "fst (Rk) = R"
    "snd (Rk) = R"
    "fst (Rd) = R"
    "snd (Rd) = invalid_assn R"
    by auto

  (* The "before" component of hf_pres is R regardless of the flag. *)
  lemma hf_pres_fst[simp]: "fst (hf_pres R k) = R" by (cases k) auto

  text ‹
    The following operator combines multiple argument assertion-pairs to
    argument assertion-pairs for the product. It is required to state
    argument assertion-pairs for uncurried functions.
    ›  
  (* RR *a SS combines two assertion pairs component-wise with prod_assn,
     giving an assertion pair over the product types. *)
  definition hfprod :: "
    (('a ⇒ 'b ⇒ assn)×('a ⇒ 'b ⇒ assn)) 
    ⇒ (('c ⇒ 'd ⇒ assn)×('c ⇒ 'd ⇒ assn))
    ⇒ ((('a×'c) ⇒ ('b × 'd) ⇒ assn) × (('a×'c) ⇒ ('b × 'd) ⇒ assn))"
    (infixl "*a" 65)
    where "RR *a SS ≡ (prod_assn (fst RR) (fst SS), prod_assn (snd RR) (snd SS))"

  (* Simp rules for the components of a *a product. *)
  lemma hfprod_fst_snd[simp]:
    "fst (A *a B) = prod_assn (fst A) (fst B)" 
    "snd (A *a B) = prod_assn (snd A) (snd B)" 
    unfolding hfprod_def by auto



  subsubsection ‹Conversion from fref to hfref›  
   (* TODO: Variant of import-param! Automate this! *)
(* never used
lemma fref_to_pure_hfref':
    assumes "(f,g) ∈ [P]f R→⟨S⟩nrest_rel"
    assumes "⋀x. x∈Domain R ∩ R¯``Collect P ⟹ f x = RETURNT (f' x)"
    shows "(ureturn o f', g) ∈ [P]a (pure R)k→pure S"
    apply (rule hfrefI) 
    unfolding hn_refine_def apply (auto simp add: execute_ureturn' pure_assn_rule)
    subgoal         
      using assms
      apply ((auto simp: fref_def pure_def pw_le_iff pw_nrest_rel_iff
            refine_pw_simps  ))
        sorry
    subgoal by (simp add: relH_def)
    done 
*)

  subsubsection ‹Conversion from hfref to hnr›  
  text ‹This section contains the lemmas. The ML code is further down. ›
  (* Unfolds an hfref membership into an hn_refine statement whose argument
     assertions are wrapped in hn_ctxt, whose applications are written with $,
     and whose contexts carry a leading emp. *)
  lemma hf2hnr:
    assumes "(f,g) ∈ [P]a R → S"
    shows "∀x xi. P x ⟶ hn_refine (emp * hn_ctxt (fst R) x xi) (f$xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def)

  (*lemma hf2hnr_new:
    assumes "(f,g) ∈ [P]a R → S"
    shows "∀x xi. (∀h. h⊨fst R x xi ⟶ P x) ⟶ hn_refine (emp * hn_ctxt (fst R) x xi) (f xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def intro: hn_refine_preI)
  *)


  (* Products that stem from currying are tagged by a special refinement relation *)  
  (* to_hnr_prod is definitionally prod_assn and unfolds by simp. *)
  definition [simp]: "to_hnr_prod ≡ prod_assn"

  (* Like hfprod_fst_snd, but the right-hand sides use the to_hnr_prod tag
     instead of prod_assn. *)
  lemma to_hnr_prod_fst_snd:
    "fst (A *a B) = to_hnr_prod (fst A) (fst B)" 
    "snd (A *a B) = to_hnr_prod (snd A) (snd B)" 
    unfolding hfprod_def by auto

  (* Warning: This lemma is carefully set up to be applicable as an unfold rule,
    for more than one level of uncurrying*)
  (* Splits a to_hnr_prod-tagged context entry for a pair argument into two
     separate hn_ctxt entries, one per component. On the right-hand side the
     entry for the second component (B) precedes the one for the first (A). *)
  lemma hnr_uncurry_unfold: "
    (∀x xi. P x ⟶ 
      hn_refine 
        (Γ * hn_ctxt (to_hnr_prod A B) x xi) 
        (fi xi) 
        (Γ' * hn_ctxt (to_hnr_prod A' B') x xi) 
        R 
        (f x))
⟷ (∀b bi a ai. P (a,b) ⟶
      hn_refine 
        (Γ * hn_ctxt B b bi * hn_ctxt A a ai) 
        (fi (ai,bi)) 
        (Γ' * hn_ctxt B' b bi * hn_ctxt A' a ai)
        R
        (f (a,b))
    )"
    by (auto simp: hn_ctxt_def prod_assn_def star_aci)
    
  (* Adds a leading emp factor to both contexts of an hn_refine statement. *)
  lemma hnr_intro_dummy:
    "∀x xi. P x ⟶ hn_refine (Γ x xi) (c xi) (Γ' x xi) R (a x) ⟹ ∀x xi. P x ⟶ hn_refine (emp*Γ x xi) (c xi) (emp*Γ' x xi) R (a x)" 
    by simp

  (* hn_ctxt is idempotent. *)
  lemma hn_ctxt_ctxt_fix_conv: "hn_ctxt (hn_ctxt R) = hn_ctxt R"
    by (simp add: hn_ctxt_def[abs_def])

  (* $-application of uncurry f to a pair equals nested $-application of f. *)
  lemma uncurry_APP: "uncurry f$(a,b) = f$a$b" by auto

  (* TODO: Replace by more general rule. *)  
  (* Pushes $-application through compositions of RETURNT with f, for f taking
     one to five arguments. *)
  lemma norm_RETURNT_o: 
    "⋀f. (RETURNT o f)$x = (RETURNT$(f$x))"
    "⋀f. (RETURNT oo f)$x$y = (RETURNT$(f$x$y))"
    "⋀f. (RETURNT ooo f)$x$y$z = (RETURNT$(f$x$y$z))"
    "⋀f. (λx. RETURNT ooo f x)$x$y$z$a = (RETURNT$(f$x$y$z$a))"
    "⋀f. (λx y. RETURNT ooo f x y)$x$y$z$a$b = (RETURNT$(f$x$y$z$a$b))"
    by auto

  (* Same normalization as norm_RETURNT_o, with ureturn in place of RETURNT. *)
  lemma norm_return_o: 
    "⋀f. (ureturn o f)$x = (ureturn$(f$x))"
    "⋀f. (ureturn oo f)$x$y = (ureturn$(f$x$y))"
    "⋀f. (ureturn ooo f)$x$y$z = (ureturn$(f$x$y$z))"
    "⋀f. (λx. ureturn ooo f x)$x$y$z$a = (ureturn$(f$x$y$z$a))"
    "⋀f. (λx y. ureturn ooo f x y)$x$y$z$a$b = (ureturn$(f$x$y$z$a$b))"
    by auto

  
  (* An hn_val over unit_rel is the empty assertion. *)
  lemma hn_val_unit_conv_emp[simp]: "hn_val unit_rel x y = emp"
    by (auto simp: hn_ctxt_def pure_def)

  subsubsection ‹Conversion from hnr to hfref›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* id_assn: the pure assertion over Id; unit_assn: id_assn at type unit. *)
  abbreviation "id_assn ≡ pure Id"
  abbreviation "unit_assn ≡ id_assn :: unit ⇒ _"

  (* unit_assn is the empty assertion. *)
  lemma pure_unit_rel_eq_empty: "unit_assn x y = emp"  
    by (auto simp: pure_def)

  (* Components of a *a product, applied to arguments and unfolded via case
     distinction on the pair arguments. *)
  lemma uc_hfprod_sel:
    "fst (A *a B) a c = (case (a,c) of ((a1,a2),(c1,c2)) ⇒ fst A a1 c1 * fst B a2 c2)" 
    "snd (A *a B) a c = (case (a,c) of ((a1,a2),(c1,c2)) ⇒ snd A a1 c1 * snd B a2 c2)" 
    unfolding hfprod_def prod_assn_def[abs_def] by auto


  subsubsection ‹Conversion from relation to fref›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* CURRY R: pairs of curried binary functions whose uncurried versions are
     related by R. *)
  definition "CURRY R ≡ { (f,g). (uncurry f, uncurry g) ∈ R }"

  (* The plain function relation equals fref with trivial precondition. *)
  lemma fref_param1: "R→S = fref (λ_. True) R S"  
    by (auto simp: fref_def fun_relD)

  (* A nested fref is the CURRY of a single fref over the product relation,
     with the two preconditions conjoined. Stated as a meta-equality. *)
  lemma fref_nest: "fref P1 R1 (fref P2 R2 S) 
    ≡ CURRY (fref (λ(a,b). P1 a ∧ P2 b) (R1×rR2) S)"
    apply (rule eq_reflection)
    by (auto simp: fref_def CURRY_def)

  (* Membership in CURRY R, unfolded. *)
  lemma in_CURRY_conv: "(f,g) ∈ CURRY R ⟷ (uncurry f, uncurry g) ∈ R"  
    unfolding CURRY_def by auto

  (* $-application of uncurry0 c to any argument yields c. *)
  lemma uncurry0_APP[simp]: "uncurry0 c $ x = c" by auto

  (* Related constants give rise to an fref over unit_rel for their uncurry0
     versions. *)
  lemma fref_param0I: "(c,a)∈R ⟹ (uncurry0 c, uncurry0 a) ∈ fref (λ_. True) unit_rel R"
    by (auto simp: fref_def)

  subsubsection ‹Composition›
  (* hr_comp R1 R2 a c: there exists an intermediate value b such that R1 b c
     holds on the heap and (b,a)∈R2. *)
  definition hr_comp :: "('b ⇒ 'c ⇒ assn) ⇒ ('b × 'a) set ⇒ 'a ⇒ 'c ⇒ assn"
    ― ‹Compose refinement assertion with refinement relation›
    where "hr_comp R1 R2 a c ≡ ∃Ab. R1 b c * ↑((b,a)∈R2)"

  (* hrp_comp applies hr_comp with S to both components of the pair RR'. *)
  definition hrp_comp 
    :: "('d ⇒ 'b ⇒ assn) × ('d ⇒ 'c ⇒ assn)
        ⇒ ('d × 'a) set ⇒ ('a ⇒ 'b ⇒ assn) × ('a ⇒ 'c ⇒ assn)"
    ― ‹Compose argument assertion-pair with refinement relation›    
    where "hrp_comp RR' S ≡ (hr_comp (fst RR') S, hr_comp (snd RR') S) "

  (* Introduction rule for hr_comp: if (b,a)∈R2 then R1 b c entails the
     composed assertion at a, c. *)
  lemma hr_compI: "(b,a)∈R2 ⟹ R1 b c ⟹A hr_comp R1 R2 a c"  
    unfolding hr_comp_def
    using entails_def entails_ex_post entails_pure_post by blast

  (* Composing the pure identity assertion with R gives pure R.
     NOTE(review): both directions are closed by smt with an explicit fact
     list; these calls are sensitive to changes in the listed lemmas. *)
  lemma hr_comp_Id1[simp]: "hr_comp (pure Id) R = pure R"  
    unfolding hr_comp_def[abs_def] pure_def
    apply (intro ext ent_iffI)    
     apply (smt BNF_Greatest_Fixpoint.IdD SepLog_Misc.mod_pure_star_dist entails_def entails_ex move_back_pure)
    by (smt IdI entails_equiv_backward entails_ex pure_conj)
    

  (* Composing an assertion with Id leaves it unchanged. *)
  lemma hr_comp_Id2[simp]: "hr_comp R Id = R"  
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    apply (smt BNF_Greatest_Fixpoint.IdD SepLog_Misc.mod_pure_star_dist entails_def entails_ex)
    by (simp add: entailsI mod_ex_dist)
 
  (*lemma hr_comp_invalid[simp]: "hr_comp (λa c. true) R a c = true * ↑(∃b. (b,a)∈R)"
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    apply sep_auto+
    done*)
    

(* Entailments on two factors combine to an entailment on their product.
   The bound variable C does not occur in the statement. *)
lemma isolate_first: "⋀A B C. Γ ⟹A Γ' ⟹ A ⟹A B ⟹ Γ * A ⟹A Γ' * B"  
  by (simp add: ent_star_mono)  


  (* Composing the constantly-emp assertion with R gives the pure statement
     that a has some R-predecessor. *)
  lemma hr_comp_emp[simp]: "hr_comp (λa c. emp) R a c = ↑(∃b. (b,a)∈R)"
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    by (auto intro: ent_ex_postI ent_ex_preI) 


  (* hr_comp distributes over products: composing a prod_assn with a product
     relation equals the prod_assn of the component-wise compositions. *)
  lemma hr_comp_prod_conv[simp]:
    "hr_comp (prod_assn Ra Rb) (Ra' ×r Rb') 
    = prod_assn (hr_comp Ra Ra') (hr_comp Rb Rb')"  
    unfolding hr_comp_def[abs_def] unfolding prod_assn_def[abs_def]
    apply (intro ext ent_iffI)
     apply (auto intro!: ent_ex_preI)
    (* Left-to-right: existential witnesses are left to unification. *)
    subgoal apply(intro ent_ex_postI)
      apply (simp only: mult.assoc)
        apply (rule match_first)
        apply (rule match_rest) 
      by simp
    (* Right-to-left: the witness is the pair of both intermediate values. *)
    subgoal for a b aa ba bb bc
      apply(intro ent_ex_postI[where x="(bc,bb)"])
      by (auto simp: mult.assoc)  
    done
 
(* An implication between propositions lifts to an entailment between the
   corresponding pure assertions. *)
lemma pure_entails: "(P⟹Q) ⟹ ↑ P ⟹A ↑ Q"  
  using entails_pure' entails_triv by blast  

(* An assertion-level existential over a pure assertion equals the pure
   assertion of the logical existential. *)
lemma ex_pure: "(∃Ab. ↑ (B b)) = ↑ (∃b. B b)"
  apply(rule assn_ext) by(simp add: mod_ex_dist pure_assn_rule)  

  (* Composing a pure assertion with a relation yields the pure assertion of
     the composed relation. *)
  lemma hr_comp_pure: "hr_comp (pure R) S = pure (R O S)"  
    apply (intro ext, rule ent_iffI)
    unfolding hr_comp_def[abs_def]
    by (auto intro!: ent_ex_preI simp: ex_pure pure_def)

  (* Purity is preserved by hr_comp; registered as a safe constraint rule. *)
  lemma hr_comp_is_pure[safe_constraint_rules]: "is_pure A ⟹ is_pure (hr_comp A B)"
    by (auto simp: hr_comp_pure is_pure_conv)

  (* The relation underlying a composed pure assertion is the composition of
     the underlying relation with B. *)
  lemma hr_comp_the_pure: "is_pure A ⟹ the_pure (hr_comp A B) = the_pure A O B"
    unfolding is_pure_conv
    by (clarsimp simp: hr_comp_pure)

  (* rdomp of a composed assertion: x is in it iff some R-predecessor y of x
     is in rdomp A. *)
  lemma rdomp_hrcomp_conv: "rdomp (hr_comp A R) x ⟷ (∃y. rdomp A y ∧ (y,x)∈R)"
    by (auto simp: rdomp_def hr_comp_def mod_ex_dist)
      

  (* For non-failing m, RETURNT c refines ⇓R m iff c is R-related to some a
     with RETURNT a ≤ m. *)
  lemma ret_le_down_conv: 
    "nofailT m ⟹ RETURNT c ≤ ⇓R m ⟷ (∃a. (c,a)∈R ∧ RETURNT a ≤ m)"
    by (auto simp: pw_le_iff refine_pw_simps) 

(* A pure factor B on the left of an entailment may be assumed when proving
   the entailment. *)
lemma entails_pure'': "(B ⟹ A ⟹A C) ⟹ A * ↑ B ⟹A C" 
  using entails_pure by blast 

(* emp entails a pure assertion ↑B exactly when B holds. *)
lemma entails_pure''': "(emp ⟹A ↑ B) = B"  
  by (metis ent_iffI pure_assn_eq_conv pure_entails pure_true)   

  (* If a does not fail and (b,a) is in the nrest_rel lifting of R2, then
     hn_rel R1 at b entails hn_rel of the composed assertion at a. *)
  lemma hn_rel_compI: 
    "⟦nofailT a; (b,a)∈⟨R2⟩nrest_rel⟧ ⟹ hn_rel R1 b c ⟹A hn_rel (hr_comp R1 R2) a c"
    unfolding hr_comp_def hn_rel_def nrest_rel_def
    apply (clarsimp intro!: ent_ex_preI entails_pure'' simp:  pure_conj[symmetric] del: pure_conj) 
    apply (drule (1)  order_trans) 
    apply (simp add: ret_le_down_conv) apply auto
    (* Two existential witnesses on the right-hand side, found by unification. *)
    apply (rule ent_ex_postI)
    apply (rule ent_ex_postI)
    apply (rule match_rest) by auto 

  (* Preciseness is preserved by hr_comp when the relation S is single-valued.
     NOTE(review): closed by smt with an explicit fact list; sensitive to
     changes in the listed lemmas. *)
  lemma hr_comp_precise[constraint_rules]:
    assumes [safe_constraint_rules]: "precise R"
    assumes SV: "single_valued S"
    shows "precise (hr_comp R S)"
    apply (rule preciseI)
    unfolding hr_comp_def
    apply clarsimp 
    by (smt SV mod_pure_star_dist and_assn_conv assms(1) assn_times_assoc mod_ex_dist mod_starD preciseD' single_valuedD)  

  (* Associativity-style law: composing with S and then with T equals
     composing once with the relation composition S O T. The commented-out
     sep_auto script below is an earlier proof kept by the original author. *)
  lemma hr_comp_assoc: "hr_comp (hr_comp R S) T = hr_comp R (S O T)"
    apply (intro ext)
    unfolding hr_comp_def
    apply (rule ent_iffI; clarsimp) (*
    apply sep_auto
    apply (rule ent_ex_preI; clarsimp) (* TODO: 
      sep_auto/solve_entails is too eager splitting the subgoal here! *)
    apply sep_auto
    done *) 
    subgoal by (smt SepLog_Misc.mod_pure_star_dist entailsI mod_ex_dist relcomp.relcompI) 
    subgoal 
      apply (clarsimp intro!: ent_ex_preI entails_pure'' simp:  pure_conj[symmetric] del: pure_conj)
      apply (rule ent_ex_postI)
      apply (rule ent_ex_postI)
      apply (rule match_rest) by auto 
    done


  (* Composition of an hn_refine statement with an abstract refinement step.
     R:  hn_refine for program c against b, with argument assertion R1 before
         and R1p after, frames Γ / Γ', and result assertion R (under P).
     S:  b refines a w.r.t. the nrest_rel lifting of R', for arguments related
         by R1' (under Q).
     PQ: Q on the outer argument implies P on every R1'-related argument.
     The conclusion is hn_refine for c against a, with all assertions composed
     via hr_comp. *)
  lemma hnr_comp:
    assumes R: "⋀b1 c1. P b1 ⟹ hn_refine (R1 b1 c1 * Γ) (c c1) (R1p b1 c1 * Γ') R (b b1)"
    assumes S: "⋀a1 b1. ⟦Q a1; (b1,a1)∈R1'⟧ ⟹ (b b1,a a1)∈⟨R'⟩nrest_rel"
    assumes PQ: "⋀a1 b1. ⟦Q a1; (b1,a1)∈R1'⟧ ⟹ P b1"
    assumes Q: "Q a1"
    shows "hn_refine 
      (hr_comp R1 R1' a1 c1 * Γ) 
      (c c1)
      (hr_comp R1p R1' a1 c1 * Γ') 
      (hr_comp R R') 
      (a a1)" 
    unfolding hn_refine_def
  proof clarsimp
    fix h as n M
    assume anofail: "a a1 = SPECT M"
    then have anofail': "nofailT (a a1)" by auto
    assume "pHeap h as n ⊨ hr_comp R1 R1' a1 c1 * Γ"
    (* Extract the intermediate argument b1 hidden in hr_comp. *)
    then obtain b1 where pb: "pHeap h as n ⊨ R1 b1 c1  * Γ * ↑ ((b1, a1) ∈ R1')" 
      unfolding hr_comp_def ex_distrib_star[symmetric] move_back_pure' mod_ex_dist by blast
    then have h: "pHeap h as n ⊨ R1 b1 c1 * Γ" and b1: "(b1, a1) ∈ R1'"   
      by auto

    from b1 PQ Q have P: "P b1" by auto
    from S Q b1 have R': "(b b1,a a1)∈⟨R'⟩nrest_rel" by auto
    (* Since a a1 does not fail, neither does b b1. *)
    with anofail have nfbb: "nofailT (b b1)" apply(auto dest!: nrest_relD) 
      apply(cases "b b1") by auto
    then obtain M' where SPbb: "b b1 = SPECT M'" by force

    from R' anofail SPbb have over: "SPECT M' ≤ ⇓ R' (SPECT M)" unfolding nrest_rel_def by simp

    (* Run the concrete program using the hn_refine assumption R at b1. *)
    from R[OF P] h nfbb SPbb have "
                   (∃h' t r.
                       execute (c c1) h = Some (r, h', t) ∧
                       (∃ra Ca.
                           Some (enat Ca) ≤ M' ra ∧
                           t ≤ n + Ca ∧ pHeap h' (new_addrs h as h') (n + Ca - t) ⊨ R1p b1 c1 * Γ' * R ra r * true) ∧
                       relH {a. a < heap.lim h ∧ a ∉ as} h h' ∧ heap.lim h ≤ heap.lim h')"
      unfolding hn_refine_def by auto
    then obtain h' t r ra' Ca'
      where "execute (c c1) h = Some (r, h', t)"
            and t: "Some (enat Ca') ≤ M' ra'" "t ≤ n + Ca'" 
           and h': "pHeap h' (new_addrs h as h') (n + Ca' - t) ⊨ R1p b1 c1 * Γ' * R ra' r * true"
       and     "relH {a. a < heap.lim h ∧ a ∉ as} h h'" "heap.lim h ≤ heap.lim h'" by blast

    from over have f: "⋀ra. M' ra ≤ Sup {M a |a. (ra, a) ∈ R'}"  unfolding conc_fun_def by (auto simp: le_fun_def)

    from t(1) have "M' ra' > None"  
      by (metis le_some_optE less_option_None_Some_code)   
    with f[of ra'] have "None < Sup {M a |a. (ra', a) ∈ R'}" by auto
    then have " {M a |a. (ra', a) ∈ R'} ≠ {}" by force

    (* Find an abstract result ra related to ra' whose bound in M covers Ca'. *)
    have "∃ra. (ra', ra) ∈ R' ∧ M ra ≥ Some (enat Ca')"
    proof (rule ccontr)
      assume a: "∄ra. (ra', ra) ∈ R' ∧ Some (enat Ca') ≤ M ra"
      hence "⋀ra. (ra', ra) ∈ R' ⟹ Some (enat Ca') > M ra" by auto
      then have "Sup {M a |a. (ra', a) ∈ R'} <  Some (enat Ca')"         
        by (smt Sup_finite_enat Sup_least dual_order.antisym linear mem_Collect_eq not_less)
      with t(1) f[of ra'] show "False" by auto
    qed

    then obtain ra where R'': "(ra', ra)∈ R'" and t': "M ra ≥ Some (enat Ca')" by blast


    (* Re-wrap the post-state assertions into the unfolded hr_comp form. *)
    have "pHeap h' (new_addrs h as h') (n + Ca' - t) ⊨ (∃Ab. R1p b c1 * ↑ ((b, a1) ∈ R1')) * Γ' * (∃Ab. R b r * ↑ ((b, ra) ∈ R')) * true"
      apply(rule entailsD[OF _ h'])
      apply(simp add: ex_distrib_star[symmetric])
      apply(rule ent_ex_postI)
      apply(rule ent_ex_postI)
      apply(simp only: mult.assoc ) apply(rule match_first) 
      apply(rule match_first) apply(rule match_first)  apply(rule match_rest) 
      using b1 R'' by auto

    show "∃h' t r.
          execute (c c1) h = Some (r, h', t) ∧
          (∃ra Ca.
              Some (enat Ca) ≤ M ra ∧
              t ≤ n + Ca ∧
              pHeap h' (new_addrs h as h') (n + Ca - t) ⊨ hr_comp R1p R1' a1 c1 * Γ' * hr_comp R R' ra r * true) ∧
          relH {a. a < heap.lim h ∧ a ∉ as} h h' ∧ heap.lim h ≤ heap.lim h'"
      apply(rule exI[where x=h'])
      apply(rule exI[where x=t])
      apply(rule exI[where x=r])
      apply safe apply fact      
      subgoal unfolding hr_comp_def  
       apply(rule exI[where x=ra]) 
        apply(rule exI[where x=Ca'])
        apply safe 
        by fact+
      by fact+
  qed
 
 (*
    unfolding hn_refine_alt
  proof clarsimp
    assume NF: "nofail (a a1)"
    show "
      <hr_comp R1 R1' a1 c1 * Γ> 
        c c1 
      <λr. hn_rel (hr_comp R R') (a a1) r * (hr_comp R1p R1' a1 c1 * Γ')>t"
      apply (subst hr_comp_def)
      apply (clarsimp intro!: norm_pre_ex_rule)
    proof -
      fix b1
      assume R1: "(b1, a1) ∈ R1'"

      from S R1 Q have R': "(b b1, a a1) ∈ ⟨R'⟩nres_rel" by blast
      with NF have NFB: "nofail (b b1)" 
        by (simp add: nres_rel_def pw_le_iff refine_pw_simps)
      
      from PQ R1 Q have P: "P b1" by blast
      with NFB R have "<R1 b1 c1 * Γ> c c1 <λr. hn_rel R (b b1) r * (R1p b1 c1 * Γ')>t"
        unfolding hn_refine_alt by auto
      thus "<R1 b1 c1 * Γ> 
        c c1 
        <λr. hn_rel (hr_comp R R') (a a1) r * (hr_comp R1p R1' a1 c1 * Γ')>t"
        apply (rule cons_post_rule)
        apply (solve_entails)
        by (intro ent_star_mono hn_rel_compI[OF NF R'] hr_compI[OF R1] ent_refl)
    qed
  qed    *)  

  (* Instance of hnr_comp with empty frames (Γ = Γ' = emp), stated with
     hn_ctxt-wrapped argument assertions and $-applications in the premises. *)
  lemma hnr_comp1_aux:
    assumes R: "⋀b1 c1. P b1 ⟹ hn_refine (hn_ctxt R1 b1 c1) (c c1) (hn_ctxt R1p b1 c1) R (b$b1)"
    assumes S: "⋀a1 b1. ⟦Q a1; (b1,a1)∈R1'⟧ ⟹ (b$b1,a$a1)∈⟨R'⟩nrest_rel"
    assumes PQ: "⋀a1 b1. ⟦Q a1; (b1,a1)∈R1'⟧ ⟹ P b1"
    assumes Q: "Q a1"
    shows "hn_refine 
      (hr_comp R1 R1' a1 c1) 
      (c c1)
      (hr_comp R1p R1' a1 c1) 
      (hr_comp R R') 
      (a a1)"
    using assms hnr_comp[where Γ=emp and Γ'=emp and a=a and b=b and c=c and P=P and Q=Q]  
    unfolding hn_ctxt_def
    by auto

  (* Composition of an hfref with an fref into nrest_rel: the argument
     assertion pair is composed with T via hrp_comp, the result assertion with
     U via hr_comp, and the precondition requires Q on the outer argument and P
     on every T-predecessor. *)
  lemma hfcomp:
    assumes A: "(f,g) ∈ [P]a RR' → S"
    assumes B: "(g,h) ∈ [Q]f T → ⟨U⟩nrest_rel"
    shows "(f,h) ∈ [λa. Q a ∧ (∀a'. (a',a)∈T ⟶ P a')]a 
      hrp_comp RR' T → hr_comp S U"
    using assms  
    unfolding fref_def hfref_def hrp_comp_def
    apply clarsimp
    apply (rule hnr_comp1_aux[of 
        P "fst RR'" f "snd RR'" S g "λa. Q a ∧ (∀a'. (a',a)∈T ⟶ P a')" T h U])
    apply (auto simp: hn_ctxt_def)
    done

  (* The precondition of an hfref needs to hold only for arguments on which
     the second program g does not fail. *)
  lemma hfref_weaken_pre_nofail: 
    assumes "(f,g) ∈ [P]a R → S"  
    shows "(f,g) ∈ [λx. nofailT (g x) ⟶ P x]a R → S"
    using assms
    unfolding hfref_def hn_refine_def
    by auto

  (* Consequence rule for hfref: strengthen the precondition (P' implies P),
     strengthen the "before" argument assertion, and weaken the "after"
     argument assertion and the result assertion, all w.r.t. ⟹t. *)
  lemma hfref_cons:
    assumes "(f,g) ∈ [P]a R → S"
    assumes "⋀x. P' x ⟹ P x"
    assumes "⋀x y. fst R' x y ⟹t fst R x y"
    assumes "⋀x y. snd R x y ⟹t snd R' x y"
    assumes "⋀x y. S x y ⟹t S' x y"
    shows "(f,g) ∈ [P']a R' → S'"
    unfolding hfref_def
    apply clarsimp
    apply (rule hn_refine_cons)
    apply (rule assms(3))
    (* Postpone the hn_refine goal until the entailment goals are solved. *)
    defer
    apply (rule entt_trans[OF assms(4)]; auto)
    apply (rule assms(5))
    apply (frule assms(2))
    using assms(1)
    unfolding hfref_def
    apply auto
    done

  subsubsection ‹Composition Automation›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* hrp_comp distributes over *a and product relations. *)
  lemma prod_hrp_comp: 
    "hrp_comp (A *a B) (C ×r D) = hrp_comp A C *a hrp_comp B D"
    unfolding hrp_comp_def hfprod_def by simp
  
  (* hrp_comp of a kept parameter is the kept composed assertion. *)
  lemma hrp_comp_keep: "hrp_comp (Ak) B = (hr_comp A B)k"
    by (auto simp: hrp_comp_def)

  (* hr_comp commutes with invalid_assn. *)
  lemma hr_comp_invalid: "hr_comp (invalid_assn R1) R2 = invalid_assn (hr_comp R1 R2)"
    apply (intro ent_iffI entailsI ext)
    unfolding invalid_assn_def hr_comp_def by(auto simp add: mod_ex_dist) 

  (* hrp_comp of a destroyed parameter is the destroyed composed assertion. *)
  lemma hrp_comp_dest: "hrp_comp (Ad) B = (hr_comp A B)d"
    by (auto simp: hrp_comp_def hr_comp_invalid)



  (* hrp_imp RR RR': the "before" component of RR' entails that of RR, and the
     "after" component of RR entails that of RR' (both w.r.t. ⟹t). *)
  definition "hrp_imp RR RR' ≡ 
    ∀a b. (fst RR' a b ⟹t fst RR a b) ∧ (snd RR a b ⟹t snd RR' a b)"

  (* hfref is monotone w.r.t. hrp_imp on the argument assertion pair. *)
  lemma hfref_imp: "hrp_imp RR RR' ⟹ [P]a RR → S ⊆ [P]a RR' → S"  
    apply clarsimp
    apply (erule hfref_cons)
    apply (simp_all add: hrp_imp_def)
    done
    
  (* Reflexivity of hrp_imp. *)
  lemma hrp_imp_refl: "hrp_imp RR RR"
    unfolding hrp_imp_def by auto

  (* Reflexivity of hrp_imp, stated with an explicit equality premise. *)
  lemma hrp_imp_reflI: "RR = RR' ⟹ hrp_imp RR RR'"
    unfolding hrp_imp_def by auto

  (* A pure fact B that holds may be added on the right of an entailment. *)
  lemma fe: "B ⟹ A ⟹A C ⟹  A ⟹A C * ↑B" 
    by simp

  (* Congruence of hrp_imp under hrp_comp (relation argument unchanged). *)
  lemma hrp_comp_cong: "hrp_imp A A' ⟹ B=B' ⟹ hrp_imp (hrp_comp A B) (hrp_comp A' B')"
    by (auto intro!: ent_ex_postI ent_ex_preI entails_pure'' 
          simp:  hrp_imp_def hrp_comp_def hr_comp_def entailst_def) 

  (* Congruence of hrp_imp under *a. *)
  lemma hrp_prod_cong: "hrp_imp A A' ⟹ hrp_imp B B' ⟹ hrp_imp (A*aB) (A'*aB')"
    by (auto simp: hrp_imp_def prod_assn_def intro: entt_star_mono)
    
  (* Transitivity of hrp_imp. *)
  lemma hrp_imp_trans: "hrp_imp A B ⟹ hrp_imp B C ⟹ hrp_imp A C"  
    unfolding hrp_imp_def
    by (fastforce intro: entt_trans)

  (* Element-wise form of hfref_imp. *)
  lemma fcomp_norm_dflt_init: "x∈[P]a R → T ⟹ hrp_imp R S ⟹ x∈[P]a S → T"
    apply (erule set_rev_mp)
    by (rule hfref_imp)

  (* comp_PRE R P Q S x: under the guard S x, P x holds and Q x y holds for
     every y with (y,x)∈R. *)
  definition "comp_PRE R P Q S ≡ λx. S x ⟶ (P x ∧ (∀y. (y,x)∈R ⟶ Q x y))"

  (* Congruence rule for comp_PRE: when rewriting Q x y one may assume P x,
     (y,x)∈R, y∈Domain R and the rewritten guard S' x. *)
  lemma comp_PRE_cong[cong]: 
    assumes "R≡R'"
    assumes "⋀x. P x ≡ P' x"
    assumes "⋀x. S x ≡ S' x"
    assumes "⋀x y. ⟦P x; (y,x)∈R; y∈Domain R; S' x ⟧ ⟹ Q x y ≡ Q' x y"
    shows "comp_PRE R P Q S ≡ comp_PRE R' P' Q' S'"
    using assms
    by (fastforce simp: comp_PRE_def intro!: eq_reflection ext)

  (* fref_compI' with the composed precondition expressed via comp_PRE and a
     trivial guard. *)
  lemma fref_compI_PRE:
    "⟦ (f,g)∈fref P R1 R2; (g,h)∈fref Q S1 S2 ⟧ 
      ⟹ (f,h) ∈ fref (comp_PRE S1 Q (λ_. P) (λ_. True)) (R1 O S1) (R2 O S2)"
    using fref_compI[of P R1 R2 Q S1 S2]   
    unfolding comp_PRE_def
    by auto

  (* Sufficient condition for comp_PRE when the inner predicate ignores y. *)
  lemma PRE_D1: "(Q x ∧ P x) ⟶ comp_PRE S1 Q (λx _. P x) S x"
    by (auto simp: comp_PRE_def)

  (* Sufficient condition for comp_PRE in the general case. *)
  lemma PRE_D2: "(Q x ∧ (∀y. (y,x)∈S1 ⟶ S x ⟶ P x y)) ⟶ comp_PRE S1 Q P S x"
    by (auto simp: comp_PRE_def)

  (* An fref with precondition P' also holds for any P that implies P'. *)
  lemma fref_weaken_pre: 
    assumes "⋀x. P x ⟶ P' x"  
    assumes "(f,h) ∈ fref P' R S"
    shows "(f,h) ∈ fref P R S"
    apply (rule set_rev_mp[OF assms(2) fref_mono])
    using assms(1) by auto
    
  (* Replaces a comp_PRE precondition whose inner predicate ignores y by the
     plain conjunction Q x ∧ P x. *)
  lemma fref_PRE_D1:
    assumes "(f,h) ∈ fref (comp_PRE S1 Q (λx _. P x) X) R S"  
    shows "(f,h) ∈ fref (λx. Q x ∧ P x) R S"
    by (rule fref_weaken_pre[OF PRE_D1 assms])

  (* Replaces a general comp_PRE precondition by its explicit form. *)
  lemma fref_PRE_D2:
    assumes "(f,h) ∈ fref (comp_PRE S1 Q P X) R S"  
    shows "(f,h) ∈ fref (λx. Q x ∧ (∀y. (y,x)∈S1 ⟶ X x ⟶ P x y)) R S"
    by (rule fref_weaken_pre[OF PRE_D2 assms])

  lemmas fref_PRE_D = fref_PRE_D1 fref_PRE_D2

  (* An hfref with precondition P' also holds for any P that implies P'. *)
  lemma hfref_weaken_pre: 
    assumes "⋀x. P x ⟶ P' x"  
    assumes "(f,h) ∈ hfref P' R S"
    shows "(f,h) ∈ hfref P R S"
    using assms
    by (auto simp: hfref_def)

  (* Variant of hfref_weaken_pre in which the implication may additionally
     assume rdomp of the "before" argument assertion. *)
  lemma hfref_weaken_pre': 
    assumes "⋀x. ⟦P x; rdomp (fst R) x⟧ ⟹ P' x"  
    assumes "(f,h) ∈ hfref P' R S"
    shows "(f,h) ∈ hfref P R S"
    apply (rule hfrefI)
    apply (rule hn_refine_preI)
    using assms
    by (auto simp: hfref_def rdomp_def)

  (* Variant in which the implication may additionally assume that the second
     program g does not fail on x. *)
  lemma hfref_weaken_pre_nofail': 
    assumes "(f,g) ∈ [P]a R → S"  
    assumes "⋀x. ⟦nofailT (g x); Q x⟧ ⟹ P x"
    shows "(f,g) ∈ [Q]a R → S"
    apply (rule hfref_weaken_pre[OF _ assms(1)[THEN hfref_weaken_pre_nofail]])
    using assms(2) 
    by blast

  (* hfcomp with the composed precondition expressed via comp_PRE and a
     trivial guard. *)
  lemma hfref_compI_PRE_aux:
    assumes A: "(f,g) ∈ [P]a RR' → S"
    assumes B: "(g,h) ∈ [Q]f T → ⟨U⟩nrest_rel"
    shows "(f,h) ∈ [comp_PRE T Q (λ_. P) (λ_. True)]a 
      hrp_comp RR' T → hr_comp S U"
    apply (rule hfref_weaken_pre[OF _ hfcomp[OF A B]])
    by (auto simp: comp_PRE_def)


  (* As hfref_compI_PRE_aux, but the comp_PRE guard is nofailT (h x), so the
     composed precondition is only required where h does not fail. *)
  lemma hfref_compI_PRE:
    assumes A: "(f,g) ∈ [P]a RR' → S"
    assumes B: "(g,h) ∈ [Q]f T → ⟨U⟩nrest_rel"
    shows "(f,h) ∈ [comp_PRE T Q (λx y. P y) (λx. nofailT (h x))]a 
      hrp_comp RR' T → hr_comp S U"
    using hfref_compI_PRE_aux[OF A B, THEN hfref_weaken_pre_nofail]  
    apply (rule hfref_weaken_pre[rotated])
    apply (auto simp: comp_PRE_def)
    done

  (* Replaces a comp_PRE precondition whose inner predicate ignores y by the
     plain conjunction Q x ∧ P x. *)
  lemma hfref_PRE_D1:
    assumes "(f,h) ∈ hfref (comp_PRE S1 Q (λx _. P x) X) R S"  
    shows "(f,h) ∈ hfref (λx. Q x ∧ P x) R S"
    by (rule hfref_weaken_pre[OF PRE_D1 assms])

  (* Replaces a general comp_PRE precondition by its explicit form. *)
  lemma hfref_PRE_D2:
    assumes "(f,h) ∈ hfref (comp_PRE S1 Q P X) R S"  
    shows "(f,h) ∈ hfref (λx. Q x ∧ (∀y. (y,x)∈S1 ⟶ X x ⟶ P x y)) R S"
    by (rule hfref_weaken_pre[OF PRE_D2 assms])

  (* Identity rule: leaves the comp_PRE precondition unchanged. *)
  lemma hfref_PRE_D3:
    assumes "(f,h) ∈ hfref (comp_PRE S1 Q P X) R S"  
    shows "(f,h) ∈ hfref (comp_PRE S1 Q P X) R S"
    using assms .

  (* Collects D1 and the identity rule D3; unlike fref_PRE_D, D2 is not
     included here. *)
  lemmas hfref_PRE_D = hfref_PRE_D1 hfref_PRE_D3

  subsection ‹Automation›  
  text ‹Purity configuration for constraint solver›
  (* Register pure_pure as a safe constraint rule. *)
  lemmas [safe_constraint_rules] = pure_pure

  text ‹Configuration for hfref to hnr conversion›
  named_theorems to_hnr_post ‹to_hnr converter: Postprocessing unfold rules›

  (* Rewrites RETURNT c under uncurry0 into its $-tagged form. *)
  lemma uncurry0_add_app_tag: "uncurry0 (RETURNT c) = uncurry0 (RETURNT$c)" by simp

  (* Default to_hnr_post rules: RETURNT/ureturn normalization, uncurry0
     unfolding, and removal of unit assertions and neutral emp factors. *)
  lemmas [to_hnr_post] = norm_RETURNT_o norm_return_o
    uncurry0_add_app_tag uncurry0_apply uncurry0_APP hn_val_unit_conv_emp
    mult_1[of "x::assn" for x] mult_1_right[of "x::assn" for x]

  named_theorems to_hfref_post ‹to_hfref converter: Postprocessing unfold rules› 
  (* A case_prod whose body ignores both components is a constant function. *)
  lemma prod_casesK[to_hfref_post]: "case_prod (λ_ _. k) = (λ_. k)" by auto
  (* Rewrites the precondition uncurry0 True to λ_. True. *)
  lemma uncurry0_hfref_post[to_hfref_post]: "hfref (uncurry0 True) R S = hfref (λ_. True) R S" 
    apply (fo_rule arg_cong fun_cong)+ by auto


  (* Currently not used, we keep it in here anyway. *)  
  text ‹Configuration for relation normalization after composition›
  (* Named theorem collections used by the fcomp normalizer. *)
  named_theorems fcomp_norm_unfold ‹fcomp-normalizer: Unfold theorems›
  named_theorems fcomp_norm_simps ‹fcomp-normalizer: Simplification theorems›
  named_theorems fcomp_norm_init "fcomp-normalizer: Initialization rules"  
  named_theorems fcomp_norm_trans "fcomp-normalizer: Transitivity rules"  
  named_theorems fcomp_norm_cong "fcomp-normalizer: Congruence rules"  
  named_theorems fcomp_norm_norm "fcomp-normalizer: Normalization rules"  
  named_theorems fcomp_norm_refl "fcomp-normalizer: Reflexivity rules"  

  text ‹Default Setup›
  (* Unfold rules: relation composition laws and the hr_comp/hrp_comp
     simplifications proved above. *)
  lemmas [fcomp_norm_unfold] = prod_rel_comp nrest_rel_comp Id_O_R R_O_Id
  lemmas [fcomp_norm_unfold] = hr_comp_Id1 hr_comp_Id2
  lemmas [fcomp_norm_unfold] = hr_comp_prod_conv
  lemmas [fcomp_norm_unfold] = prod_hrp_comp hrp_comp_keep hrp_comp_dest hr_comp_pure
  (*lemmas [fcomp_norm_unfold] = prod_casesK uncurry0_hfref_post*)

  (* Simplification rules: fold pure (the_pure P) back to P for pure P, and
     drop True premises. *)
  lemma [fcomp_norm_simps]: "CONSTRAINT is_pure P ⟹ pure (the_pure P) = P" by simp
  lemmas [fcomp_norm_simps] = True_implies_equals 

  (* Initialization, transitivity, congruence and reflexivity rules, all based
     on hrp_imp. *)
  lemmas [fcomp_norm_init] = fcomp_norm_dflt_init
  lemmas [fcomp_norm_trans] = hrp_imp_trans
  lemmas [fcomp_norm_cong] = hrp_comp_cong hrp_prod_cong
  (*lemmas [fcomp_norm_norm] = hrp_comp_dest*)
  lemmas [fcomp_norm_refl] = refl hrp_imp_refl
 
    

  (* Lifts an fref between plain functions to an fref between their RETURNT
     compositions, with the result relation lifted by nrest_rel. *)
  lemma ensure_fref_nresI: "(f,g)∈[P]f R→S ⟹ (RETURNT o f, RETURNT o g)∈[P]f R→⟨S⟩nrest_rel" 
    by (auto intro!: RETURNT_refine nrest_relI simp: fref_def)
     

  (* Moves RETURNT compositions inside uncurry0 / uncurry. *)
  lemma ensure_fref_nres_unfold:
    "⋀f. RETURNT o (uncurry0 f) = uncurry0 (RETURNT f)" 
    "⋀f. RETURNT o (uncurry f) = uncurry (RETURNT oo f)"
    "⋀f. (RETURNT ooo uncurry) f = uncurry (RETURNT ooo f)"
    by auto

  text ‹Composed precondition normalizer›  
  (* Collection of simplification rules for composed preconditions. *)
  named_theorems fcomp_prenorm_simps ‹fcomp precondition-normalizer: Simplification theorems›

  text ‹Support for preconditions of the form @{text "_∈Domain R"}, 
    where @{text R} is the relation of the next more abstract level.›
  declare DomainI[fcomp_prenorm_simps]

  (* hfref precondition weakening in which the original precondition is
     wrapped in PROTECT in the implication premise. *)
  lemma auto_weaken_pre_init_hf: 
    assumes "⋀x. PROTECT P x ⟶ P' x"  
    assumes "(f,h) ∈ hfref P' R S"
    shows "(f,h) ∈ hfref P R S"
    using assms
    by (auto simp: hfref_def)

  (* The same rule for fref. *)
  lemma auto_weaken_pre_init_f: 
    assumes "⋀x. PROTECT P x ⟶ P' x"  
    assumes "(f,h) ∈ fref P' R S"
    shows "(f,h) ∈ fref P R S"
    using assms
    by (auto simp: fref_def)

  lemmas auto_weaken_pre_init = auto_weaken_pre_init_hf auto_weaken_pre_init_f  

  (* One uncurrying step under PROTECT: a protected case_prod applied to a
     pair is rewritten using the rewrite for its first component. *)
  lemma auto_weaken_pre_uncurry_step:
    assumes "PROTECT f a ≡ f'"
    shows "PROTECT (λ(x,y). f x y) (a,b) ≡ f' b" 
    using assms
    by (auto simp: curry_def dest!: meta_eq_to_obj_eq intro!: eq_reflection)

  (* Final step: removes PROTECT. *)
  lemma auto_weaken_pre_uncurry_finish:  
    "PROTECT f x ≡ f x" by (auto)

  (* Start rule: replaces the antecedent P of an implication by a
     meta-equal P'. *)
  lemma auto_weaken_pre_uncurry_start:
    assumes "P ≡ P'"
    assumes "P'⟶Q"
    shows "P⟶Q"
    using assms by (auto)

  (* Introduction rule for comp_PRE: show P x from S x, and show Q x y for every
     y related to x by R, assuming P x and S x. Used as the first step of the
     tactic in the ML function mk_simp_thm. *)
  lemma auto_weaken_pre_comp_PRE_I:
    assumes "S x ⟹ P x"
    assumes "⋀y. ⟦(y,x)∈R; P x; S x⟧ ⟹ Q x y"
    shows "comp_PRE R P Q S x"
    using assms by (auto simp: comp_PRE_def)

  (* Normal form for the generated implication: nested implications become a
     single implication with a conjunction as antecedent, and conjunctions are
     re-associated to the right. Unfolded by the ML function mk_simp_thm. *)
  lemma auto_weaken_pre_to_imp_nf:
    "(A⟶B⟶C) = (A∧B ⟶ C)"
    "((A∧B)∧C) = (A∧B∧C)"
    by auto

  (* Turn a plain fact P into the implication "True ⟶ P". mk_simp_thm uses this
     when the simplified precondition theorem is not already an implication. *)
  lemma auto_weaken_pre_add_dummy_imp:
    "P ⟹ True ⟶ P" by simp


  text ‹Synthesis for hfref statements›  
  (* Logically trivial tag (always True) pairing an assertion R with an
     abstract argument. It appears as an extra assumption in
     hfsynth_hnr_from_hfI and is consumed by hfsynth_ID_R_D. *)
  definition hfsynth_ID_R :: "('a ⇒ _ ⇒ assn) ⇒ 'a ⇒ bool" where
    [simp]: "hfsynth_ID_R _ _ ≡ True"

  (* Destruction rule: from an hfsynth_ID_R tag for R and a, together with an
     interface binding "intf_of_assn R I", obtain the interface annotation
     "a ::i I". Applied with dresolve_tac in prepare_hfref_synth_tac; the
     second premise is then solved with the intf_of_assn rules.
     The proof is by simp alone and does not refer to the assumptions. *)
  lemma hfsynth_ID_R_D:
    fixes I :: "'a itself"
    assumes "hfsynth_ID_R R a"
    assumes "intf_of_assn R I"
    shows "a ::i I"
    by simp

  (* Introduction rule used to start hfref synthesis: to show an hfref
     statement, show the corresponding hn_refine statement for all arguments
     x, xi under the precondition P x and the hfsynth_ID_R tag for the
     pre-assertion (fst R). The function applications carry $-tags and the
     contexts a leading emp, which prepare_hfref_synth_tac removes again. *)
  lemma hfsynth_hnr_from_hfI:
    assumes "∀x xi. P x ∧ hfsynth_ID_R (fst R) x ⟶ hn_refine (emp * hn_ctxt (fst R) x xi) (f$xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    shows "(f,g) ∈ [P]a R → S"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def)


  (* Rewrite rules distributing the hfsynth_ID_R tag over a product of
     assertions applied to a pair, and removing fst/hf_pres around the
     assertion. Unfolded by prepare_hfref_synth_tac while currying parameters. *)
  lemma hfsynth_ID_R_uncurry_unfold: 
    "hfsynth_ID_R (to_hnr_prod R S) (a,b) ≡ hfsynth_ID_R R a ∧ hfsynth_ID_R S b" 
    "hfsynth_ID_R (fst (hf_pres R k)) ≡ hfsynth_ID_R R"
    by (auto intro!: eq_reflection)

  ML ‹

    (* Interface of the refinement-rule management: analysis and construction
      of relation terms, conversion of theorems between the higher-order
      parametricity, fref, hfref and hnr forms, and composition of rules. *)
    signature SEPREF_RULES = sig
      (* Analysis of relations, both fref and fun_rel *)
      (* "R1→...→Rn→_" / "[_]f ((R1×rR2)...×rRn)"  ↦  "[R1,...,Rn]" *)
      val binder_rels: term -> term list 
      (* "_→...→_→S" / "[_]f _ → S"  ↦  "S" *)
      val body_rel: term -> term 
      (* Map →/fref to (precond,args,res). NONE if no/trivial precond. *)
      val analyze_rel: term -> term option * term list * term 
      (* Make trivial ("λ_. True") precond *)
      val mk_triv_precond: term list -> term 
      (* Make "[P]f ((R1×rR2)...×rRn) → S". Insert trivial precond if NONE. *)
      val mk_rel: term option * term list * term -> term 
      (* Map relation to (args,res) *)
      val strip_rel: term -> term list * term 

      (* Make hfprod (op *a) *)
      val mk_hfprod : term * term -> term
      val mk_hfprods : term list -> term

      (* Determine interface type of refinement assertion, using default fallback
        if necessary. Use named_thms intf_of_assn for configuration. *)
      val intf_of_assn : Proof.context -> term -> typ

      (*
        Convert a parametricity theorem in higher-order form to
        uncurried fref-form. For functions without arguments, 
        a unit-argument is added.

        TODO/FIXME: Currently this only works for higher-order theorems,
          i.e., theorems of the form (f,g)∈R1→…→Rn. 
          
          First-order theorems are silently treated as refinement theorems
          for functions with zero arguments, i.e., a unit-argument is added.
      *)
      val to_fref : Proof.context -> thm -> thm

      (* Convert a parametricity or fref theorem to first order form *)
      val to_foparam : Proof.context -> thm -> thm

      (* Convert schematic hfref goal to hnr-goal *)
      val prepare_hfref_synth_tac : Proof.context -> tactic'

      (* Convert theorem in hfref-form to hnr-form *)
      val to_hnr : Proof.context -> thm -> thm

      (* Convert theorem in hnr-form to hfref-form *)
      val to_hfref: Proof.context -> thm -> thm

      (* Convert theorem to given form, if not yet in this form *)
      val ensure_fref : Proof.context -> thm -> thm
      val ensure_fref_nres : Proof.context -> thm -> thm
      val ensure_hfref : Proof.context -> thm -> thm
      val ensure_hnr : Proof.context -> thm -> thm


      (* Components of a theorem in hnr-form, as extracted by analyze_hnr *)
      type hnr_analysis = {
        thm: thm,                     (* Original theorem, may be normalized *)
        precond: term,                (* Precondition, abstracted over abs-arguments *)
        prems : term list,            (* Premises not depending on arguments *)
        ahead: term * bool,           (* Abstract function, has leading RETURN *)
        chead: term * bool,           (* Concrete function, has leading return *)
        argrels: (term * bool) list,  (* Argument relations, preserved (keep-flag) *)
        result_rel: term              (* Result relation *)
      }
  
      (* Split a theorem in hnr-form into its components; raises THM if the
        theorem does not have the expected shape *)
      val analyze_hnr: Proof.context -> thm -> hnr_analysis
      (* Display an analysis result *)
      val pretty_hnr_analysis: Proof.context -> hnr_analysis -> Pretty.T
      (* Build and prove the hfref-form theorem described by an analysis.
        to_hfref is the composition of analyze_hnr and mk_hfref_thm. *)
      val mk_hfref_thm: Proof.context -> hnr_analysis -> thm
  
  

      (* Simplify precondition of fref/hfref-theorem *)
      val simplify_precond: Proof.context -> thm -> thm

      (* Normalize hfref-theorem after composition *)
      val norm_fcomp_rule: Proof.context -> thm -> thm

      (* Replace "pure ?A" by "?A'" and is_pure constraint, then normalize *)
      val add_pure_constraints_rule: Proof.context -> thm -> thm

      (* Compose fref/hfref and fref theorem, to produce hfref theorem.
        The input theorems may also be in ho-param or hnr form, and
        are converted accordingly.
      *)
      val gen_compose : Proof.context -> thm -> thm -> thm

      (* FCOMP-attribute *)
      val fcomp_attrib: attribute context_parser
    end

    structure Sepref_Rules: SEPREF_RULES = struct

      local open Refine_Util Relators in
        (* Argument relations of a relation term: the binders of a chain
          "R1→…→Rn→_", or the components of the argument product of an fref.
          Every other term has no argument relations. *)
        fun binder_rels rel =
          case rel of
            @{mpat "?F → ?G"} => F :: binder_rels G
          | @{mpat "fref _ ?F _"} => strip_prodrel_left F
          | _ => []
    
        local 
          (* Skip all binders of a function-relation chain *)
          fun skip_binders rel =
            case rel of
              @{mpat "_ → ?G"} => skip_binders G
            | _ => rel
        in    
          (* Result relation: third component of an fref, otherwise whatever
            remains after the binders of a function-relation chain *)
          fun body_rel rel =
            case rel of
              @{mpat "fref _ _ ?G"} => G
            | _ => skip_binders rel
        end
    
        (* Split a relation into (argument relations, result relation) *)
        fun strip_rel R = let
          val args = binder_rels R
          val res = body_rel R
        in
          (args, res)
        end
    
        (* Decompose a relation into (precondition, argument relations, result).
          The precondition is NONE for an fref with the trivial precondition
          and for anything that is not an fref. *)
        fun analyze_rel rel =
          case rel of
            @{mpat "fref (λ_. True) ?R ?S"} => (NONE, strip_prodrel_left R, S)
          | @{mpat "fref ?P ?R ?S"} => (SOME P, strip_prodrel_left R, S)
          | _ => (NONE, binder_rels rel, body_rel rel)
    
        (* Constant-True abstraction over the product of the abstract types
          of the given argument relations *)
        fun mk_triv_precond Rs = let
          val argT = list_prodT_left (map rel_absT Rs)
        in
          absdummy argT @{term True}
        end
    
        (* Assemble an fref relation from an optional precondition, the argument
          relations and the result relation. A missing precondition is replaced
          by the trivial one. The names P, R and S are referenced by the
          mk_term antiquotation and must not be renamed. *)
        fun mk_rel (P,Rs,S) = let 
          val R = list_prodrel_left Rs 
          val P = (case P of 
              NONE => mk_triv_precond Rs
            | SOME given => given)
        in 
          @{mk_term "fref ?P ?R ?S"} 
        end
      end


      (* Combine two argument assertions with the infix product of the term
        pattern below. The parameter names a and b are referenced by the
        mk_term antiquotation. *)
      fun mk_hfprod (a, b) = @{mk_term "?a*a?b"}
  
      local 
        (* Works on the reversed list: the head of the list becomes the
          right-most factor, so the resulting product is nested to the left *)
        fun build_rev rels =
          case rels of
            [] => @{mk_term "unit_assnk"}
          | [Rk] => Rk
          | Rkn::Rks => mk_hfprod (build_rev Rks, Rkn)
      in
        (* Product of a list of argument assertions; the empty list yields the
          unit assertion term, a singleton list its only element *)
        fun mk_hfprods rels = build_rev (rev rels)
      end


      (* Determine the interface type bound to the refinement assertion t.
        A goal "intf_of_assn t TYPE(?'T)" with schematic type is proved from the
        registered intf_of_assn rules, with the fallback rule appended last,
        and the instantiated type is read off the proved statement.
        The names t and v are referenced by the antiquotations below. *)
      fun intf_of_assn ctxt t = let
        val outer_ctxt = ctxt
        val (t,ctxt) = yield_singleton (Variable.import_terms false) t ctxt

        val v = Logic.mk_type (TVar (("T",0),Proof_Context.default_sort ctxt ("T",0)))
        val goal = @{mk_term "Trueprop (intf_of_assn ?t ?v)"}

        val binding_rules = 
          Named_Theorems_Rev.get ctxt @{named_theorems_rev intf_of_assn}
          @ @{thms intf_of_assn_fallback}

        val proved = Goal.prove ctxt [] [] goal (fn {context,...} => 
          ALLGOALS (REPEAT_ALL_NEW (resolve_tac context binding_rules)))

        val intf_term = case Thm.concl_of proved of
            @{mpat "Trueprop (intf_of_assn _ (?v ASp TYPE (_)))"} => v 
          | _ => raise THM("Intf_of_assn: Proved a different theorem?",~1,[proved])
      in
        (* Export back to the original context before extracting the type *)
        Logic.dest_type (singleton (Variable.export_terms ctxt outer_ctxt) intf_term)
      end

      (* Shape of the conclusion of a refinement theorem, as classified by
        the function rthm_type *)
      datatype rthm_type = 
        RT_HOPARAM    (* (_,_) ∈ _ → … → _ *)
      | RT_FREF       (* (_,_) ∈ [_]f _ → _ *)
      | RT_HNR        (* hn_refine _ _ _ _ _ *)
      | RT_HFREF      (* (_,_) ∈ [_]a _ → _ *)
      | RT_OTHER      (* none of the above *)

      (* Classify a theorem by the shape of its conclusion. The patterns are
        tried in order, so the generic membership pattern only catches
        statements that are neither fref nor hfref. *)
      fun rthm_type thm = let
        fun classify @{mpat "(_,_) ∈ fref _ _ _"} = RT_FREF
          | classify @{mpat "(_,_) ∈ hfref _ _ _"} = RT_HFREF
          | classify @{mpat "hn_refine _ _ _ _ _"} = RT_HNR
          | classify @{mpat "(_,_) ∈ _"} = RT_HOPARAM (* TODO: Distinction between ho-param and fo-param *)
          | classify _ = RT_OTHER
      in
        classify (HOLogic.dest_Trueprop (Thm.concl_of thm))
      end


      (* Convert a parametricity theorem to uncurried fref-form.
        A theorem with a function relation is rewritten with fref_param1,
        repeated fref_nest and in_CURRY_conv; any other membership statement
        is treated as a zero-argument function via fref_param0I.
        Raises THM for conclusions that are not membership statements. *)
      fun to_fref ctxt thm =
        case HOLogic.dest_Trueprop (Thm.concl_of thm) of
          @{mpat "(_,_)∈_→_"} => let
              val nest_conv = 
                Refine_Util.ftop_conv (K (Conv.rewr_conv @{thm fref_nest})) ctxt
            in
              thm
              |> Local_Defs.unfold0 ctxt @{thms fref_param1}
              |> Conv.fconv_rule (Conv.repeat_conv nest_conv)
              |> Local_Defs.unfold0 ctxt @{thms in_CURRY_conv}
            end
        | @{mpat "(_,_)∈_"} => thm RS @{thm fref_param0I}   
        | _ => raise THM ("to_fref: Expected theorem of form (_,_)∈_",~1,[thm])

      (* Convert a parametricity or fref theorem to first-order form.
        An fref theorem is destructed with frefD and its tupled arguments are
        unfolded; other membership statements go through Parametricity.fo_rule.
        Raises THM for any other conclusion. *)
      fun to_foparam ctxt thm =
        case Thm.concl_of thm of
          @{mpat "Trueprop ((_,_) ∈ fref _ _ _)"} => let
              val fo_simps = @{thms 
                split_tupled_all prod_rel_simp uncurry_apply cnv_conj_to_meta Product_Type.split}
              val destructed = forall_intr_vars (@{thm frefD} OF [thm])
            in
              Variable.gen_all ctxt (Local_Defs.unfold0 ctxt fo_simps destructed)
            end
        | @{mpat "Trueprop ((_,_) ∈ _)"} =>
            Parametricity.fo_rule thm
        | _ => raise THM("Expected parametricity or fref theorem",~1,[thm])

      (* Convert a theorem in hfref-form to hnr-form: resolve with hf2hnr,
        apply the unfolding steps below in the listed order, then normalize
        the result and eta-contract it. *)
      fun to_hnr ctxt thm = let
        val hnr = thm RS @{thm hf2hnr}

        val unfold_steps = [
          (* Resolve fst and snd over *a and Rk, Rd *)
          @{thms to_hnr_prod_fst_snd keep_drop_sels},
          (* Resolve products for uncurried parameters *)
          @{thms hnr_uncurry_unfold},
          (* Remove the uncurry modifiers, the emp-dummy, and unfold product cases *)
          @{thms uncurry_apply uncurry_APP assn_one_left split},
          (* Remove duplicate hn_ctxt tagging *)
          @{thms hn_ctxt_ctxt_fix_conv},
          (* Convert to meta-level, remove vacuous condition *)
          @{thms all_to_meta imp_to_meta HOL.True_implies_equals HOL.implies_True_equals Pure.triv_forall_equality cnv_conj_to_meta},
          (* Post-Processing *)
          Named_Theorems.get ctxt @{named_theorems to_hnr_post}
        ]

        val unfolded = fold (Local_Defs.unfold0 ctxt) unfold_steps hnr
      in
        Conv.fconv_rule Thm.eta_conversion (Goal.norm_result ctxt unfolded)
      end

      (* Convert schematic hfref-goal to hn_refine goal.

        If the goal can be resolved with hfsynth_hnr_from_hfI, the resulting
        subgoal is rewritten in place by the unfolding steps below; otherwise
        the tactic succeeds without changing the goal (K all_tac).
      *)  
      fun prepare_hfref_synth_tac ctxt = let
        (* Interface bindings: registered rules, with the fallback rule last *)
        val i_of_assn_rls = 
          Named_Theorems_Rev.get ctxt @{named_theorems_rev intf_of_assn}
          @ @{thms intf_of_assn_fallback}

        val to_hnr_post_rls = 
          Named_Theorems.get ctxt @{named_theorems to_hnr_post}

        (* Repeatedly turn an hfsynth_ID_R premise into an interface annotation
          via hfsynth_ID_R_D; the intf_of_assn side condition created by the
          destruction rule must be solved completely (SOLVED') by the
          interface-binding rules. *)
        val i_of_assn_tac = (
          REPEAT' (
            DETERM o dresolve_tac ctxt @{thms hfsynth_ID_R_D}
            THEN' DETERM o SOLVED' (REPEAT_ALL_NEW (resolve_tac ctxt i_of_assn_rls))
          )
        )
      in
        (* Note: To re-use the to_hnr infrastructure, we first work with
          $-tags on the abstract function, which are finally removed.
        *)
        resolve_tac ctxt @{thms hfsynth_hnr_from_hfI} THEN_ELSE' (
          SELECT_GOAL (
            unfold_tac ctxt @{thms to_hnr_prod_fst_snd keep_drop_sels hf_pres_fst} (* Distribute fst,snd over product and hf_pres *)
            THEN unfold_tac ctxt @{thms hnr_uncurry_unfold hfsynth_ID_R_uncurry_unfold} (* Curry parameters *)
            THEN unfold_tac ctxt @{thms uncurry_apply uncurry_APP assn_one_left split} (* Curry parameters (II) and remove emp assertion *)
            (*THEN unfold_tac ctxt @{thms hn_ctxt_ctxt_fix_conv} (* Remove duplicate hn_ctxt (Should not be necessary) *)*)
            THEN unfold_tac ctxt @{thms all_to_meta imp_to_meta HOL.True_implies_equals HOL.implies_True_equals Pure.triv_forall_equality cnv_conj_to_meta} (* Convert precondition to meta-level *)
            THEN ALLGOALS i_of_assn_tac (* Generate _::i_ premises*)
            THEN unfold_tac ctxt to_hnr_post_rls (* Postprocessing *)
            THEN unfold_tac ctxt @{thms APP_def} (* Get rid of $ - tags *)
          )
        ,
          K all_tac
        )
      end


      (************************************)  
      (* Analyze hnr *)
      (* Table keyed by pairs of terms, ordered component-wise with
        fast_term_ord. analyze_hnr uses it for the (abstract, concrete)
        argument pairs found in the pre- and postcondition. *)
      structure Termtab2 = Table(
        type key = term * term 
        val ord = prod_ord Term_Ord.fast_term_ord Term_Ord.fast_term_ord);
  
      (* Components of a theorem in hnr-form; must agree with the type of the
        same name in the signature *)
      type hnr_analysis = {
        thm: thm,                     (* Original theorem, may be normalized *)
        precond: term,                (* Precondition, abstracted over abs-arguments *)
        prems : term list,            (* Premises not depending on arguments *)
        ahead: term * bool,           (* Abstract function, has leading RETURN *)
        chead: term * bool,           (* Concrete function, has leading return *)
        argrels: (term * bool) list,  (* Argument relations, preserved (keep-flag) *)
        result_rel: term              (* Result relation *)
      }
  
    
      (* Analyze a theorem whose conclusion is an hn_refine statement and
        extract the components of an hnr_analysis record: the precondition
        (abstracted over the abstract arguments), the remaining premises, the
        abstract and concrete head together with a flag telling whether a
        leading RETURNT/ureturn was stripped, the argument relations with
        their keep-flags, and the result relation.

        On any violation of the expected shape, the debug information
        collected so far is traced and THM is raised.
      *)
      fun analyze_hnr (ctxt:Proof.context) thm = let
    
        (* Debug information: Stores string*term pairs, which are pretty-printed on error *)
        val dbg = Unsynchronized.ref []
        fun add_dbg msg ts = (
          dbg := (msg,ts) :: !dbg;
          ()
        )
        fun pretty_dbg (msg,ts) = Pretty.block [
          Pretty.str msg,
          Pretty.str ":",
          Pretty.brk 1,
          Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) ts)
        ]
        fun pretty_dbgs l = map pretty_dbg l |> Pretty.fbreaks |> Pretty.block
    
        (* Entries are consed onto dbg, hence the rev to print them in insertion order *)
        fun trace_dbg msg = Pretty.block [Pretty.str msg, Pretty.fbrk, pretty_dbgs (rev (!dbg))] |> Pretty.string_of |> tracing
    
        (* fail never returns; assert yields true or fails *)
        fun fail msg = (trace_dbg msg; raise THM(msg,~1,[thm])) 
        fun assert cond msg = cond orelse fail msg;
    
    
        (* Heads may have a leading return/RETURN.
          The following code strips off the leading return, unless it has the form
          "return x" for an argument x
        *)
        fun check_strip_leading args t f = (* Handle the case RETURN x, where x is an argument *)
          if Termtab.defined args f then (t,false) else (f,true)
    
        (* Both the $-tagged and the plain application are recognized *)
        fun strip_leading_RETURN args (t as @{mpat "RETURNT$(?f)"}) = check_strip_leading args t f
          | strip_leading_RETURN args (t as @{mpat "RETURNT ?f"}) = check_strip_leading args t f
          | strip_leading_RETURN _ t = (t,false)
    
        fun strip_leading_return args (t as @{mpat "ureturn$(?f)"}) = check_strip_leading args t f
            | strip_leading_return args (t as @{mpat "ureturn ?f"}) = check_strip_leading args t f
            | strip_leading_return _ t = (t,false)
    
    
        (* The following code strips the arguments of the concrete or abstract
          function. It knows how to handle APP-tags ($), and stops at PR_CONST-tags.
    
          Moreover, it only strips actual arguments that occur in the 
          precondition-section of the hn_refine-statement. This ensures
          that non-arguments, like maxsize, are treated correctly.

          The outermost (last) argument is consed in front of the arguments of
          the inner application, so the returned list is in reverse argument
          order.
        *)    
        fun strip_fun _ (t as @{mpat "PR_CONST _"}) = (t,[])
          | strip_fun s (t as @{mpat "?f$?x"}) = check_arg s t f x
          | strip_fun s (t as @{mpat "?f ?x"}) = check_arg s t f x
          | strip_fun _ f = (f,[])
        and check_arg s t f x = 
            if Termtab.defined s x then
              strip_fun s f |> apsnd (curry op :: x)
            else (t,[])  
    
        (* Arguments in the pre/postcondition are wrapped into hn_ctxt tags. 
          This function strips them off. *)    
        fun dest_hn_ctxt @{mpat "hn_ctxt ?R ?a ?c"} = ((a,c),R)
          | dest_hn_ctxt _ = fail "Invalid hn_ctxt parameter in pre or postcondition"
    
    
        fun dest_hn_refine @{mpat "(hn_refine ?G ?c ?G' ?R ?a)"} = (G,c,G',R,a) 
          | dest_hn_refine _ = fail "Conclusion is not a hn_refine statement"
    
        (*
          Strip separation conjunctions. Special case for "emp", which is ignored. 
        *)  
        fun is_emp @{mpat emp} = true | is_emp _ = false
  
        val strip_star' = Sepref_Basic.strip_star #> filter (not o is_emp)
  
        (* Compare Termtab2s for equality of keys *)  
        fun pairs_eq pairs1 pairs2 = 
                  Termtab2.forall (Termtab2.defined pairs1 o fst) pairs2
          andalso Termtab2.forall (Termtab2.defined pairs2 o fst) pairs1
    
    
        fun atomize_prem @{mpat "Trueprop ?p"} = p
          | atomize_prem _ = fail "Non-atomic premises"
    
        (* Make HOL conjunction list *)  
        fun mk_conjs [] = @{const True}
          | mk_conjs [p] = p
          | mk_conjs (p::ps) = HOLogic.mk_binop @{const_name "HOL.conj"} (p,mk_conjs ps)
    
    
        (***********************)      
        (* Start actual analysis *)
    
        val _ = add_dbg "thm" [Thm.prop_of thm]
        val prems = Thm.prems_of thm
        val concl = Thm.concl_of thm |> HOLogic.dest_Trueprop
        val (G,c,G',R,a) = dest_hn_refine concl
    
        (* Map (abstract,concrete) argument pairs to their relation, once for
          the precondition and once for the postcondition *)
        val pre_pairs = G 
          |> strip_star'
          |> tap (add_dbg "precondition")
          |> map dest_hn_ctxt
          |> Termtab2.make
    
        val post_pairs = G' 
          |> strip_star'
          |> tap (add_dbg "postcondition")
          |> map dest_hn_ctxt
          |> Termtab2.make
    
        val _ = assert (pairs_eq pre_pairs post_pairs) 
          "Parameters in precondition do not match postcondition"
    
        (* Candidate argument sets: everything mentioned in the precondition *)
        val aa_set = pre_pairs |> Termtab2.keys |> map fst |> Termtab.make_set
        val ca_set = pre_pairs |> Termtab2.keys |> map snd |> Termtab.make_set
    
        val (a,leading_RETURN) = strip_leading_RETURN aa_set a
        val (c,leading_return) = strip_leading_return ca_set c
    
        val _ = add_dbg "stripped abstract term" [a]
        val _ = add_dbg "stripped concrete term" [c]
    
        val (ahead,aargs) = strip_fun aa_set a;
        val (chead,cargs) = strip_fun ca_set c;
    
        val _ = add_dbg "abstract head" [ahead]
        val _ = add_dbg "abstract args" aargs
        val _ = add_dbg "concrete head" [chead]
        val _ = add_dbg "concrete args" cargs
    
    
        val _ = assert (length cargs = length aargs) "Different number of abstract and concrete arguments";
    
        val _ = assert (not (has_duplicates op aconv aargs)) "Duplicate abstract arguments"
        val _ = assert (not (has_duplicates op aconv cargs)) "Duplicate concrete arguments"
    
        val argpairs = aargs ~~ cargs
        val ap_set = Termtab2.make_set argpairs
        val _ = assert (pairs_eq pre_pairs ap_set) "Arguments from pre/postcondition do not match operation's arguments"
    
        (* The lookups cannot fail: the key sets were just checked to coincide *)
        val pre_rels = map (the o (Termtab2.lookup pre_pairs)) argpairs
        val post_rels = map (the o (Termtab2.lookup post_pairs)) argpairs
    
        val _ = add_dbg "pre-rels" pre_rels
        val _ = add_dbg "post-rels" post_rels

        (* NOTE(review): as written, the pattern binds Rk, while the right-hand
          side R refers to the result relation bound by dest_hn_refine above.
          Presumably the pattern was meant to match the keep-annotation of a
          relation R and return that R (a superscript may have been lost in
          this copy of the file) - confirm against the original theory.
          The same applies to the "snd (_d)" pattern in is_invalid below. *)
        fun adjust_hf_pres @{mpat "snd (?Rk)"} = R
          | adjust_hf_pres t = t
          
        val post_rels = map adjust_hf_pres post_rels
    
        fun is_invalid R @{mpat "invalid_assn ?R'"} = R aconv R'
          | is_invalid _ @{mpat "snd (_d)"} = true
          | is_invalid _ _ = false
    
        (* An argument is kept if its relation is unchanged, and dropped if the
          post relation is the invalidated pre relation; anything else is an error *)
        fun is_keep (R,R') =
          if R aconv R' then true
          else if is_invalid R R' then false
          else fail "Mismatch between pre and post relation for argument"
    
        val keep = map is_keep (pre_rels ~~ post_rels)
    
        val argrels = pre_rels ~~ keep

        (* From here on, the argument sets contain exactly the stripped arguments *)
        val aa_set = Termtab.make_set aargs
        val ca_set = Termtab.make_set cargs

        (* A premise belongs to the precondition if it mentions an abstract
          argument. Mentioning a concrete argument is an error. *)
        fun is_precond t =
          (exists_subterm (Termtab.defined ca_set) t andalso fail "Premise contains concrete argument")
          orelse exists_subterm (Termtab.defined aa_set) t

        val (preconds, prems) = split is_precond prems  
    
        (* Conjoin the precondition premises and abstract over the abstract
          arguments. fold lambda binds the head of aargs innermost. *)
        val precond = 
          map atomize_prem preconds 
          |> mk_conjs
          |> fold lambda aargs
    
        val _ = add_dbg "precond" [precond]
        val _ = add_dbg "prems" prems
    
      in
        {
          thm = thm,
          precond = precond,
          prems = prems,
          ahead = (ahead,leading_RETURN),
          chead = (chead,leading_return),
          argrels = argrels,
          result_rel = R
        }
      end  
    
      (* Pretty-print an analysis result as

           [chead, ahead][precond] R1 → … → Rn → result_rel

        where each argument relation is suffixed with "k" (kept) or "d"
        (dropped), and a head whose leading return/RETURNT was stripped by the
        analysis is shown with that prefix re-attached.

        The two heads are printed as a comma-separated list; enclosing them
        without a separator made them run together in the output.
      *)
      fun pretty_hnr_analysis 
        ctxt 
        ({precond,ahead,chead,argrels,result_rel,...} : hnr_analysis) 
        : Pretty.T =
      let  
        val pretty_term = Syntax.pretty_term ctxt

        fun pretty_argrel (R,k) = Pretty.block [
          pretty_term R,
          if k then Pretty.str "k" else Pretty.str "d"
        ]
    
        (* Head with optional prefix for a stripped leading return *)
        fun pretty_head _ (t,false) = pretty_term t
          | pretty_head prefix (t,true) = Pretty.block [Pretty.str prefix, pretty_term t]

        val pretty_chead = pretty_head "return " chead
        val pretty_ahead = pretty_head "RETURNT " ahead

      in
        Pretty.fbreaks [
          (*Display.pretty_thm ctxt thm,*)
          Pretty.block [ 
            Pretty.list "[" "]" [pretty_chead, pretty_ahead],
            Pretty.enclose "[" "]" [pretty_term precond],
            Pretty.brk 1,
            Pretty.block (Pretty.separate " →" (map pretty_argrel argrels @ [pretty_term result_rel]))
          ]
        ] |> Pretty.block
    
      end
    
    
      (* Build and prove the hfref-form theorem described by an analysis:

           ⟦prems⟧ ⟹ (chead,ahead) ∈ [precond]a argrel → result_rel

        Heads and precondition are uncurried, a stripped leading
        return/RETURNT is composed back onto the heads, and the statement is
        proved from the original theorem by the tactic below. The result is
        exported to the given context and rewritten with the to_hfref_post
        rules.
      *)
      fun mk_hfref_thm 
        ctxt 
        ({thm,precond,prems,ahead,chead,argrels,result_rel}) = 
      let
    
        (* Attach the keep/drop annotation to an argument relation.
          NOTE(review): as written, the patterns reference ML variables Rk/Rd
          rather than R; presumably they denote R with a keep resp. drop
          superscript that was lost in this copy of the file - confirm. *)
        fun mk_keep (R,true) = @{mk_term "?Rk"}
          | mk_keep (R,false) = @{mk_term "?Rd"}
    
        (* TODO: Move, this is of general use! *)  
        fun mk_uncurry f = @{mk_term "uncurry ?f"}  
      
        (* Uncurry function for the given number of arguments. 
          For zero arguments, add a unit-parameter.
        *)
        fun rpt_uncurry n t =
          if n=0 then @{mk_term "uncurry0 ?t"}
          else if n=1 then t 
          else funpow (n-1) mk_uncurry t
      
        (* Rewrite uncurried lambda's to λ(_,_). _ form. Use top-down rewriting
          to correctly handle nesting to the left. 
    
          TODO: Combine with abstraction and  uncurry-procedure,
            and mark the deviation about uncurry as redundant 
            intermediate step to be eliminated.
        *)  
        fun rew_uncurry_lambda t = let
          val rr = map (Logic.dest_equals o Thm.prop_of) @{thms uncurry_def uncurry0_def}
          val thy = Proof_Context.theory_of ctxt
        in 
          Pattern.rewrite_term_top thy rr [] t 
        end  
    
        (* Shortcuts for simplification tactics *)
        (* Simplifier on HOL_basic_ss, extended by the context modifier sec *)
        fun gsimp_only ctxt sec = let
          val ss = put_simpset HOL_basic_ss ctxt |> sec
        in asm_full_simp_tac ss end
    
        fun simp_only ctxt thms = gsimp_only ctxt (fn ctxt => ctxt addsimps thms)
    
    
        (********************************)
        (* Build theorem statement *)
        (* ⟦prems⟧ ⟹ (chead,ahead) ∈ [precond] rels → R *)
    
        (* Uncurry precondition *)
        val num_args = length argrels
        val precond = precond
          |> rpt_uncurry num_args
          |> rew_uncurry_lambda (* Convert to nicer λ((...,_),_) - form*)

        (* Re-attach leading RETURN/return *)
        (* The result type T is obtained by stripping num_args argument types
          from the type of the head; the constant is then composed onto the
          head with mk_compN. *)
        fun mk_RETURN (t,r) = if r then 
            let
              val T = funpow num_args range_type (fastype_of (fst ahead))
              val tRETURN = Const (@{const_name RETURNT}, T --> Type(@{type_name nrest},[T]))
            in
              Refine_Util.mk_compN num_args tRETURN t
            end  
          else t
    
        (* NOTE(review): this re-attaches the constant "return", whereas
          analyze_hnr strips a leading "ureturn" from the concrete head -
          confirm that both denote the same constant in this development. *)
        fun mk_return (t,r) = if r then 
            let
              val T = funpow num_args range_type (fastype_of (fst chead))
              val tRETURN = Const (@{const_name return}, T --> Type(@{type_name Heap},[T]))
            in
              Refine_Util.mk_compN num_args tRETURN t
            end  
          else t
          
        (* Hrmpf!: Gone for good from 2015→2016. Inserting ctxt-based substitute here. *)  
        fun certify_inst ctxt (instT, inst) =
         (map (apsnd (Thm.ctyp_of ctxt)) instT,
          map (apsnd (Thm.cterm_of ctxt)) inst);

        (*  
        fun mk_RETURN (t,r) = if r then @{mk_term "RETURN o ?t"} else t
        fun mk_return (t,r) = if r then @{mk_term "return o ?t"} else t
        *)
    
        (* Uncurry abstract and concrete function, append leading return *)
        val ahead = ahead |> mk_RETURN |> rpt_uncurry num_args  
        val chead = chead |> mk_return |> rpt_uncurry num_args 
    
        (* Add keep-flags and summarize argument relations to product *)
        (* On the rev: analyze_hnr collects the arguments with the last
          argument first (strip_fun conses the outermost argument in front),
          so its argrels are in reverse argument order; reversing restores
          the order of the function's parameters. *)
        val argrel = map mk_keep argrels |> rev (* TODO: Why this rev? *) |> mk_hfprods
    
        (* Produce final result statement *)
        val result = @{mk_term "Trueprop ((?chead,?ahead) ∈ [?precond]a ?argrel → ?result_rel)"}
        val result = Logic.list_implies (prems,result)
    
        (********************************)
        (* Prove theorem *)
    
        (* Create context and import result statement and original theorem *)
        val orig_ctxt = ctxt
        (*val thy = Proof_Context.theory_of ctxt*)
        val (insts, ctxt) = Variable.import_inst true [result] ctxt
        val insts' = certify_inst ctxt insts
        val result = Term_Subst.instantiate insts result
        val thm = Thm.instantiate insts' thm
    
        (* Unfold APP tags. This is required as some APP-tags have also been unfolded by analysis *)
        val thm = Local_Defs.unfold0 ctxt @{thms APP_def} thm
    
        (* Tactic to prove the theorem. 
          A first step uses hfrefI to get a hnr-goal.
          This is then normalized in several consecutive steps, which 
            get rid of uncurrying. Finally, the original theorem is used for resolution,
            where the pre- and postcondition, and result relation are connected with 
            a consequence rule, to handle unfolded hn_ctxt-tags, re-ordered relations,
            and introduced unit-parameters (TODO: 
              Mark artificially introduced unit-parameter specially, it may get confused 
              with intentional unit-parameter, e.g., functional empty_set ()!)
    
          *)
        fun tac ctxt = 
                resolve_tac ctxt @{thms hfrefI}
          THEN' gsimp_only ctxt (fn c => c 
            addsimps @{thms uncurry_def hn_ctxt_def uncurry0_def
                            keep_drop_sels uc_hfprod_sel o_apply
                            APP_def}
            |> Splitter.add_split @{thm prod.split}
          ) 
    
          THEN' TRY o (
            REPEAT_ALL_NEW (match_tac ctxt @{thms allI impI})
            THEN' simp_only ctxt @{thms Product_Type.split prod.inject})
    
          THEN' TRY o REPEAT_ALL_NEW (DETERM o ematch_tac ctxt @{thms conjE})
          THEN' TRY o hyp_subst_tac ctxt
          THEN' simp_only ctxt @{thms triv_forall_equality}
          THEN' (
            resolve_tac ctxt @{thms hn_refine_cons[rotated]} 
            THEN' (resolve_tac ctxt [thm] THEN_ALL_NEW assume_tac ctxt))
          THEN_ALL_NEW simp_only ctxt 
            @{thms hn_ctxt_def entt_refl pure_unit_rel_eq_empty
              mult_ac mult_1 mult_1_right keep_drop_sels}  
    
        (* Prove theorem *)  
        val result = Thm.cterm_of ctxt result
        val rthm = Goal.prove_internal ctxt [] result (fn _ => ALLGOALS (tac ctxt))
    
        (* Export statement to original context *)
        val rthm = singleton (Variable.export ctxt orig_ctxt) rthm
    
        (* Post-processing *)
        (* Note: ctxt is still the imported context here, not orig_ctxt *)
        val rthm = Local_Defs.unfold0 ctxt (Named_Theorems.get ctxt @{named_theorems to_hfref_post}) rthm

      in
        rthm
      end
  
      (* Convert a theorem in hnr-form to hfref-form: analyze it, then build
        and prove the corresponding hfref theorem *)
      fun to_hfref ctxt thm = mk_hfref_thm ctxt (analyze_hnr ctxt thm)




      (***********************************)
      (* Composition *)

      local
        (* Current content of one of the fcomp_norm_* theorem collections *)
        fun fcomp_thms_of ctxt collection = Named_Theorems.get ctxt collection
      in  
        (* Normalize an hfref theorem after composition. One round tries, in
          this order, simplification with the fcomp_norm_simps rules,
          unfolding of the fcomp_norm_unfold rules, and the partial-order
          normalizer set up from the init/trans/cong/norm/refl rules; rounds
          are repeated via repeat_rule, each guarded by changed_rule. *)
        fun norm_fcomp_rule ctxt = let
          open PO_Normalizer Refine_Util

          val po_rules = {
            trans_rules = fcomp_thms_of ctxt @{named_theorems fcomp_norm_trans},
            cong_rules = fcomp_thms_of ctxt @{named_theorems fcomp_norm_cong},
            norm_rules = fcomp_thms_of ctxt @{named_theorems fcomp_norm_norm},
            refl_rules = fcomp_thms_of ctxt @{named_theorems fcomp_norm_refl}
          }
          val po_norm = 
            gen_norm_rule (fcomp_thms_of ctxt @{named_theorems fcomp_norm_init}) po_rules ctxt

          val unfold_norm = 
            Local_Defs.unfold0 ctxt (fcomp_thms_of ctxt @{named_theorems fcomp_norm_unfold})

          val simp_ctxt = 
            put_simpset HOL_basic_ss ctxt addsimps fcomp_thms_of ctxt @{named_theorems fcomp_norm_simps}
          val simp_norm = Conv.fconv_rule (Simplifier.asm_full_rewrite simp_ctxt)
    
          fun round thm = 
            thm |> try_rule simp_norm |> try_rule unfold_norm |> try_rule po_norm
        in
          repeat_rule (changed_rule round)
        end
      end  

      (* Replace schematic relation variables occurring as "pure ?x" by
        schematic assertion variables constrained to be pure.

        For each such variable ?x (of relation type over c×a), a variable with
        the same indexname but assertion type a ⇒ c ⇒ assn is created, every
        occurrence of ?x in the statement is substituted by the_pure of the
        new variable, and a premise "CONSTRAINT is_pure" for the new variable
        is prepended. The new statement is proved by resolution with the
        original theorem and then normalized with norm_fcomp_rule.

        Raises TERM if the same indexname is collected with different types.
      *)
      fun add_pure_constraints_rule ctxt thm = let
        val orig_ctxt = ctxt
    
        val t = Thm.prop_of thm
    
        (* Collect (indexname, assertion type, replacement term) for every
          "pure ?x" in the term; the typed pattern binds x, c and a *)
        fun 
          cnv (@{mpat (typs) "pure (mpaq_STRUCT (mpaq_Var ?x _) :: (?'v_c×?'v_a) set)"}) = 
          let
            val T = a --> c --> @{typ assn}
            val t = Var (x,T)
            val t = @{mk_term "(the_pure ?t)"}
          in
            [(x,T,t)]
          end
        | cnv (t$u) = union op= (cnv t) (cnv u)
        | cnv (Abs (_,_,t)) = cnv t  
        | cnv _ = []
    
        val pvars = cnv t
    
        val _ = (pvars |> map #1 |> has_duplicates op=) 
          andalso raise TERM ("Duplicate indexname with different type",[t]) (* This should not happen *)
    
        val substs = map (fn (x,_,t) => (x,t)) pvars
    
        (* Substitutes all occurrences of the collected variables, not only
          those directly below pure *)
        val t' = subst_Vars substs t  
    
        (* Purity constraint for a collected variable *)
        fun mk_asm (x,T,_) = let
          val t = Var (x,T)
          val t = @{mk_term "Trueprop (CONSTRAINT is_pure ?t)"}
        in
          t
        end
    
        val assms = map mk_asm pvars
    
        (* Put additional premises in front of the existing ones *)
        fun add_prems prems t = let
          val prems' = Logic.strip_imp_prems t
          val concl = Logic.strip_imp_concl t
        in
          Logic.list_implies (prems@prems', concl)
        end
    
        val t' = add_prems assms t'
    
        val (t',ctxt) = yield_singleton (Variable.import_terms true) t' ctxt
    
        (* The original theorem solves the goal; its premises are discharged
          by assumption *)
        val thm' = Goal.prove_internal ctxt [] (Thm.cterm_of ctxt t') (fn _ => 
          ALLGOALS (resolve_tac ctxt [thm] THEN_ALL_NEW assume_tac ctxt))
    
        val thm' = norm_fcomp_rule ctxt thm'

        val thm' = singleton (Variable.export ctxt orig_ctxt) thm'
      in
        thm'
      end  


      (* Boolean configuration option fcomp_simp_precond, enabled by default.
        It controls whether composed rules get their precondition simplified
        (see simplify_precond_if_cfg). *)
      val cfg_simp_precond = 
        Attrib.setup_config_bool @{binding fcomp_simp_precond} (K true)

      local
        (* Derive a simplified form of the composed precondition t.

          The term t is set up as a goal, opened with
          auto_weaken_pre_comp_PRE_I and simplified; the hypotheses of the
          remaining subgoals are then removed, so that the resulting theorem
          has the simplified conditions as premises and t as conclusion.
          The result is atomized and brought into the form "P ⟶ Q".

          Raises THM if the tactic fails, if a premise still contains an outer
          quantifier, or if the result has an unexpected form.
        *)
        fun mk_simp_thm ctxt t = let
          val st = t
            |> HOLogic.mk_Trueprop
            |> Thm.cterm_of ctxt
            |> Goal.init
      
          (* Silence the context and extend its simpset for the precondition simplification *)
          val ctxt = Context_Position.set_visible false ctxt  
          val ctxt = ctxt addsimps (
              refine_pw_simps.get ctxt 
            @ Named_Theorems.get ctxt @{named_theorems fcomp_prenorm_simps}
            @ @{thms split_tupled_all cnv_conj_to_meta}  
            )
          
          (* Print the goal state if the conclusion of some subgoal contains
            loose bound variables after stripping its outer quantifiers *)
          val trace_incomplete_transfer_tac =
            COND (Thm.prems_of #> exists (strip_all_body #> Logic.strip_imp_concl #> Term.is_open))
              (print_tac ctxt "Failed transfer from intermediate level:") all_tac
    
          (* filter_prems_tac with a constantly false predicate drops all
            hypotheses of a subgoal *)
          val tac = 
            ALLGOALS (resolve_tac ctxt @{thms auto_weaken_pre_comp_PRE_I} )
            THEN ALLGOALS (Simplifier.asm_full_simp_tac ctxt)
            THEN trace_incomplete_transfer_tac
            THEN ALLGOALS (TRY o filter_prems_tac ctxt (K false))
            THEN Local_Defs.unfold0_tac ctxt [Drule.triv_forall_equality]
      
          (* Only the first result of the tactic is used *)
          val st' = tac st |> Seq.take 1 |> Seq.list_of
          val thm = case st' of [st'] => Goal.conclude st' | _ => raise THM("Simp_Precond: Simp-Tactic failed",~1,[st])
    
          (* Check generated premises for leftover intermediate stuff *)
          val _ = exists (Logic.is_all) (Thm.prems_of thm) 
            andalso raise THM("Simp_Precond: Transfer from intermediate level failed",~1,[thm])
    
          val thm = 
             thm
          (*|> map (Simplifier.asm_full_simplify ctxt)*)
          |> Conv.fconv_rule (Object_Logic.atomize ctxt)
          |> Local_Defs.unfold0 ctxt @{thms auto_weaken_pre_to_imp_nf}
    
          (* Ensure the form "P ⟶ Q"; a plain fact gets a dummy antecedent True *)
          val thm = case Thm.concl_of thm of
            @{mpat "Trueprop (_ ⟶ _)"} => thm
          | @{mpat "Trueprop _"} => thm RS @{thm auto_weaken_pre_add_dummy_imp}  
          | _ => raise THM("Simp_Precond: Generated odd theorem, expected form 'P⟶Q'",~1,[thm])
    
    
        in
          thm
        end
      in  
        (* simplify_precond ctxt thm: simplifies the precondition of thm.

           Steps:
             1. Wrap thm by auto_weaken_pre_init and uncurry the tupled
                precondition via the auto_weaken_pre_uncurry_* rules.
             2. Take the first premise of the result, import its schematic
                variables, focus on its meta-quantified variables, and extract
                the right-hand side of the implication.
             3. Prove the simplification theorem for it with mk_simp_thm,
                export it to the original context, and discharge the first
                premise with it.
             4. Unfold prod_casesK in the result.

           Raises THM if step 1 leaves no premise, and TERM if the focused
           premise is not an implication. Exceptions of mk_simp_thm are passed
           through. *)
        fun simplify_precond ctxt thm = let
          val outer_ctxt = ctxt

          val init_thm = Refine_Util.OF_fst @{thms auto_weaken_pre_init} [asm_rl,thm]
          val start_thm =
            (Local_Defs.unfold0 outer_ctxt @{thms split_tupled_all} init_thm)
            OF @{thms auto_weaken_pre_uncurry_start}

          (* Apply the uncurry step as long as it resolves, then finish *)
          fun uncurry_all th =
            (case try (fn () => th OF @{thms auto_weaken_pre_uncurry_step}) () of
              SOME th' => uncurry_all th'
            | NONE => th OF @{thms auto_weaken_pre_uncurry_finish})

          val uncurried_thm =
            Conv.fconv_rule Thm.eta_conversion (uncurry_all start_thm)

          val prem =
            (case Thm.prems_of uncurried_thm of
              p::_ => p
            | _ => raise THM("Simp-Precond: Expected at least one premise",~1,[uncurried_thm]))

          (* Fix schematic variables, then the meta-quantified ones *)
          val (imported_prem, import_ctxt) =
            yield_singleton (Variable.import_terms false) prem outer_ctxt
          val ((_, focused_prem), inner_ctxt) =
            Variable.focus NONE imported_prem import_ctxt

          (* The pattern variable ?t binds the right-hand side of the implication *)
          val goal =
            (case focused_prem of
              @{mpat "Trueprop (_ ⟶ ?t)"} => t
            | _ => raise TERM("Simp_Precond: Expected implication",[focused_prem]))

          val simp_thm =
            singleton (Variable.export inner_ctxt outer_ctxt) (mk_simp_thm inner_ctxt goal)

          val discharged_thm = uncurried_thm OF [simp_thm]
        in
          Local_Defs.unfold0 inner_ctxt @{thms prod_casesK} discharged_thm
        end

        (* Like simplify_precond, but only if the configuration option
           cfg_simp_precond is set in ctxt. Otherwise, the theorem is
           returned unchanged. *)
        fun simplify_precond_if_cfg ctxt =
          (case Config.get ctxt cfg_simp_precond of
            true => simplify_precond ctxt
          | false => I)

      end  

      (* fref O fref:
         Composes the fref-theorems A and B by fref_compI_PRE, normalizes the
         result, simplifies its precondition (if enabled by configuration),
         and eta-contracts it. *)
      fun compose_ff ctxt A B = let
        val composed = @{thm fref_compI_PRE} OF [A,B]
        val normalized = norm_fcomp_rule ctxt composed
        val simplified = simplify_precond_if_cfg ctxt normalized
      in
        Conv.fconv_rule Thm.eta_conversion simplified
      end

      (* hfref O fref:
         Composes the hfref-theorem A with the fref-theorem B by
         hfref_compI_PRE. Same post-processing as for compose_ff, followed by
         add_pure_constraints_rule and another eta-contraction. *)
      fun compose_hf ctxt A B = let
        val eta_contract = Conv.fconv_rule Thm.eta_conversion

        val composed = @{thm hfref_compI_PRE} OF [A,B]
        val normalized = norm_fcomp_rule ctxt composed
        val simplified = simplify_precond_if_cfg ctxt normalized
        val contracted = eta_contract simplified
        val constrained = add_pure_constraints_rule ctxt contracted
      in
        eta_contract constrained
      end

      (* Returns thm in fref-form: fref-theorems are returned as they are,
         higher-order parametricity theorems are converted by to_fref.
         Raises THM for any other kind of theorem. *)
      fun ensure_fref ctxt thm =
        (case rthm_type thm of
          RT_FREF => thm
        | RT_HOPARAM => to_fref ctxt thm
        | _ => raise THM("Expected parametricity or fref theorem",~1,[thm]))

      (* Returns thm in fref-form (see ensure_fref), such that the result
         relation has a type of the shape (_ nrest × _) set. If it does not
         yet have this type, the theorem is resolved with ensure_fref_nresI
         and ensure_fref_nres_unfold is unfolded in the result.
         Raises THM if the conclusion is no fref-statement. *)
      fun ensure_fref_nres ctxt thm = let
        val fref_thm = ensure_fref ctxt thm

        fun lift_to_nres th =
          Local_Defs.unfold0 ctxt @{thms ensure_fref_nres_unfold}
            (th RS @{thm ensure_fref_nresI})
      in
        (* The first pattern is the more specific one, so the order matters *)
        (case Thm.concl_of fref_thm of
          @{mpat (typs) "Trueprop (_∈fref _ _ (_::(_ nrest×_)set))"} => fref_thm
        | @{mpat "Trueprop ((_,_)∈fref _ _ _)"} => lift_to_nres fref_thm
        | _ => raise THM("Expected fref-theorem",~1,[fref_thm]))
      end

      (* Returns thm in hfref-form: hfref-theorems are returned as they are,
         hnr-theorems are converted by to_hfref.
         Raises THM for any other kind of theorem. *)
      fun ensure_hfref ctxt thm =
        (case rthm_type thm of
          RT_HFREF => thm
        | RT_HNR => to_hfref ctxt thm
        | _ => raise THM("Expected hnr or hfref theorem",~1,[thm]))

      (* Returns thm in hnr-form: hnr-theorems are returned as they are,
         hfref-theorems are converted by to_hnr.
         Raises THM for any other kind of theorem. *)
      fun ensure_hnr ctxt thm =
        (case rthm_type thm of
          RT_HFREF => to_hnr ctxt thm
        | RT_HNR => thm
        | _ => raise THM("Expected hnr or hfref theorem",~1,[thm]))

      (* Generic composition of the rules A and B. The kind of A selects the
         composition:
           - parametricity or fref theorem: compose_ff, with both theorems
             brought into fref-form
           - anything else: compose_hf, with A brought into hfref-form and B
             into fref-form with nrest result *)
      fun gen_compose ctxt A B = let
        fun via_ff () =
          compose_ff ctxt (ensure_fref ctxt A) (ensure_fref ctxt B)
        fun via_hf () =
          compose_hf ctxt (ensure_hfref ctxt A) (ensure_fref_nres ctxt B)
      in
        (case rthm_type A of
          RT_HOPARAM => via_ff ()
        | RT_FREF => via_ff ()
        | _ => via_hf ())
      end

      (* Parser for the flags of the FCOMP attribute: parenthesized lists of
         the boolean flag "prenorm", which is tied to cfg_simp_precond.
         NOTE(review): the exact concrete syntax is determined by
         Refine_Util.parse_paren_lists and Refine_Util.parse_bool_config. *)
      val parse_fcomp_flags =
        cfg_simp_precond
        |> Refine_Util.parse_bool_config "prenorm"
        |> Refine_Util.parse_paren_lists

      (* Parser of the FCOMP attribute: optional flags, followed by a theorem
         B. The resulting rule attribute maps a theorem A to the composition
         gen_compose ctxt A B. The parse result of the flags is discarded.
         NOTE(review): the flags presumably take effect via the context
         threaded through the parser - confirm in Refine_Util. *)
      val fcomp_attrib =
        parse_fcomp_flags |-- Attrib.thm
        >> (fn B =>
              Thm.rule_attribute []
                (fn context => fn A => gen_compose (Context.proof_of context) A B))

    end
  ›

  (* Rule attribute that applies Sepref_Rules.to_fref to the theorem.
     Cartouches replace the legacy {* ... *} verbatim token, which newer
     Isabelle releases no longer accept. *)
  attribute_setup to_fref = ‹
    Scan.succeed (Thm.rule_attribute [] (Sepref_Rules.to_fref o Context.proof_of))
  › ‹Convert parametricity theorem to uncurried fref-form›

  (* Rule attribute that applies Sepref_Rules.to_foparam to the theorem.
     Cartouches replace the legacy {* ... *} verbatim token, which newer
     Isabelle releases no longer accept. *)
  attribute_setup to_foparam = ‹
      Scan.succeed (Thm.rule_attribute [] (Sepref_Rules.to_foparam o Context.proof_of))
  › ‹Convert param or fref rule to first order rule›
  (* Overrides the existing param_fo attribute from Parametricity.thy: from
     here on, param_fo behaves exactly like to_foparam.
     Cartouches replace the legacy {* ... *} verbatim token, which newer
     Isabelle releases no longer accept. *)
  attribute_setup param_fo = ‹
      Scan.succeed (Thm.rule_attribute [] (Sepref_Rules.to_foparam o Context.proof_of))
  › ‹Convert param or fref rule to first order rule›

  (* Rule attribute that applies Sepref_Rules.to_hnr to the theorem.
     Cartouches replace the legacy {* ... *} verbatim token, which newer
     Isabelle releases no longer accept. *)
  attribute_setup to_hnr = ‹
    Scan.succeed (Thm.rule_attribute [] (Sepref_Rules.to_hnr o Context.proof_of))
  › ‹Convert hfref-rule to hnr-rule›
  
  (* Rule attribute that applies Sepref_Rules.to_hfref to the theorem *)
  attribute_setup to_hfref = ‹
    Scan.succeed (Thm.rule_attribute [] (Sepref_Rules.to_hfref o Context.proof_of))
  › ‹Convert hnr to hfref theorem›


  (* Rule attribute that applies Sepref_Rules.ensure_fref_nres to the theorem.
     A description was added, in line with the other attributes of this
     theory. *)
  attribute_setup ensure_fref_nres = ‹Scan.succeed (
      Thm.rule_attribute [] (Context.proof_of #> Sepref_Rules.ensure_fref_nres)
    )› ‹Convert param or fref rule to fref rule with nrest result›

  (* Rule attribute that applies Sepref_Rules.norm_fcomp_rule to the theorem.
     In FCOMP, this is the normalization step applied directly after the
     composition. A description was added, in line with the other attributes
     of this theory. *)
  attribute_setup sepref_dbg_norm_fcomp_rule = ‹Scan.succeed (
      Thm.rule_attribute [] (Context.proof_of #> Sepref_Rules.norm_fcomp_rule)
    )› ‹Debugging: Apply the normalization step of FCOMP to theorem›

  (* Rule attribute that applies Sepref_Rules.simplify_precond to the theorem.
     The simplification is applied unconditionally, i.e., the configuration
     option fcomp_simp_precond is not consulted here. *)
  attribute_setup sepref_simplify_precond = ‹
    Scan.succeed (Thm.rule_attribute [] (fn context =>
      Sepref_Rules.simplify_precond (Context.proof_of context)))
  › ‹Simplify precondition of fref/hfref-theorem›

  (* The FCOMP attribute: parser and semantics are given by Sepref_Rules.fcomp_attrib *)
  attribute_setup FCOMP = ‹Sepref_Rules.fcomp_attrib› ‹Composition of refinement rules›

end