header ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Id_Op
begin

text {*
  In this phase, a monadic program is converted to complete monadic form,
  that is, the computation of compound expressions is made visible as
  top-level operations in the monad.

  The monadify process consists of two main steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.

    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values.
  \end{enumerate}
*}

definition SP -- ‹Tag to protect content from further application of arity and combinator equations›
  where [simp]: "SP x ≡ x"
lemma SP_cong[cong]: "SP x ≡ SP x" by simp

definition RCALL -- ‹Tag that marks recursive call›
  where [simp]: "RCALL D ≡ D"

definition EVAL -- ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x ≡ RETURN x"

text {*
  Internally, the package first applies rewriting rules from
  @{text sepref_monadify_arity}, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.

  Next, rewriting rules from @{text sepref_monadify_comb} are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad.

  Finally, the expressions inside the eval-tags are flattened. In this step,
  rewrite rules from @{text sepref_monadify_evalcomb} are applied, in
  conjunction with a default rule that evaluates the arguments of each
  function from left to right.

  After flattening, every return of a bound variable is marked by the tag
  @{text PASS}, and the equations @{text remove_pass_simps} are used to
  eliminate the resulting trivial bindings.
*}

lemma monadify_simps:
  "bind$(RETURN$x)$(λ⇩2x. f x) = f x"
  "EVAL$x ≡ RETURN$x"
  by simp_all

definition [simp]: "PASS ≡ RETURN" -- "Pass on value, invalidating old one"

lemma remove_pass_simps:
  "bind$(PASS$x)$(λ⇩2x. f x) ≡ f x"
  "bind$m$(λ⇩2x. PASS$x) ≡ m"
  by simp_all

