Theory Sepref_Basic

theory Sepref_Basic
imports Sep_Main Refine_Dflt Id_Op
header ‹Basic Definitions›
theory Sepref_Basic
imports 
  "../Separation_Logic_Imperative_HOL/Sep_Main"
  "../Collections/Refine_Dflt"
  Id_Op
begin

text ‹
  In this theory, we define the basic concept of refinement 
  from a nondeterministic program specified in the 
  Isabelle Refinement Framework to an imperative deterministic one 
  specified in Imperative/HOL.
›

subsection {* Values on Heap *}
text ‹We tag every refinement assertion with the tag @{text hn_ctxt}, to
  avoid higher-order unification problems when the refinement assertion 
  is schematic.›
(* hn_ctxt P a c is, by hn_ctxt_def, literally P a c.  The constant exists
   only as a syntactic marker around a refinement assertion P that relates
   an abstract value a to a concrete value c. *)
definition hn_ctxt :: "('a=>'c=>assn) => 'a => 'c => assn" 
  -- {* Tag for refinement assertion *}
  where
  "hn_ctxt P a c ≡ P a c"

(* pure R a c lifts the membership test (c,a) ∈ R to an assertion via \<up>.
   Note the argument order: the relation R contains pairs (concrete,
   abstract), while the assertion takes the abstract value first and the
   concrete value second, like hn_ctxt. *)
definition pure :: "('b × 'a) set => 'a => 'b => assn"
  -- {* Pure binding, not involving the heap *}
  where "pure R ≡ (λa c. \<up>((c,a)∈R))"

(* hn_val R: tagged refinement assertion for a value refined through the
   pure relation R.  hn_val_unfold removes both the hn_ctxt tag and pure. *)
abbreviation "hn_val R ≡ hn_ctxt (pure R)"

lemma hn_val_unfold: "hn_val R a b = \<up>((b,a)∈R)"
  by (simp add: hn_ctxt_def pure_def)

(* hn_invalid a c is the trivial assertion `true`, wrapped in the hn_ctxt
   tag.  The variable pair a/c stays visible in the context, but nothing
   about its concrete representation is asserted any more. *)
abbreviation hn_invalid 
  -- "Vacuous refinement assertion for invalidated variables"
  where "hn_invalid a c ≡ hn_ctxt (λ_ _. true) a c"

(* Given an entailment from A to B, the same entailment still holds when
   hn_invalid x x' is added as a separating conjunct on the right-hand
   side.  After unfolding hn_ctxt, hn_invalid is just `true`, and
   ent_true_drop closes the goal. *)
lemma fr_invalidate: "A==>AB ==> A==>AB*hn_invalid x x'"
  apply (simp add: hn_ctxt_def)
  by (rule ent_true_drop)

subsection ‹Heap-Nres Refinement Calculus›

text {* Predicate that expresses refinement. Given a heap
  @{text "Γ"}, program @{text "c"} produces a heap @{text "Γ'"} and
  a concrete result that is related by predicate @{text "R"} to some
  abstract result of @{text "m"}. This is only required if @{text "m"}
  does not fail. *}
(* hn_refine G c G' R m: provided m does not fail, the triple
   <G> c <...>t holds.  Its postcondition states that the heap satisfies
   G' and that, for some abstract x with RETURN x ≤ m, R x r relates x
   to the concrete result r. *)
definition "hn_refine Γ c Γ' R m ≡ nofail m -->
  <Γ> c <λr. Γ' * (∃Ax. R x r * \<up>(RETURN x ≤ m)) >t"

(* Introduction rule for hn_refine.  It just unfolds hn_refine_def and is
   declared [intro?], so that a bare `rule` step picks it up. *)
lemma hn_refineI[intro?]:
  assumes "nofail m 
    ==> <Γ> c <λr. Γ' * (∃Ax. R x r * \<up>(RETURN x ≤ m)) >t"
  shows "hn_refine Γ c Γ' R m"
  using assms unfolding hn_refine_def by blast

(* Destruction rule: given a refinement and nofail m, obtain the Hoare
   triple from hn_refine_def. *)
lemma hn_refineD:
  assumes "hn_refine Γ c Γ' R m"
  assumes "nofail m"
  shows "<Γ> c <λr. Γ' * (∃Ax. R x r * \<up>(RETURN x ≤ m)) >t"
  using assms unfolding hn_refine_def by blast

