(* Theory Monad *)

section ‹Nterm-Fail-Exception-State Monad›
theory Monad
imports
  Basic_Imports ELenses
begin



  section ‹Additions to Partial Function›
  context partial_function_definitions
  begin
    (* Introduction rule: if "λf. F f x" satisfies mono_body for every x,
       then F is monotone w.r.t. le_fun on both sides. *)
    lemma monotoneI:
      "(⋀x. mono_body (λf. F f x)) ⟹ monotone le_fun le_fun F"
      by (auto simp: monotone_def fun_ord_def)

    (* Pointwise unfolding of a function defined as fixp_fun F, under
       pointwise mono_body of F (via mono_body_fixp). *)
    lemma fp_unfold:
      assumes "f ≡ fixp_fun F"
      assumes "(⋀x. mono_body (λf. F f x))"
      shows "f x = F f x"
      using assms mono_body_fixp[of F] by auto

  end

  (* Destruction rule: the lifted function order implies the order pointwise. *)
  lemma fun_ordD: "fun_ord le f g ⟹ le (f x) (g x)"
    by (auto simp: fun_ord_def)

  (* Monotonicity into a function order is monotonicity at every argument. *)
  lemma fun_ord_mono_alt: "monotone le (fun_ord le') f ⟷ (∀x. monotone le le' (λy. f y x))"
    by (metis (mono_tags, lifting) fun_ord_def monotone_def)



  (* Proof method: focuses on the first subgoal (including its premises) and
     runs Partial_Function.mono_tac on all goals; CHANGED makes it fail if
     nothing happens. *)
  method_setup pf_mono_prover = ‹Scan.succeed (SIMPLE_METHOD' o Subgoal.FOCUS_PREMS (fn {context=ctxt,...} => CHANGED (ALLGOALS (Partial_Function.mono_tac ctxt))))›

  (* discharge_monos ctxt thm: for each premise of thm that matches
     "⋀_. monotone (fun_ord _) _ _", try to prove it with
     Partial_Function.mono_tac and discharge it. Premises that do not match,
     or whose proof attempt fails, are kept (instantiated with asm_rl). *)
  ML ‹fun discharge_monos ctxt thm = let
      fun aux ctxt thm = let
        val prems = Thm.prems_of thm

        (* Prove term t as a closed goal using tac on all subgoals *)
        fun prove_simple tac t ctxt = Goal.prove ctxt [] [] t (fn {context=ctxt, ...} => ALLGOALS (tac ctxt))


        (*val mono_tac = Subgoal.FOCUS (fn {context=ctxt,...} => CHANGED (ALLGOALS (Partial_Function.mono_tac ctxt)))*)
        fun mono_tac ctxt = CHANGED o (Partial_Function.mono_tac ctxt)

        (* Theorem to plug into a premise: a proved monotonicity fact, or asm_rl to keep it *)
        fun cinst (t as @{mpat "⋀_. monotone (fun_ord _) _ _"}) = the_default asm_rl (try (prove_simple mono_tac t) ctxt)
          | cinst _ = asm_rl

        val insts = map cinst prems

        val thm = thm OF insts
      in
        thm
      end
    in
      (* Avoid surprises with schematic variables being instantiated *)
      singleton (Variable.trade (map o aux) ctxt) thm
    end›

  (* Rule attribute wrapping the ML function discharge_monos. *)
  attribute_setup discharge_monos
    = ‹Scan.succeed (Thm.rule_attribute [] (discharge_monos o Context.proof_of))›
    ‹Try to discharge monotonicity premises›




  section ‹Monad Definition›

  subsection ‹Inner Type›
  (* Result of a computation: NTERM, FAIL with a payload of type 'f, or
     EXC / SUCC carrying an exception 'e resp. a value 'a, together with a
     'c component and the resulting state 's. *)
  datatype (discs_sels) ('a,'e,'c,'s,'f) mres = NTERM | FAIL (the_failure: 'f)
    | EXC 'e 'c (the_state: 's) | SUCC 'a 'c (the_state: 's)
  (* A monadic computation wraps a function from a state to a result. *)
  datatype ('a,'e,'c,'s,'f) M = M (run: "'s ⇒ ('a,'e,'c,'s,'f) mres")

  (* map_mres restricted to the state component (4th) resp. the failure
     component (5th); all other components are mapped with id. *)
  abbreviation "map_mres_state f ≡ map_mres id id id f id"
  abbreviation "map_mres_fail f ≡ map_mres id id id id f"

  (* Inversion rules: what the argument m must look like when
     map_mres_state f m is a FAIL / EXC / SUCC result. *)
  lemma map_mres_state_invert[simp]:
    (*"map_mres_state f m = NTERM ⟷ m = NTERM"*)
    "map_mres_state f m = FAIL msg ⟷ m = FAIL msg"
    "map_mres_state f m = EXC e c s ⟷ (∃ss. s=f ss ∧ m = EXC e c ss)"
    "map_mres_state f m = SUCC x c s ⟷ (∃ss. s=f ss ∧ m = SUCC x c ss)"
    by (cases m; auto; fail)+

  (* Inversion rules for map_mres_fail: only the FAIL payload is mapped,
     EXC and SUCC results are unchanged. *)
  lemma map_mres_fail_invert[simp]:
    (*"map_mres_fail f m = NTERM ⟷ m = NTERM"*)
    "map_mres_fail f m = FAIL msg ⟷ (∃msg'. msg = f msg' ∧ m = FAIL msg')"
    "map_mres_fail f m = EXC e c s ⟷ m = EXC e c s"
    "map_mres_fail f m = SUCC x c s ⟷ m = SUCC x c s"
    by (cases m; auto; fail)+

  (* Extensionality: two computations are equal if they run equally on every state. *)
  lemma M_eqI[intro?]: "⟦ ⋀s. run m s = run m' s ⟧ ⟹ m=m'"
    by (cases m; cases m'; auto)


  subsection ‹Ordering Structure›
  (* Flat order on results with NTERM as least element, lifted to state
     functions and transported to M along run. *)
  abbreviation "mres_ord ≡ flat_ord NTERM"
  abbreviation "mres_lub ≡ flat_lub NTERM"
  abbreviation "mres_mono ≡ monotone (fun_ord mres_ord) mres_ord"
  abbreviation "M_ord ≡ img_ord run (fun_ord mres_ord)"
  abbreviation "M_lub ≡ img_lub run M (fun_lub mres_lub)"
  abbreviation "M_mono ≡ monotone (fun_ord M_ord) M_ord"

  (* Instantiates the partial_function_definitions locale for M_ord / M_lub;
     its facts are available under the prefix "M." below. *)
  interpretation M:
    partial_function_definitions "M_ord" "M_lub"
    apply (intro partial_function_image partial_function_lift flat_interpretation)
    by (auto simp: M.expand)


  (* addcost c' r: adds c' (on the left) to the 'c component of an EXC or SUCC
     result; NTERM and FAIL are returned unchanged. *)
  fun addcost :: "'a ⇒ ('b, 'c, 'a::{monoid_add}, 'd, 'e) mres ⇒ ('b, 'c, 'a, 'd, 'e) mres"  where
    "addcost _ NTERM = NTERM"
  | "addcost _ (FAIL v) = (FAIL v)"
  | "addcost c' (EXC e c s) = (EXC e (c'+c) s)"
  | "addcost c' (SUCC x c s) = (SUCC x (c'+c) s)"

  subsection ‹Basic Combinators›
  
  (* Recursion combinator: the fixed point from the M interpretation. *)
  definition REC where "REC ≡ M.fixp_fun"
  (* Constant computations yielding NTERM resp. FAIL msg on every state. *)
  definition internal_nterm where "internal_nterm ≡ M (λ_. NTERM)"
  definition fail where "fail msg ≡ M (λ_. FAIL msg)"
  (* Succeed with x, zero cost, state unchanged. *)
  definition return where "return x ≡ M (SUCC x 0)"

  (* Sequential composition: on success of m, run f on the result and the new
     state and add m's cost; all other results of m are propagated. *)
  definition bind where "bind m f ≡ M (λs. case run m s of SUCC x c s ⇒ addcost c (run (f x) s)
                     | NTERM ⇒ NTERM | FAIL msg ⇒ FAIL msg | EXC e c s ⇒ EXC e c s)"
  (* Read / overwrite the state, and raise an exception, each at zero cost. *)
  definition get where "get ≡ M (λs. SUCC s 0 s)"
  definition set where "set s ≡ M (λ_. SUCC () 0 s)"
  definition raise where "raise e ≡ M (EXC e 0)"
  (* Exception handling: dual to bind, the handler h runs on an EXC result. *)
  definition handle where "handle m h ≡ M (λs. case run m s of EXC e c s ⇒ addcost c (run (h e) s)
         | SUCC x c s ⇒ SUCC x c s | NTERM ⇒ NTERM | FAIL msg ⇒ FAIL msg)"

  (* Succeed with unit and cost t, state unchanged. *)
  definition consume where "consume t = M (λs. SUCC () t s)"

  (* mblock: map the state with begin before running m and with end afterwards.
     mfail: map the failure payload of m's result with f. *)
  definition mblock where "mblock begin end m ≡ M (map_mres_state end o run m o begin)"
  definition mfail where "mfail f m ≡ M (map_mres_fail f o run m)"


  text ‹Derived, but required for some laws.›
  (* Apply f to the state, returning unit at zero cost. *)
  definition "map_state f ≡ M (λs. SUCC () 0 (f s))"
  (* map_state expressed through get and set. *)
  lemma map_state_bind:
    "((map_state f)::(unit, 'a, 'c::{monoid_add}, 'b, 'd) M) ≡ bind get (set o f)" 
    by(auto simp: map_state_def set_def bind_def get_def)


  section ‹Syntax›
  
  (* Syntactic aliases for bind used by the doM notation, with infix symbols. *)
  abbreviation (do_notation) bind_doI where "bind_doI ≡ bind"
  abbreviation (do_notation) then_doI where "then_doI m f ≡ bind_doI m (λ_. f)"

  notation bind_doI (infixr "⤜" 54)

  notation then_doI (infixr "⪢" 54)

  (* Concrete syntax for doM { ... } blocks: binds (x ← m), lets, plain
     statements, and ';'-separated sequences. *)
  nonterminal doI_binds and doI_bind
  syntax
    "_doI_block" :: "doI_binds ⇒ 'a" ("doM {//(2  _)//}" [12] 62)
    "_doI_bind"  :: "[pttrn, 'a] ⇒ doI_bind" ("(2_ ←/ _)" 13)
    "_doI_let" :: "[pttrn, 'a] ⇒ doI_bind" ("(2let _ =/ _)" [1000, 13] 13)
    "_doI_then" :: "'a ⇒ doI_bind" ("_" [14] 13)
    "_doI_final" :: "'a ⇒ doI_binds" ("_")
    "_doI_cons" :: "[doI_bind, doI_binds] ⇒ doI_binds" ("_;//_" [13, 12] 12)
    (*"_thenM" :: "['a, 'b] ⇒ 'c" (infixr "⪢" 54)*)

  syntax (ASCII)
    "_doI_bind" :: "[pttrn, 'a] ⇒ doI_bind" ("(2_ <-/ _)" 13)
    (*"_thenM" :: "['a, 'b] ⇒ 'c" (infixr ">>" 54)*)

  (* Syntax translations mapping doM blocks to then_doI / bind_doI / let.
     Longer blocks are nested: the tail becomes an inner block. *)
  translations
    "_doI_block (_doI_cons (_doI_then t) (_doI_final e))"
      ⇌ "CONST then_doI t e"

    "_doI_block (_doI_cons (_doI_bind p t) (_doI_final e))"
      ⇌ "CONST bind_doI t (λp. e)"

    "_doI_block (_doI_cons (_doI_let p t) bs)"
      ⇌ "let p = t in _doI_block bs"

    "_doI_block (_doI_cons b (_doI_cons c cs))"
      ⇌ "_doI_block (_doI_cons b (_doI_final (_doI_block (_doI_cons c cs))))"

    "_doI_cons (_doI_let p t) (_doI_final s)"
      ⇌ "_doI_final (let p = t in s)"

    "_doI_block (_doI_final e)" ⇀ "e"

  section ‹Monad Laws›

  (* map_state with the identity is return (); with a constant function it is set. *)
  lemma map_state_laws[simp]:
    "map_state (λx. x) = return ()"
    "map_state id = return ()"
    "map_state (λ_. c) = set c"
    unfolding return_def bind_def map_state_def get_def set_def
    by (auto split: mres.splits)

  (* Adding zero cost is the identity. *)
  lemma addcost_zero[simp]: "addcost 0 m = m"
    by (cases m) auto
  (* Nested cost additions combine into a single addition. *)
  lemma addcost_assoc[simp]: "addcost c1 (addcost c2 m) = addcost (c1 + c2) m"
    by (cases m) (auto simp: add.assoc)

  (* Monad laws for bind (right/left unit, associativity), propagation of
     fail / internal_nterm / raise through bind, and summing of consecutive
     consume costs. *)
  lemma bind_laws[simp]:
    fixes m :: "('a, 'b, 'c::{monoid_add},'d, 'e) M"
    shows
      "bind m return = m"
      "bind (return x) f = f x"
      "bind (bind m (λx. f x)) g = bind m (λx. bind (f x) g)"
      "bind (fail msg) f = fail msg"
      "bind (internal_nterm) f = internal_nterm"
      "bind (raise e) f = raise e"
      "bind (consume c1) (λ_. consume c2) = consume (c1+c2)"
      unfolding bind_def return_def fail_def raise_def internal_nterm_def consume_def
      by (cases m; auto split: mres.split)+

  (*
  lemma bind_laws_comm:
    fixes m :: "('a, 'b, 'c::{comm_monoid_add},'d, 'e) M"
    shows
        "bind (consume c) (λ_. m) = bind m (λx. bind (consume c) (λ_. return x))" *)

  (* Laws for handle: non-exceptional computations pass through, a raised
     exception invokes the handler, raise is a right unit, and handlers nest
     associatively. *)
  lemma handle_laws[simp]:
    "handle (return x) h = return x"
    "handle (consume c) h = consume c"
    "handle (fail msg) h = fail msg"
    "handle (internal_nterm) h = internal_nterm"
    "handle (raise e) h = h e"
    "handle m raise = m"
    "handle (handle m (λe. h e)) i = handle m (λe. handle (h e) i)"
    unfolding handle_def return_def fail_def raise_def internal_nterm_def consume_def
    by ((auto split: mres.split | (cases m; auto split: mres.split)) [])+

  (* Interaction laws of get and set with bind, and the fact that handle
     leaves get / set untouched. *)
  lemma state_laws[simp]:
    "bind get set = return ()"
    "bind get (λs. bind (set s) (f s)) = bind get (λs. f s ())" (* From Lars Hupel's HOL-Library.State_Monad *)
    "bind (set s) (λ_. set s') = set s'"
    "bind get (λ_. m) = m"
    "bind (set s) (λ_. get) = bind (set s) (λ_. return s)"

    "handle get h = get"
    "handle (set s) h = set s"
    unfolding handle_def return_def bind_def get_def set_def
    by (auto)

  (* How mblock acts on the basic combinators. *)
  lemma mblock_laws[simp]:
    "mblock begin end (return x) = doM {map_state (end o begin); return x}"
    "mblock begin end (raise e) = doM {map_state (end o begin); raise e}"
    "mblock begin end (fail msg) = fail msg"
    "mblock begin end (internal_nterm) = internal_nterm"
    "mblock begin end (get) = doM { s←get; map_state (end o begin); return (begin s) }"
    "mblock begin end (set s) = set (end s)"
    unfolding return_def fail_def raise_def mblock_def bind_def map_state_def get_def set_def internal_nterm_def
    by (auto split: mres.splits del: ext intro!: ext)

  (* mfail only affects fail (mapping its message); the other basic
     combinators are unchanged. *)
  lemma mfail_laws[simp]:
    "mfail f (return x) = return x"
    "mfail f (raise e) = raise e"
    "mfail f (fail msg) = fail (f msg)"
    "mfail f (internal_nterm) = internal_nterm"
    "mfail f (get) = get"
    "mfail f (set s) = set s"
    unfolding return_def fail_def raise_def mfail_def bind_def map_state_def get_def set_def internal_nterm_def
    by (auto split: mres.splits del: ext intro!: ext)

  (* return, raise and fail are injective in their argument. *)
  lemma m_injects[simp]: 
    "return x = return x' ⟷ x=x'"
    "raise e = raise e' ⟷ e=e'"
    "fail msg = fail msg' ⟷ msg=msg'"
    unfolding return_def fail_def raise_def
    by (auto dest: fun_cong)
    
  
  section ‹Recursion Setup›

  subsection ‹Fixed-Point Induction›
  (* Admissibility of pointwise predicates on run results, provided the
     predicate holds for NTERM. Used for fixed-point induction below. *)
  lemma M_admissible_aux:
    assumes "⋀x s. PQ x s NTERM"
    shows "M.admissible (λf. ∀x s. PQ x s (run (f x) s))"
    apply (rule admissible_fun)
    apply unfold_locales
    apply (rule admissible_image)
    apply (rule partial_function_lift)
    apply (rule flat_interpretation)
    using assms
    apply (simp add: comp_def)
    apply (smt ccpo.admissibleI chain_fun flat_lub_in_chain fun_lub_def mem_Collect_eq)
    apply (auto simp: M.expand)
    done

  (* The pointwise lub of the empty set is the computation that is NTERM everywhere. *)
  lemma M_lub_fun_empty[simp]: "M.lub_fun {} x = M (λ_. NTERM)"
    by (auto simp: img_lub_def fun_lub_def flat_lub_def)


  (* Unfolding rule for functions defined via REC with a monotone body. *)
  lemma REC_unfold:
    assumes DEF: "f ≡ REC F"
    assumes MONO: "⋀x. M.mono_body (λfa. F fa x)"
    shows "f = F f"
    by (metis DEF M.mono_body_fixp MONO REC_def)

  (* Fixed-point induction for REC: PQ must hold for NTERM and be preserved
     by one unfolding of the body F. *)
  lemma REC_partial_rule:
    fixes PQ :: "'a ⇒ 'b ⇒ ('c, 'd, 'cc, 'b,'f) mres ⇒ bool"
      and F :: "('a ⇒ ('c, 'd, 'cc, 'b,'f) M) ⇒ 'a ⇒ ('c, 'd, 'cc, 'b,'f) M"
    assumes "f ≡ REC F"
        and "⋀x. M.mono_body (λfa. F fa x)"
        and "⋀x s. PQ x s NTERM"
        and "⋀f x s. ⟦⋀x' s'. PQ x' s' (run (f x') s')⟧ ⟹ PQ x s (run (F f x) s)"
    shows "PQ x s (run (f x) s)"
    using ccpo.fixp_induct[OF M.ccpo M_admissible_aux M.monotoneI, simplified]
    using assms
    unfolding REC_def
    by blast

  (* Registers the mode "ners" with the partial_function package, using the
     fixed point and rules from the M interpretation. *)
  declaration ‹Partial_Function.init "ners" @{term M.fixp_fun}
    @{term M.mono_body} @{thm M.fixp_rule_uc} @{thm M.fixp_induct_uc}
    (NONE)›


  subsection ‹Well-Founded Induction›
  (* Well-founded induction for REC: the hypothesis is available for all
     (argument, state) pairs that are R-smaller than the current one. *)
  lemma REC_total_rule:
    fixes PQ :: "'a ⇒ 'b ⇒ ('c, 'd,'cc, 'b,'f) mres ⇒ bool"
      and F :: "('a ⇒ ('c, 'd,'cc, 'b,'f) M) ⇒ 'a ⇒ ('c, 'd,'cc, 'b,'f) M"
    assumes DEF: "f ≡ REC F"
        and MONO: "⋀x. M.mono_body (λfa. F fa x)"
        and WF: "wf R"
        and STEP: "⋀f x s. ⟦⋀x' s'. ((x',s'),(x,s))∈R ⟹ PQ x' s' (run (f x') s')⟧ ⟹ PQ x s (run (F f x) s)"
    shows "PQ x s (run (f x) s)"
    using WF
    apply (induction "(x,s)" arbitrary: x s)
    by (metis DEF M.mono_body_fixp MONO REC_def STEP)


  subsection ‹Monotonicity Reasoner Setup›

  (* bind is monotone when both of its arguments are; registered for the
     partial_function monotonicity prover. *)
  lemma M_bind_mono[partial_function_mono]:
    assumes mf: "M_mono B" and mg: "⋀y. M_mono (λf. C y f)"
    shows "M_mono (λf. bind (B f) (λy. C y f))"
    apply (rule monotoneI)
    using monotoneD[OF mf] monotoneD[OF mg]
    unfolding bind_def img_ord_def fun_ord_def
    apply (auto simp: flat_ord_def run_def split!: M.splits mres.splits)
    apply (smt M.collapse M.sel mres.distinct(1) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.inject(1) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.distinct(7) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.distinct(9) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) mres.distinct(7) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) mres.inject(2) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) mres.inject(2) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) mres.inject(2) run_def)
    apply (smt M.collapse M.sel mres.distinct(11) mres.distinct(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) mres.distinct(9) run_def)
    apply (smt M.collapse M.sel mres.distinct(11) mres.distinct(5) run_def)
    apply (smt M.collapse M.sel addcost.simps(1) mres.distinct(5) mres.sel(5) mres.sel(6) mres.sel(7) run_def)
    done

  (* handle is monotone when both of its arguments are; registered for the
     partial_function monotonicity prover. *)
  lemma M_handle_mono[partial_function_mono]:
    assumes mf: "M_mono B" and mg: "⋀y. M_mono (λf. C y f)"
    shows "M_mono (λf. handle (B f) (λy. C y f))"
    apply (rule monotoneI)
    using monotoneD[OF mf] monotoneD[OF mg]
    unfolding handle_def img_ord_def fun_ord_def
    apply (auto simp: flat_ord_def run_def split!: M.splits mres.splits)
    apply (smt M.collapse M.sel mres.distinct(1) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.inject(1) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.distinct(7) run_def)
    apply (smt M.collapse M.sel mres.distinct(1) mres.distinct(9) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(3) mres.distinct(7) run_def)
    apply (smt M.collapse M.sel addcost.simps(1) mres.distinct(3) mres.sel(2) mres.sel(3) mres.sel(4) run_def)
    apply (smt M.collapse M.sel mres.distinct(11) mres.distinct(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) mres.distinct(9) run_def)
    apply (smt M.collapse M.sel mres.distinct(11) mres.distinct(5) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) mres.inject(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) mres.inject(3) run_def)
    apply (smt M.collapse M.sel mres.distinct(5) mres.inject(3) run_def)
    done

  (* mblock preserves monotonicity of its computation argument. *)
  lemma mblock_mono[partial_function_mono]:
    assumes "M_mono (λfa. m fa)"
    shows "M_mono (λfa. mblock begin end (m fa))"
    apply (rule monotoneI)
    using monotoneD[OF assms]
    unfolding mblock_def
    unfolding flat_ord_def fun_ord_def img_ord_def
    by simp metis

  (* mfail preserves monotonicity of its computation argument. *)
  lemma mfail_mono[partial_function_mono]:
    assumes "M_mono (λfa. m fa)"
    shows "M_mono (λfa. mfail f (m fa))"
    apply (rule monotoneI)
    using monotoneD[OF assms]
    unfolding mfail_def
    unfolding flat_ord_def fun_ord_def img_ord_def
    by simp metis


  (*
    TODO: Make this proof generic, in partial_function_definitions or so.
  *)
  (* Auxiliary: REC (B D) is monotone in D if B D is monotone for each D and
     B itself is monotone in D. *)
  lemma REC_mono_aux:
    assumes MONO: "⋀D. monotone M.le_fun M.le_fun (B D)"
    assumes 1: "monotone M.le_fun (fun_ord M.le_fun) B"
    shows "monotone M.le_fun M.le_fun (λD. REC (B D))"
    unfolding REC_def
    apply (rule monotoneI)
    apply (rule ccpo.fixp_lowerbound[OF M.ccpo MONO])
    apply (subst (2) ccpo.fixp_unfold[OF M.ccpo MONO])
    supply R=fun_ordD[of M.le_fun "B x" "B y" for x y]
    apply (rule R)
    apply (rule monotoneD[OF 1])
    .

  (* Monotonicity rule for nested REC, in the pointwise form expected by the
     partial_function monotonicity prover. *)
  lemma REC_mono[partial_function_mono]:
    assumes MONO: "⋀D x. M.mono_body (λE. B D E x)"
    assumes 1: "⋀E x. M_mono (λD. B D E x)"
    shows "M_mono (λD. REC (B D) x)"
    using assms REC_mono_aux fun_ord_mono_alt by metis




