Theory Sepref_Monadify

theory Sepref_Monadify
imports Sepref_Basic
header ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Id_Op
begin


text {*
  In this phase, a monadic program is converted to complete monadic form,
  that is, the computations of compound expressions are made visible as
  top-level operations in the monad.

  The monadify process is separated into two main steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands 
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.

    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values. 
  \end{enumerate}
*}


(* SP is an identity tag (SP_def is [simp]). While the arity and combinator
   equations are applied, the congruence rule SP_cong is added to the simpset,
   so the simplifier does not rewrite inside an SP term; afterwards SP_def
   removes the tags again. *)
definition SP -- ‹Tag to protect content from further application of arity
  and combinator equations›
  where [simp]: "SP x ≡ x"
(* Congruence rule without premises: the argument of SP is not rewritten. *)
lemma SP_cong[cong]: "SP x ≡ SP x" by simp

(* RCALL is an identity tag (RCALL_def is [simp]); the RECT equation of
   dflt_arity wraps recursive calls into it. *)
definition RCALL -- ‹Tag that marks recursive call›
  where [simp]: "RCALL D ≡ D"
(* EVAL x is defined as RETURN x. Combinator equations insert EVAL tags around
   plain expressions; the ML conversion monadify_conv then flattens them. *)
definition EVAL -- ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x ≡ RETURN x"

text {*
  Internally, the package first applies rewriting rules from 
  @{text sepref_monadify_arity}, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.

  Next, rewriting rules from @{text sepref_monadify_comb} are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad. 

  Finally, the expressions inside the eval-tags are flattened. In this step,
  rewrite rules from @{text sepref_monadify_evalcomb} are applied, 
  in conjunction with a default rule that evaluates the arguments of each 
  function from left to right.
*}

(* Simp rules used by the ML function monadify_conv_aux to prove that a
   flattened term equals the original EVAL term: EVAL is replaced by RETURN,
   and a bind of a RETURN value is collapsed by substituting the value. *)
lemma monadify_simps:
  "bind$(RETURN$x)$(λ2x. f x) = f x"
  "EVAL$x ≡ RETURN$x"
  by simp_all

(* PASS is RETURN under another name (PASS_def is [simp]). The ML function
   mark_params replaces RETURN of a bound variable by PASS; the resulting binds
   are then eliminated with remove_pass_simps. *)
definition [simp]: "PASS ≡ RETURN"
  -- "Pass on value, invalidating old one"

(* Simp rules for the last step of monadify_tac: a bind of PASS$x is replaced
   by substituting x into the continuation, and a bind whose continuation only
   passes on its argument is replaced by the computation itself. *)
lemma remove_pass_simps:
  "bind$(PASS$x)$(λ2x. f x) ≡ f x"
  "bind$m$(λ2x. PASS$x) ≡ m"
  by simp_all