(* Trivial cases, both declared [simp].  A false precondition makes any
   program a refinement.  So does a failing abstract program FAIL, because
   the nofail premise of hn_refine cannot hold for it. *)
lemma hn_refine_false[simp]: "hn_refine false c Γ' R m"
  by rule auto

lemma hn_refine_fail[simp]: "hn_refine Γ c Γ' R FAIL"
  by rule auto

(* Frame rule.  Suppose P entails F * P' and c refines m from P' to Q'.
   Then c refines m from P to F * Q', i.e. the frame F is carried through
   unchanged. *)
lemma hn_refine_frame:
  assumes "hn_refine P' c Q' R m"
  assumes "P ==>A F * P'"
  shows "hn_refine P c (F * Q') R m"
  using assms
  unfolding hn_refine_def
  apply clarsimp
  apply (erule cons_pre_rule)
  apply (rule cons_post_rule)
  apply (erule frame_rule_left)
  apply (simp only: star_aci)
  apply simp
  done

(* Consequence rule for hn_refine.  It strengthens the precondition
   (P entails P', assumption I) and weakens the postcondition (Q entails
   Q', assumption I').  The result relation R and the abstract program m
   are unchanged. *)
lemma hn_refine_cons:
  assumes I: "P==>AP'"
  assumes R: "hn_refine P' c Q R m"
  assumes I': "Q==>AQ'"
  shows "hn_refine P c Q' R m"
  using R unfolding hn_refine_def
  apply clarsimp
  apply (rule cons_pre_rule[OF I])
  apply (erule cons_post_rule)
  apply (rule ent_star_mono ent_refl I' ent_ex_preI ent_ex_postI)+
  done

(* Precondition-only instance of hn_refine_cons: the postcondition side
   is discharged with ent_refl. *)
lemma hn_refine_cons_pre:
  assumes I: "P==>AP'"
  assumes R: "hn_refine P' c Q R m"
  shows "hn_refine P c Q R m"
  using assms
  by (rule hn_refine_cons[OF _ _ ent_refl])

(* Postcondition-only instance of hn_refine_cons: the precondition side
   is discharged with ent_refl. *)
lemma hn_refine_cons_post:
  assumes R: "hn_refine P c Q R m"
  assumes I: "Q==>AQ'"
  shows "hn_refine P c Q' R m"
  using assms
  by (rule hn_refine_cons[OF ent_refl])

(* Monotonicity in the abstract program.  A refinement of m is also a
   refinement of any m' with m ≤ m'.  The nofail m premise needed for
   hn_refineD follows from nofail m' via pw_le_iff.  Each witness with
   RETURN x ≤ m is lifted to m' by order_trans. *)
lemma hn_refine_ref:
  assumes LE: "m≤m'"
  assumes R: "hn_refine P c Q R m"
  shows "hn_refine P c Q R m'"
  apply rule
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF R])
  using LE apply (simp add: pw_le_iff)
  apply (sep_auto intro: order_trans[OF _ LE])
  done