section ‹Reasoning Setup›

  subsection ‹Simplifier Based›
  named_theorems run_simps

  (* mwp m N F E S: case distinction on the result m, with one continuation
     per constructor (NTERM, FAIL, EXC, SUCC). *)
  definition "mwp m N F E S ≡ case_mres N F E S m"

  lemma mwp_simps[simp]:
    "mwp NTERM N F E S = N"
    "mwp (FAIL msg) N F E S = F msg"
    "mwp (EXC e c s) N F E S = E e c s"
    "mwp (SUCC x c s) N F E S = S x c s"
    by (auto simp: mwp_def)

  (* Congruence rule: only the result argument of mwp is rewritten. *)
  lemma mwp_cong[cong]: "m=m' ⟹ mwp m N F E S = mwp m' N F E S" by simp

  (* Elimination rule for equations "mwp m N F E S = r": one case per
     constructor of m. *)
  lemma mwp_eq_cases:
    assumes "mwp m N F E S = r"
    assumes "m = NTERM ⟹ r = N ⟹ thesis"
    assumes "⋀e. m = FAIL e ⟹ r = F e ⟹ thesis"
    assumes "⋀e c s. m = EXC e c s ⟹ r = E e c s ⟹ thesis"
    assumes "⋀v c s. m = SUCC v c s ⟹ r = S v c s ⟹ thesis"
    shows thesis
    using assms unfolding mwp_def by (auto split: mres.splits)

  (* Nested mwp: the outer case distinction is pushed into the
     continuations of the inner one. *)
  lemma mwp_invert[simp]:
    "mwp (mwp m N F E S) N' F' E' S' =
      (mwp m
        (mwp N N' F' E' S')
        (λx. mwp (F x) N' F' E' S')
        (λe c s. mwp (E e c s) N' F' E' S')
        (λx c s. mwp (S x c s) N' F' E' S')
      )"
    by (auto simp: mwp_def split: mres.splits)

  (* Two mwp terms over the same result are equal if the continuations agree
     on the constructor that m actually is. *)
  lemma mwp_eqI[intro!]:
    assumes "m=NTERM ⟹ N=N'"
    assumes "⋀f. m=FAIL f ⟹ F f = F' f"
    assumes "⋀e c s. m=EXC e c s ⟹ E e c s = E' e c s"
    assumes "⋀x c s. m=SUCC x c s ⟹ S x c s = S' x c s"
    shows "mwp m N F E S = mwp m N' F' E' S'"
    using assms by (cases m) auto


  (* Consequence rule: each continuation of a (boolean) mwp may be weakened. *)
  lemma mwp_cons:
    assumes "mwp r N' F' E' S'"
    assumes "N'⟹N"
    assumes "⋀msg. F' msg ⟹ F msg"
    assumes "⋀e c s. E' e c s ⟹ E e c s"
    assumes "⋀x c s. S' x c s ⟹ S x c s"
    shows "mwp r N F E S"
    using assms by (auto simp: mwp_def split: mres.split)

  (* mwp over a state-mapped result: apply f to the state inside the EXC and
     SUCC continuations instead. *)
  lemma mwp_map_mres_state[simp]: "mwp (map_mres_state f s) N F E S = mwp s N F (λe c s. E e c (f s)) (λr c s. S r c (f s))"
    by (cases s) auto

  (* mwp with trivially true continuations holds for every result. *)
  lemma mwp_triv[simp]: 
    "mwp m top top top top"
    "mwp m True (λ_. True) (λ_ _ _. True) (λ_ _ _. True)"
    by (cases m; auto; fail)+
  
  (* If every continuation holds unconditionally, mwp holds for any result. *)
  lemma mwp_trivI: "⟦N; ⋀f. F f; ⋀e c s. E e c s; ⋀x c s. S x c s ⟧ ⟹ mwp m N F E S"
    by (cases m; auto)
    
    
    
  (* Orient equations so that "run m s" is on the left-hand side. *)
  lemma flip_run_eq[simp]:
    "SUCC r c s' = run m s ⟷ run m s = SUCC r c s'"
    "EXC e c s' = run m s ⟷ run m s = EXC e c s'"
    by auto

  (* Orient equations so that the mwp term is on the left-hand side. *)
  lemma flip_mwp_eq[simp]:
    "SUCC r c s' = mwp m N F E S ⟷ mwp m N F E S = SUCC r c s'"
    "EXC e c s' = mwp m N F E S ⟷ mwp m N F E S = EXC e c s'"
    by auto

  (* run of the basic combinators, for every state s. *)
  lemma basic_run_simps[run_simps]:
    "⋀s. run (return x) s = SUCC x 0 s"
    "⋀s. run (fail msg) s = FAIL msg"
    "⋀s. run (internal_nterm) s = NTERM"
    "⋀s. run (raise e) s = EXC e 0 s"
    "⋀s. run (get) s = SUCC s 0 s"
    "⋀s. run (set s') s = SUCC () 0 s'"
    by (auto simp: return_def fail_def raise_def get_def set_def internal_nterm_def)

  (* run of a let-expression: substitute the bound value. *)
  lemma run_Let[run_simps]: "run (let x=v in f x) s = run (f v) s" by auto

  (* run of bind, phrased with mwp: on success, continue with f and add the cost. *)
  lemma run_bind[run_simps]: "run (bind m f) s
    = (mwp (run m s) NTERM (λx. FAIL x) (λe c s. EXC e c s) (λx c s.  addcost c (run (f x) s)))"
    unfolding bind_def mwp_def by simp

  (* run of handle, phrased with mwp: on exception, continue with h and add the cost. *)
  lemma run_handle[run_simps]: "run (handle m h) s
    = (mwp (run m s) NTERM (λmsg. FAIL msg) (λe c s.  addcost c (run (h e) s)) (λx c s. SUCC x c s))"
    unfolding handle_def mwp_def by simp

  (* run of mblock: run m on (b s), then map the resulting state with e. *)
  lemma run_mblock[run_simps]: "run (mblock b e m) s = map_mres_state e (run m (b s))"
    unfolding mblock_def by simp

  (* run of mfail: map the failure payload of m's result. *)
  lemma run_mfail[run_simps]: "run (mfail f m) s = map_mres_fail f (run m s)"
    unfolding mfail_def by simp

  (* run of map_state: succeed with unit, zero cost, and state f s. *)
  lemma run_map_state[run_simps]: "run (map_state f) s = SUCC () 0 (f s)"
    unfolding map_state_def
    by (simp add: run_simps)

  (* Variant of REC_partial_rule where the run result is named r via an
     equation, suitable as an induction rule consuming "run (f x) s = r". *)
  lemma lrmwpe_REC_partial:
    assumes "f ≡ REC F"
        and "run (f x) s = r"
        and "⋀x. M.mono_body (λfa. F fa x)"
        and "⋀x s. P x s NTERM"
        and "⋀f x s r. ⟦⋀x' s' r'. run (f x') s' = r' ⟹ P x' s' r'; run (F f x) s = r ⟧ ⟹ P x s r"
    shows "P x s r"
  proof -
    note A = assms
    show ?thesis
      apply (rule REC_partial_rule[OF A(1,3), where PQ=P, of x s, unfolded A(2)])
      apply fact
      by (rule A(5)) auto
  qed

  (* Variant of REC_total_rule where the run result is named r via an
     equation; the hypothesis is restricted to R-smaller (argument, state) pairs. *)
  lemma lrmwpe_REC_total:
    assumes "f ≡ REC F"
        and "run (f x) s = r"
        and "⋀x. M.mono_body (λfa. F fa x)"
        and "wf R"
        and "⋀f x s r. ⟦⋀x' s' r'. ⟦run (f x') s' = r'; ((x',s'), (x,s))∈R⟧ ⟹ P x' s' r'; run (F f x) s = r ⟧ ⟹ P x s r"
    shows "P x s r"
  proof -
    note A = assms
    show ?thesis
      apply (rule REC_total_rule[OF A(1,3,4), where PQ=P, of x s, unfolded A(2)])
      by (rule A(5)) auto
  qed


  (* Introduces a name r for "run m s", so that induction rules consuming
     such an equation can be applied. *)
  lemma mwp_inductI:
    assumes "⋀r. run m s = r ⟹ mwp r N F E S"
    shows "mwp (run m s) N F E S"
    using assms by auto

  subsection ‹Simulation›


  (* sim m m': on every state, if m yields an EXC or SUCC result then m'
     yields exactly the same result; nothing is required when m yields NTERM or FAIL. *)
  definition "sim m m' ≡ ∀s. mwp (run m s) top top (λe c s'. run m' s = EXC e c s') (λx c s'. run m' s = SUCC x c s')"

  named_theorems sim_rules


  (* sim is reflexive. *)
  lemma sim_refl[intro!,simp]: "sim m m"
    unfolding sim_def mwp_def by (auto split: mres.split)

  (* fail is simulated by any computation. *)
  lemma sim_fail[sim_rules]: "sim (fail msg) m'"
    unfolding sim_def fail_def by auto

  (* internal_nterm is simulated by any computation. *)
  lemma sim_internal_nterm[sim_rules]: "sim (internal_nterm) m'"
    unfolding sim_def internal_nterm_def by auto

  (* Simulation rules for return, get and set (equal arguments). *)
  lemma sim_return[sim_rules]: "x=x' ⟹ sim (return x) (return x')"
    by (auto simp: sim_def run_simps)

  lemma sim_get[sim_rules]: "sim get get"
    by (auto simp: sim_def run_simps)

  lemma sim_set[sim_rules]: "s=s' ⟹ sim (set s) (set s')"
    by (auto simp: sim_def run_simps)


  (* Simulation of two recursive functions, given that their bodies preserve
     simulation. Proved by fixed-point induction on the left function. *)
  lemma sim_REC:
    assumes DEF: "f ≡ REC F"
    assumes DEF': "f' ≡ REC F'"
    assumes MONO: "⋀x. M.mono_body (λf. F f x)" "⋀x. M.mono_body (λf. F' f x)"
    assumes SIM: "⋀f f' x. (⋀x. sim (f x) (f' x)) ⟹ sim (F f x) (F' f' x)"
    shows "sim (f x) (f' x)"
    unfolding sim_def apply clarify
  proof (rule mwp_inductI)
    fix s r
    assume "run (f x) s = r"
    then show "mwp r top top (λe c s'. run (f' x) s = EXC e c s') (λxa c s'. run (f' x) s = SUCC xa c s')"
    proof (induction rule: lrmwpe_REC_partial[OF DEF _ MONO(1), consumes 1, case_names nterm step])
      case (nterm x s)
      then show ?case by simp
    next
      case (step f x s r)
      then show ?case
        apply (clarsimp)
        apply (subst REC_unfold[OF DEF' MONO(2)])
        apply (subst (2) REC_unfold[OF DEF' MONO(2)])
        using SIM[of f f' x]
        apply (auto simp: sim_def)
        done
    qed
  qed

  (* Destruction rules: if addcost yields EXC / SUCC, the argument had the
     same shape and the costs are related by addition. *)
  lemma addcost_EXC_D: "addcost c1 m = EXC e c s ⟹ (∃c2. m = EXC e c2 s ∧ c=c1+c2)"
    apply(cases m) by auto

  lemma addcost_SUCC_D: "addcost c1 m = SUCC a c s ⟹ (∃c2. m = SUCC a c2 s ∧ c=c1+c2)"
    apply(cases m) by auto
  lemma addcost_SUCC_sym_D: "SUCC a c s = addcost c1 m ⟹ (∃c2. m = SUCC a c2 s ∧ c=c1+c2)"
    apply(cases m) by auto

  (* bind preserves simulation. *)
  lemma sim_bind[sim_rules]:
    assumes "sim m m'" assumes "⋀x. sim (f x) (f' x)"
    shows "sim (bind m f) (bind m' f')"
    using assms
    unfolding sim_def
    by (fastforce simp: run_simps mwp_def split: mres.splits
                      dest!: addcost_EXC_D addcost_SUCC_D)

  (* handle preserves simulation. *)
  lemma sim_handle[sim_rules]:
    assumes "sim m m'" assumes "⋀x. sim (h x) (h' x)"
    shows "sim (handle m h) (handle m' h')"
    using assms
    unfolding sim_def
    by (fastforce simp: run_simps mwp_def split: mres.splits
                      dest!: addcost_EXC_D addcost_SUCC_D)

  (* mblock and mfail preserve simulation. *)
  lemma sim_mblock[sim_rules]:
    "sim m m' ⟹ sim (mblock begin end m) (mblock begin end m')"
    unfolding sim_def
    by (auto simp: run_simps mwp_def split: mres.splits)

  lemma sim_mfail[sim_rules]:
    "sim m m' ⟹ sim (mfail fm m) (mfail fm m')"
    unfolding sim_def
    by (auto simp: run_simps mwp_def split: mres.splits)

section ‹Integration of Lenses›

subsection ‹Monadic mblock›
(* Monadic block: "begin" computes an inner state s', m runs on s' (inside an
   mblock that restores the outer state s afterwards), and "end" is run on the
   final inner state, both on normal return and when m raised an exception
   (which is then re-raised). *)
definition "mmblock begin end m ≡ doM {
  s' ← begin;
  s ← get;
  (x,s') ← handle (
    mblock (λ_. s') (λ_. s) (
      handle
        (doM { x←m; s'←get; return (x,s') })
        (λe. doM { s'←get; raise (e,s') })
    )
  ) (λ(e,s'). doM { end s'; raise e });

  end s';
  return x
}"

(* run of mmblock as nested mwp case distinctions over begin, m and end;
   the costs c1, c2, c3 of the three phases are summed. *)
lemma run_mmblock[run_simps]:
  "run (mmblock begin end m) s = mwp (run begin s) NTERM FAIL EXC
    (λs' c1 s. mwp (run m s')
      NTERM
      FAIL
      (λe c2 s'. mwp (run (end s') s) NTERM FAIL (λe c3 s. EXC e (c1+c2+c3) s)
                                                 (λ_ c3 s. EXC e (c1+c2+c3) s))
      (λx c2 s'. mwp (run (end s') s) NTERM FAIL (λe c3 s. EXC e (c1+c2+c3) s)
                                                 (λ_ c3 s. SUCC x (c1+c2+c3) s))
    )"
  by (auto simp add: add.assoc mmblock_def run_simps mwp_def cong del: mwp_cong split: prod.splits mres.splits)