ML {*
  structure Sepref_Monadify = struct

    (* Named theorem collections used by the three rewriting phases;
       they are registered by the setup value at the end of this structure. *)
    structure arity_eqs = Named_Thms (
      val name = @{binding sepref_monadify_arity}
      val description = "Sepref.Monadify: Arity alignment equations"
    )
    structure comb_eqs = Named_Thms (
      val name = @{binding sepref_monadify_comb}
      val description = "Sepref.Monadify: Combinator equations"
    )
    structure eval_comb_eqs = Named_Thms (
      val name = @{binding sepref_monadify_evalcomb}
      val description = "Sepref.Monadify: Eval-Combinator equations"
    )

    local
      (* The i-th argument variable: binder name v<i> paired with a free
         variable __v<i> of type T.
         NOTE(review): assumes the processed term contains no free variable
         with a name of the form __v<i>. *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

      (* Build the lambda-2 form: abstract the free variable of n, which is a
         (name, free) pair, over PROTECT2 t DUMMY. *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end

      (* Wrap exp0 into nested binds, one per (name/free, monadic term) pair.
         The first pair becomes the outermost bind, so the bound computations
         are sequenced from left to right. *)
      fun bind_args exp0 [] = exp0
        | bind_args exp0 ((x,m)::xms) = let
            val lr = bind_args exp0 xms
              |> incr_boundvars 1
              |> lambda2_name x
          in
            @{mk_term "bind$?m$?lr"}
          end

      (* Default flattening of an application f$a0$...$an: every argument
         is evaluated via EVAL and bound to a fresh variable, left to right,
         and finally f applied to those variables is returned.
         Raises TERM if the head f is a lambda abstraction. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse raise TERM ("monadify: higher-order",[t])

        val argTs = map fastype_of args
        val args = map (fn a => @{mk_term "EVAL$?a"}) args

        val argVs = tag_list 0 argTs |> map cr_var

        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in
          @{mk_term "RETURN$?x"}
        end

        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end

      (* Rewrite a term of the form EVAL$t to the flattened form computed by
         monadify. The equation is proved by simp, using the context simpset
         extended with monadify_simps. Raises TERM on any other term shape. *)
      fun monadify_conv_aux ctxt ct = case term_of ct of
        @{mpat "EVAL$_"} => let
          val ss = ctxt addsimps @{thms monadify_simps}
          val tac = simp_tac ss 1
        in
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac ct
        end
      | t => raise TERM ("monadify_conv",[t])

      (* Try the user-declared eval-combinator equations. *)
      fun extract_comb_conv ctxt = Conv.rewrs_conv (eval_comb_eqs.get ctxt)

    in
      (* Top-down conversion: at each position, first try an eval-combinator
         equation, otherwise the default flattening; positions where neither
         applies are left unchanged. *)
      val monadify_conv = Conv.top_conv
        (fn ctxt =>
          Conv.try_conv (
            extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
          )
        )
    end

    (* Replace every RETURN applied to a bound variable by PASS.
       env accumulates the binder types of the enclosing abstractions. *)
    fun mark_params env @{mpat "RETURN$(?x AS⇩s mpaq_Bound _)"} =
          @{mk_term env: "PASS$?x"}
      | mark_params env (t1$t2) = mark_params env t1 $ mark_params env t2
      | mark_params env (Abs (x,T,t)) = Abs (x,T,mark_params (T::env) t)
      | mark_params _ t = t

    (* Conversion applying mark_params, justified by unfolding PASS_def. *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt
      (mark_params [])
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PASS_def}) 1)

    (* The complete monadify phase, in order: arity alignment, combinator
       rewriting, flattening of EVAL tags in the conclusion, PASS marking,
       and removal of trivial PASS bindings. *)
    fun monadify_tac ctxt = let
      val arity1_ss = put_simpset HOL_basic_ss ctxt
        addsimps arity_eqs.get ctxt
        |> Simplifier.add_cong @{thm SP_cong}
      val arity2_ss = put_simpset HOL_basic_ss ctxt
        addsimps @{thms beta SP_def}
      val arity_tac = simp_tac arity1_ss THEN' simp_tac arity2_ss

      val comb1_ss = put_simpset HOL_basic_ss ctxt
        addsimps comb_eqs.get ctxt
        addsimps eval_comb_eqs.get ctxt
        |> Simplifier.add_cong @{thm SP_cong}
      val comb2_ss = put_simpset HOL_basic_ss ctxt
        addsimps @{thms SP_def}
      val comb_tac = simp_tac comb1_ss THEN' simp_tac comb2_ss

      open Sepref_Basic
    in
      arity_tac
      THEN' comb_tac
      THEN' CONVERSION (hn_refine_concl_conv_a monadify_conv ctxt)
      THEN' CONVERSION (hn_refine_concl_conv_a (K (mark_params_conv ctxt)) ctxt)
      THEN' simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps})
    end

    val setup = I
      #> arity_eqs.setup
      #> comb_eqs.setup
      #> eval_comb_eqs.setup
  end
*}
setup Sepref_Monadify.setup

(* Arity alignment: eta-expand each combinator to its full number of
   operands, protecting the result by SP. case_option is aligned like
   case_list and case_prod, so that the case_option rules of dflt_comb and
   dflt_plain_comb below (which expect the eta-expanded branch) apply. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RECT ≡ λ⇩2B x. SP RECT$(λ⇩2D x. B$(λ⇩2x. RCALL$D$x)$x)$x"
  "case_list ≡ λ⇩2fn fc l. SP case_list$fn$(λ⇩2x xs. fc$x$xs)$l"
  "case_prod ≡ λ⇩2fp p. SP case_prod$(λ⇩2a b. fp$a$b)$p"
  "case_option ≡ λ⇩2fn fs ov. SP case_option$fn$(λ⇩2x. fs$x)$ov"
  "If ≡ λ⇩2b t e. SP If$b$t$e"
  "Let ≡ λ⇩2x f. SP Let$x$(λ⇩2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)

(* Combinator rules: evaluate the plain (non-monadic) operand of each
   combinator in the monad before applying the combinator. *)
lemma dflt_comb[sepref_monadify_comb]:
  "!!B x. RECT$B$x ≡ bind$(EVAL$x)$(λ⇩2x. SP (RECT$B$x))"
  "!!D x. RCALL$D$x ≡ bind$(EVAL$x)$(λ⇩2x. SP (RCALL$D$x))"
  "!!fn fc l. case_list$fn$fc$l ≡ bind$(EVAL$l)$(λ⇩2l. (SP case_list$fn$fc$l))"
  "!!fp p. case_prod$fp$p ≡ bind$(EVAL$p)$(λ⇩2p. (SP case_prod$fp$p))"
  "!!fn fs ov. case_option$fn$fs$ov ≡ bind$(EVAL$ov)$(λ⇩2ov. (SP case_option$fn$fs$ov))"
  "!!b t e. If$b$t$e ≡ bind$(EVAL$b)$(λ⇩2b. (SP If$b$t$e))"
  "!!x. RETURN$x ≡ bind$(EVAL$x)$(λ⇩2x. SP (RETURN$x))"
  "!!x f. Let$x$f ≡ bind$(EVAL$x)$(λ⇩2x. (SP Let$x$f))"
  by (simp_all)

(* Evaluation of plain case distinctions and conditionals: evaluate the
   discriminator first, then push EVAL into the branches. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ bind$(EVAL$b)$(λ⇩2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ⇩2x xs. fc x xs)$l) ≡ bind$(EVAL$l)$(λ⇩2l. case_list$(EVAL$fn)$(λ⇩2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ⇩2a b. fp a b)$p) ≡ bind$(EVAL$p)$(λ⇩2p. case_prod$(λ⇩2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ⇩2x. fs x)$ov) ≡ bind$(EVAL$ov)$(λ⇩2ov. case_option$(EVAL$fn)$(λ⇩2x. EVAL$(fs x))$ov)"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* Protected constants are returned as they are, without flattening. *)
lemma evalcomb_PR_CONST[sepref_monadify_evalcomb]:
  "EVAL$(PR_CONST x) ≡ RETURN$(PR_CONST x)"
  by simp

end