(* Theory Sepref_Monadify *)

section ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Sepref_Id_Op
begin


text ‹
  In this phase, a monadic program is converted to complete monadic form,
  that is, the computation of compound expressions is made visible as 
  top-level operations in the monad.

  The monadify process is separated into two steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands 
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.

    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values. 
  \end{enumerate}
›


(* SP is the identity on its argument; it only serves as a syntactic tag.
   The definition is a simp rule, so plain simp removes the tag again. *)
definition SP ― ‹Tag to protect content from further application of arity 
  and combinator equations›
  where [simp]: "SP x ≡ x"
(* Premise-free congruence rule for SP: with it in the simpset, the
   simplifier does not rewrite inside the argument of SP. *)
lemma SP_cong[cong]: "SP x ≡ SP x" by simp
(* Premise-free congruence rule for PR_CONST: with it in the simpset, the
   simplifier does not rewrite inside the argument of PR_CONST. *)
lemma PR_CONST_cong[cong]: "PR_CONST x ≡ PR_CONST x" by simp

(* RCALL is the identity on its argument; it only tags the term D. *)
definition RCALL ― ‹Tag that marks recursive call›
  where [simp]: "RCALL D ≡ D"
(* EVAL x is RETURN x under another name; the tag tells the monadify phase
   that x is a plain expression which still has to be flattened. *)
definition EVAL ― ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x ≡ RETURN x"

text ‹
  Internally, the package first applies rewriting rules from 
  @{text sepref_monadify_arity}, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.

  Next, rewriting rules from @{text sepref_monadify_comb} are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad. The @{const EVAL} tags are flattened using a default simproc 
  that generates left-to-right argument order.
›

(* Simp rules used to justify the term produced by the monadify simproc:
   binding a RETURNed value is the same as applying the continuation to it,
   and EVAL unfolds to RETURN. *)
lemma monadify_simps: 
  "Refine_Basic.bind$(RETURN$x)$(λ⇩2x. f x) = f x" 
  "EVAL$x ≡ RETURN$x"
  by simp_all

(* PASS is RETURN under another name; the definition is a simp rule. *)
definition [simp]: "PASS ≡ RETURN"
  ― ‹Pass on value, invalidating old one›

(* Rewrite rules that eliminate PASS again: a bind whose first component is
   PASS$x collapses to the continuation applied to x, and a bind whose
   continuation merely PASSes its argument collapses to the bound program. *)
lemma remove_pass_simps:
  "Refine_Basic.bind$(PASS$x)$(λ⇩2x. f x) ≡ f x" 
  "Refine_Basic.bind$m$(λ⇩2x. PASS$x) ≡ m"
  by simp_all


(* COPY is the identity function; it only tags its argument. *)
definition COPY :: "'a ⇒ 'a" 
  ― ‹Marks required copying of parameter›
  where [simp]: "COPY x ≡ x"
(* Returning a COPY-tagged value equals passing the value itself: COPY and
   PASS both unfold by simp (to the identity and to RETURN, respectively). *)
lemma RET_COPY_PASS_eq:
  "RETURN$(COPY$p) = PASS$p"
  by simp


(* Named theorem collections consulted by the monadify tactics:
   sepref_monadify_arity holds the equations used for arity alignment,
   sepref_monadify_comb the equations used for combinator rewriting. *)
named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations"
named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations"