(* mmblock preserves monotonicity of its computation argument. *)
lemma mmblock_mono[partial_function_mono]:
  "monotone M.le_fun M_ord m ⟹ monotone M.le_fun M_ord (λf. mmblock begin end (m f))"
  unfolding mmblock_def
  by pf_mono_prover



subsection ‹Lifting from Sum-Type›
(* Lift a sum value into the monad: Inl becomes fail, Inr becomes return. *)
definition "lift_sum m ≡ case m of Inl f ⇒ fail f | Inr x ⇒ return x"

lemma lift_sum_simps[simp]:
  "lift_sum (Inl f) = fail f"
  "lift_sum (Inr x) = return x"
  by (auto simp: lift_sum_def)

lemma run_lift_sum[run_simps]:
  "run (lift_sum m) s = (case m of Inl f ⇒ FAIL f | Inr x ⇒ SUCC x 0 s)"
  by (auto simp: lift_sum_def run_simps split: sum.splits)

subsection ‹Lifting Lenses›

(* Monadic lens get / put on an explicit value s, lifted from eget / eput. *)
definition "mget L s ≡ lift_sum (eget L s)"
definition "mput L x s ≡ lift_sum (eput L x s)"

(* use L reads through lens L from the monad state; L ::= x writes x through
   lens L into the monad state. *)