(* Combines hn_refine_cons and hn_refine_ref.  In one step it strengthens
   the precondition, weakens the postcondition, and enlarges the abstract
   program from m to m'. *)
lemma hn_refine_cons_complete:
  assumes I: "P==>AP'"
  assumes R: "hn_refine P' c Q R m"
  assumes I': "Q==>AQ'"
  assumes LE: "m≤m'"
  shows "hn_refine P c Q' R m'"
  apply (rule hn_refine_ref[OF LE])
  apply (rule hn_refine_cons[OF I R I'])
  done
  
subsection "Convenience Lemmas"

(* Replace the concrete program f of a refinement by a program f_conc
   that has been shown equal to it. *)
lemma hn_refine_guessI:
  assumes "hn_refine P f P' R f'"
  assumes "f=f_conc"
  shows "hn_refine P f_conc P' R f'"
  -- ‹To prove a refinement, first synthesize one, and then prove equality›
  using assms by simp


(* Correctness transfer.  It takes a refinement of a and an abstract
   correctness statement a ≤ SPEC Φ.  From these it derives a Hoare triple
   for c whose postcondition says that some abstract result r, related to
   the concrete result r' by R, satisfies Φ.  The nofail a premise comes
   from le_RES_nofailI. *)
lemma imp_correctI:
  assumes R: "hn_refine Γ c Γ' R a"
  assumes C: "a ≤ SPEC Φ"
  shows "<Γ> c <λr'. ∃Ar. Γ' * R r r' * \<up>(Φ r)>t"
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF R])
  apply (rule le_RES_nofailI[OF C])
  apply (sep_auto dest: order_trans[OF _ C])
  done

subsubsection ‹Return›
(* `return p` refines `RETURN x` when x and p are related by R in the
   context.  The binding is handed over to the result, so the context
   entry for x/p becomes hn_invalid afterwards. *)
lemma hnr_RETURN_pass:
  "hn_refine (hn_ctxt R x p) (return p) (hn_invalid x p) R (RETURN x)"
  -- ‹Pass on a value from the heap as return value›
  by rule (sep_auto simp: hn_ctxt_def)

(* `return c` refines `RETURN a` with result relation pure R, whenever
   (c,a) ∈ R.  The context is emp both before and after. *)
lemma hnr_RETURN_pure:
  assumes "(c,a)∈R"
  shows "hn_refine emp (return c) emp (pure R) (RETURN a)"
  -- ‹Return pure value›
  unfolding hn_refine_def using assms
  by (sep_auto simp: pure_def)
  
subsubsection ‹Assertion›
(* Same statement as hn_refine_fail in the calculus section.  This copy is
   additionally declared intro!, so that classical reasoning closes
   refinement goals for FAIL eagerly. *)
lemma hnr_FAIL[simp, intro!]: "hn_refine Γ c Γ' R FAIL"
  unfolding hn_refine_def
  by simp

(* An abstract assertion may be assumed when refining the rest of the
   program.  If Φ implies that c refines c', then c refines
   `ASSERT Φ; c'`.  The proof does a case split on Φ. *)
lemma hnr_ASSERT:
  assumes "Φ ==> hn_refine Γ c Γ' R c'"
  shows "hn_refine Γ c Γ' R (do { ASSERT Φ; c'})"
  using assms
  apply (cases Φ)
  by auto

subsubsection ‹Bind›
(* Suppose x is a possible result of m and y a possible result of f x.
   Then y is a possible result of the bind of m and f.  The proof goes via
   bind_mono. *)
lemma bind_det_aux: "[| RETURN x ≤ m; RETURN y ≤ f x |] ==> RETURN y ≤ m »= f"
  apply (rule order_trans[rotated])
  apply (rule bind_mono)
  apply assumption
  apply (rule order_refl)
  apply simp
  done

(* Refinement rule for sequential composition (bind).
   D1:  m' refines m, taking context Γ to Γ1, with result relation Rh.
   D2:  for every bound pair x/x', f' x' refines f x in the context
        extended by hn_ctxt Rh x x', ending in Γ2 x x'.
   IMP: each Γ2 x x' entails Γ' together with a tagged assertion for the
        bound variable.
   The conclusion keeps only Γ'.  The bound-variable part from IMP is
   framed away (ent_frame_fwd / frame_inference in the proof below). *)
lemma hnr_bind:
  assumes D1: "hn_refine Γ m' Γ1 Rh m"
  assumes D2: 
    "!!x x'. hn_refine (Γ1 * hn_ctxt Rh x x') (f' x') (Γ2 x x') R (f x)"
  assumes IMP: "!!x x'. Γ2 x x' ==>A Γ' * hn_ctxt Rx x x'"
  shows "hn_refine Γ (m'»=f') Γ' R (m»=f)"
  using assms
  unfolding hn_refine_def
  apply (clarsimp simp add: pw_bind_nofail)
  apply (rule Hoare_Triple.bind_rule)
  apply assumption
  apply (clarsimp intro!: normalize_rules simp: hn_ctxt_def)
  (* Remaining goal: the triple for the continuation f' x', for a
     concrete x' and an abstract x with RETURN x ≤ m. *)
proof -
  fix x' x
  assume 1: "RETURN x ≤ m" 
    and "nofail m" "∀x. inres m x --> nofail (f x)"
  hence "nofail (f x)" by (auto simp: pw_le_iff)
  moreover assume "!!x x'.
           nofail (f x) --> <Γ1 * Rh x x'> f' x'
           <λr'. ∃Ar. Γ2 x x' * R r r' * true * \<up> (RETURN r ≤ f x)>"
  ultimately have "!!x'. <Γ1 * Rh x x'> f' x'
           <λr'. ∃Ar. Γ2 x x' * R r r' * true * \<up> (RETURN r ≤ f x)>"
    by simp
  (* Weaken the intermediate context Γ2 x x' to Γ' using IMP. *)
  also have "!!r'. ∃Ar. Γ2 x x' * R r r' * true * \<up> (RETURN r ≤ f x) ==>AAr. Γ' * R r r' * true * \<up> (RETURN r ≤ f x)"
    apply sep_auto
    apply (rule ent_frame_fwd[OF IMP])
    apply frame_inference
    apply (solve_entails)
    done
  finally (cons_post_rule) have 
    R: "<Γ1 * Rh x x'> f' x' 
        <λr'. ∃Ar. Γ' * R r r' * true * \<up>(RETURN r ≤ f x)>"
    .
  (* Lift RETURN r ≤ f x to RETURN r ≤ m »= f via bind_det_aux. *)
  show "<Γ1 * Rh x x' * true> f' x'
          <λr'. ∃Ar. Γ' * R r r' * true * \<up> (RETURN r ≤ m »= f)>"
    by (sep_auto heap: R intro: bind_det_aux[OF 1])
qed

subsubsection ‹Recursion›

(* hn_rel P m r: the result part of the hn_refine postcondition, seen as
   an assertion on the concrete result r.  Some abstract x with
   RETURN x ≤ m is related to r by P. *)
definition "hn_rel P m ≡ λr. ∃Ax. P x r * \<up>(RETURN x ≤ m)"

(* Alternative form of hn_refine_def, stated with hn_rel.  The frame
   Fpost is moved to the right of the result assertion.  The equation is
   stated with ≡ and proved via eq_reflection. *)
lemma hn_refine_alt: "hn_refine Fpre c Fpost P m ≡ nofail m -->
  <Fpre> c <λr. hn_rel P m r * Fpost>t"
  apply (rule eq_reflection)
  unfolding hn_refine_def hn_rel_def
  apply (simp add: hn_ctxt_def)
  apply (simp only: star_aci)
  done

(* Moves a universal quantifier into the postcondition.
   W: c has a triple with postcondition true from P.
   T: for every x with A x, the triple <P> c <Q x> holds.
   Conclusion: a single triple whose postcondition is the negation of
   "some x satisfies A x and not Q x r", i.e. Q x r holds for every x
   with A x.  The proof works directly on hoare_triple_def, with W
   supplying the side conditions. *)
lemma wit_swap_forall:
  assumes W: "<P> c <λ_. true>"
  assumes T: "(∀x. A x --> <P> c <Q x>)"
  shows "<P> c <λr. ¬A (∃Ax. \<up>(A x) * ¬A Q x r)>"
  (* TODO: Clean up this mess! *)
  unfolding hoare_triple_def Let_def
  apply (intro conjI impI allI)
  apply (elim conjE)
  apply (rule hoare_tripleD[OF W], assumption+) []
  defer
  apply (elim conjE)
  apply (rule hoare_tripleD[OF W], assumption+) []
  apply (elim conjE)
  apply (rule hoare_tripleD[OF W], assumption+) []

  apply (clarsimp, intro conjI allI)
  apply (rule models_in_range)
  apply (rule hoare_tripleD[OF W], assumption+) []

  apply (simp only: disj_not2, intro impI)
  apply (drule spec[OF T, THEN mp])
  apply (drule (2) hoare_tripleD(2))
  .

(* Admissibility of hn_rel postconditions under INF.  It is used for the
   gfp induction in hnr_RECT.
   Premise E: every f ∈ A with nofail (f x) yields the triple with
   postcondition hn_rel Ry (f x).
   Premise NF: the INF over A does not fail.
   Conclusion: the triple also holds for hn_rel Ry (INF f:A. f x).
   Precision of Ry (PREC, used via preciseD') makes the abstract witness
   obtained for one f the same as the witness for every other f' ∈ A, so
   that witness lies below the INF. *)
lemma hn_admissible:
  assumes PREC: "precise Ry"
  assumes E: "∀f∈A. nofail (f x) --> <P> c <λr. hn_rel Ry (f x) r * F>"
  assumes NF: "nofail (INF f:A. f x)"
  shows "<P> c <λr. hn_rel Ry (INF f:A. f x) r * F>"
proof -
  from NF obtain f where "f∈A" and "nofail (f x)"
    by (simp only: refine_pw_simps INF_def) blast

  with E have "<P> c <λr. hn_rel Ry (f x) r * F>" by blast
  hence W: "<P> c <λ_. true>" by (rule cons_post_rule, simp)

  from E have 
    E': "∀f. f∈A ∧ nofail (f x) --> <P> c <λr. hn_rel Ry (f x) r * F>"
    by blast
  from wit_swap_forall[OF W E'] have 
    E'': "<P> c
     <λr. ¬A (∃Axa. \<up> (xa ∈ A ∧ nofail (xa x)) *
                ¬A (hn_rel Ry (xa x) r * F))>" .
  
  thus ?thesis
    apply (rule cons_post_rule)
    unfolding entails_def hn_rel_def
    apply clarsimp
  proof -
    fix h as p
    assume A: "∀f. f∈A --> (∃a.
      ((h, as) \<Turnstile> Ry a p * F ∧ RETURN a ≤ f x)) ∨ ¬ nofail (f x)"
    with `f∈A` and `nofail (f x)` obtain a where 
      1: "(h, as) \<Turnstile> Ry a p * F" and "RETURN a ≤ f x"
      by blast
    have
      "∀f∈A. nofail (f x) --> (h, as) \<Turnstile> Ry a p * F ∧ RETURN a ≤ f x"
    proof clarsimp
      fix f'
      assume "f'∈A" and "nofail (f' x)"
      with A obtain a' where 
        2: "(h, as) \<Turnstile> Ry a' p * F" and "RETURN a' ≤ f' x"
        by blast

      moreover note preciseD'[OF PREC 1 2] 
      ultimately show "(h, as) \<Turnstile> Ry a p * F ∧ RETURN a ≤ f' x" by simp
    qed
    hence "RETURN a ≤ (INF f:A. f x)"
      by (metis (mono_tags) le_INF_iff le_nofailI)
    with 1 show "∃a. (h, as) \<Turnstile> Ry a p * F ∧ RETURN a ≤ (INF f:A. f x)"
      by blast
  qed
qed

(* Variant of hn_admissible for triples in the `t` form.  It is obtained
   by instantiating the frame F of hn_admissible with F*true. *)
lemma hn_admissible':
  assumes PREC: "precise Ry"
  assumes E: "∀f∈A. nofail (f x) --> <P> c <λr. hn_rel Ry (f x) r * F>t"
  assumes NF: "nofail (INF f:A. f x)"
  shows "<P> c <λr. hn_rel Ry (INF f:A. f x) r * F>t"
  apply (rule hn_admissible[OF PREC, where F="F*true", simplified])
  apply simp
  by fact+

(* Refinement rule for recursion.  heap.fixp_fun cB refines RECT aB under
   three assumptions:
   S:    assuming every recursive call cf px refines af ax, the body
         cB cf px refines aB af ax;
   M:    cB is monotone (mono_Heap);
   PREC: the result relation Ry is precise.
   After unfolding RECT_gfp_def, simp leaves the case trimono aB.  That
   case is proved by gfp_cadm_induct, with admissibility supplied by
   hn_admissible' and the step by unfolding the fixpoint and applying S. *)
lemma hnr_RECT:
  assumes S: "!!cf af ax px. [|
    !!ax px. hn_refine (hn_ctxt Rx ax px * F) (cf px) (F' ax px) Ry (af ax)|] 
    ==> hn_refine (hn_ctxt Rx ax px * F) (cB cf px) (F' ax px) Ry (aB af ax)"
  assumes M: "(!!x. mono_Heap (λf. cB f x))"
  assumes PREC: "precise Ry"
  shows "hn_refine 
    (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry (RECT aB ax)"
  unfolding RECT_gfp_def
proof (simp, intro conjI impI)
  assume "trimono aB"
  hence "mono aB" by (simp add: trimonoD)
  have "∀ax px. 
    hn_refine (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry 
      (gfp aB ax)"
    apply (rule gfp_cadm_induct[OF _ _ `mono aB`])

    apply rule
    apply (auto simp: hn_refine_alt intro: hn_admissible'[OF PREC]) []

    apply (auto simp: hn_refine_alt) []

    apply clarsimp
    apply (subst heap.mono_body_fixp[of cB, OF M])
    apply (rule S)
    apply blast
    done
  thus "hn_refine (hn_ctxt Rx ax px * F)
     (ccpo.fixp (fun_lub Heap_lub) (fun_ord Heap_ord) cB px) (F' ax px) Ry
     (gfp aB ax)" by simp
qed

(* Refinement rule for if-then-else.
   P:     Γ entails Γ1 together with a pure bool_rel binding between the
          abstract condition a and the concrete condition a'.
   RT/RE: in that context, the then-branch b' refines b (assuming a),
          and the else-branch c' refines c (assuming ¬a).
   IMP:   the disjunction of the two branch postconditions entails Γ'.
   The proof splits on a and, in each case, weakens the branch
   postcondition into Γ' through the matching disjunct of IMP. *)
lemma hnr_If:
  assumes P: "Γ ==>A Γ1 * hn_val bool_rel a a'"
  assumes RT: "a ==> hn_refine (Γ1 * hn_val bool_rel a a') b' Γ2b R b"
  assumes RE: "¬a ==> hn_refine (Γ1 * hn_val bool_rel a a') c' Γ2c R c"
  assumes IMP: "Γ2b ∨A Γ2c ==>A Γ'"
  shows "hn_refine Γ (if a' then b' else c') Γ' R (if a then b else c)"
  apply rule
  apply (cases a)
    apply (rule cons_pre_rule[OF P])
    apply vcg
    apply (frule RT[unfolded hn_refine_def])
    apply (simp add: pure_def)
    apply (erule cons_post_rule)
    apply (sep_auto intro: ent_star_mono ent_disjI1[OF IMP] ent_refl)
    apply (simp add: pure_def hn_ctxt_def)

    apply (rule cons_pre_rule[OF P])
    apply vcg
    apply (simp add: hn_ctxt_def pure_def)

    apply (frule RE[unfolded hn_refine_def])
    apply (simp add: hn_ctxt_def pure_def)
    apply (erule cons_post_rule)
    apply (sep_auto intro: ent_star_mono ent_disjI2[OF IMP] ent_refl)
  done



subsection ‹ML-Level Utilities›
ML {*
  signature SEPREF_BASIC = sig
    (* ------------------------------------------------------------------
       Conversions on terms of the form  hn_refine G c G' R m
       ------------------------------------------------------------------ *)

    (* hn_refine_conv c1 c2 c3 c4 c5 converts the five arguments
       G, c, G', R, m with c1 .. c5 respectively; the head constant is
       left unchanged. Fails with CTERM on any other term. *)
    val hn_refine_conv : conv -> conv -> conv -> conv -> conv -> conv

    (* Converts only the abstract program m, i.e. the last argument. *)
    val hn_refine_conv_a : conv -> conv

    (* Converts the abstract program of the hn_refine term that forms the
       (HOL) conclusion of a goal or theorem. *)
    val hn_refine_concl_conv_a: (Proof.context -> conv) -> Proof.context -> conv

    (* ------------------------------------------------------------------
       Separating conjunction and entailment
       ------------------------------------------------------------------ *)

    (* a * b *)
    val mk_star : term * term -> term
    (* Left-nested product of the list elements, in reverse list order;
       emp for the empty list. *)
    val list_star : term list -> term
    (* Factors of a (nested) separating conjunction, left to right. *)
    val strip_star : term -> term list
    (* Assertion entailment between two assertion terms. *)
    val mk_entails : term * term -> term
    (* Certified meta-equality  ct == cu. *)
    val mk_cequals : cterm * cterm -> cterm

    (* ------------------------------------------------------------------
       hn_ctxt-tagged assertions
       ------------------------------------------------------------------ *)

    val is_hn_ctxt : term -> bool
    (* (R, a, p) for  hn_ctxt R a p ; raises TERM otherwise. *)
    val dest_hn_ctxt : term -> term * term * term
    (* As dest_hn_ctxt, but NONE instead of raising. *)
    val dest_hn_ctxt_opt : term -> (term * term * term) option

    (* ------------------------------------------------------------------
       APP applications; each destructor also returns a rebuild function
       ------------------------------------------------------------------ *)

    (* Head and all arguments of a nested APP chain; rebuild takes a new
       head and a new argument list of the same length. *)
    val strip_APPc :
       term -> (term * term list) * (term * term list -> term)
    (* Only the arguments; rebuild swaps in new arguments. *)
    val strip_APP_argc : term -> term list * (term list -> term)
    (* Function and argument of one APP; rebuild takes both. *)
    val dest_APPc : term -> (term * term) * (term * term -> term)
    (* Only the argument of one APP; rebuild swaps in a new argument. *)
    val dest_APP_argc : term -> term * (term -> term)
  end

  structure Sepref_Basic : SEPREF_BASIC = struct

    (* ---- Conversions on hn_refine terms ---- *)
    local open Conv in
      (* hn_refine G c G' R m is the application spine
         ((((hn_refine $ G) $ c) $ G') $ R) $ m.  Starting with all_conv on
         the head constant, each combination_conv step wraps one more
         argument, so c1 .. c5 end up on G, c, G', R, m. *)
      fun hn_refine_conv c1 c2 c3 c4 c5 ct =
        (case term_of ct of
          @{mpat "hn_refine _ _ _ _ _"} =>
            fold (fn arg_cv => fn spine_cv => combination_conv spine_cv arg_cv)
              [c1, c2, c3, c4, c5] all_conv ct
        | _ => raise CTERM ("hn_refine_conv",[ct]))

      (* Only the last argument (the abstract program) is converted. *)
      fun hn_refine_conv_a cv =
        hn_refine_conv all_conv all_conv all_conv all_conv cv

      (* Lift hn_refine_conv_a to the HOL conclusion of a goal/theorem. *)
      fun hn_refine_concl_conv_a conv ctxt =
        Refine_Util.HOL_concl_conv (hn_refine_conv_a o conv) ctxt
    end

    (* ---- Term construction ---- *)

    (* FIXME: Strange dependency! The certified == is borrowed from
       SMT_Utils. *)
    fun mk_cequals (lhs, rhs) = SMT_Utils.mk_cequals lhs rhs

    fun mk_entails (lhs, rhs) =
      HOLogic.mk_binrel @{const_name "entails"} (lhs, rhs)

    fun mk_star (lhs, rhs) =
      HOLogic.mk_binop @{const_name "Groups.times_class.times"} (lhs, rhs)

    (* Left-nested product in reverse list order:
       list_star [a, b, c] = (c * b) * a,  list_star [] = emp. *)
    fun list_star xs =
      (case rev xs of
        [] => @{term "emp::assn"}
      | last_x :: rest => fold (fn y => fn acc => mk_star (acc, y)) rest last_x)

    (* Factors of a nested separating conjunction, left to right. *)
    fun strip_star t =
      let
        fun go u acc =
          (case u of
            @{mpat "?a*?b"} => go a (go b acc)
          | _ => u :: acc)
      in
        go t []
      end

    (* ---- hn_ctxt assertions ---- *)

    fun dest_hn_ctxt_opt @{mpat "hn_ctxt ?R ?a ?p"} = SOME (R, a, p)
      | dest_hn_ctxt_opt _ = NONE

    fun is_hn_ctxt t = is_some (dest_hn_ctxt_opt t)

    fun dest_hn_ctxt t =
      (case dest_hn_ctxt_opt t of
        SOME parts => parts
      | NONE => raise TERM("dest_hn_ctxt",[t]))

    (* ---- APP applications ---- *)

    fun dest_APPc t =
      (case t of
        Const (@{const_name "APP"}, T) $ f $ x =>
          ((f, x), fn (f', x') => Const (@{const_name "APP"}, T) $ f' $ x')
      | _ => raise TERM("dest_APPc",[t]))

    local
      (* Peel nested APP applications off a term.  Returns the head and the
         (APP-type, argument) pairs, innermost application first. *)
      fun peel_APP (Const (@{const_name "APP"}, T) $ f $ x) acc =
            peel_APP f ((T, x) :: acc)
        | peel_APP h acc = (h, acc)

      (* Re-apply APP, with the recorded types, to a new head and new
         arguments.  The argument count must match the recorded one. *)
      fun rebuild_APP h [] [] = h
        | rebuild_APP h ((T, _) :: targs) (y :: ys) =
            rebuild_APP (Const (@{const_name "APP"}, T) $ h $ y) targs ys
        | rebuild_APP _ (_ :: _) [] =
            error "strip_APPc (constructor): Too few args"
        | rebuild_APP _ [] (_ :: _) =
            error "strip_APPc (constructor): Extra #args"
    in
      fun strip_APPc t =
        let
          val (h, targs) = peel_APP t []
        in
          ((h, map snd targs), fn (h', args) => rebuild_APP h' targs args)
        end
    end

    fun dest_APP_argc t =
      let
        val ((f, x), rebuild) = dest_APPc t
      in
        (x, fn x' => rebuild (f, x'))
      end

    fun strip_APP_argc t =
      let
        val ((f, args), rebuild) = strip_APPc t
      in
        (args, fn args' => rebuild (f, args'))
      end

  end
*}

end