ML {*
  (* Monadify phase of sepref. The entry point is monadify_tac; the named
     theorem collections declared here are registered by setup. *)
  structure Sepref_Monadify = struct
    (* Arity alignment equations (attribute sepref_monadify_arity): eta-expand
       combinators so that they always occur with the same number of operands. *)
    structure arity_eqs = Named_Thms (
      val name = @{binding sepref_monadify_arity}
      val description = "Sepref.Monadify: Arity alignment equations"
    )

    (* Combinator equations (attribute sepref_monadify_comb): insert EVAL tags
       around plain expressions that have to be evaluated in the monad. *)
    structure comb_eqs = Named_Thms (
      val name = @{binding sepref_monadify_comb}
      val description = "Sepref.Monadify: Combinator equations"
    )

    (* Eval-combinator equations (attribute sepref_monadify_evalcomb): tried by
       monadify_conv on a term before the default flattening of EVAL terms. *)
    structure eval_comb_eqs = Named_Thms (
      val name = @{binding sepref_monadify_evalcomb}
      val description = "Sepref.Monadify: Eval-Combinator equations"
    )

    local
      (* i-th fresh variable: binder name "v<i>" and placeholder Free "__v<i>". *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

      (* lambda2 abstraction over the Free of n = (name, Free): the body is
         wrapped as PROTECT2 t DUMMY, then the Free is abstracted. *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end

      (* bind_args e [(x1,m1),...,(xk,mk)] builds
           bind$m1$(λ2x1. ... bind$mk$(λ2xk. e))
         so the computations m1 ... mk are bound from left to right.
         incr_boundvars 1 accounts for the binder added by lambda2_name. *)
      fun
        bind_args exp0 [] = exp0
      | bind_args exp0 ((x,m)::xms) = let
          val lr = bind_args exp0 xms
            |> incr_boundvars 1
            |> lambda2_name x
        in @{mk_term "bind$?m$?lr"} end

      (* Default flattening of an application t = f$a1$...$an:
           bind$(EVAL$a1)$(λ2v0. ... bind$(EVAL$an)$(λ2v(n-1).
             RETURN$(f$v0$...$v(n-1))))
         i.e. the arguments are evaluated from left to right.
         Raises TERM if the head f is an abstraction. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse
          raise TERM ("monadify: higher-order",[t])

        val argTs = map fastype_of args
        val args = map (fn a => @{mk_term "EVAL$?a"}) args

        val argVs = tag_list 0 argTs
          |> map cr_var

        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in
          @{mk_term "RETURN$?x"}
        end

        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end

      (* Rewrites EVAL$t to monadify t. The equation is proved by simp with
         monadify_simps on top of a basic simpset, like the other simp calls
         in this structure, so that unrelated simp rules of the context cannot
         interfere with (or loop on) this purely structural proof.
         Raises TERM on anything that is not an EVAL term. *)
      fun monadify_conv_aux ctxt ct = case term_of ct of
        @{mpat "EVAL$_"} => let
          val ss = put_simpset HOL_basic_ss ctxt
            addsimps @{thms monadify_simps}
          val tac = simp_tac ss 1
        in (*Refine_Util.monitor_conv "monadify"*) (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
        end
      | t => raise TERM ("monadify_conv",[t])

      (* Rewrite with the declared eval-combinator equations. *)
      fun extract_comb_conv ctxt = Conv.rewrs_conv (eval_comb_eqs.get ctxt)
    in

      (* Traverses the term with Conv.top_conv. At each subterm, the
         eval-combinator equations are tried first, then the default
         flattening of EVAL terms; try_conv leaves other subterms unchanged. *)
      val monadify_conv = Conv.top_conv
        (fn ctxt =>
          Conv.try_conv (
            extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
          )
        )

    end

    (* Replaces every RETURN of a bound variable by PASS. env holds the types
       of the enclosing binders (needed to build the PASS term). *)
    fun mark_params env @{mpat "RETURN$(?x ASs mpaq_Bound _)"} =
          @{mk_term env: "PASS$?x"}
      | mark_params env (t1$t2) = mark_params env t1 $ mark_params env t2
      | mark_params env (Abs (x,T,t)) = Abs (x,T,mark_params (T::env) t)
      | mark_params _ t = t

    (* Conversion applying mark_params; proved by unfolding PASS_def. *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt
      (mark_params [])
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PASS_def}) 1)

    (* The monadify tactic. The conversions are applied through
       Sepref_Basic.hn_refine_concl_conv_a (presumably to the abstract program
       of an hn_refine conclusion -- see Sepref_Basic). Steps:
       1. arity_tac: apply arity_eqs (SP_cong protects SP-tagged parts),
          then beta-reduce and remove SP tags;
       2. comb_tac: apply comb_eqs and eval_comb_eqs, then remove SP tags;
       3. flatten the EVAL terms with monadify_conv;
       4. turn RETURN of bound variables into PASS (mark_params_conv);
       5. eliminate PASS binds with remove_pass_simps. *)
    fun monadify_tac ctxt = let
      val arity1_ss = put_simpset HOL_basic_ss ctxt
        addsimps arity_eqs.get ctxt
        |> Simplifier.add_cong @{thm SP_cong}

      val arity2_ss = put_simpset HOL_basic_ss ctxt
        addsimps @{thms beta SP_def}

      val arity_tac = simp_tac arity1_ss THEN' simp_tac arity2_ss

      val comb1_ss = put_simpset HOL_basic_ss ctxt
        addsimps comb_eqs.get ctxt
        addsimps eval_comb_eqs.get ctxt
        |> Simplifier.add_cong @{thm SP_cong}

      val comb2_ss = put_simpset HOL_basic_ss ctxt
        addsimps @{thms SP_def}

      val comb_tac = simp_tac comb1_ss THEN' simp_tac comb2_ss

      open Sepref_Basic
    in
      arity_tac
      THEN' comb_tac
      THEN' CONVERSION (hn_refine_concl_conv_a monadify_conv ctxt)
      THEN' CONVERSION (hn_refine_concl_conv_a (K (mark_params_conv ctxt)) ctxt)
      THEN' simp_tac
        (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps})
    end


    (* Registers the three named theorem collections. *)
    val setup = I
      #> arity_eqs.setup
      #> comb_eqs.setup
      #> eval_comb_eqs.setup
  end
*}

(* Register the named theorem collections declared in Sepref_Monadify. *)
setup Sepref_Monadify.setup

(* Default arity equations: eta-expand the standard combinators so that they
   always occur with all their operands. SP (protected by SP_cong during this
   phase) keeps the combinator constant on the right-hand side from being
   rewritten again, and the RECT equation marks recursive calls with RCALL.
   The case_option equation provides the 3-operand form (with the some-branch
   as a lambda2 abstraction) that the case_option equations of dflt_comb and
   dflt_plain_comb expect; it is appended last so that the indices of the
   existing equations stay unchanged. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RECT ≡ λ2B x. SP RECT$(λ2D x. B$(λ2x. RCALL$D$x)$x)$x"
  "case_list ≡ λ2fn fc l. SP case_list$fn$(λ2x xs. fc$x$xs)$l"
  "case_prod ≡ λ2fp p. SP case_prod$(λ2a b. fp$a$b)$p"
  "If ≡ λ2b t e. SP If$b$t$e"
  "Let ≡ λ2x f. SP Let$x$(λ2x. f$x)"
  "case_option ≡ λ2fn fs ov. SP case_option$fn$(λ2x. fs$x)$ov"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)


(* Default combinator equations. For each combinator, the operand it inspects
   (the RECT/RCALL argument, the scrutinee of case_list/case_prod/case_option,
   the If condition, the RETURN value, the Let-bound value) is evaluated via an
   EVAL tag and bound to a fresh variable. The rebuilt combinator is wrapped in
   SP, which (with SP_cong) prevents these equations from firing on it again. *)
lemma dflt_comb[sepref_monadify_comb]:
  "!!B x. RECT$B$x ≡ bind$(EVAL$x)$(λ2x. SP (RECT$B$x))"
  "!!D x. RCALL$D$x ≡ bind$(EVAL$x)$(λ2x. SP (RCALL$D$x))"
  "!!fn fc l. case_list$fn$fc$l ≡ bind$(EVAL$l)$(λ2l. (SP case_list$fn$fc$l))"
  "!!fp p. case_prod$fp$p ≡ bind$(EVAL$p)$(λ2p. (SP case_prod$fp$p))"
  "!!fn fs ov. case_option$fn$fs$ov 
    ≡ bind$(EVAL$ov)$(λ2ov. (SP case_option$fn$fs$ov))"
  "!!b t e. If$b$t$e ≡ bind$(EVAL$b)$(λ2b. (SP If$b$t$e))"
  "!!x. RETURN$x ≡ bind$(EVAL$x)$(λ2x. SP (RETURN$x))"
  "!!x f. Let$x$f ≡ bind$(EVAL$x)$(λ2x. (SP Let$x$f))"
  by (simp_all)


(* Combinators occurring inside an EVAL tag (plain expressions): evaluate the
   condition or scrutinee first, then push EVAL tags into the branches, so that
   each branch is flattened separately. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ bind$(EVAL$b)$(λ2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ2x xs. fc x xs)$l) ≡ 
    bind$(EVAL$l)$(λ2l. case_list$(EVAL$fn)$(λ2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ2a b. fp a b)$p) ≡ 
    bind$(EVAL$p)$(λ2p. case_prod$(λ2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ2x. fs x)$ov) ≡ 
    bind$(EVAL$ov)$(λ2ov. case_option$(EVAL$fn)$(λ2x. EVAL$(fs x))$ov)"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* Eval-combinator equation (tried before the default flattening): a PR_CONST
   term is returned as a whole, so its argument is not flattened. *)
lemma evalcomb_PR_CONST[sepref_monadify_evalcomb]:
  "EVAL$(PR_CONST x) ≡ RETURN$(PR_CONST x)"
  by simp


end