definition "use L ≡ doM { s←get; mget L s }"
definition assign (infix "::=" 51) where "assign L x ≡ doM { s←get; s←mput L x s; set s }"

(*
definition "eget_cases L a f1 f2 ≡ case eget L a of Inr b ⇒ f1 b | Inl e ⇒ f2 e"

lemma eget_cases_split:
  "P (eget_cases L a f1 f2) ⟷ (epre_get L a ⟶ P (f1 (eget' L a))) ∧ (∀e. eget L a = Inl e ⟶ P (f2 e))"
  unfolding eget_cases_def by (auto split: sum.split)

lemma eget_cases_split_asm:
  "P (eget_cases L a f1 f2) ⟷ ¬ ((epre_get L a ∧ ¬P (f1 (eget' L a))) ∨ (∃e. eget L a = Inl e ∧ ¬ P (f2 e)))"
  apply (subst eget_cases_split[of P]) by blast

definition "eput_cases L b a f1 f2 ≡ case eput L b a of Inr a ⇒ f1 a | Inl e ⇒ f2 e"

lemma eput_cases_split:
  "P (eput_cases L b a f1 f2) ⟷ (epre_put_single_point L b a ⟶ P (f1 (eput' L b a))) ∧ (∀e. eput L b a = Inl e ⟶ P (f2 e))"
  unfolding eput_cases_def
  by (auto split: sum.split simp: eput_Inr_conv_sp)

lemma eput_cases_split_asm:
  "P (eput_cases L b a f1 f2) ⟷ ¬((epre_put_single_point L b a ∧ ¬ P (f1 (eput' L b a))) ∨ (∃e. eput L b a = Inl e ∧ ¬ P (f2 e)))"
  apply (subst eput_cases_split[of P]) by blast

lemmas epg_splits = eget_cases_split eget_cases_split_asm eput_cases_split eput_cases_split_asm
*)