ML ‹
  (* Implementation of the monadify phase. The entry point is monadify_tac,
     which chains the phases arity, comb, check_EVAL, mark_params, dup and
     remove_pass. *)
  structure Sepref_Monadify = struct
    local
      (* Pair the display name "v<i>" with a free variable "__v<i>" of type T. *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

      (* Abstract the body t over the (name, variable) pair n, after wrapping
         the body as PROTECT2 t DUMMY. *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end


      (* Wrap exp0 into nested binds, one per (variable, program) pair.
         The first pair of the list becomes the outermost bind, so the
         programs are sequenced in list order. *)
      fun 
        bind_args exp0 [] = exp0
      | bind_args exp0 ((x,m)::xms) = let
          val lr = bind_args exp0 xms 
            |> incr_boundvars 1 
            |> lambda2_name x
        in @{mk_term "Refine_Basic.bind$?m$?lr"} end

      (* Flatten one application: split t into head f and arguments, tag
         every argument with EVAL, bind each tagged argument to a fresh
         variable, and finish with SP (RETURN$(f applied to the variables)).
         Raises TERM if the head is a lambda abstraction. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse 
          raise TERM ("monadify: higher-order",[t])

        (* Types are taken from the original arguments, before tagging. *)
        val argTs = map fastype_of args
        (*val args = map monadify args*)
        val args = map (fn a => @{mk_term "EVAL$?a"}) args

        (*val fT = fastype_of f
        val argTs = binder_types fT*)
  
        (* One (name, free variable) pair per argument, numbered from 0. *)
        val argVs = tag_list 0 argTs
          |> map cr_var

        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in 
          @{mk_term "SP (RETURN$?x)"}
        end

        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end

      (* Conversion for terms of the form EVAL$x: the new right-hand side is
         monadify x, and the equation is discharged by simp_tac with
         monadify_simps and SP_def on top of HOL_basic_ss.
         Raises TERM for any other term shape. *)
      fun monadify_conv_aux ctxt ct = case Thm.term_of ct of
        @{mpat "EVAL$_"} => let
          val ctxt = put_simpset HOL_basic_ss ctxt
          val ctxt = (ctxt addsimps @{thms monadify_simps SP_def})
          fun tac ctxt = (simp_tac ctxt 1)
        in (*Refine_Util.monitor_conv "monadify"*) (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
        end
      | t => raise TERM ("monadify_conv",[t])

      (*fun extract_comb_conv ctxt = Conv.rewrs_conv 
        (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})
      *)  
    in
      (*
      val monadify_conv = Conv.top_conv 
        (fn ctxt => 
          Conv.try_conv (
            extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
          )
        )
      *)  

      (* Simproc triggered on the pattern EVAL$a. The conversion is run
         under try, so an exception yields NONE instead of propagating. *)
      val monadify_simproc = 
        Simplifier.make_simproc @{context} "monadify_simproc"
         {lhss =
          [Logic.varify_global @{term "EVAL$a"}],
          proc = K (try o monadify_conv_aux)};

    end

    local
      open Sepref_Basic
      (* Rewrite the last component a of an hn_refine term: every RETURN$x
         whose x is a bound variable, or is aconv to one of the terms pps
         extracted from the first component P, becomes PASS$x. All other
         components are rebuilt unchanged.
         NOTE(review): pps is the second component of each
         dest_hn_ctxt_opt result; presumably the abstract parameter of the
         hn_ctxt assertion - confirm against Sepref_Basic. *)
      fun mark_params t = let
        val (P,c,Q,R,CP,a) = dest_hn_refine t
        val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2)

        (* env collects the types of the abstractions passed so far and is
           handed to mk_term. *)
        fun tr env (t as @{mpat "RETURN$?x"}) = 
              if is_Bound x orelse member (op aconv) pps x then
                @{mk_term env: "PASS$?x"}
              else t
          | tr env (t1$t2) = tr env t1 $ tr env t2
          | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t)
          | tr _ t = t

        val a = tr [] a
      in
        mk_hn_refine (P,c,Q,R,CP,a)
      end

    in  
    (* Conversion replacing a term by mark_params of it; the equation is
       discharged by unfolding PASS_def. *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt 
      (mark_params) 
      (fn ctxt => simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PASS_def}) 1)

    end  

    local

      open Sepref_Basic

      (* Duplicate-parameter detection. Returns the rewritten term together
         with the list of parameters passed by a directly nested chain of
         bind$(PASS$p)$(protected lambda) terms.
         For such a bind, the lambda body is processed first, yielding the
         parameters ps passed further inside. If p is aconv to one of them,
         this occurrence is rewritten to RETURN$(COPY$p); otherwise it stays
         PASS$p. In both cases p is added to the returned list.
         Applications and abstractions are traversed, but return an empty
         parameter list.
         NOTE(review): dest_lambda_rc presumably opens the abstraction and
         returns a function rc that closes it again - confirm against
         Sepref_Basic. *)
      fun dp ctxt (@{mpat "Refine_Basic.bind$(PASS$?p)$(?t' AS⇩p (λ_. PROTECT2 _ DUMMY))"}) = 
          let
            val (t',ps) = let
                val ((t',rc),ctxt) = dest_lambda_rc ctxt t'
                val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match 
                val (f,ps) = dp ctxt f
                val t' = @{mk_term "PROTECT2 ?f DUMMY"}
                val t' = rc t'
              in
                (t',ps)
              end
  
            val dup = member (op aconv) ps p
            val t = if dup then
              @{mk_term "Refine_Basic.bind$(RETURN$(COPY$?p))$?t'"}
            else
              @{mk_term "Refine_Basic.bind$(PASS$?p)$?t'"}
          in
            (t,p::ps)
          end
        | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[])
        | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[])
        | dp _ t = (t,[])

      (* Conversion replacing a term by the first component of dp; the
         equation is discharged with RET_COPY_PASS_eq. *)
      fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt 
        (#1 o dp ctxt) 
        (fn ctxt => ALLGOALS (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms RET_COPY_PASS_eq}))) 


    in
      (* Tactic applying dp_conv through hn_refine_concl_conv_a. *)
      fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt)
    end


    (* Arity phase: first simplify with the sepref_monadify_arity equations
       (with SP_cong and PR_CONST_cong as congruence rules), then simplify
       with beta and SP_def, which unfolds the SP tags again. *)
    fun arity_tac ctxt = let
      val arity1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps ((Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_arity}))
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val arity2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms beta SP_def}
    in
      simp_tac arity1_ss THEN' simp_tac arity2_ss
    end

    (* Combinator phase: first simplify with the sepref_monadify_comb
       equations plus monadify_simproc (with SP_cong and PR_CONST_cong as
       congruence rules), then simplify with SP_def, which unfolds the SP
       tags again. *)
    fun comb_tac ctxt = let
      val comb1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_comb})
        (*addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})*)
        addsimprocs [monadify_simproc]
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val comb2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms SP_def}
    in
      simp_tac comb1_ss THEN' simp_tac comb2_ss
    end

    (*fun ops_tac ctxt = CONVERSION (
      Sepref_Basic.hn_refine_concl_conv_a monadify_conv ctxt)*)

    (* Tactic applying mark_params_conv through Refine_Util.HOL_concl_conv. *)
    fun mark_params_tac ctxt = CONVERSION (
      Refine_Util.HOL_concl_conv (mark_params_conv) ctxt)

    (* True iff the last argument of the hn_refine proposition contains the
       constant EVAL. Raises TERM if the term is not a Trueprop of an
       hn_refine application. *)
    fun contains_eval @{mpat "Trueprop (hn_refine _ _ _ _ _ ?a)"} =   
      Term.exists_subterm (fn @{mpat EVAL} => true | _ => false) a
    | contains_eval t = raise TERM("contains_eval",[t]);  

    (* Simplify with remove_pass_simps only, eliminating PASS tags. *)
    fun remove_pass_tac ctxt = 
      simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps})

    (* Complete monadify phase: runs the listed sub-phases in order via
       PHASES', controlled by flag_phases_ctrl ctxt dbg.
       NOTE(review): the check_EVAL phase presumably lets the goal pass
       only if no EVAL tag is left in the conclusion - confirm the
       semantics of CONCL_COND'. *)
    fun monadify_tac dbg ctxt = let
      open Sepref_Basic
    in
      PHASES' [
        ("arity", arity_tac, 0),
        ("comb", comb_tac, 0),
        (*("ops", ops_tac, 0),*)
        ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0),
        ("mark_params", mark_params_tac, 0),
        ("dup", dup_tac, 0),
        ("remove_pass", remove_pass_tac, 0)
      ] (flag_phases_ctrl ctxt dbg) ctxt
    end

  end
