section ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Sepref_Id_Op
begin


text {*
  In this phase, a monadic program is converted to complete monadic form,
  that is, the computations of compound expressions are made visible as
  top-level operations in the monad.

  The monadify process is separated into two steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands 
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.

    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values. 
  \end{enumerate}
*}


(* SP is the identity. It is only used as a protection tag: SP_cong and
   PR_CONST_cong are congruence rules without premises. They are declared
   [cong] and also added explicitly to the simpsets of the arity and comb
   phases, so the simplifier does not rewrite inside SP or PR_CONST. *)
definition SP ― ‹Tag to protect content from further application of arity 
  and combinator equations›
  where [simp]: "SP x ≡ x"
lemma SP_cong[cong]: "SP x ≡ SP x" by simp
lemma PR_CONST_cong[cong]: "PR_CONST x ≡ PR_CONST x" by simp

(* RCALL is the identity on the recursive function. It is inserted by the
   dflt_arity rule for RECT around every application of the recursion
   parameter D.
   EVAL x is RETURNT x. It tags plain expressions that still have to be
   flattened; monadify_simproc fires on EVAL$a. *)
definition RCALL ― ‹Tag that marks recursive call›
  where [simp]: "RCALL D ≡ D"
definition EVAL ― ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x ≡ RETURNT x"

text {*
  Internally, the package first applies rewriting rules from 
  @{text sepref_monadify_arity}, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.

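  Next, rewriting rules from @{text sepref_monadify_comb} are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad. The @{const EVAL} tags are flattened using a default simproc 
  that generates left-to-right argument order.

  Finally, the package checks that no @{const EVAL} tag remains. It then
  replaces returns of bound variables and parameters by @{text PASS} and
  inserts @{text COPY} where a parameter is passed more than once along a
  bind chain. Last, it removes redundant @{text PASS} bindings.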
*}

(* Rules used (together with SP_def) to prove each step of the monadify
   simproc: a bind whose first computation is RETURNT$x is reduced to the
   continuation applied to x, and EVAL is unfolded to RETURNT.
   Note that the first rule is stated with = and the second with ≡. *)
lemma monadify_simps: 
  "bindT$(RETURNT$x)$(λ2x. f x) = f x" 
  "EVAL$x ≡ RETURNT$x"
  by simp_all

(* PASS is RETURNT. The mark_params phase uses it to replace RETURNT$x
   when x is a bound variable or a parameter occurring in an hn_ctxt
   assertion of the precondition. *)
definition [simp]: "PASS ≡ RETURNT"
  ― ‹Pass on value, invalidating old one›

(* Rules for the final remove_pass phase:
   - a bind on PASS$x is replaced by the continuation applied to x;
   - a bind whose continuation only passes its argument on is replaced by
     the bound computation itself. *)
lemma remove_pass_simps:
  "bindT$(PASS$x)$(λ2x. f x) ≡ f x" 
  "bindT$m$(λ2x. PASS$x) ≡ m"
  by simp_all


(* COPY is the identity. In the dup phase, a PASS$p of a parameter p that
   is passed again later on the same bind chain is replaced by
   RETURNT$(COPY$p). RET_COPY_PASS_eq is the equation used to justify that
   replacement. *)
definition COPY :: "'a ⇒ 'a" 
  ― ‹Marks required copying of parameter›
  where [simp]: "COPY x ≡ x"
lemma RET_COPY_PASS_eq: "RETURNT$(COPY$p) = PASS$p" by simp


(* Rule collections read by the arity phase (sepref_monadify_arity) and the
   comb phase (sepref_monadify_comb) via Named_Theorems_Rev.get. *)
named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations"
named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations"

ML {*
  (* Tactics of the Sepref monadify phase:
     - eta-expansion of combinators (arity);
     - flattening of EVAL-tagged expressions (comb);
     - marking of passed parameters (mark_params);
     - copy insertion for parameters passed more than once (dup);
     - removal of PASS bindings (remove_pass). *)
  structure Sepref_Monadify = struct
    local
      (* Fresh operand variable for argument position i of type T: the binder
         name "v<i>" paired with the free variable "__v<i>". *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

      (* Wrap t as PROTECT2 t DUMMY and abstract it over the free variable
         given in n = (name, variable). This is presumably the term form
         behind the λ2 binder notation used in the rule lemmas below. *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end


      (* bind_args exp0 [(x1,m1), ..., (xn,mn)] builds
           bindT$m1$(λ2x1. ... bindT$mn$(λ2xn. exp0) ...)
         so the computations m1 ... mn are sequenced left to right. *)
      fun 
        bind_args exp0 [] = exp0
      | bind_args exp0 ((x,m)::xms) = let
          (* Loose bound variables of the inner term are shifted by one before
             abstracting, presumably because Term.lambda_name does not shift
             them itself. *)
          val lr = bind_args exp0 xms 
            |> incr_boundvars 1 
            |> lambda2_name x
        in @{mk_term "bindT$?m$?lr"} end

      (* Flatten one application t = f a1 ... an. f must not be an
         abstraction; otherwise TERM "monadify: higher-order" is raised.
         Each argument ai is wrapped as EVAL$ai and bound to a fresh variable
         vi, left to right. The innermost term is SP (RETURNT$(f v1 ... vn)).
         The SP tag protects the head application from being rewritten
         again. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse 
          raise TERM ("monadify: higher-order",[t])

        (* Operand types are taken from the actual arguments. *)
        val argTs = map fastype_of args
        (*val args = map monadify args*)
        (* Arguments are not flattened recursively here. Each is wrapped in
           EVAL, so the simplifier can process it again with the comb rules
           and monadify_simproc. *)
        val args = map (fn a => @{mk_term "EVAL$?a"}) args

        (*val fT = fastype_of f
        val argTs = binder_types fT*)
  
        (* One fresh variable per argument, numbered from 0. *)
        val argVs = tag_list 0 argTs
          |> map cr_var

        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in 
          @{mk_term "SP (RETURNT$?x)"}
        end

        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end

      (* Conversion on EVAL$t: rewrites the term to monadify t. The equation
         is discharged by simp with monadify_simps and SP_def, via
         Refine_Util.f_tac_conv. Any other term raises TERM "monadify_conv". *)
      fun monadify_conv_aux ctxt ct = case Thm.term_of ct of
        @{mpat "EVAL$_"} => let
          val ss = put_simpset HOL_basic_ss ctxt
          val ss = (ss addsimps @{thms monadify_simps SP_def})
          val tac = (simp_tac ss 1)
        in (*Refine_Util.monitor_conv "monadify"*) (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
        end
      | t => raise TERM ("monadify_conv",[t])

      (* Disabled alternative, kept as a comment in the original source. *)
      (*fun extract_comb_conv ctxt = Conv.rewrs_conv 
        (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})
      *)  
    in
      (* Disabled alternative, kept as a comment in the original source. *)
      (*
      val monadify_conv = Conv.top_conv 
        (fn ctxt => 
          Conv.try_conv (
            extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
          )
        )
      *)  

      (* Simproc triggered on EVAL$a. try turns an exception from the
         conversion into NONE, so the simproc simply does not fire on terms
         it cannot handle. *)
      val monadify_simproc = 
        Simplifier.make_simproc @{context} "monadify_simproc"
         {lhss =
          [Logic.varify_global @{term "EVAL$a"}],
          proc = K (try o monadify_conv_aux)};

    end

    local
      open Sepref_Basic
      (* Rewrite the abstract program a of an hn_refine term. Each RETURNT$x
         becomes PASS$x when x is a bound variable or is (aconv) one of the
         values pps extracted from the hn_ctxt conjuncts of the precondition P.
         The argument of a matched RETURNT is not traversed further. *)
      fun mark_params t = let
        val (P,c,Q,R,a) = dest_hn_refine t
        (* NOTE(review): #2 of dest_hn_ctxt_opt is presumably the abstract
           value of an hn_ctxt conjunct -- confirm against Sepref_Basic. *)
        val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2)

        (* env: binder types of the enclosing abstractions. It is handed to
           mk_term when building PASS$x, presumably to type loose bound
           variables. *)
        fun tr env (t as @{mpat "RETURNT$?x"}) = 
              if is_Bound x orelse member (aconv) pps x then
                @{mk_term env: "PASS$?x"}
              else t
          | tr env (t1$t2) = tr env t1 $ tr env t2
          | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t)
          | tr _ t = t

        val a = tr [] a
      in
        mk_hn_refine (P,c,Q,R,a)
      end

    in  
    (* Conversion applying mark_params. The resulting equation is proved by
       simp with PASS_def. *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt 
      (mark_params) 
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PASS_def}) 1)

    end  

    local

      open Sepref_Basic

      (* Copy insertion along a bind chain.
         For bindT$(PASS$p)$(λ2_. f), the continuation f is processed first;
         dest_lambda_rc presumably opens the binder, and rc rebuilds it. This
         yields the parameters passed later on the chain. If p is among them,
         this step becomes bindT$(RETURNT$(COPY$p))$...; otherwise it stays a
         PASS.
         Returns the rewritten term and the parameters passed on the chain
         starting at this term. For other applications and abstractions the
         subterms are processed independently and the list is []. *)
      fun dp ctxt (@{mpat "bindT$(PASS$?p)$(?t' ASp (λ_. PROTECT2 _ DUMMY))"}) = 
          let
            val (t',ps) = let
                val ((t',rc),ctxt) = dest_lambda_rc ctxt t'
                val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match 
                val (f,ps) = dp ctxt f
                val t' = @{mk_term "PROTECT2 ?f DUMMY"}
                val t' = rc t'
              in
                (t',ps)
              end
  
            val dup = member (aconv) ps p
            val t = if dup then
              @{mk_term "bindT$(RETURNT$(COPY$?p))$?t'"}
            else
              @{mk_term "bindT$(PASS$?p)$?t'"}
          in
            (t,p::ps)
          end
        | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[])
        | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[])
        | dp _ t = (t,[])

      (* Conversion applying dp. The equation is proved by simp with
         RET_COPY_PASS_eq. *)
      fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt 
        (#1 o dp ctxt) 
        (ALLGOALS (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms RET_COPY_PASS_eq}))) 


    in
      (* Apply dp_conv to the abstract program of the hn_refine conclusion,
         via Sepref_Basic.hn_refine_concl_conv_a. *)
      fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt)
    end


    (* Arity phase. First rewrite with the sepref_monadify_arity rules, with
       SP and PR_CONST protected by their congruence rules. Then beta-reduce
       and unfold SP. *)
    fun arity_tac ctxt = let
      val arity1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps ((Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_arity}))
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val arity2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms beta SP_def}
    in
      simp_tac arity1_ss THEN' simp_tac arity2_ss
    end

    (* Combinator phase. First rewrite with the sepref_monadify_comb rules
       plus monadify_simproc, with SP and PR_CONST protected. Then unfold SP. *)
    fun comb_tac ctxt = let
      val comb1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_comb})
        (*addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})*)
        addsimprocs [monadify_simproc]
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val comb2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms SP_def}
    in
      simp_tac comb1_ss THEN' simp_tac comb2_ss
    end

    (*fun ops_tac ctxt = CONVERSION (
      Sepref_Basic.hn_refine_concl_conv_a monadify_conv ctxt)*)

    (* Apply mark_params_conv to the HOL conclusion of the goal. *)
    fun mark_params_tac ctxt = CONVERSION (
      Refine_Util.HOL_concl_conv (K (mark_params_conv ctxt)) ctxt)

    (* True iff the abstract program a of Trueprop (hn_refine _ _ _ _ a)
       contains the constant EVAL. Raises TERM on other terms. Used by the
       check_EVAL phase to detect EVAL tags left after the comb phase. *)
    fun contains_eval @{mpat "Trueprop (hn_refine _ _ _ _ ?a)"} =   
      Term.exists_subterm (fn @{mpat EVAL} => true | _ => false) a
    | contains_eval t = raise TERM("contains_eval",[t]);  

    (* Eliminate PASS bindings using remove_pass_simps. *)
    fun remove_pass_tac ctxt = 
      simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps})

    (* Full monadify pipeline, run as phases in the listed order:
       arity, comb, check_EVAL, mark_params, dup, remove_pass.
       dbg is passed to flag_phases_ctrl. *)
    fun monadify_tac dbg ctxt = let
      open Sepref_Basic
    in
      PHASES' [
        ("arity", arity_tac, 0),
        ("comb", comb_tac, 0),
        (*("ops", ops_tac, 0),*)
        ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0),
        ("mark_params", mark_params_tac, 0),
        ("dup", dup_tac, 0),
        ("remove_pass", remove_pass_tac, 0)
      ] (flag_phases_ctrl dbg) ctxt
    end

  end
*}

(* Default arity equations. Each combinator is eta-expanded to its full
   arity, and its head is tagged with SP so the equation does not apply to
   it again. Binder arguments are eta-expanded as λ2 abstractions.
   For RECT, the function passed to the body B is λ2x. RCALL$D$x, so every
   recursive call is tagged with RCALL. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RETURNT ≡ λ2x. SP RETURNT$x" 
  "RECT ≡ λ2B x. SP RECT$(λ2D x. B$(λ2x. RCALL$D$x)$x)$x" 
  "case_list ≡ λ2fn fc l. SP case_list$fn$(λ2x xs. fc$x$xs)$l" 
  "case_prod ≡ λ2fp p. SP case_prod$(λ2a b. fp$a$b)$p" 
  "case_option ≡ λ2fn fs ov. SP case_option$fn$(λ2x. fs$x)$ov" 
  "If ≡ λ2b t e. SP If$b$t$e" 
  "Let ≡ λ2x f. SP Let$x$(λ2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)


(* Default combinator equations for monadic combinators. The plain argument
   is first evaluated via EVAL and bound:
   - x for RECT, RCALL, RETURNT and Let;
   - the scrutinee for the case combinators;
   - the condition for If.
   The combinator itself is then protected by SP. *)
lemma dflt_comb[sepref_monadify_comb]:
  "⋀B x. RECT$B$x ≡ bindT$(EVAL$x)$(λ2x. SP (RECT$B$x))"
  "⋀D x. RCALL$D$x ≡ bindT$(EVAL$x)$(λ2x. SP (RCALL$D$x))"
  "⋀fn fc l. case_list$fn$fc$l ≡ bindT$(EVAL$l)$(λ2l. (SP case_list$fn$fc$l))"
  "⋀fp p. case_prod$fp$p ≡ bindT$(EVAL$p)$(λ2p. (SP case_prod$fp$p))"
  "⋀fn fs ov. case_option$fn$fs$ov 
    ≡ bindT$(EVAL$ov)$(λ2ov. (SP case_option$fn$fs$ov))"
  "⋀b t e. If$b$t$e ≡ bindT$(EVAL$b)$(λ2b. (SP If$b$t$e))"
  "⋀x. RETURNT$x ≡ bindT$(EVAL$x)$(λ2x. SP (RETURNT$x))"
  "⋀x f. Let$x$f ≡ bindT$(EVAL$x)$(λ2x. (SP Let$x$f))"
  by (simp_all)


(* Combinator equations for combinators occurring inside a plain expression
   (under EVAL). The condition, scrutinee or bound value is evaluated first,
   and EVAL is pushed into the branches, case bodies or the Let body.
   NOTE(review): the Let rule is written with the (⤜) bind notation, while
   the other rules use bindT. Presumably both denote the same bind on nres;
   confirm before unifying the notation. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ bindT$(EVAL$b)$(λ2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ2x xs. fc x xs)$l) ≡ 
    bindT$(EVAL$l)$(λ2l. case_list$(EVAL$fn)$(λ2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ2a b. fp a b)$p) ≡ 
    bindT$(EVAL$p)$(λ2p. case_prod$(λ2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ2x. fs x)$ov) ≡ 
    bindT$(EVAL$ov)$(λ2ov. case_option$(EVAL$fn)$(λ2x. EVAL$(fs x))$ov)"
  "EVAL $ (Let $ v $ (λ2x. f x)) ≡ (⤜) $ (EVAL $ v) $ (λ2x. EVAL $ (f x))"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* A PR_CONST-wrapped term is treated as atomic. Evaluating it just returns
   it, and the SP tag prevents any further processing of the result. *)
lemma evalcomb_PR_CONST[sepref_monadify_comb]:
  "EVAL$(PR_CONST x) ≡ SP (RETURNT$(PR_CONST x))"
  by simp


end