(* Input abbreviations: case split on the lens precondition; None selects the
   success continuation f1, Some e selects f2 applied to e. *)
abbreviation (input) "eget_cases L s f1 f2 ≡ case epre_get L s of None ⇒ f1 (eget' L s) | Some e ⇒ f2 e"
abbreviation (input) "eput_cases L x s f1 f2 ≡ case epre_put L x s of None ⇒ f1 (eput' L x s) | Some e ⇒ f2 e"

(* run of the lens operations, as case distinctions on the lens preconditions. *)
lemma run_mget[run_simps]:
  "run (mget L s) xx = (eget_cases L s (λx. SUCC x 0 xx) (FAIL))"
  by (auto simp: mget_def run_simps split: sum.splits option.splits)

lemma run_mput[run_simps]:
  "elens L ⟹ run (mput L x s) xx = (eput_cases L x s (λx. SUCC x 0 xx) FAIL)"
  by (auto simp: mput_def run_simps split: sum.splits option.splits)

lemma run_use[run_simps]:
  "elens L ⟹ run (use L) s = (eget_cases L s (λx. SUCC x 0 s) FAIL)"
  by (auto simp: use_def run_simps)

lemma run_assign[run_simps]:
  "elens L ⟹ run (assign L x) s = eput_cases L x s (SUCC () 0) FAIL"
  by (auto simp: assign_def run_simps split: option.splits)



