(* Theory Sep_Lift *)

theory Sep_Lift
imports Sep_Generic_Wp "HOL-Library.Rewrite"
begin
    
  (* sep_lifting_scheme: abstract interface for embedding a separation algebra 'asmall
     into a bigger separation algebra 'abig.
       lift       : embeds a small element into the big algebra;
       project    : extracts the small part of a big element;
       carve      : the rest of a big element once the projected part is removed;
       splittable : the big elements on which project and carve are well behaved.
     For splittable b, complete_split states that lift (project b) + carve b = b.
     NOTE(review): in this copy several logical connectives (function arrows, implications,
     iff, inequality, premise brackets) appear to be missing from the quoted terms;
     confirm against the original theory source. *)
  locale sep_lifting_scheme = 
    fixes lift :: "'asmall::unique_zero_sep_algebra  'abig::unique_zero_sep_algebra"  
      and project :: "'abig  'asmall"
      and carve :: "'abig  'abig"
      and splittable :: "'abig  bool"
    (* lift preserves 0, distributes over disjoint sums, and relates disjointness of
       lifted elements to disjointness of the original elements. *)
    assumes lift_pres_zero[simp]: "lift 0 = 0" 
        and lift_add_distrib[simp]: "a1##a2  lift (a1+a2) = lift a1 + lift a2"
        and lift_disj_distrib[simp]: "lift a1 ## lift a2  a1##a2"
        
    (* Lifted elements are splittable. Splittability carries over to a disjoint partner
       of an element with nonzero projection. A disjoint sum is splittable exactly when
       both summands are. *)
    assumes splittable_lift[simp]: "splittable (lift a)"    
        and splittable_same_struct: "b1##b2; splittable b1; project b10  splittable b2"
        (* Alternative formulation of splittable_same_struct, currently disabled: *)
        (*and splittable_same_struct: "⟦b1##b2; b1≠0; b2≠0⟧ ⟹ splittable b2 = splittable b1"*)
        and splittable_add_distrib[simp]: "b1##b2  splittable (b1 + b2)  splittable b1  splittable b2"
        
    (* project preserves 0 and distributes over disjoint sums of splittable elements. *)
    assumes project_pres_zero[simp]: "project 0 = 0"
    assumes project_add_distrib[simp]: 
          "b1##b2; splittable b1; splittable b2  project (b1 + b2) = project b1 + project b2"
        and project_disj_distrib[simp]: "b1##b2; splittable b1; splittable b2  project b1 ## project b2"     
    (* carve likewise preserves 0 and distributes over disjoint sums of splittable elements. *)
    assumes carve_pres_zero[simp]: "carve 0 = 0"
    assumes carve_add_distrib[simp]: "b1##b2; splittable b1; splittable b2  carve (b1 + b2) = carve b1 + carve b2"
        and carve_disj_distrib[simp]: "b1##b2; splittable b1; splittable b2  carve b1 ## carve b2"
        
    (* Interaction laws: project undoes lift, the remainder carve b has zero projection,
       and lifted elements have zero remainder. *)
    assumes project_lift_id[simp]: "project (lift a) = a"
    assumes project_carve_Z[simp]: "splittable b  project (carve b) = 0"
    assumes carve_lift_Z[simp]: "carve (lift a) = 0"
    
    (* Disjointness from the projection gives disjointness of the lifted element from the
       big element. The remainder is disjoint from every lifted element. *)
    assumes disj_project_imp_lift: "splittable b  a ## project b  lift a ## b"
    assumes carve_disj_lift1: "splittable b  carve b ## lift a"
                
    (* A splittable element is the sum of its lifted projection and its remainder. *)
    assumes complete_split: "splittable b  lift (project b) + carve b = b"
    
    assumes carve_not_splittable[simp]: "¬ splittable b  carve b = b"
      ― ‹We force carve to be the identity outside the splittable range.
        This is an arbitrary choice for an otherwise unspecified value, which makes
        some properties simpler to prove (e.g. associativity of lifter composition).
        ›
  begin
    (* 0 is splittable, because 0 = lift 0. *)
    lemma splittableZ[simp]: "splittable 0"
      by (metis lift_pres_zero splittable_lift)
        
    (* The remainder of a splittable element is splittable again. *)
    lemma splittable_carve_if[simp]: " splittable p  splittable (carve p)"
      by (metis carve_disj_lift1 complete_split sep_add_commute splittable_add_distrib) 
  
    (* carve_disj_lift1 in both orientations. *)
    lemma carve_disj_lift[simp]: 
      assumes "splittable b"
      shows "carve b ## lift a" "lift a ## carve b" 
      using carve_disj_lift1 carve_disj_lift1[THEN sep_disj_commuteI] assms by auto
  
    (* lift maps only 0 to 0 (project undoes lift). *)
    lemma lift_eqZ_iffZ[simp]: "lift a = 0  a=0"
      by (metis lift_pres_zero project_lift_id)
    
    (* Anything disjoint from a splittable p is also disjoint from its remainder carve p. *)
    lemma carve_smaller1: "splittable p  p ## f  carve p ## f"  
      by (metis carve_disj_lift(2) sep_add_disjD complete_split)

    (* carve_smaller1 in both orientations, as simp rules. *)
    lemmas carve_smaller[simp] = carve_smaller1 carve_smaller1[THEN sep_disj_commuteI]
      
    (* For splittable b: disjointness from project b in the small algebra corresponds to
       disjointness of the lifted element from b in the big algebra. *)
    lemma disj_project_eq_lift1: 
      "splittable b  a ## project b  lift a ## b"
      using disj_project_imp_lift project_disj_distrib by fastforce
      
    (* Commuted version of disj_project_eq_lift1. *)
    lemma disj_project_eq_lift2: 
      "splittable b  project b ## a  b ## lift a"
      by (meson disj_project_eq_lift1 sep_disj_commuteI)

    lemmas disj_project_eq_lift = disj_project_eq_lift1 disj_project_eq_lift2  
        
    (* Disabled lemmas, kept for reference: *)
    (*  
    lemma proj_Z_imp_disj1:
      assumes "project b = 0"
      shows "lift a ## b"
      by (simp add: assms disj_project_imp_lift)
      
    lemmas proj_Z_imp_disj[simp] = proj_Z_imp_disj1 proj_Z_imp_disj1[THEN sep_disj_commuteI]  
    *)
      
    (* complete_split oriented as b = lift (project b) + carve b. It is used with rewrite to
       expand an element into its two parts. *)
    lemmas split_complete = complete_split[symmetric]

    
    (* The projection of a disjoint sum of splittable elements is 0 exactly when both
       summands project to 0. *)
    lemma projectZ_add_distrib[simp]: 
      "b1##b2; splittable b1; splittable b2  project (b1 + b2) = 0  project b1=0  project b2=0"  
      by (auto simp: )
      
    (* Inversion of project: a disjoint decomposition a1 + a2 of project b is realised by a
       disjoint decomposition b = b1 + b2 into splittable parts with project b1 = a1 and
       project b2 = a2. *)
    lemma project_add_invE:
      assumes "project b = a1 + a2" "splittable b" "a1##a2" 
      obtains b1 b2 where "b1##b2" "b=b1+b2" "project b1 = a1" "project b2 = a2" "splittable b1" "splittable b2"
      using assms
      (* NOTE(review): this smt call mixes local fact names with qualified
         sep_lifting_scheme.* names and sep_lifting_scheme_axioms; the local names
         likely suffice inside this context. *)
      by (smt lift_add_distrib lift_disj_distrib project_add_distrib project_carve_Z project_lift_id sep_add_assoc sep_lifting_scheme.carve_disj_lift(2) sep_lifting_scheme.disj_project_eq_lift1 sep_lifting_scheme_axioms split_complete splittable_add_distrib)
  
    (* An element with zero remainder and zero projection is zero. The proof splits on
       splittability: complete_split in one case, carve_not_splittable in the other. *)
    lemma carve_project_Z_imp_Z: "carve x = 0; project x = 0  x = 0"  
      by (metis carve_not_splittable carve_pres_zero complete_split project_pres_zero splittableZ)
      
      
  end        
    
    
  (* Data of a lifter: the abstract lifting scheme (lift, project, carve, splittable over
     the abstract algebras 'asmall and 'abig), together with concrete-level information:
       L      : relates concrete big states 'cbig and concrete small states 'csmall
                (used as a lens via hlens, get' and put' in locale sep_lifter);
       αb, αs : abstraction of concrete big and small states;
       tyb, tys : type information of concrete big and small states.
     The laws these fields must satisfy are stated in locale sep_lifter. *)
  record ('asmall,'csmall,'tsmall,'abig,'cbig,'tbig) sep_lifter =
    lift :: "'asmall::unique_zero_sep_algebra  'abig::unique_zero_sep_algebra"  
    project :: "'abig  'asmall"
    carve :: "'abig  'abig"
    splittable :: "'abig  bool"
    
    L :: "'csmall  'cbig"
    αb :: "'cbig  'abig"
    αs :: "'csmall  'asmall"
    
    tyb :: "'cbig  'tbig"
    tys :: "'csmall  'tsmall"
    
  (* Hide the short selector names; the fields are accessed qualified, e.g. sep_lifter.lift. *)
  hide_const (open) lift project carve splittable L αb αs tyb tys

  (* Context over a fixed lifter record LFT, with no assumptions. It provides short
     abbreviations for the record fields and defines the lifted assertion lift_assn. *)
  locale pre_sep_lifter = 
    fixes LFT :: "('asmall::unique_zero_sep_algebra,'csmall,'tsmall,'abig::unique_zero_sep_algebra,'cbig,'tbig) sep_lifter"
  begin
    abbreviation lift where "lift  sep_lifter.lift LFT"
    abbreviation project where "project  sep_lifter.project LFT"
    abbreviation carve where "carve  sep_lifter.carve LFT"
    abbreviation splittable where "splittable  sep_lifter.splittable LFT"
    abbreviation L where "L  sep_lifter.L LFT"
    abbreviation αb where "αb  sep_lifter.αb LFT"
    abbreviation αs where "αs  sep_lifter.αs LFT"
    abbreviation tyb where "tyb  sep_lifter.tyb LFT"
    abbreviation tys where "tys  sep_lifter.tys LFT"
    
    
    (* lift_assn P holds for a big element b when b is splittable, has no remainder
       (carve b = 0), and P holds for its projection. *)
    definition "lift_assn P b  splittable b  carve b = 0  (P (project b))"

    
  end  
  
      
  (* A sep_lifter is a lifter record whose abstract part satisfies sep_lifting_scheme and
     whose concrete lens L is compatible with the abstraction functions αb and αs and the
     type functions tyb and tys. *)
  locale sep_lifter = pre_sep_lifter LFT + sep_lifting_scheme lift project carve splittable
    for LFT :: "('asmall::unique_zero_sep_algebra,'csmall,'tsmall,'abig::unique_zero_sep_algebra,'cbig,'tbig) sep_lifter" +
    
    assumes lensL[simp, intro!]: "hlens L"  
    (* If the abstract big state is splittable and has a nonzero projection, then the
       lens can be read (pre_get) and written (pre_put) on the concrete state. *)
    assumes precond: "splittable (αb cb); project (αb cb)0  pre_get L cb  pre_put L cb"
    (* Reading through the lens corresponds to project on the abstract level. *)
    assumes get_xfer[simp]: "splittable (αb cb); project (αb cb)0  αs (get' L cb) = project (αb cb)"
    (* Writing a type-compatible x through the lens replaces the projected part:
       the result is lift (αs x) plus the old remainder. *)
    assumes put_xfer[simp]: "splittable (αb cb); project (αb cb)0; tys x = tys (get' L cb) 
       αb (put' L x cb) = lift (αs x) + carve (αb cb)"

    (* A type-compatible write preserves the big type. *)
    assumes ty_put_xfer[simp]: "pre_get L cb; tys x = tys (get' L cb)   tyb (put' L x cb) = tyb cb"
      
  begin
  
    declare sep_lifter_axioms[simp, intro!]
  

    (* lift_assn distributes over separating conjunction. *)
    lemma lift_assn_distrib[simp]: "lift_assn (P**Q) = (lift_assn P ** lift_assn Q)"
      apply (rule ext)
      apply (auto simp: sep_conj_def lift_assn_def)
      apply (erule project_add_invE; auto)
      apply force
      done
      
    (* Lifting an exact assertion yields the exact assertion for the lifted element. *)
    lemma lift_assn_EXACT_eq[simp]: "lift_assn (EXACT v) = EXACT (lift v)"
      apply (rule ext)
      apply (auto simp: lift_assn_def)
      using split_complete by fastforce

    (* Lifting a pure assertion built from Φ yields the same assertion. *)
    lemma lift_assn_pure[simp]: "lift_assn (Φ) = Φ"
      apply (rule ext)
      unfolding lift_assn_def
      by (auto simp: pred_lift_extract_simps intro: carve_project_Z_imp_Z)

    (* Lifting the empty assertion yields the empty assertion. *)
    lemma lift_assn_empty[simp]: "lift_assn  = "  
      apply (rule ext)
      unfolding lift_assn_def
      by (auto simp: sep_algebra_simps intro: carve_project_Z_imp_Z)
            
    (* If c preserves the small type tys from every start state, then c zoomed through
       lift_lens e L preserves the big type tyb. *)
    lemma lift_ty:   
      assumes PRESTY: "s. wlp c (λ_ s'. tys s' = tys s) s"
      shows "wlp (zoom (lift_lens e L) c) (λ_ s'. tyb s' = tyb s) s"
      using PRESTY
      apply (auto simp: mwp_def wlp_def run_simps split: option.splits)
      apply (drule meta_spec[where x="get' L s"])
      apply (auto split: mres.split)
      done
      
    (* If a state s abstracts to p + f, where p satisfies lift_assn P and P does not hold
       for 0, then the lens can be read on s. *)
    lemma infer_pre_get_with_frame:
      assumes "lift_assn P p" "p ## f"  "αb s = p + f"
      assumes NZ: "¬P 0"
      shows "pre_get L s"
      apply (rule precond[THEN conjunct1])
      using assms splittable_same_struct unfolding lift_assn_def 
      by fastforce+

    (* Main transfer theorem: a time-free Hoare triple for c over the small abstraction αs
       lifts to a triple for c zoomed through L over the big abstraction αb, with pre- and
       postconditions wrapped in lift_assn.
       NZ ensures a nonzero projection, which precond, get_xfer and put_xfer require.
       PRESTY requires c to preserve the small type tys. *)
    lemma lift_operation:
      assumes NZ: "¬P 0"
      assumes PRESTY: "s. wlp c (λ_ s'. tys s' = tys s) s"
      assumes HT: "notime.htriple αs P c Q"
      shows "notime.htriple αb (lift_assn P) (zoom (lift_lens e L) c) (λr. lift_assn (Q r))"
    proof (rule notime.htripleI'; clarsimp simp: lift_assn_def)
      (* p is the part of the abstract state described by lift_assn P; f is the frame. *)
      fix p s f
      assume A: "p ## f" and [simp]: "αb s = p + f" "splittable p" "carve p = 0" and "P (project p)"
      (* project p is nonzero because P (project p) holds but P 0 does not. *)
      hence [simp]: "project p0" "p0" "p##f" "f##p" using NZ by (auto simp: sep_algebra_simps)
      
      (* The frame inherits splittability from p. *)
      have [simp]: "splittable f"
        using A ‹project p  0 ‹splittable p splittable_same_struct by blast
      
      (* NOTE(review): OFR does not appear to be referenced later in this proof. *)
      from A have OFR: "αb s = lift (project (p + f)) + carve (p+f)"
        apply (rewrite in "_ = " complete_split) 
        by simp_all
        

      (* The whole abstract state also has a nonzero projection. *)
      have [simp]: "project (αb s)  0" using A by simp 
        
      (* Instantiate the small triple at the concrete view get' L s with frame project f. *)
      note HT'= notime.htripleD'[OF HT _ _ P (project p), where f="project f" and s="get' L s", simplified]
      
      (* Strengthen the postcondition with c's type preservation from PRESTY. *)
      from PRESTY HT' have HT': 
        "wpn c (λr s'. tys s' = tys (get' L s)  (p'. p' ## project f  αs s' = p' + project f  Q r p')) (get' L s)"
        apply (auto simp: wlp_def wpn_def mwp_def split: mres.splits) 
        apply (drule meta_spec[where x="get' L s"]; simp) (* FIXME: Why can't we solve that goal in one line? *)
        done
      
      note HT' = HT'[unfolded wpn_def]

      (*from precond have [simp]: "pre_get L s" "pre_put L s" by auto*)
      (* Use precond as a simp rule to discharge the lens applicability side conditions. *)
      note [simp] = precond
      
      show "wpn (zoom (lift_lens e L) c) (λr s'. p'. p' ## f  αb s' = p' + f  splittable p'  carve p' = 0  Q r (project p')) s"
        apply (auto simp: wpn_def run_simps split: option.splits)
        apply (rule mwp_cons[OF HT'])
        apply (clarsimp_all simp: )
        (* Witness for the new lifted part: the new small part p' lifted, plus the
           unchanged remainder of p. *)
        subgoal for x s' p'
          apply (intro exI[where x="lift p' + carve p"])
          apply (auto simp: disj_project_eq_lift sep_algebra_simps)
          apply (rewrite in "_=" split_complete[of f])
          apply (auto simp: disj_project_eq_lift sep_algebra_simps)
          done
        done    
    qed
    
  end


  (* listα α l views a list l of concrete values as an abstract function on indices.
     A valid index i maps to α (l!i); every other index maps to 0. *)
  definition "listα α l i ≡ if i<length l then α (l!i) else (0::_::unique_zero_sep_algebra)"  
  
  (* The empty list is abstracted to 0, i.e. every index maps to 0. *)
  lemma listα_Nil[simp]: "listα α [] = 0"
    by (auto simp: listα_def)
  
  
  (* Updating a valid index of the list updates the abstract function at that index only. *)
  lemma listα_upd[simp]: 
    "i<length xs ⟹ listα α (xs[i := x]) = (listα α xs)(i:=α x)"
    by (rule ext) (auto simp: listα_def)

  (* Lifter for the element at index i of a list.
     The abstract big state is a function on indices (αb = listα αs). lift places a value
     at index i, project reads index i, carve sets index i to 0, and every state is
     splittable. The concrete lens is idxL i, and the big type is the list of element
     types (map tys). *)
  definition "idx_lifter tys αs i   
    sep_lifter.lift = fun_upd (λx. 0) i,
    sep_lifter.project = (λf. f i),
    sep_lifter.carve = (λf. f(i:=0)),
    sep_lifter.splittable = (λf. True),
    sep_lifter.L = idxL i,
    sep_lifter.αb = listα αs,
    sep_lifter.αs = αs,
    sep_lifter.tyb = map tys,
    sep_lifter.tys = tys
        "

  (* Field selectors of idx_lifter, as simp rules. *)
  lemma idx_lifter_simps[simp]:      
    "sep_lifter.lift (idx_lifter tys αs i) = fun_upd (λx. 0) i"
    "sep_lifter.project (idx_lifter tys αs i) = (λf. f i)"
    "sep_lifter.carve (idx_lifter tys αs i) = (λf. f(i:=0))"
    "sep_lifter.splittable (idx_lifter tys αs i) = (λf. True)"
    "sep_lifter.L (idx_lifter tys αs i) = idxL i"
    "sep_lifter.αb (idx_lifter tys αs i) = listα αs"
    "sep_lifter.αs (idx_lifter tys αs i) = αs"
    "sep_lifter.tyb (idx_lifter tys αs i) = map tys"
    "sep_lifter.tys (idx_lifter tys αs i) = tys"
    unfolding idx_lifter_def by auto    
            
  (* idx_lifter satisfies all sep_lifter laws. *)
  lemma idx_lifter[simp, intro!]: "sep_lifter (idx_lifter tys αs i)"
    apply unfold_locales
    apply (simp_all add: idx_lifter_def)
    apply (auto simp: sep_algebra_simps)
    (* Witness for remainder-style existentials: the function with index i reset to 0. *)
    apply (auto simp: sep_disj_fun_def sep_algebra_simps intro!: exI[where x="f(i:=0::'a)" for f])
    apply (auto simp: listα_def[abs_def] nth_list_update map_upd_eq split: if_splits)
    done
    
  (* optionα α abstracts an optional concrete value: None maps to 0, Some x maps to α x. *)
  fun optionα :: "('c ⇒ 'a) ⇒ 'c option ⇒ 'a::unique_zero_sep_algebra" 
    where "optionα α None = 0" | "optionα α (Some x) = α x"
    
  (* Equivalent characterisation of optionα as a single case expression. *)
  lemma optionα_alt: "optionα α x = (case x of None ⇒ 0 | Some y ⇒ α y)" by (cases x) auto 
    
  (* Lifter from a concrete value into a concrete option of that value.
     The abstract algebra is the same on both levels (lift and project are id), carve is
     constantly 0, and every state is splittable. The lens is theL; the big abstraction is
     optionα αs and the big type is map_option tys. *)
  definition "option_lifter tys αs  
    sep_lifter.lift = id,
    sep_lifter.project = id,
    sep_lifter.carve = λ_. 0,
    sep_lifter.splittable = (λf. True),
    sep_lifter.L = theL,
    sep_lifter.αb = optionα αs,
    sep_lifter.αs = αs,
    sep_lifter.tyb = map_option tys,
    sep_lifter.tys = tys
  "  
    
  (* Field selectors of option_lifter, as simp rules. *)
  lemma option_lifter_simps[simp]:
    "sep_lifter.lift (option_lifter tys αs) = id"
    "sep_lifter.project (option_lifter tys αs) = id"
    "sep_lifter.carve (option_lifter tys αs) = (λ_. 0)"
    "sep_lifter.splittable (option_lifter tys αs) = (λf. True)"
    "sep_lifter.L (option_lifter tys αs) = theL"
    "sep_lifter.αb (option_lifter tys αs) = optionα αs"
    "sep_lifter.αs (option_lifter tys αs) = αs"
    "sep_lifter.tyb (option_lifter tys αs) = map_option tys"
    "sep_lifter.tys (option_lifter tys αs) = tys"
    unfolding option_lifter_def by auto
  
  
  
  (* option_lifter satisfies all sep_lifter laws. *)
  lemma option_lifter[simp, intro!]: "sep_lifter (option_lifter tys αs)"
    apply unfold_locales
    apply (auto simp: option_lifter_def)
    apply (auto simp: optionα_alt split: option.splits)
    done
  
  (* Splittability for the composed lifter (compose_lifter l1 l2): x must be splittable
     for the outer lifter l2, and its l2-projection must be splittable for the inner
     lifter l1. *)
  definition "compose_splittable l1 l2 ≡ (λx. sep_lifter.splittable l2 x ∧ sep_lifter.splittable l1 (sep_lifter.project l2 x))"  
    
  (* Remainder for the composed lifter. On a compose-splittable x it is the l2-remainder of
     x plus the l1-remainder of x's l2-projection, lifted back by l2. Outside that range it
     is the identity, as carve_not_splittable demands of every lifting scheme. *)
  definition "compose_carve l1 l2 ≡ λx. 
    if compose_splittable l1 l2 x then
      sep_lifter.carve l2 x + sep_lifter.lift l2 ( sep_lifter.carve l1 (sep_lifter.project l2 x))
    else x"  
    
  (* Composition of lifters. l1 lifts from the small level ('as/'cs/'ts) to the middle
     level ('ab/'cb/'tb); l2 lifts from the middle level to the large level ('al/'cl/'tl).
     The composition lifts directly from the small level to the large one:
       lift goes through l1 then l2, and project goes through l2 then l1;
       carve and splittable are compose_carve and compose_splittable;
       L combines the lenses of l2 and l1 (presumably lens composition; the operator
       symbol is not visible in this copy);
       the outer abstraction and type come from l2, the inner ones from l1. *)
  definition 
    compose_lifter :: 
    "('as::unique_zero_sep_algebra,'cs,'ts,'ab::unique_zero_sep_algebra,'cb,'tb) sep_lifter 
       ('ab,'cb,'tb,'al::unique_zero_sep_algebra,'cl,'tl) sep_lifter
       ('as, 'cs, 'ts, 'al, 'cl, 'tl) sep_lifter"
    (infixl "lft" 80)  
  where "compose_lifter l1 l2  
    sep_lifter.lift = sep_lifter.lift l2 o sep_lifter.lift l1,
    sep_lifter.project = sep_lifter.project l1 o sep_lifter.project l2,
    sep_lifter.carve = compose_carve l1 l2,
    sep_lifter.splittable = compose_splittable l1 l2,
    sep_lifter.L = sep_lifter.L l2 L sep_lifter.L l1,
    sep_lifter.αb = sep_lifter.αb l2,
    sep_lifter.αs = sep_lifter.αs l1,
    sep_lifter.tyb = sep_lifter.tyb l2,
    sep_lifter.tys = sep_lifter.tys l1
  "  

  (* Field selectors of a composed lifter. Unlike the other *_simps lemmas in this theory,
     these are not declared [simp]. *)
  lemma compose_lifter_simps:
    "sep_lifter.lift (l1 lft l2) = sep_lifter.lift l2 o sep_lifter.lift l1"
    "sep_lifter.project (l1 lft l2) = sep_lifter.project l1 o sep_lifter.project l2"
    "sep_lifter.carve (l1 lft l2) = compose_carve l1 l2"
    "sep_lifter.splittable (l1 lft l2) = compose_splittable l1 l2"
    "sep_lifter.L (l1 lft l2) = sep_lifter.L l2 L sep_lifter.L l1"
    "sep_lifter.αb (l1 lft l2) = sep_lifter.αb l2"
    "sep_lifter.αs (l1 lft l2) = sep_lifter.αs l1"
    "sep_lifter.tyb (l1 lft l2) = sep_lifter.tyb l2"
    "sep_lifter.tys (l1 lft l2) = sep_lifter.tys l1"
    unfolding compose_lifter_def by auto
                                              
  (* Composition of two sep_lifters is again a sep_lifter, provided the middle level
     agrees: the small abstraction and type of l2 must equal the big abstraction and type
     of l1 (αEQ, tEQ). *)
  lemma compose_sep_lifter[simp, intro!]:
    fixes l1 :: "('as::unique_zero_sep_algebra, 'cs, 'ts, 'ab::unique_zero_sep_algebra, 'cb, 'tb) sep_lifter" 
      and l2 :: "('ab, 'cb, 'tb, 'al::unique_zero_sep_algebra, 'cl, 'tl) sep_lifter"
    assumes "sep_lifter l1"  
    assumes "sep_lifter l2"  
    assumes αEQ: "sep_lifter.αs l2 = sep_lifter.αb l1"
    assumes tEQ: "sep_lifter.tys l2 = sep_lifter.tyb l1"
    shows "sep_lifter (compose_lifter l1 l2)"
  proof -
    interpret l1: sep_lifter l1 by fact
    interpret l2: sep_lifter l2 by fact

    show ?thesis
      apply unfold_locales
      apply (clarsimp_all 
        simp: compose_lifter_def compose_carve_def compose_splittable_def
        simp: sep_algebra_simps 
        )
      (* presumably splittable_same_struct for the composition, from both components' versions *)
      subgoal for b1 b2
        using l1.splittable_same_struct l2.splittable_same_struct by force
     
      subgoal for b1 b2 by auto
      (* disjointness goal, via disj_project_eq_lift of both components *)
      subgoal by (clarsimp simp: l1.disj_project_eq_lift l2.disj_project_eq_lift sep_algebra_simps)
      (* complete_split for the composition: expand with both components' split_complete *)
      subgoal 
        apply (rewrite at "_=" l2.split_complete, assumption)
        apply (rewrite in "_=" l1.split_complete, assumption)
        by (auto simp: sep_algebra_simps)
      (* Remaining goals: the concrete-level laws (precond, get_xfer, put_xfer) for the
         composed lens. *)
      proof goal_cases  
        fix cb
        assume A: 
          "l1.project (l2.project (l2.αb cb))  0"
          "l2.splittable (l2.αb cb)"
          "l1.splittable (l2.project (l2.αb cb))"
        then show "pre_get l2.L cb  pre_get l1.L (get' l2.L cb)"
          using αEQ l2.precond l1.sep_lifter_axioms l2.sep_lifter_axioms sep_lifter.get_xfer sep_lifter.precond 
          by fastforce
          
          
        hence [simp]: "pre_get l2.L cb" "pre_get l1.L (get' l2.L cb)" by auto
          
        from A show "l1.αs (get' (l2.L L l1.L) cb) = l1.project (l2.project (l2.αb cb))"
          using αEQ l2.sep_lifter_axioms sep_lifter.get_xfer by fastforce
        
        fix x  
        assume "l1.tys x = l1.tys (get' (l2.L L l1.L) cb)"
        (* Writing through the composed lens: apply l2.put_xfer, then l1.put_xfer on the
           middle level, then rearrange the sum. *)
        with A show "l2.αb (put' (l2.L L l1.L) x cb) =
            l2.carve (l2.αb cb) 
          + (l2.lift (l1.lift (l1.αs x)) 
              + l2.lift (l1.carve (l2.project (l2.αb cb))))"  
          apply (auto)
          apply (subst l2.put_xfer) apply (auto simp: tEQ) [3]
          apply (subst l1.put_xfer[folded αEQ]; (subst l2.get_xfer)?; auto) 
          apply (auto simp: sep_algebra_simps)
          done
          
      qed (simp add: tEQ)
      
  qed    
    
    
  (* Composition of sep_lifters is associative. The proof unfolds all composition
     definitions; per the note in sep_lifting_scheme, fixing carve to the identity
     outside the splittable range is what makes this simple. *)
  lemma compose_lifter_assoc[simp]:
    assumes "sep_lifter l1" "sep_lifter l2" "sep_lifter l3"
    shows "l1 lft l2 lft l3 = l1 lft (l2 lft l3)"
  proof -
    interpret l1: sep_lifter l1 by fact
    interpret l2: sep_lifter l2 by fact
    interpret l3: sep_lifter l3 by fact
  
    show ?thesis
      unfolding compose_lifter_def compose_carve_def compose_splittable_def
      by (auto del: ext intro!: ext simp: sep_algebra_simps)
      
  qed
  
  (* Lifting an assertion through a composed lifter equals lifting it through l1 and then
     through l2. *)
  lemma lift_assn_compose: 
    assumes "sep_lifter l1"
    assumes "sep_lifter l2"
    shows "pre_sep_lifter.lift_assn (l1 lft l2) P = pre_sep_lifter.lift_assn l2 (pre_sep_lifter.lift_assn l1 P)"  
  proof -
    interpret l1: sep_lifter l1 by fact
    interpret l2: sep_lifter l2 by fact
  
    show ?thesis  
      apply (rule ext)
      apply (auto simp: pre_sep_lifter.lift_assn_def)
      apply (auto simp: compose_lifter_def compose_splittable_def compose_carve_def)
      done
  qed
    
  (* Identity lifter: lift and project are id, carve is constantly 0, every state is
     splittable, and the lens is idL. Both levels use the same abstraction α and the same
     type function ty. *)
  definition "id_lifter ty α  
    sep_lifter.lift = id,
    sep_lifter.project = id,
    sep_lifter.carve = λx. 0,
    sep_lifter.splittable = (λf. True),
    sep_lifter.L = idL,
    sep_lifter.αb = α,
    sep_lifter.αs = α,
    sep_lifter.tyb = ty,
    sep_lifter.tys = ty
  "
  
  (* Field selectors of id_lifter, as simp rules. *)
  lemma id_lifter_simps[simp]:
    "sep_lifter.lift (id_lifter ty α) = id"
    "sep_lifter.project (id_lifter ty α) = id"
    "sep_lifter.carve (id_lifter ty α) = (λx. 0)"
    "sep_lifter.splittable (id_lifter ty α) = (λf. True)"
    "sep_lifter.L (id_lifter ty α) = idL"
    "sep_lifter.αb (id_lifter ty α) = α"
    "sep_lifter.αs (id_lifter ty α) = α"
    "sep_lifter.tyb (id_lifter ty α) = ty"
    "sep_lifter.tys (id_lifter ty α) = ty"
    unfolding id_lifter_def by auto
    
  
  
  (* id_lifter satisfies all sep_lifter laws. *)
  lemma id_lifter[simp, intro!]: "sep_lifter (id_lifter ty α)"
    apply unfold_locales
    apply (auto simp: id_lifter_def)
    done
  
  (* An identity lifter with l's small type and abstraction is a left unit of composition. *)
  lemma compose_lifter_id_left[simp]:
    assumes "sep_lifter l" shows "id_lifter (sep_lifter.tys l) (sep_lifter.αs l) lft l = l"  
  proof -
    interpret sep_lifter l by fact
    
    (* The composed remainder collapses to l's own carve (also outside the splittable range). *)
    have [simp]: "compose_carve (id_lifter (sep_lifter.tys l) (sep_lifter.αs l)) l = carve"
      by (rule ext) (auto simp: compose_carve_def compose_splittable_def id_lifter_def)
    
    show ?thesis 
      apply (auto simp: compose_lifter_def)
      apply (auto simp: compose_splittable_def id_lifter_def)
      done
  qed    

  (* An identity lifter with l's big type and abstraction is a right unit of composition. *)
  lemma compose_lifter_id_right[simp]:
    assumes "sep_lifter l" shows "l lft (id_lifter (sep_lifter.tyb l) (sep_lifter.αb l)) = l"  
  proof -
    interpret sep_lifter l by fact
    
    (* The composed remainder collapses to l's own carve (also outside the splittable range). *)
    have [simp]: "compose_carve l (id_lifter (sep_lifter.tyb l) (sep_lifter.αb l)) = carve"
      by (rule ext) (auto simp: compose_carve_def compose_splittable_def id_lifter_def)
    
    show ?thesis
      apply (auto simp: compose_lifter_def)
      apply (auto simp: compose_splittable_def id_lifter_def)
      done
      
  qed    
    
end