header ‹Translation›
theory Sepref_Translate
imports Sepref_Lin_Ana DF_Solver Sepref_Frame Pf_Mono_Prover
begin

(* TODO: Move *)
(* An ASSERT followed by m equals m if the asserted condition holds,
  and FAIL otherwise. *)
lemma bind_ASSERT_eq_if: "do { ASSERT Φ; m } = (if Φ then m else FAIL)"
  by auto


text ‹
  This theory defines the translation phase.
  
  The main functionality of the translation phase is to
  apply refinement rules. Thereby, the linearity information is
  exploited to create copies of parameters that are still required, but
  would be destroyed by a synthesized operation.
  These \emph{frame-based} rules are in the named theorem collection
  @{text sepref_fr_rules}, and the collection @{text sepref_copy_rules}
  contains rules to handle copying of parameters.

  Apart from the frame-based rules described above, there is also a set of
  rules for combinators, in the collection @{text sepref_comb_rules}, 
  where no automatic copying of parameters is applied.

  Moreover, this theory contains 
  \begin{itemize}
    \item A setup for the basic monad combinators and recursion.
    \item A tool to import parametricity theorems.
    \item Some setup to identify pure refinement relations, i.e., those not
      involving the heap.
    \item A preprocessor that identifies parameters in refinement goals
      and flags them with a special tag that allows them to be handled correctly.
  \end{itemize}
›

subsection ‹Basic Translation Tool›  
(* COPY is just RETURN; it marks an explicit copy of a parameter and is
  unfolded by the simplifier ([simp]). *)
definition COPY -- "Copy operation"
   where [simp]: "COPY ≡ RETURN" 

(* Binding a RETURNed value reduces to applying the continuation; part of
  prepare_refine_simps. *)
lemma tagged_nres_monad1: "bind$(RETURN$x)$(λ2x. f x) = f x" by simp

text ‹The PREPARED-tag is used internally to flag a refinement goal
  with the index of the refinement rule to be used.›
(* Identity on its first argument; the nat argument records a rule index. *)
definition PREPARED_TAG :: "'a => nat => 'a"
  where [simp]: "PREPARED_TAG x i == x"
(* The tag does not affect refinement: a refinement of a is also a
  refinement of the tagged term. *)
lemma PREPARED_TAG_I: 
  "hn_refine Γ c Γ' R a ==> hn_refine Γ c Γ' R (PREPARED_TAG a i)"
  by simp

(* Simp rules used to prove that a prepared abstract term (with COPY
  bindings, linearity annotations and PREPARED_TAG) equals the original one. *)
lemmas prepare_refine_simps = tagged_nres_monad1 COPY_def LIN_ANNOT_def 
  PREPARED_TAG_def


