Theory Sepref_Frame

theory Sepref_Frame
imports Sepref_Basic
header ‹Frame Inference›
theory Sepref_Frame
imports Sepref_Basic
begin
  text ‹ In this theory, we provide two specialized tactics for Sepref:
    a frame inference tactic and a merge tactic.

    The first tactic, @{text frame_tac}, is a standard frame inference tactic, 
    based on the assumption that only @{const hn_ctxt}-assertions need to be
    matched.

    The second tactic, @{text merge_tac}, resolves entailments of the form
      @{text "F1 ∨A F2 ==>A ?F"}
    that occur during translation of if and case statements.
    It synthesizes a new frame ?F in which every variable that has the
    same refinement in @{text F1} and @{text F2} keeps that refinement,
    while all other variables are set to @{const hn_invalid}.
    ›

(* Basic rules used by frame_tac for frame inference on goals P ==>A Q.
   frame_tac resolves with them in the order listed:
   1. reflexivity: an assertion entails itself;
   2. weakening: any hn_ctxt entry may be weakened to hn_invalid
      (proved by unfolding hn_ctxt_def);
   3. decomposition: split both sides as F*P and F'*P', leaving the two
      subgoals P ==>A P' and F ==>A F';
   4. single atom on the left: match P against P' and require that the
      remaining right-hand frame F' follows from emp. *)
lemma frame_thms:
  "P ==>A P"
  "hn_ctxt R x y ==>A hn_invalid x y"

  "P==>AP' ==> F==>AF' ==> F*P ==>A F'*P'"

  "P==>AP' ==> emp==>AF' ==> P==>AF'*P'"
  (* one proof step per statement above, in the same order *)
  apply (blast intro: ent_refl ent_star_mono)
  apply (simp add: hn_ctxt_def)
  apply (erule (1) ent_star_mono)
  by (metis assn_one_left ent_star_mono)

(* Reduce an entailment between two hn_ctxt entries for the same pair of
   variables to an equality of their refinement relations.
   frame_ctxt_discharge_tac applies this rule. It then proves the
   remaining R = R' by simplification with the sepref_normrel_eqs
   theorems. *)
lemma frame_ctxt_dischargeI:
  "R = R' ==> hn_ctxt R x y ==>A hn_ctxt R' x y"
  by simp


(* Merge base case: the disjunction of two empty contexts entails emp. *)
lemma hn_merge0:
  "emp ∨A emp ==>A emp"
  by simp

(* Merge rules for an entry whose refinement R is the same in both
   branches: the entry is kept unchanged in the merged frame.
   The first statement covers a single entry.
   The second statement splits one hn_ctxt entry off both branches and
   leaves the rest of the frames to the premise Fl ∨A Fr ==>A F. *)
lemma hn_merge1:
  "hn_ctxt R x x' ∨A hn_ctxt R x x' ==>A hn_ctxt R x x'"
  "[| Fl ∨A Fr ==>A F |] ==> 
    Fl * hn_ctxt R x x' ∨A Fr * hn_ctxt R x x' ==>A F * hn_ctxt R x x'"
  (* first statement: case split on the disjunction; both cases are
     reflexive entailments *)
  apply (rule ent_disjE)
  apply (rule ent_refl)
  apply (rule ent_refl)

  (* second statement: case split; in each case keep the common hn_ctxt
     factor by reflexivity and reduce to Fl ==>A F (resp. Fr ==>A F),
     which follows from the premise via ent_disjI1 (resp. ent_disjI2) *)
  apply (rule ent_disjE)
  apply (rule ent_star_mono[OF _ ent_refl])
  apply (erule ent_disjI1)
  apply (rule ent_star_mono[OF _ ent_refl])
  apply (erule ent_disjI2)
  done

(* Merge rules for an entry whose refinements R1 and R2 may differ
   between the branches: the entry is weakened to hn_invalid in the
   merged frame.
   These rules also match when R1 = R2. The collection hn_merge lists
   hn_merge1 before this lemma, so the rule that keeps the refinement is
   tried first. *)
lemma hn_merge2:  
  "hn_ctxt R1 x x' ∨A hn_ctxt R2 x x' ==>A hn_invalid x x'"
  "[| Fl ∨A Fr ==>A F |] ==> 
    Fl * hn_ctxt R1 x x' ∨A Fr * hn_ctxt R2 x x' ==>A F * hn_invalid x x'"
  (* first statement: case split; each case is closed by unfolding
     hn_ctxt_def and pure_def *)
  apply (rule ent_disjE)
  apply (simp add: hn_ctxt_def pure_def)
  apply (simp add: hn_ctxt_def pure_def)

  (* second statement: case split; in each case ent_star_mono yields the
     frame part (from the premise via ent_disjI1/ent_disjI2) and the
     weakening of the hn_ctxt entry to hn_invalid (by unfolding) *)
  apply (rule ent_disjE)
  apply (rule ent_star_mono)
  apply (erule ent_disjI1)
  apply (simp add: hn_ctxt_def pure_def)
  apply (rule ent_star_mono)
  apply (erule ent_disjI2)
  apply (simp add: hn_ctxt_def pure_def)
  done

(* All merge rules, in the order in which merge_tac resolves with them.
   hn_merge1 (keep an entry whose refinement is identical) comes before
   hn_merge2 (weaken the entry to hn_invalid), so the first result keeps
   identical refinements. *)
lemmas hn_merge = hn_merge0 hn_merge1 hn_merge2

(* A trivial tautology over merge-shaped entailments P1 ∨A P2 ==>A P,
   closed by the immediate proof `.`. It is not referenced in this theory.
   NOTE(review): presumably used elsewhere to recognize or tag merge
   obligations; confirm with its users. *)
lemma is_merge: "P1∨AP2==>AP ==> P1∨AP2==>AP" .
ML {*
(* Interface of the Sepref frame-inference and merge machinery. *)
signature SEPREF_FRAME = sig
  (* Check if subgoal is a frame obligation: an entailment P ==>A Q whose
     left-hand side, split at the separating conjunction, contains no
     atom that is an application of a binary assertion operator *)
  val is_frame : term -> bool 
  (* Check if subgoal is a merge obligation of the form P1 ∨A P2 ==>A Q *)
  val is_merge : term -> bool
  (* Perform frame inference on the given subgoal *)
  val frame_tac : Proof.context -> int -> tactic
  (* Perform merging: synthesize the frame of a merge obligation *)
  val merge_tac : Proof.context -> int -> tactic
  (* Reorder frame, used for debugging; this is also the preparation
     step that frame_tac runs first *)
  val prepare_frame_tac : Proof.context -> int -> tactic

  (* Add, remove, or query theorems of the sepref_normrel_eqs collection,
     used to normalize refinement relations before frame matching *)
  val add_normrel_eq : thm -> Context.generic -> Context.generic
  val del_normrel_eq : thm -> Context.generic -> Context.generic
  val get_normrel_eqs : Proof.context -> thm list

  (* Theory setup: registers the sepref_normrel_eqs collection *)
  val setup: theory -> theory
end


(* Implementation of frame inference (frame_tac) and frame merging
   (merge_tac) for Sepref. *)
structure Sepref_Frame : SEPREF_FRAME = struct

  (* Named theorem collection sepref_normrel_eqs. frame_ctxt_discharge_tac
     uses its theorems, on top of HOL_basic_ss, when proving that two
     refinement relations are equal. *)
  structure normrel_eqs = Named_Thms (
    val name = @{binding sepref_normrel_eqs}
    val description = "Equations to normalize relations before frame matching"
  )

  val add_normrel_eq = normrel_eqs.add_thm
  val del_normrel_eq = normrel_eqs.del_thm
  val get_normrel_eqs = normrel_eqs.get

  local
    open Sepref_Basic Refine_Util Conv
  
    (* Order on assertion atoms:
       - hn_ctxt atoms come before all other atoms;
       - two hn_ctxt atoms are compared by the second component returned by
         dest_hn_ctxt_opt (the abstract variable), using
         Term_Ord.fast_term_ord;
       - any two non-hn_ctxt atoms compare EQUAL. *)
    fun assn_ord p = case pairself dest_hn_ctxt_opt p of
        (NONE,NONE) => EQUAL
      | (SOME _, NONE) => LESS
      | (NONE, SOME _) => GREATER
      | (SOME (_,a,_), SOME (_,a',_)) => Term_Ord.fast_term_ord (a,a')

  in
    (* Conversion: rewrite a separating conjunction ct into the conjunction
       of the same atoms, sorted by assn_ord. The equation between ct and
       new_ct is proved by simp with the AC rules star_aci. *)
    fun reorder_ctxt_conv ctxt ct = let
      val cert = cterm_of (theory_of_cterm ct)

      val new_ct = term_of ct 
        |> strip_star
        |> sort assn_ord
        |> list_star
        |> cert

      val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) 
        (fn _ => simp_tac 
          (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) 1)

    in
      thm
    end
  
    (* Conversion on an entailment P ==>A Q. It reorders both sides so that
       the hn_ctxt atoms of P and Q that refer to the same abstract variable
       are aligned at the front. These matched atoms appear in the same
       order (the key order of Qtab) on both sides. The new goal is
       P' ==>A Q', where
         P' = matched atoms of P, then all remaining atoms of P (these are
              consed onto Pum, so they end up in reverse order);
         Q' = matched atoms of Q, then the unmatched hn_ctxt atoms of Q,
              then the non-hn_ctxt atoms of Q.
       The equation between the old and new goal is proved by simp with
       star_aci.
       If Q contains two hn_ctxt atoms for the same abstract variable, the
       goal is traced and Termtab.DUP is re-raised. A term that is not an
       entailment is rejected with no_conv. *)
    fun prepare_fi_conv ctxt ct = case term_of ct of
      @{mpat "?P ==>A ?Q"} => let
        val cert = cterm_of (theory_of_cterm ct)
  
        (* Build table from abs-vars to ctxt *)
        val (Qm, Qum) = strip_star Q |> List.partition is_hn_ctxt
        val Qtab = (
          Qm |> map (fn x => (#2 (dest_hn_ctxt x),(NONE,x))) 
          |> Termtab.make
        ) handle
            e as (Termtab.DUP _) => (
              tracing ("Dup heap: " ^ PolyML.makestring ct); reraise e)
        
        (* Go over entries in P and try to find a partner *)
        (* only the first P atom for a given abstract variable is paired;
           later ones, and all non-hn_ctxt atoms, go to Pum *)
        val (Qtab,Pum) = fold (fn a => fn (Qtab,Pum) => 
          case dest_hn_ctxt_opt a of
            NONE => (Qtab,a::Pum)
          | SOME (_,p,_) => ( case Termtab.lookup Qtab p of
              SOME (NONE,tg) => (Termtab.update (p,(SOME a,tg)) Qtab, Pum)
            | _ => (Qtab,a::Pum)
            )
        ) (strip_star P) (Qtab,[])

        (* Read out information from Qtab *)
        (* pairs: matched (P atom, Q atom); Qum2: unmatched hn_ctxt atoms of Q *)
        val (pairs,Qum2) = Termtab.dest Qtab |> map #2 
          |> List.partition (is_some o #1)
          |> apfst (map (apfst the))
          |> apsnd (map #2)
  
        (* Build reordered terms *)
        val P' = map fst pairs @ Pum |> list_star
        val Q' = map snd pairs @ Qum2 @ Qum |> list_star
        
        val new_ct = mk_entails (P',Q') |> cert
  
        (*
          val _ = pairs |> map (pairself cert) |> PolyML.makestring |> tracing
        *)
  
        val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) 
          (fn _ => simp_tac 
            (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) 1)
  
      in 
        thm
      end
    | _ => no_conv ct
  
  end

  (* A merge obligation is a judgment of the form P1 ∨A P2 ==>A Q. *)
  fun is_merge @{mpat "Trueprop (_ ∨A _ ==>A _)"} = true | is_merge _ = false
  (* A frame obligation is an entailment whose left-hand side, split at the
     separating conjunction, has no atom that is an application of a binary
     operator of type assn => assn => assn (for example, no ∨A remains). *)
  fun 
    is_frame @{mpat "Trueprop (?P ==>A _)"} = let
      open Sepref_Basic
      val Ps = strip_star P
  
      fun is_atomic (Const (_,@{typ "assn=>assn=>assn"})$_$_) = false
        | is_atomic _ = true
  
    in 
      forall is_atomic Ps
    end
  | is_frame _ = false

  (* Prepare a frame obligation in four steps:
     1. eta-convert the subgoal;
     2. require is_frame (the tactic fails otherwise);
     3. simplify with mult_1_right/mult_1_left at type assn, which removes
        unit factors;
     4. align the hn_ctxt atoms of both sides with prepare_fi_conv. *)
  fun prepare_frame_tac ctxt = let
    open Refine_Util Conv
    val frame_ss = put_simpset HOL_basic_ss ctxt addsimps 
      @{thms mult_1_right[where 'a=assn] mult_1_left[where 'a=assn]}
  in
    CONVERSION Thm.eta_conversion THEN'
    CONCL_COND' is_frame THEN'
    simp_tac frame_ss THEN'
    CONVERSION (HOL_concl_conv (fn _ => prepare_fi_conv ctxt) ctxt)
  end    

  (* Discharge hn_ctxt R x y ==>A hn_ctxt R' x y. It resolves with
     frame_ctxt_dischargeI, then proves R = R' by full_simp_tac using only
     the sepref_normrel_eqs theorems on top of HOL_basic_ss. SOLVED' makes
     the tactic fail unless that subgoal is closed completely. *)
  fun frame_ctxt_discharge_tac ctxt = let
    val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt
  in
    rtac @{thm frame_ctxt_dischargeI}
    THEN' SOLVED' (full_simp_tac ss)
  end

  (* Frame inference. First run prepare_frame_tac; the whole tactic fails
     if that step fails, e.g. when the goal is not a frame obligation.
     Then, on this subgoal and all subgoals it produces, repeatedly apply
     frame_thms, falling back to frame_ctxt_discharge_tac. DETERM commits
     to the first result of each step. TRY keeps the rule phase from
     failing when no rule applies. *)
  fun frame_tac ctxt = let
    open Refine_Util Conv
    val frame_thms = @{thms frame_thms}
  in
    prepare_frame_tac ctxt THEN' 
    TRY o REPEAT_ALL_NEW (DETERM o 
      ( resolve_tac frame_thms 
        ORELSE' frame_ctxt_discharge_tac ctxt))
  end

  (* Merging for goals P1 ∨A P2 ==>A F:
     1. eta-convert the subgoal and require is_merge;
     2. AC-normalize with star_aci;
     3. sort each disjunct with reorder_ctxt_conv, applied via
        HOL_concl_conv to the left operand of the entailment, so that the
        hn_ctxt atoms of both disjuncts appear in the same order;
     4. repeatedly resolve with the hn_merge rules. TRY keeps the final
        phase from failing when no merge rule applies. *)
  fun merge_tac ctxt = let
    open Refine_Util Conv
    val merge_conv = arg1_conv (binop_conv (reorder_ctxt_conv ctxt))
    val merge_thms = @{thms hn_merge}
  in
    CONVERSION Thm.eta_conversion THEN'
    CONCL_COND' is_merge THEN'
    simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) THEN'
    CONVERSION (HOL_concl_conv (fn _ => merge_conv) ctxt) THEN'
    TRY o REPEAT_ALL_NEW (resolve_tac merge_thms)
  end

  (* Theory setup: registers the normrel_eqs named-theorem collection *)
  val setup = normrel_eqs.setup
end
*}

(* Register the sepref_normrel_eqs named-theorem collection of Sepref_Frame *)
setup Sepref_Frame.setup

end