(* zoom L m: run m on the part of the state selected by lens L, as an
   mmblock with "use L" as begin and "assign L" as end. *)
definition "zoom L m ≡ mmblock (use L) (assign L) m"

(* run of zoom for an elens: fail if the get-precondition fails; otherwise
   run m on the focused part and write the resulting part back with eput'. *)
lemma run_zoom[run_simps]:
  assumes [simp]: "elens L"
  shows
  "run (zoom L m) s = (
    eget_cases L s
      (λss. mwp (run m ss) NTERM FAIL (λe c ss. EXC e c (eput' L ss s))
                                      (λx c ss. SUCC x c (eput' L ss s)))
      FAIL
    )"
  by (auto simp: zoom_def run_simps split: option.splits)


(* zoom preserves monotonicity of its computation argument. *)
lemma zoom_mono[partial_function_mono]:
  "monotone M.le_fun M_ord m ⟹ monotone M.le_fun M_ord (λf. zoom L (m f))"
  unfolding zoom_def
  by pf_mono_prover

(* Zooming get / set through a lens gives use / assign. *)
lemma zoom_get_is_use[simp]: "elens L ⟹ zoom L get = use L"
  apply (rule)
  apply (auto simp: run_simps split: option.split)
  done

lemma zoom_set_is_assign[simp]: "ehlens L ⟹ zoom L (set x) = (L ::= x)"
  apply (rule)
  apply (auto simp: run_simps split: option.split)
  done

(* Zooming through a combination of two lenses equals nested zooming.
   NOTE(review): the binary operator between L1 and L2 (presumably lens
   composition from ELenses) and the meta-logic brackets/arrow around the
   premises appear to be missing from the statement below -- confirm against
   ELenses and restore. *)
lemma zoom_comp_eq[simp]: "elens L1; elens L2  zoom (L1  L2) f = zoom L1 (zoom L2 f)"
  apply rule
  apply (auto simp: run_simps split: option.split)
  done

(* TODO: Move 
   TODO/FIXME: Simplifier should derive this on its own! *)  
(* After a put whose precondition holds, the get-precondition holds on the result. *)
lemma eget_put_pre: "elens L ⟹ epre_put L x s = None ⟹ epre_get L (eput' L x s) = None"
  by (metis (mono_tags, lifting) LENS_downstage(1) epre_get_def lens.simp_rls(4) lower_epre_put' lower_get_def lower_invert(1) lower_lens_def not_None_eq pre_eq_conv(2))
  
(* Zooming a return still performs the lens access (which may fail) first. *)
lemma zoom_return: "elens L ⟹ zoom L (return x) = use L⪢return x"
  apply (rule M_eqI)
  apply (auto simp: run_simps eget_put_pre split: option.split)
  done

(* addcost can be moved from the outside of an mwp onto its scrutinee,
   provided the continuations commute with addcost as stated. *)
lemma addcost_mwp: "(addcost c N = N)
  ⟹ (⋀x. addcost c (F x) = F x)
  ⟹ (⋀e c' s. addcost c (E e c' s) = E e (c + c') s)
  ⟹ (⋀a c' s. addcost c (S a c' s) = S a (c + c') s)
  ⟹ addcost c (mwp m N F E S) = mwp (addcost c m) N F E S"
  apply(cases m) by auto

(* zoom distributes over bind. *)
lemma zoom_bind: "elens L ⟹ zoom L (m⤜f) = zoom L m ⤜ zoom L o f"
  apply (rule M_eqI)
  apply (auto simp: addcost_mwp run_simps eget_put_pre split: option.split)
  done
  
  
  

(* ap_state f applies f to the monad state; L %= f applies f to the part of
   the state focused by lens L. *)
definition "ap_state f ≡ doM {s←get; set (f s)}"
definition ap_lens (infix "%=" 51) where "ap_lens L f ≡ zoom L (ap_state f)"

lemma run_ap_state[run_simps]: "run (ap_state f) s = SUCC () 0 (f s)"
  by (auto simp: ap_state_def run_simps)

lemma run_ap_lens[run_simps]: "elens L
  ⟹ run (L%=f) s = (eget_cases L s (λss. SUCC () 0 (eput' L (f ss) s)) FAIL)"
  by (auto simp: ap_lens_def run_simps split: option.splits)


(* map_lens L f s: read through L from the explicit value s, apply the
   monadic function f, and write the result back into s through L. *)
definition "map_lens L f s ≡ doM {
  x ← mget L s;
  x ← f x;
  mput L x s
}"

thm run_simps

(* run of map_lens for an elens, as a case distinction on the get-precondition. *)
lemma run_map_lens[run_simps]:
  "elens L ⟹ run (map_lens L f a) s = (
    eget_cases L a (λb.
      mwp (run (f b) s) NTERM FAIL EXC (λb s. SUCC (eput' L b a) s))
    ) FAIL"
  by (auto simp add: map_lens_def run_simps split: option.splits)


(* For presentation in paper *)
  
(* noexc m: on every state, m neither yields NTERM nor an exception. *)
definition "noexc m ≡ ∀s. run m s ≠ NTERM ∧ ¬is_EXC (run m s)"

(* Presentation variants of zoom_get_is_use / zoom_set_is_assign. *)
lemma "elens L ⟹ use L = zoom L get" by simp
lemma "ehlens L ⟹ (L ::= x) = (zoom L (set x))" by simp

  
  
  
  
section ‹Derived Constructs›

subsection ‹While›

  (* While loop over a loop state σ, defined via REC: evaluate the condition
     b σ; if it holds, run the body f σ and iterate on the new σ, otherwise
     return σ. mwhile' is the variant with unit loop state. *)
  definition "mwhile b f ≡ REC (λmwhile σ. doM { ctd ← b σ; if ctd then doM {σ←f σ; mwhile σ } else return σ })"
  abbreviation "mwhile' b f ≡ mwhile (λ_::unit. b) (λ_. f) ()"

  (* mwhile preserves simulation of condition and body. *)
  lemma sim_mwhile[sim_rules]:
    "⟦⋀σ. sim (b σ) (b' σ); ⋀σ. sim (f σ) (f' σ)⟧ ⟹ sim (mwhile b f σ) (mwhile b' f' σ)"
    by (auto intro!: sim_rules sim_REC[OF mwhile_def mwhile_def, discharge_monos])

  (* mwhile is monotone if condition and body are monotone pointwise. *)
  lemma mwhile_mono[partial_function_mono]:
    assumes "⋀x. M_mono (λf. b f x)"
    assumes "⋀x. M_mono (λf. c f x)"
    shows "M_mono (λD. mwhile (b D) (c D) σ)"
    supply assms[partial_function_mono]
    unfolding mwhile_def
    by pf_mono_prover

  (* Unfolding equation for mwhile (also used as code equation), obtained from
     REC_unfold with the monotonicity premise discharged by discharge_monos. *)
  lemmas mwhile_unfold[code] = REC_unfold[OF mwhile_def, discharge_monos]



subsection ‹Check›
  (* fcheck e φ: succeed with unit if φ holds, otherwise fail with e. *)
  definition "fcheck e φ ≡ if φ then return () else fail e"

  lemma fcheck_laws[simp]:
    "fcheck e True = return ()"
    "fcheck e False = fail e"
    by (auto simp: fcheck_def)

  lemma run_fcheck[run_simps]: "run (fcheck f Φ) s = (if Φ then SUCC () 0 s else FAIL f)"
    by (auto simp: fcheck_def run_simps)




subsection ‹Fold›
  (* Monadic left-to-right fold over a list, threading an accumulator s.
     mfold' is the variant with unit accumulator. *)
  fun mfold where
    "mfold f [] s = return s"
  | "mfold f (x#xs) s = doM {
      s ← f x s;
      mfold f xs s
  }"

  abbreviation "mfold' f xs ≡ mfold (λx _. f x) xs ()"

  (* mfold preserves simulation of the step function. *)
  lemma mfold_sim[sim_rules]:
    assumes [sim_rules]: "⋀x s. sim (f x s) (f' x s)"
    shows "sim (mfold f xs s) (mfold f' xs s)"
    apply (induction xs arbitrary: s)
    apply (auto intro!: sim_rules)
    done

  (* mfold is monotone if the step function is; by induction on the list. *)
  lemma mfold_mono[partial_function_mono]:
    assumes [partial_function_mono]: "⋀a σ. M_mono (λfa. f fa a σ)"
    shows "M_mono (λD. mfold (f D) l σ)"
  proof (induction l arbitrary: σ)
    case Nil
    then show ?case by simp pf_mono_prover
  next
    case [partial_function_mono]: (Cons a l)
    show ?case
      by simp pf_mono_prover
  qed


subsection ‹Map›

(* Monadic map over a list: applies f to each element from left to right and
   collects the results. *)
fun mmap where
  "mmap _ [] = return []"
| "mmap f (x#xs) = doM { x←f x; xs←mmap f xs; return (x#xs) }"

(* mmap preserves simulation; the functions need only be related on the
   elements of the list. *)
lemma mmap_sim[sim_rules]:
  assumes "⋀x. x∈list.set xs ⟹ sim (f x) (f' x)"
  shows "sim (mmap f xs) (mmap f' xs)"
  using assms
  apply (induction xs)
  by (auto intro!: sim_rules)


(* mmap is monotone if the mapped function is; by induction on the list. *)
lemma mmap_mono[partial_function_mono]:
  assumes [partial_function_mono]: "⋀a. M_mono (λfa. f fa a)"
  shows "M_mono (λD. mmap (f D) xs)"
proof (induction xs)
  case Nil
  then show ?case by simp pf_mono_prover
next
  case [partial_function_mono]: (Cons a xs)
  show ?case by simp pf_mono_prover
qed


(* For unit state: if mmap succeeds with ys, then f succeeds on the i-th
   input with the i-th output (for some cost). *)
lemma run_mmap_unit_state_idxD:
  assumes "run (mmap f xs) () = SUCC ys c ()"
  assumes "i<length xs"
  shows "∃c. run (f (xs!i)) () = SUCC (ys!i) c ()"
  using assms apply (induction xs arbitrary: i ys c)
  by (auto simp: run_simps nth_Cons split: nat.splits
            elim!: mwp_eq_cases dest!: addcost_SUCC_sym_D)


(* A successful mmap returns a list of the same length as its input. *)
lemma run_mmap_length_eq:
  assumes "run (mmap f xs) s = SUCC ys c s'"
  shows "length ys = length xs"
  using assms apply (induction xs arbitrary: ys c s)
  by (auto simp: run_simps elim!: mwp_eq_cases dest!: addcost_SUCC_sym_D)


(* For unit state: if mmap succeeds with ys, then for every input element x,
   f succeeds on x with some element of ys as result. *)
lemma run_mmap_unit_state_elemD:
  assumes "run (mmap f xs) () = SUCC ys c ()"
  assumes "x∈List.set xs"
  shows "∃y∈List.set ys. ∃c. run (f x) () = SUCC y c ()"
  using assms
  by (auto simp: in_set_conv_nth Bex_def run_mmap_unit_state_idxD run_mmap_length_eq)

(* run of mmap over an appended list: run on xs, then on ys from the
   resulting state; results are appended and costs added. *)
lemma run_mmap_append[run_simps]:
  "run (mmap f (xs@ys)) s = mwp (run (mmap f xs) s) NTERM FAIL EXC (λrxs cxs s.
  mwp (run (mmap f ys) s) NTERM FAIL (λa cys s. EXC a (cxs+cys) s) (λrys cys s. SUCC (rxs@rys) (cxs+cys) s))"
  apply (induction xs arbitrary: s)
  apply (auto simp: run_simps mwp_def add.assoc split: mres.splits)
  done


(* TODO: What are good rules for mmap ? *)



subsection ‹Lookup›

  (* lookup f m s: look up s in the map m; fail with f if absent, otherwise
     return the value. *)
  definition "lookup f m s ≡ case m s of None ⇒ fail f | Some x ⇒ return x"

  lemma run_lookup[run_simps]:
    "run (lookup f m k) s = (case m k of None ⇒ FAIL f | Some v ⇒ SUCC v 0 s)"
    by (auto simp: lookup_def run_simps split: option.splits)

  (* Lookup in a sub-map is simulated by lookup in the larger map. *)
  lemma lookup_sim[sim_rules]:
    assumes "π ⊆⇩m π'"
    shows "sim (lookup f π x2) (lookup f' π' x2)"
    using map_leD[OF assms]
    by (auto simp: sim_def run_simps split: option.split)


subsection ‹Hiding too generic Names›    
    
(* Hide the short names get, set and M.M; (open) keeps them accessible by
   qualified name. *)
hide_const (open) get set M.M
    
end