ML {*
(* Translation phase: theorem collections for refinement rules, analysis
  of rules and goals, preparation of goals (inserting COPY operations for
  arguments that must be copied), and the translation tactics. *)
structure Sepref_Translate = struct

  (* Frame-based refinement rules; arguments are copied automatically
    where required *)
  structure sepref_fr_rules = Named_Thms (
    val name = @{binding "sepref_fr_rules"}
    val description = "Sepref: Frame-based rules"
  )

  (* Rules for combinators; no automatic copying of arguments *)
  structure sepref_comb_rules = Named_Thms (
    val name = @{binding "sepref_comb_rules"}
    val description = "Sepref: Combinator rules"
  )

  (* Rules that refine COPY operations *)
  structure sepref_copy_rules = Named_Thms (
    val name = @{binding "sepref_copy_rules"}
    val description = "Sepref: Copy rules"
  )


  local 
    open Autoref_Tagging Sepref_Basic
    (* Formal arguments of a rule must be schematic variables *)
    fun dest_arg (Var (x,_)) = x
      | dest_arg trm = raise TERM("Argument must be variable",[trm])

    (* Admissible heads of an operation: constants, free variables,
      and PR_CONST-protected terms *)
    fun is_valid_head (Const _) = true
      | is_valid_head (Free _) = true
      | is_valid_head @{mpat "PR_CONST _"} = true
      | is_valid_head _ = false

    (* Strip RETURN wrappers, then split the operation into its head and
      the names of its argument variables *)
    fun 
      dest_opr @{mpat "RETURN$?f"} = dest_opr f
    | dest_opr t = let
        val (f,args) = strip_app t
        val _ = is_valid_head f orelse 
          raise TERM("get_args: Expected constant head",[t])
  
      in (f,map dest_arg args) end
   
    (* Decode one heap assertion of a rule: (x, NONE) for hn_invalid,
      (x, SOME R) for hn_ctxt R *)
    fun valid_pair @{mpat "hn_invalid (mpaq_STRUCT(mpaq_Var ?x _)) _"} = (x,NONE)
      | valid_pair @{mpat "hn_ctxt ?R (mpaq_STRUCT(mpaq_Var ?x _)) _"} = (x,SOME R)
      | valid_pair t = raise TERM("Invalid assertion in heap",[t])
    
    fun is_emp @{mpat emp} = true | is_emp _ = false
  
  in
    (* Given a frame theorem, return the constant 
      and a list of refinement relations for the arguments, NONE if the
      argument is not preserved *)
    fun 
      analyze_args thm = (
        case concl_of thm of
          @{mpat "Trueprop (hn_refine ?G _ ?G' _ ?a)"} => let
            val in_args = (
                 strip_star G 
              |> filter (not o is_emp)
              |> map valid_pair
              |> Vartab.make
            ) handle
                Vartab.DUP _ => raise THM ("analyze_args: Dup in-args",~1,[thm])
      
            val out_args = (
                 strip_star G'
              |> filter (not o is_emp)
              |> map valid_pair
              |> Vartab.make
            ) handle
                Vartab.DUP _ => raise THM ("analyze_args: Dup out-args",~1,[thm])

            val (f,formal_args) = dest_opr a 
      
            (* Check that no parameters are dropped or invented *)
            val _ = Vartab.forall (fn (x,_) => Vartab.defined out_args x) in_args
              orelse raise THM ("analyze_args: Dropped parameters",~1,[thm])
            val _ = Vartab.forall (fn (x,_) => Vartab.defined in_args x) out_args
              orelse raise THM ("analyze_args: Invented parameters",~1,[thm])

            (* Check that precisely the parameters are present *)
            val _ = forall (fn x => Vartab.defined in_args x) formal_args
              orelse raise THM ("analyze_args: Missing parameters",~1,[thm])
            val _ = Vartab.forall (fn (x,_) => member op= formal_args x) in_args
              orelse raise THM ("analyze_args: Extra parameters",~1,[thm])
                  
            (* Safe: every formal argument is in in_args, hence in out_args *)
            val preserved = map (the o Vartab.lookup out_args) formal_args
      
          in
            (f,preserved)
          end
        | _ => raise THM("Invalid hn_refine theorem",~1,[thm])
      )
  
  end

  local
    open Autoref_Tagging Sepref_Basic 

    fun is_valid_head (Const _) = true
      | is_valid_head (Free _) = true
      | is_valid_head @{mpat "PR_CONST _"} = true
      | is_valid_head _ = false

    (* Decode a linearity-annotated actual argument: (x, true) for the
      L annotation, (x, false) for N.
      NOTE(review): the annotation appears as a plain trailing L/N in
      these patterns; presumably a superscript LIN_ANNOT notation --
      verify against Sepref_Lin_Ana. *)
    fun dest_arg @{mpat "(?x ASs (mpaq_Free _ _))L"} = (x,true)
      | dest_arg @{mpat "(?x ASs (mpaq_Free _ _))N"} = (x,false)
      | dest_arg t = raise TERM("Malformed argument",[t])

    (* Like the rule-side analysis, but also return a function that
      rebuilds the term from a (head, argument list) pair; RETURN
      wrappers are kept by composing the rebuild functions *)
    fun 
      dest_opr (t as @{mpat "RETURN$_"}) = let
        val (f,c) = dest_APP_argc t
        val (res,c') = dest_opr f
      in
        (res,c o c')
      end
    | dest_opr t = let
        val ((f,args),c) = strip_APPc t
        val _ = is_valid_head f orelse
          raise TERM("get_args: Expected constant head",[t])
      in ((f,map dest_arg args),c) end


  in
    (* Analyze the arguments in the actual refinement goal.
      Return the head constant, a list of 
        argument × linear × refinement-relation option
      and a function to reconstruct the term.
    *)
    fun 
      analyze_actual_args @{mpat "hn_refine ?G _ _ _ ?a"} = let
        val ((f,args),c) = dest_opr a

        fun dest_hn_ctxt @{mpat "hn_ctxt ?R ?a _"} = SOME (a,R)
          | dest_hn_ctxt _ = NONE

        (* Refinement relation of each abstract variable on the heap *)
        val on_heap = 
          strip_star G
          |> map_filter dest_hn_ctxt
          |> Termtab.make

        val args = map (fn (x,l) => (x,l,Termtab.lookup on_heap x)) args
      in 
        ((f,args),c)
      end
    | analyze_actual_args t = raise TERM("No hn-refine subgoal",[t])
  end

  local
    open Conv

    (* Abstract over the free variable n in t, wrapping the body as
      PROTECT2 t DUMMY *)
    fun lambda2_name n t = Term.lambda_name n @{mk_term "PROTECT2 ?t DUMMY"}

    (* Build the prepared abstract term for applying rule number i, whose
      analysis is (f', is_pres), to the analyzed goal ((f, largs), mk_a).
      Arguments that must be copied are bound via bind$(COPY$x)$...;
      the result is tagged with PREPARED_TAG _ i.
      Raises TERM if the heads or the frames do not match. *)
    fun prepare_refine thy ((f,largs),mk_a) (i,(f',is_pres)) = let
      val _ = Pattern.matches thy (f',f) orelse raise TERM("No match",[f',f])

      (* Quick frame check *)
      fun check_rel (_,_,SOME R) (SOME R') = Term.could_unify (R,R')
        | check_rel _ NONE = true (* Strange case, should not happen! *)
        | check_rel _ _  = false

      val _ = forall2 check_rel largs is_pres 
        orelse raise TERM("No frame match",[f',f])

      (* An argument needs a copy if it is not linear and the rule
        does not preserve it *)
      val args_nc
        = map2 (fn (x,l,_) => fn r => (x,not l andalso is_none r)) 
          largs is_pres

      (* Aliasing check: Avoid aliasing due to duplicate arguments.
        Every occurrence of a term except its last one is marked. *)
      fun dups [] = ([],Termtab.empty)
        | dups ((x,_,_)::l) = let
            val (l,tab) = dups l
            val d = Termtab.defined tab x
            val tab = Termtab.update (x,()) tab
          in
            (d::l, tab)
          end
  
      val is_dup = #1 (dups largs)
  
      val args_nc = map2 (fn (x,nc) => fn d => (x,nc orelse d)) args_nc is_dup

      (* Walk the arguments; the recursion counter i is the 1-based
        argument position, used to name fresh copy variables. In the
        base case, i refers to the rule index of prepare_refine. *)
      fun prep ((x,c)::l) args i = 
        if c then let
          val name = "v" ^ string_of_int i
          val fv = Free ("__prep__"^name,fastype_of x)
          val r = prep l (fv::args) (i+1)
            |> lambda2_name (name,fv)
        in 
          @{mk_term "bind$(COPY$?x)$?r"}
        end
        else 
          prep l (x::args) (i+1)
      | prep [] args _ = let
          val i = HOLogic.mk_number @{typ nat} i
          val a = mk_a (f,rev args)
        in @{mk_term "PREPARED_TAG ?a ?i"} end

    in
      prep args_nc [] 1
    end
  in
    (* Try to prepare refinement with the specified index×theorem pair.
      The abstract term of the hn_refine goal is rewritten to its prepared
      form; the equality is proved by simp with prepare_refine_simps. *)
    fun prepare_refine_conv (i,thm) ctxt ct = let
      val thy = Proof_Context.theory_of ctxt
      val aa = analyze_actual_args (term_of ct)
      val ta = analyze_args thm
      val a' = prepare_refine thy aa (i,ta)
      val tac = 
        ALLGOALS (simp_tac 
          (put_simpset HOL_basic_ss ctxt addsimps @{thms prepare_refine_simps}))
      open Sepref_Basic    
    in
      hn_refine_conv_a (Refine_Util.f_tac_conv ctxt (K a') tac) ct
    end
  
  end  

  (* Refine with the specified theorem: only prepares the goal; the rule
    itself is applied later by prepared_tac *)
  fun try_refine_tac (i,thm) ctxt =
    CONVERSION Thm.eta_conversion THEN'
    CONVERSION (Refine_Util.HOL_concl_conv (prepare_refine_conv (i,thm)) ctxt)
    THEN' 
    simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}) 
    (*THEN'
    rtac @{thm hn_refine_frame} THEN'
    rtac thm THEN'
    SOLVED' (DETERM o Sepref_Frame.frame_tac ctxt)*)

  (* Refine with all matching theorems, allow backtracking.
    Rules are indexed from 0 in collection order; prepared_tac relies on
    the same indexing. *)  
  fun refine_fr_tac ctxt = let
    val fr_rules = sepref_fr_rules.get ctxt
      |> tag_list 0
    val tacs = map (fn ixthm => try_refine_tac ixthm ctxt) fr_rules
  in
    APPEND_LIST' tacs
  end


  local
    (* Combine rule with frame-thm,
      move frame-premise before first non-PREFER premise *)
    fun prep_fr_rule thm = let
      val thm = thm RS @{thm hn_refine_frame}
      val prems = prems_of thm
      val fix = find_index (fn @{mpat "Trueprop (PREFER_tag _)"} => false | _ => true) prems
    in
      Thm.permute_prems fix ~1 thm
    end

    (* Discharge premises: keep constraints; otherwise solve the premise
      completely with one of the tagged solver, frame inference, or the
      independence check *)
    fun discharge_rprem ctxt = 
    DF_Solver.is_constraint_tac  (* Keep constraints *)
    ORELSE' SOLVED' (
    FIRST' [
      resolve_tac @{thms PREFER_tagI DEFER_tagI}
      THEN' CONVERSION (Refine_Util.HOL_concl_conv (K (Id_Op.unprotect_conv ctxt)) ctxt)
      THEN' Tagged_Solver.solve_tac ctxt
    ,
      DETERM o Sepref_Frame.frame_tac ctxt
    ,
      CONVERSION (Refine_Util.HOL_concl_conv (K (Id_Op.unprotect_conv ctxt)) ctxt)
      THEN' Tagged_Solver.solve_tac ctxt
    ,
      DETERM o Indep_Vars.indep_tac
    ])

  in    
    (* Solve prepared frame: apply the frame rule whose index is recorded
      in the PREPARED_TAG of the goal; no result if there is no tag *)  
    fun prepared_tac ctxt i st = 
      case Logic.concl_of_goal (prop_of st) i of
        @{mpat "Trueprop (hn_refine _ _ _ _ (PREPARED_TAG _ ?n))"} => let
          val n = #2 (HOLogic.dest_number n)
          val thm = nth (sepref_fr_rules.get ctxt) n |> prep_fr_rule
        in 
          rtac @{thm PREPARED_TAG_I} THEN'
          (rtac thm THEN_ALL_NEW_FWD (discharge_rprem ctxt))
        end i st
      | _ => Seq.empty

  end  

  (* Single translation step. Alternatives are tried in order:
    combinator rules, recursive calls, copy rules, prepared goals,
    preparation via frame rules, and miscellaneous side conditions. *)  
  fun trans_step_tac ctxt = let
    val combinator_rules = sepref_comb_rules.get ctxt

    val rcall_tac = 
      rtac @{thm hn_refine_frame[where m="RCALL$D$(xL)" for D x]}
      THEN' DETERM o rprems_tac ctxt
      THEN' SOLVED' (DETERM o Sepref_Frame.frame_tac ctxt)

    val copy_tac = 
      resolve_tac (sepref_copy_rules.get ctxt)
      THEN' SOLVED' (DETERM o Sepref_Frame.frame_tac ctxt)

    val misc_tac = FIRST' [
      SOLVED' (DETERM o Sepref_Frame.frame_tac ctxt),
      SOLVED' (DETERM o Indep_Vars.indep_tac),
      SOLVED' (DETERM o Sepref_Frame.merge_tac ctxt),
      SOLVED' (Pf_Mono_Prover.mono_tac ctxt),
      SOLVED' (
        CONVERSION (Refine_Util.HOL_concl_conv (K (Id_Op.unprotect_conv ctxt)) ctxt)
        THEN' (TRY o resolve_tac @{thms PREFER_tagI DEFER_tagI})
        THEN' Tagged_Solver.solve_tac ctxt)
    ]

    val fr_tac = refine_fr_tac ctxt
  in
    (*(K (print_tac "trans_step_tac")) THEN'*)
    (
    FIRST' [ 
      resolve_tac combinator_rules,
      rcall_tac,
      copy_tac,
      prepared_tac ctxt,
      fr_tac,
      misc_tac
    ] (*THEN' (K (print_tac "yields"))*)
    (*ORELSE' (K (print_tac "FAILED"))*)
    )
  end

  (* Do translation: solve with DF_Solver.DF_SOLVE_FWD_C using
    trans_step_tac, then try SOLVED_I on each resulting goal *)
  fun trans_tac ctxt = 
    DF_Solver.DF_SOLVE_FWD_C true (trans_step_tac ctxt) ctxt
    THEN_ALL_NEW (TRY o rtac @{thm SOLVED_I})

  (* Single translation step on an existing goal, with constraint
    handling: trans_step_tac is run under DF_Solver.defer_constraints,
    then DF_Solver.check_constraints_tac is applied, and finally the
    goal is eta-converted *)  
  fun cstep_tac ctxt = IF_EXGOAL (
    DF_Solver.defer_constraints (trans_step_tac ctxt)
    THEN' K (DF_Solver.check_constraints_tac ctxt)
    THEN' (CONVERSION Thm.eta_conversion))

  (* Registers the three named theorem collections *)
  val setup = I
    #> sepref_fr_rules.setup
    #> sepref_comb_rules.setup
    #> sepref_copy_rules.setup

end
*}

(* Register the named theorem collections declared in Sepref_Translate. *)
setup Sepref_Translate.setup

subsubsection ‹Basic Setup›
(* PASS returns the concrete value unchanged: the argument is invalidated,
  and the result is related by the argument's relation P. *)
lemma hn_pass[sepref_fr_rules]:
  shows "hn_refine (hn_ctxt P x x') (return x') (hn_invalid x x') P (PASS$x)"
  by rule (sep_auto simp: hn_ctxt_def)

(* Combinator rule for bind: translate m (D1), then the continuation with
  the bound variable added to the heap context (D2); afterwards the
  assertion on the bound variable is dropped (IMP). *)
lemma hn_bind[sepref_comb_rules]:
  assumes D1: "hn_refine Γ m' Γ1 Rh m"
  assumes D2: 
    "!!x x'. hn_refine (Γ1 * hn_ctxt Rh x x') (f' x') (Γ2 x x') R (f x)"
  assumes IMP: "!!x x'. Γ2 x x' ==>A Γ' * hn_ctxt Rx x x'"
  shows "hn_refine Γ (m'»=f') Γ' R (bind$m$(λ2x. f x))"
  using assms
  unfolding APP_def PROTECT2_def
  by (rule hnr_bind)

(* Combinator rule for RECT with a linearly annotated argument.
  The body aB is translated to cB, assuming that every recursive call
  (RCALL) is refined by cf (premise S); the concrete recursion is
  heap.fixp_fun cB. Side conditions: frame entailments FR and FR',
  monotonicity of cB (M), and precision of the result relation (PREC). *)
lemma hn_RECT'[sepref_comb_rules]:
  assumes "INDEP Ry" "INDEP Rx" "INDEP Rx'"
  assumes FR: "P ==>A hn_ctxt Rx ax px * F"
  assumes S: "!!cf af ax px. [|
    !!ax px. hn_refine (hn_ctxt Rx ax px * F) (cf px) (hn_ctxt Rx' ax px * F) Ry 
      (RCALL$af$(axL))|] 
    ==> hn_refine (hn_ctxt Rx ax px * F) (cB cf px) (F' ax px) Ry 
          (aB af ax)"
  assumes FR': "!!ax px. F' ax px ==>A hn_ctxt Rx' ax px * F"
  assumes M: "(!!x. mono_Heap (λf. cB f x))"
  assumes PREC[unfolded CONSTRAINT_def]: "CONSTRAINT precise Ry"
  shows "hn_refine 
    (P) (heap.fixp_fun cB px) (hn_ctxt Rx' ax px * F) Ry 
        (RECT$(λ2D x. aB D x)$(axL))"
  unfolding APP_def PROTECT2_def LIN_ANNOT_def
  apply (rule hn_refine_cons_pre[OF FR])
  apply (rule hnr_RECT)

  apply (rule hn_refine_cons_post[OF _ FR'])
  apply (rule S[unfolded RCALL_def APP_def LIN_ANNOT_def])
  apply assumption
  apply fact+
  done

(* RECT with a non-linear (N) argument: reduce to the linear case by first
  binding a COPY of the argument. *)
lemma hn_RECT_nl[sepref_comb_rules]: 
  assumes "hn_refine 
    (P) t (hn_ctxt Rx' ax px * F) Ry 
        (bind$(COPY$ax)$(λ2ax. RECT$(λ2D x. aB D x)$(axL)))"
  shows "hn_refine 
    (P) t (hn_ctxt Rx' ax px * F) Ry 
        (RECT$(λ2D x. aB D x)$(axN))"
  using assms by simp

(* Recursive call with a non-linear (N) argument: reduce to the linear case
  by first binding a COPY of the argument. *)
lemma hn_RCALL_nl[sepref_comb_rules]:
  assumes "hn_refine Γ c Γ' R (bind$(COPY$x)$(λ2x. RCALL$D$(xL)))"
  shows "hn_refine Γ c Γ' R (RCALL$D$(xN))"
  using assms by simp

(* While loop with invariant I and a monadic condition b, expressed with
  RECT; the invariant is asserted at the start of each iteration. *)
definition "monadic_WHILEIT I b f s ≡ do {
  RECT (λD s. do {
    ASSERT (I s);
    bv \<leftarrow> b s;
    if bv then do {
      s \<leftarrow> f s;
      D s
    } else do {RETURN s}
  }) s
}"

(* Imperative counterpart of monadic_WHILEIT (without invariant), expressed
  as a fixed point in the Heap monad. *)
definition "heap_WHILET b f s ≡ do {
  heap.fixp_fun (λD s. do {
    bv \<leftarrow> b s;
    if bv then do {
      s \<leftarrow> f s;
      D s
    } else do {return s}
  }) s
}"

(* One-step unfolding of heap_WHILET, declared as code equation. The fixed
  point is unfolded via heap.mono_body_fixp, whose monotonicity side
  condition is discharged by the monotonicity prover. *)
lemma heap_WHILET_unfold[code]: "heap_WHILET b f s = 
  do {
    bv \<leftarrow> b s;
    if bv then do {
      s \<leftarrow> f s;
      heap_WHILET b f s
    } else
      return s
  }"
  unfolding heap_WHILET_def
  apply (subst heap.mono_body_fixp)
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  apply simp
  done



(* WHILEIT with a pure condition is monadic_WHILEIT with the condition
  lifted by RETURN. *)
lemma WHILEIT_to_monadic: "WHILEIT I b f s = monadic_WHILEIT I (λs. RETURN (b s)) f s"
  unfolding WHILEIT_def monadic_WHILEIT_def
  unfolding WHILEI_body_def bind_ASSERT_eq_if
  by (simp cong: if_cong)

(* Operation identification: WHILEIT with a fixed invariant is treated as
  one operation; WHILET is WHILEIT with the trivial invariant. *)
lemma WHILEIT_pat[def_pat_rules]:
  "WHILEIT$I ≡ UNPROTECT (WHILEIT I)"
  "WHILET ≡ PR_CONST (WHILEIT (λ_. True))"
  by (simp_all add: WHILET_def)

(* Type of the identified WHILEIT operation *)
lemma id_WHILEIT[id_rules]: 
  "PR_CONST (WHILEIT I) ::i TYPE(('a => bool) => ('a => 'a nres) => 'a => 'a nres)"
  by simp

(* Monadify arity: WHILEIT takes condition, body, and initial state *)
lemma WHILE_arities[sepref_monadify_arity]:
  (*"WHILET ≡ WHILEIT$(λ2_. True)"*)
  "PR_CONST (WHILEIT I) ≡ λ2b f s. SP (PR_CONST (WHILEIT I))$(λ2s. b$s)$(λ2s. f$s)$s"
  by (simp_all add: WHILET_def)

(* Monadify combinator rule: evaluate the initial state, then switch to
  monadic_WHILEIT with the condition evaluated monadically *)
lemma WHILEIT_comb[sepref_monadify_comb]:
  "PR_CONST (WHILEIT I)$(λ2x. b x)$f$s ≡ 
    bind$(EVAL$s)$(λ2s. 
      SP (PR_CONST (monadic_WHILEIT I))$(λ2x. (EVAL$(b x)))$f$s
    )"
  by (simp_all add: WHILEIT_to_monadic)

(* Skeleton for linearity analysis: the initial state, then a recursion over
  condition, body, and the recursive call *)
lemma [sepref_la_skel]: "SKEL (PR_CONST (monadic_WHILEIT I)$b$f$x) ≡ 
  la_seq 
    (la_op x) 
    (la_rec (λD. la_seq (SKEL b) (la_seq (SKEL f) (la_rcall D))))"
  by simp

(* TODO: Integrate into merge? *)
(* When merging a disjunction, an extra hn_val assertion on either side can be
  dropped (it is pure; see hn_ctxt_def and pure_def in the proof). *)
lemma merge4:
  "[|Fl ∨A Fr ==>A F|] ==> Fl * hn_val R x x' ∨A Fr ==>A F"
  "[|Fl ∨A Fr ==>A F|] ==> Fl ∨A Fr * hn_val R x x' ==>A F"
  apply (rule ent_disjE)
  apply (drule ent_disjI1)
  apply (sep_auto simp: hn_ctxt_def pure_def)
  apply (erule ent_disjI2)

  apply (rule ent_disjE)
  apply (erule ent_disjI1)
  apply (drule ent_disjI2)
  apply (sep_auto simp: hn_ctxt_def pure_def)
  done
  

(* heap_WHILET refines monadic_WHILEIT: the condition (b_ref) and the body
  (f_ref) are refined under context Γ plus the loop state; afterwards the
  initial state is invalidated. Rs must be precise (PREC). *)
lemma hn_monadic_WHILE_aux:
  assumes FR: "P ==>A Γ * hn_ctxt Rs s' s"
  assumes b_ref: "!!s s'. hn_refine 
    (Γ * hn_ctxt Rs s' s)
    (b s)
    (Γb s' s)
    (pure bool_rel)
    (b' s')"
  assumes b_fr: "!!s' s. Γb s' s ==>A Γ * hn_ctxt Rs s' s"

  assumes f_ref: "!!s' s. hn_refine
    (Γ * hn_ctxt Rs s' s)
    (f s)
    (Γf s' s)
    Rs
    (f' s')"
  assumes f_fr: "!!s' s. Γf s' s ==>A Γ * hn_invalid s' s"
  assumes PREC: "precise Rs"
  shows "hn_refine (P) (heap_WHILET b f s) (Γ * hn_invalid s' s) Rs (monadic_WHILEIT I b' f' s')"
  unfolding monadic_WHILEIT_def heap_WHILET_def
  apply (rule hn_refine_cons_pre[OF FR])
  apply (rule hn_refine_cons_pre[OF _ hnr_RECT])
  apply (subst mult_ac(2)[of Γ]) apply (rule ent_refl)
  apply (rule hnr_ASSERT)
  apply (rule hnr_bind)
  apply (rule hn_refine_cons[OF _ b_ref b_fr])
  apply sep_auto []
  apply (rule hnr_If)
  apply sep_auto []
  apply (rule hnr_bind)
  apply (rule hn_refine_cons[OF _ f_ref f_fr])
  apply (sep_auto simp: hn_ctxt_def pure_def) []
  apply (rule hn_refine_frame)
  apply rprems
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  apply sep_auto []
  apply (rule hn_refine_frame)
  apply (rule hnr_RETURN_pass)
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  apply (tactic {* Sepref_Frame.merge_tac @{context} 1*})
  apply (rule hn_merge merge4)
  apply sep_auto []
  apply (rule fr_invalidate)
  apply simp
  apply (tactic {* Pf_Mono_Prover.mono_tac @{context} 1 *})
  apply (rule PREC)
  done

(* Combinator rule for monadic_WHILEIT with a linearly annotated state:
  the tagged form of hn_monadic_WHILE_aux. The frame side conditions of
  condition and body carry TERM (monadic_WHILEIT, ''cond''/''body'') premises. *)
lemma hn_monadic_WHILE_lin[sepref_comb_rules]:
  assumes "INDEP Rs"
  assumes FR: "P ==>A Γ * hn_ctxt Rs s' s"
  assumes b_ref: "!!s s'. hn_refine 
    (Γ * hn_ctxt Rs s' s)
    (b s)
    (Γb s' s)
    (pure bool_rel)
    (b' s')"
  assumes b_fr: "!!s' s. TERM (monadic_WHILEIT,''cond'') ==> Γb s' s ==>A Γ * hn_ctxt Rs s' s"

  assumes f_ref: "!!s' s. hn_refine
    (Γ * hn_ctxt Rs s' s)
    (f s)
    (Γf s' s)
    Rs
    (f' s')"
  assumes f_fr: "!!s' s. TERM (monadic_WHILEIT,''body'') ==> Γf s' s ==>A Γ * hn_invalid s' s"
  assumes "CONSTRAINT precise Rs"
  shows "hn_refine 
    P 
    (heap_WHILET b f s) 
    (Γ * hn_invalid s' s) 
    Rs 
    (PR_CONST (monadic_WHILEIT I)$(λ2s'. b' s')$(λ2s'. f' s')$(s'L))"
  using assms(2-)
  unfolding APP_def PROTECT2_def LIN_ANNOT_def CONSTRAINT_def PR_CONST_def
  by (rule hn_monadic_WHILE_aux)

(* Operation for ASSERT I followed by m; the pattern rule folds
  bind$(ASSERT$I)$(λ2_. m) into this operation. *)
definition [simp]: "op_ASSERT_bind I m ≡ bind (ASSERT I) (λ_. m)"
lemma pat_ASSERT_bind[def_pat_rules]:
  "bind$(ASSERT$I)$(λ2_. m) ≡ UNPROTECT (op_ASSERT_bind I)$m"
  by simp

(* Type of the identified operation: op_ASSERT_bind I acts on 'a nres. *)
lemma id_op_ASSERT_bind[id_rules]: 
  "PR_CONST (op_ASSERT_bind I) ::i TYPE('a nres => 'a nres)"
  by simp

(* Monadify arity: op_ASSERT_bind I takes one (monadic) argument *)
lemma arity_ASSERT_bind[sepref_monadify_arity]:
  "PR_CONST (op_ASSERT_bind I) ≡ λ2m. SP (PR_CONST (op_ASSERT_bind I))$m"
  apply (rule eq_reflection)
  by auto

(* For linearity analysis, the assertion is transparent *)
lemma skel_ASSERT_bind[sepref_la_skel]: 
  "SKEL (PR_CONST (op_ASSERT_bind I)$m) = SKEL m"
  by simp

(* Combinator rule: translate m, assuming the asserted condition I *)
lemma hn_ASSERT_bind[sepref_comb_rules]: 
  assumes "I ==> hn_refine Γ c Γ' R m"
  shows "hn_refine Γ c Γ' R (PR_CONST (op_ASSERT_bind I)$m)"
  using assms
  apply (cases I)
  apply auto
  done

subsection "Import of Parametricity Theorems"
(* A pure refinement (c,a)∈R valid under Q yields a heap refinement of
  RETURN a by return c, with Q kept as a pure assertion before and after. *)
lemma pure_hn_refineI:
  assumes "Q --> (c,a)∈R"
  shows "hn_refine (\<up>Q) (return c) (\<up>Q) (pure R) (RETURN a)"
  unfolding hn_refine_def using assms
  by (sep_auto simp: pure_def)

(* Variant of pure_hn_refineI without a precondition *)
lemma pure_hn_refineI_no_asm:
  assumes "(c,a)∈R"
  shows "hn_refine emp (return c) emp (pure R) (RETURN a)"
  unfolding hn_refine_def using assms
  by (sep_auto simp: pure_def)

(* Rewrite rules used by Sepref_Import_Param.import, in three stages:
  stage 1 turns premises into one conjunctive implication (with
  PREFER_tag conjuncts moved to the front) *)
lemma import_param_1: 
  "(P==>Q) ≡ Trueprop (P-->Q)"
  "(P-->Q-->R) <-> (P∧Q --> R)"
  "(a,c)∈Rel ∧ PREFER_tag P <-> PREFER_tag P ∧ (a,c)∈Rel"
  apply (rule, simp+)+
  done

(* Stage 2: leading PREFER/DEFER tags become meta-level premises again *)
lemma import_param_2:
  "Trueprop (PREFER_tag P ∧ Q --> R) ≡ (PREFER_tag P ==> Q-->R)"
  "Trueprop (DEFER_tag P ∧ Q --> R) ≡ (DEFER_tag P ==> Q-->R)"
  apply (rule, simp+)+
  done

(* Stage 3: split pure assertions and express relation membership as hn_val *)
lemma import_param_3:
  "\<up>(P ∧ Q) = \<up>P*\<up>Q"
  "\<up>((c,a)∈R) = hn_val R a c"
  by (simp_all add: hn_ctxt_def pure_def)

ML {*
(* Import of parametricity theorems as frame-based refinement rules
  (attribute sepref_import_param). *)
structure Sepref_Import_Param = struct

  (* User-declared rewrite rules, applied as the last step of import *)
  structure sepref_import_rewrite = Named_Thms (
    val name = @{binding "sepref_import_rewrite"}
    val description = "Rewrite rules on importing parametricity theorems"
  )

  (* Convert a parametricity theorem into an hn_refine rule: first-order
    form, premises as a pure precondition, hn_val form of the relation
    assertions, protected abstract term, then user rewrites. *)
  fun import ctxt thm = let
    open Sepref_Basic
    val thm = Parametricity.fo_rule thm
      |> Local_Defs.unfold ctxt @{thms import_param_1}
      |> Local_Defs.unfold ctxt @{thms import_param_2}

    (* An implication conclusion becomes a precondition; otherwise use
      the variant without precondition *)
    val thm = case concl_of thm of
      @{mpat "Trueprop (_-->_)"} => thm RS @{thm pure_hn_refineI}
    | _ => thm RS @{thm pure_hn_refineI_no_asm}

    val thm = Local_Defs.unfold ctxt @{thms import_param_3} thm
      |> Conv.fconv_rule (hn_refine_concl_conv_a (K (Id_Op.protect_conv ctxt)) ctxt)

    val thm = Local_Defs.unfold ctxt (sepref_import_rewrite.get ctxt) thm
  in
    thm
  end

  (* Attribute: import the theorem and add the result to sepref_fr_rules *)
  val import_attr = Scan.succeed (Thm.mixed_attribute (fn (context,thm) =>
    let
      val thm = import (Context.proof_of context) thm
      val context = Sepref_Translate.sepref_fr_rules.add_thm thm context
    in (context,thm) end
  ))

  val setup = I
    #> sepref_import_rewrite.setup 
    #> Attrib.setup @{binding sepref_import_param} import_attr
        "Sepref: Import parametricity rule"

end
*}

(* Register sepref_import_rewrite and the sepref_import_param attribute. *)
setup Sepref_Import_Param.setup


subsection "Purity"
(* P is pure if every assertion P x x' is a lifted proposition. *)
definition "is_pure P ≡ ∃P'. ∀x x'. P x x'=\<up>(P' x x')"
lemma is_pureI[intro?]: 
  assumes "!!x x'. P x x' = \<up>(P' x x')"
  shows "is_pure P"
  using assms unfolding is_pure_def by blast

lemma is_pureE:
  assumes "is_pure P"
  obtains P' where "!!x x'. P x x' = \<up>(P' x x')"
  using assms unfolding is_pure_def by blast

(* pure R is pure; purity is preserved by hn_ctxt *)
lemma pure_pure[constraint_rules, simp]: "is_pure (pure P)"
  unfolding pure_def by rule blast
lemma pure_hn_ctxt[constraint_rules, intro!]: "is_pure P ==> is_pure (hn_ctxt P)"
  unfolding hn_ctxt_def[abs_def] .


(* The relation underlying a pure assertion (note the swapped pair order). *)
definition "the_pure P ≡ THE P'. ∀x x'. P x x'=\<up>((x',x)∈P')"

(* TODO: Move *)
(* The basic assertions true, false and emp are pairwise distinct. *)
lemma assn_basic_inequalities[simp, intro!]:
  "true ≠ emp" "emp ≠ true"
  "false ≠ emp" "emp ≠ false"
  "true ≠ false" "false ≠ true"
proof -
  (* A partial heap with the non-empty address set {0} (in range, as lim = 1):
    it models true, but neither false nor emp *)
  def neh ≡ "((| arrays = undefined, refs = undefined, lim = 1 |), {0::nat})"
  have [simp]: "in_range neh" unfolding neh_def 
    by (simp add: in_range.simps)

  have "neh \<Turnstile> true" by simp
  moreover have "¬(neh \<Turnstile> false)" by simp
  moreover have "¬(neh \<Turnstile> emp)" by (simp add: mod_emp neh_def)
  (* A partial heap with an empty address set models emp, but not false *)
  moreover have "(h,{}) \<Turnstile> emp" by (simp add: mod_emp)
  moreover have "¬((h,{}) \<Turnstile> false)" by simp
  ultimately show 
    "true ≠ emp" "emp ≠ true"
    "false ≠ emp" "emp ≠ false"
    "true ≠ false" "false ≠ true"
    by metis+
qed

(* Lifted propositions are equal iff the propositions are equal. *)
lemma pure_assn_eq_conv[simp]: "\<up>P = \<up>Q <-> P=Q"
  apply (cases P, simp_all)
  apply (cases Q, simp_all)
  apply (cases Q, simp_all)
  done


(* the_pure inverts pure *)
lemma the_pure_pure[simp]: "the_pure (pure R) = R"
  unfolding pure_def the_pure_def
  by (rule theI2[where a=R]) auto

(* Purity stated in terms of a relation (with swapped pair order) *)
lemma is_pure_alt_def: "is_pure R <-> (∃Ri. ∀x y. R x y = \<up>((y,x)∈Ri))"
  unfolding is_pure_def
  apply auto
  apply (rename_tac P')
  apply (rule_tac x="{(x,y). P' y x}" in exI)
  apply auto
  done

(* For pure R, pure inverts the_pure *)
lemma pure_the_pure[simp]: "is_pure R ==> pure (the_pure R) = R"
  unfolding is_pure_alt_def pure_def the_pure_def
  apply (intro ext)
  apply clarsimp
  apply (rename_tac a c Ri)
  apply (rule_tac a=Ri in theI2)
  apply auto
  done
  
(* Lift a unary/binary relator R to assertions: the argument assertions
  must be pure, and R is applied to their underlying relations. *)
definition "import_rel1 R ≡ λA c ci. \<up>(is_pure A ∧ (ci,c)∈⟨the_pure A⟩R)"
definition "import_rel2 R ≡ λA B c ci. \<up>(is_pure A ∧ is_pure B ∧ (ci,c)∈⟨the_pure A, the_pure B⟩R)"
  
(* On pure arguments, the lifted relators reduce to pure relations *)
lemma import_rel1_pure_conv: "import_rel1 R (pure A) = pure (⟨A⟩R)"
  unfolding import_rel1_def
  apply simp
  apply (simp add: pure_def)
  done

lemma import_rel2_pure_conv: "import_rel2 R (pure A) (pure B) = pure (⟨A,B⟩R)"
  unfolding import_rel2_def
  apply simp
  apply (simp add: pure_def)
  done

(* Copying a value of pure relation P: the original stays valid, and the
  result additionally records that the copied abstract value equals x. *)
lemma hn_pure_copy_complete: 
  assumes F: "Γ ==>A F * hn_ctxt P x x'"
  assumes P: "CONSTRAINT is_pure P"
  shows "hn_refine Γ (return x') (F * hn_ctxt P x x') (λxc xc'. P xc xc' * \<up>(xc = x)) 
    (COPY$x)"
  apply (rule is_pureE[OF P[simplified]])
  apply (rule hn_refine_frame[OF _ F])
  apply rule
  apply (sep_auto simp: hn_ctxt_def)
  done

(* TODO: We need some measures not to decouple
  variables from their assertions on copying.

  One possibility would be to add the extra equality 
  information to the copy-rule somehow. However, this needs 
  support in the translation process to extract the equality 
  and move it to the meta-premises.
*)
(* Copy rule for pure relations: the copy is just return x'; the original
  assertion is preserved. *)
lemma hn_pure_copy[sepref_copy_rules]: 
  assumes F: "Γ ==>A F * hn_ctxt P x x'"
  assumes P: "CONSTRAINT is_pure P"
  shows "hn_refine Γ (return x') (F * hn_ctxt P x x') P 
    (COPY$x)"
  apply (rule is_pureE[OF P[simplified]])
  apply (rule hn_refine_frame[OF _ F])
  apply rule
  apply (sep_auto simp: hn_ctxt_def)
  done


(* Pure assertions of single-valued relations are precise. *)
lemma precise_pure[constraint_rules]: "single_valued R ==> precise (pure R)"
  unfolding precise_def pure_def
  by (auto dest: single_valuedD)

(* TODO: Integrate with tagged-solver! *)
lemmas [constraint_rules] = single_valued_Id list_set_rel_sv
(* TODO: Also use a constraint-principle that stops at schematics in autoref! *)


subsection {* Parameters *}
(* PARAM s f marks s as a parameter bound in f; it is just f s. *)
definition PARAM :: "'a => ('a => 'b nres) => 'b nres" 
  where [simp]: "PARAM s f == f s"
(* Skeleton: the parameter is used first, then the body *)
lemma skel_param[sepref_la_skel]: 
  "SKEL (PARAM$s$f) = la_seq (la_op s) (SKEL f)"
  by simp

(* Combinator rule for PARAM: translate the body with the parameter's
  assertion split off the context (F1) and recombined afterwards (F2). *)
lemma hn_param_bind[sepref_comb_rules]:
  fixes f :: "'a => 'b nres"
  assumes F1: "Γ ==>A Γ1 * hn_ctxt R x x'"
  assumes R: "!!x x'. hn_refine 
    (Γ1 * hn_ctxt R x x')
    (c x') (Γ2 x x') S (f x)"
  assumes F2: "!!x x'. Γ2 x x' ==>A Γ' * hn_ctxt R' x x'"
  shows "hn_refine Γ (c x') (Γ' * hn_ctxt R' x x') S (PARAM$x$(λ2x. f x))"
  apply (simp)
  apply (rule hn_refine_cons_pre[OF F1])
  apply (rule hn_refine_cons_post[OF _ F2])
  apply (rule R)
  done
  

ML {*
  (* Identification of parameters in refinement goals *)
  structure Sepref_Param = struct
    (* For a goal hn_refine P c P' R a: each free variable of a that is
      the second component returned by dest_hn_ctxt_opt for some
      assertion of P is abstracted with a PARAM binder. The binder for
      the first such assertion in P ends up outermost. *)
    fun 
      id_param @{mpat "hn_refine ?P ?c ?P' ?R ?a"} = let
        open Sepref_Basic
        val used_params = Term.add_frees a [] |> map Free
          |> Termtab.make_set

        val params = strip_star P 
          |> map_filter (dest_hn_ctxt_opt) |> map #2
          |> filter (Termtab.defined used_params)
          |> rev

        fun abs_p (p as (Free (x,_))) a = let
            val b = Term.lambda_name (x,p) a
            val res = @{mk_term "PARAM ?p ?b"}
          in res end
        | abs_p _ _ = 
            error ("id_param: Internal error: expected only frees in params")

        val new_a = fold abs_p params a
      in
       @{mk_term "hn_refine ?P ?c ?P' ?R ?new_a"}
      end
    | id_param t = raise TERM ("id_param",[t])

    (* Conversion applying id_param; the equality is proved by
      unfolding PARAM_def *)
    fun id_param_conv ctxt = Refine_Util.f_tac_conv ctxt
      (id_param) 
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PARAM_def}) 1)
  
  end
*}

end