›

(* Default arity equations. Each combinator is eta-expanded to a fixed
   number of operands, the expanded right-hand side is tagged with SP, and
   for RECT the recursive call inside the body is tagged with RCALL.
   The proof only unfolds the tags SP, APP, PROTECT2 and RCALL. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RETURN ≡ λ⇩2x. SP RETURN$x" 
  "RECT ≡ λ⇩2B x. SP RECT$(λ⇩2D x. B$(λ⇩2x. RCALL$D$x)$x)$x" 
  "case_list ≡ λ⇩2fn fc l. SP case_list$fn$(λ⇩2x xs. fc$x$xs)$l" 
  "case_prod ≡ λ⇩2fp p. SP case_prod$(λ⇩2a b. fp$a$b)$p" 
  "case_option ≡ λ⇩2fn fs ov. SP case_option$fn$(λ⇩2x. fs$x)$ov" 
  "If ≡ λ⇩2b t e. SP If$b$t$e" 
  "Let ≡ λ⇩2x f. SP Let$x$(λ⇩2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)


(* Default combinator equations. The operand that has to be evaluated first
   (recursion argument, scrutinee, condition, returned or let-bound value)
   is pulled out into a bind over EVAL of that operand, and the combinator
   applied to the bound variable is tagged with SP. *)
lemma dflt_comb[sepref_monadify_comb]:
  "⋀B x. RECT$B$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RECT$B$x))"
  "⋀D x. RCALL$D$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RCALL$D$x))"
  "⋀fn fc l. case_list$fn$fc$l ≡ Refine_Basic.bind$(EVAL$l)$(λ⇩2l. (SP case_list$fn$fc$l))"
  "⋀fp p. case_prod$fp$p ≡ Refine_Basic.bind$(EVAL$p)$(λ⇩2p. (SP case_prod$fp$p))"
  "⋀fn fs ov. case_option$fn$fs$ov 
    ≡ Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. (SP case_option$fn$fs$ov))"
  "⋀b t e. If$b$t$e ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. (SP If$b$t$e))"
  "⋀x. RETURN$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RETURN$x))"
  "⋀x f. Let$x$f ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. (SP Let$x$f))"
  by (simp_all)


(* Combinator equations for plain (EVAL-tagged) expressions. EVAL of a
   combinator application is turned into a bind over EVAL of the scrutinee,
   condition or let-bound value, with the EVAL tag pushed into the branches
   or the let body. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ⇩2x xs. fc x xs)$l) ≡ 
    Refine_Basic.bind$(EVAL$l)$(λ⇩2l. case_list$(EVAL$fn)$(λ⇩2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ⇩2a b. fp a b)$p) ≡ 
    Refine_Basic.bind$(EVAL$p)$(λ⇩2p. case_prod$(λ⇩2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ⇩2x. fs x)$ov) ≡ 
    Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. case_option$(EVAL$fn)$(λ⇩2x. EVAL$(fs x))$ov)"
  "EVAL $ (Let $ v $ (λ⇩2x. f x)) ≡ (⤜) $ (EVAL $ v) $ (λ⇩2x. EVAL $ (f x))"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* EVAL of a PR_CONST-tagged term is rewritten directly to an SP-tagged
   RETURN of that term. *)
lemma evalcomb_PR_CONST[sepref_monadify_comb]:
  "EVAL$(PR_CONST x) ≡ SP (RETURN$(PR_CONST x))"
  by simp


end