section ‹Sepref Tool›
theory Sepref_Tool
imports Sepref_Translate Sepref_Definition Sepref_Combinator_Setup Sepref_Intf_Util
begin
text ‹In this theory, we set up the sepref tool.›
subsection ‹Sepref Method›
(* Consequence rule used to initialize synthesis (resolved by
   Sepref.cons_init_tac): a refinement with postcondition Γ' and result
   relation R is weakened to postcondition Γc' and result relation Rc.
   Resolving with this rule leaves the refinement goal followed by the two
   entailment premises as new subgoals. *)
lemma CONS_init:
assumes "hn_refine Γ c Γ' R a"
assumes "Γ' ⟹⇩t Γc'"
assumes "⋀a c. hn_ctxt R a c ⟹⇩t hn_ctxt Rc a c"
shows "hn_refine Γ c Γc' Rc a"
(* use the general consequence rule for hn_refine *)
apply (rule hn_refine_cons)
(* the precondition Γ is unchanged, so its entailment closes by reflexivity *)
apply (rule entt_refl)
(* the remaining premises are the assumptions, with hn_ctxt unfolded *)
apply (rule assms[unfolded hn_ctxt_def])+
done
(* Entry rule of the operation-identification phase (resolved by
   Sepref.id_tac): the abstract program a is replaced by a'. The premise
   ID a a' TYPE('T) is left as a subgoal, which Id_Op.id_tac then solves. *)
lemma ID_init: "⟦ID a a' TYPE('T); hn_refine Γ c Γ' R a'⟧
⟹ hn_refine Γ c Γ' R a" by simp
(* Entry rule of the translation phase (the opt_init phase of
   Sepref.sepref_tac). Translation synthesizes some program c, which is
   related to the final program c' by CNV. The CNV subgoal is closed at the
   end of Sepref.opt_tac via CNV_I. *)
lemma TRANS_init: "⟦ hn_refine Γ c Γ' R a; CNV c c' ⟧
⟹ hn_refine Γ c' Γ' R a"
by simp
(* Trivial entailment; Sepref.cons_solve_tac tries it first. *)
lemma infer_post_triv: "P ⟹⇩t P" by (rule entt_refl)
ML {*
(* Sepref: the complete synthesis tactic (sepref_tac) and its individual
   phases, plus the named-theorem collections used by the preprocessing and
   optimization phases. *)
structure Sepref = struct
(* Simp rules used by preproc_tac (attribute sepref_preproc). *)
structure sepref_preproc_simps = Named_Thms (
val name = @{binding sepref_preproc}
val description = "Sepref: Preprocessor simplifications"
)
(* Simp rules for the first simplification pass of opt_tac. *)
structure sepref_opt_simps = Named_Thms (
val name = @{binding sepref_opt_simps}
val description = "Sepref: Post-Translation optimizations, phase 1"
)
(* Simp rules for the second simplification pass of opt_tac. *)
structure sepref_opt_simps2 = Named_Thms (
val name = @{binding sepref_opt_simps2}
val description = "Sepref: Post-Translation optimizations, phase 2"
)
(* Weaken the postcondition, then resolve with CONS_init. This leaves the
   refinement goal followed by the two entailment premises of CONS_init. *)
fun cons_init_tac ctxt = Sepref_Frame.weaken_post_tac ctxt THEN' resolve_tac ctxt @{thms CONS_init}
(* Solve an entailment goal, either trivially (infer_post_triv) or with
   Sepref_Translate.side_frame_tac. Unless dbg is set, the attempt is wrapped
   in SOLVED', so it fails unless the goal is closed. With dbg set, partial
   progress is returned as is. *)
fun cons_solve_tac dbg ctxt = let
val dbgSOLVED' = if dbg then I else SOLVED'
in
dbgSOLVED' (
resolve_tac ctxt @{thms infer_post_triv}
ORELSE' Sepref_Translate.side_frame_tac ctxt
)
end
(* Preprocessing: Sepref_Rules.prepare_hfref_synth_tac, then simplification
   with HOL_basic_ss extended only by the sepref_preproc rules.
   NOTE: both steps receive the context with this reduced simpset. *)
fun preproc_tac ctxt = let
val ctxt = put_simpset HOL_basic_ss ctxt
val ctxt = ctxt addsimps (sepref_preproc_simps.get ctxt)
in
Sepref_Rules.prepare_hfref_synth_tac ctxt THEN'
Simplifier.simp_tac ctxt
end
(* Operation identification: resolve with ID_init, eta-normalize, and solve
   the resulting ID subgoal with Id_Op.id_tac in Normal mode. DETERM keeps
   only its first result. *)
fun id_tac ctxt =
resolve_tac ctxt @{thms ID_init}
THEN' CONVERSION Thm.eta_conversion
THEN' DETERM o Id_Op.id_tac Id_Op.Normal ctxt
(* Debugging entry points exposing single modes of Id_Op.id_tac (used by the
   sepref_dbg_id_init / _step / _solve methods). Unlike id_tac, these are
   not wrapped in DETERM. *)
fun id_init_tac ctxt =
resolve_tac ctxt @{thms ID_init}
THEN' CONVERSION Thm.eta_conversion
THEN' Id_Op.id_tac Id_Op.Init ctxt
fun id_step_tac ctxt =
Id_Op.id_tac Id_Op.Step ctxt
fun id_solve_tac ctxt =
Id_Op.id_tac Id_Op.Solve ctxt
(*fun id_param_tac ctxt = CONVERSION (Refine_Util.HOL_concl_conv
(K (Sepref_Param.id_param_conv ctxt)) ctxt)*)
(* NOTE(review): despite its name, the argument here is not a context. It is
   the boolean flag passed first to Sepref_Monadify.monadify_tac (callers
   write monadify_tac false / monadify_tac true). *)
fun monadify_tac ctxt = Sepref_Monadify.monadify_tac ctxt
(*fun lin_ana_tac ctxt = Sepref_Lin_Ana.lin_ana_tac ctxt*)
fun trans_tac ctxt = Sepref_Translate.trans_tac ctxt
(* Post-translation optimization of the synthesized program:
   1. simp with sepref_opt_simps and the let_simp simproc, using the
      congruence rules SP_cong and PR_CONST_cong;
   2. unfold SP;
   3. simp with sepref_opt_simps2 and let_simp;
   4. unfold SP again.
   Finally, eta-normalize and close the CNV goal (from TRANS_init) with CNV_I. *)
fun opt_tac ctxt = let
val opt1_ss = put_simpset HOL_basic_ss ctxt
addsimps sepref_opt_simps.get ctxt
addsimprocs [@{simproc "HOL.let_simp"}]
|> Simplifier.add_cong @{thm SP_cong}
|> Simplifier.add_cong @{thm PR_CONST_cong}
val unsp_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms SP_def}
val opt2_ss = put_simpset HOL_basic_ss ctxt
addsimps sepref_opt_simps2.get ctxt
addsimprocs [@{simproc "HOL.let_simp"}]
in
simp_tac opt1_ss THEN' simp_tac unsp_ss THEN'
simp_tac opt2_ss THEN' simp_tac unsp_ss THEN'
CONVERSION Thm.eta_conversion THEN'
resolve_tac ctxt @{thms CNV_I}
end
(* The complete synthesis tactic. It first ensures a constraint slot exists,
   then runs the phases below in order via Sepref_Basic.PHASES', with phase
   control derived from dbg. Each entry is (name, tactic, n).
   NOTE(review): n appears to be the expected change in the number of
   subgoals. CONS_init adds 2 premises, TRANS_init adds 1, and phases that
   close a goal have ~1. Confirm against Sepref_Basic.PHASES'.
   The final phase solves the accumulated constraints and removes the slot. *)
fun sepref_tac dbg ctxt =
(K Sepref_Constraints.ensure_slot_tac)
THEN'
Sepref_Basic.PHASES'
[
("preproc",preproc_tac,0),
("cons_init",cons_init_tac,2),
("id",id_tac,0),
("monadify",monadify_tac false,0),
("opt_init",fn ctxt => resolve_tac ctxt @{thms TRANS_init},1),
("trans",trans_tac,~1),
("opt",opt_tac,~1),
("cons_solve1",cons_solve_tac false,~1),
("cons_solve2",cons_solve_tac false,~1),
("constraints",fn ctxt => K (Sepref_Constraints.solve_constraint_slot ctxt THEN Sepref_Constraints.remove_slot_tac),~1)
] (Sepref_Basic.flag_phases_ctrl dbg) ctxt
(* Theory setup registering the three named-theorem collections. *)
val setup = I
#> sepref_preproc_simps.setup
#> sepref_opt_simps.setup
#> sepref_opt_simps2.setup
end
*}
setup Sepref.setup
(* Main method: run the complete synthesis on the addressed subgoal. It fails
   unless that subgoal is solved completely (SOLVED'), and only the first
   result is kept (DETERM). *)
method_setup sepref = ‹
let
  fun sepref_method ctxt =
    SIMPLE_METHOD (DETERM (SOLVED' (IF_EXGOAL (Sepref.sepref_tac false ctxt)) 1))
in
  Scan.succeed sepref_method
end›
‹Automatic refinement to Imperative/HOL›
(* Debug variant: runs sepref_tac with dbg = true and without the
   SOLVED'/DETERM wrapper, so the subgoal does not have to be solved.
   To additionally trace operation identification, one can apply
   Config.put Id_Op.cfg_id_debug true to the context first. *)
method_setup sepref_dbg_keep = ‹Scan.succeed (fn ctxt =>
  SIMPLE_METHOD (IF_EXGOAL (Sepref.sepref_tac true ctxt) 1))›
‹Automatic refinement to Imperative/HOL, debug mode›
subsubsection ‹Default Optimizer Setup›
(* Binding the result of ureturn is the same as a let; registered as an
   optimizer rule on the next line. *)
lemma return_bind_eq_let: "do { x←ureturn v; f x } = do { let x=v; f x }" by simp
(* Phase-1 optimizer rules (sepref_opt_simps, used by the first simp pass of
   Sepref.opt_tac). *)
lemmas [sepref_opt_simps] = return_bind_eq_let bind_return bind_bind id_def
text ‹We allow the synthesized function to contain tagged function applications.
This is important to avoid higher-order unification problems when synthesizing
generic algorithms, for example the to-list algorithm for foreach-loops.›
lemmas [sepref_opt_simps] = Autoref_Tagging.APP_def
text {* Revert case-pulling done by monadify *}
(* Each of the following rules pulls a return that appears in every branch of
   a case_prod / case_option / case_list / If out of the branching construct. *)
lemma case_prod_return_opt[sepref_opt_simps]:
"case_prod (λa b. return (f a b)) p = return (case_prod f p)"
by (simp split: prod.split)
lemma case_option_return_opt[sepref_opt_simps]:
"case_option (return fn) (λs. return (fs s)) v = return (case_option fn fs v)"
by (simp split: option.split)
lemma case_list_return[sepref_opt_simps]:
"case_list (return fn) (λx xs. return (fc x xs)) l = return (case_list fn fc l)"
by (simp split: list.split)
lemma if_return[sepref_opt_simps]:
"If b (return t) (return e) = return (If b t e)" by simp
text {* In some cases, pushing in the returns is more convenient *}
(* Phase-2 rule (sepref_opt_simps2, used by the second simp pass of
   Sepref.opt_tac): push the return under a pair-matching lambda. *)
lemma case_prod_opt2[sepref_opt_simps2]:
"(λx. return (case x of (a,b) ⇒ f a b))
= (λ(a,b). return (f a b))"
by auto
subsection ‹Debugging Methods›
ML ‹
(* Method parsers for methods that take no arguments.
   SIMPLE_METHOD_NOPARAM tac: tac ctxt is a whole-goal tactic.
   SIMPLE_METHOD_NOPARAM' tac: tac ctxt is a subgoal tactic and is wrapped in
   IF_EXGOAL before being turned into a method with SIMPLE_METHOD'. *)
fun SIMPLE_METHOD_NOPARAM tac = Scan.succeed (SIMPLE_METHOD o tac)
fun SIMPLE_METHOD_NOPARAM' tac = Scan.succeed (SIMPLE_METHOD' o IF_EXGOAL o tac)
›
(* Debug methods for the individual phases of Sepref.sepref_tac.
   sepref_dbg_preproc first ensures a constraint slot, as sepref_tac does. *)
method_setup sepref_dbg_preproc = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => K (Sepref_Constraints.ensure_slot_tac) THEN' Sepref.preproc_tac ctxt)›
‹Sepref debug: Preprocessing phase›
method_setup sepref_dbg_cons_init = ‹SIMPLE_METHOD_NOPARAM' Sepref.cons_init_tac›
‹Sepref debug: Initialize consequence reasoning›
method_setup sepref_dbg_id = ‹SIMPLE_METHOD_NOPARAM' (Sepref.id_tac)›
‹Sepref debug: Identify operations phase›
method_setup sepref_dbg_id_keep = ‹SIMPLE_METHOD_NOPARAM' (Config.put Id_Op.cfg_id_debug true #> Sepref.id_tac)›
‹Sepref debug: Identify operations phase. Debug mode, keep intermediate subgoals on failure.›
method_setup sepref_dbg_monadify = ‹SIMPLE_METHOD_NOPARAM' (Sepref.monadify_tac false)›
‹Sepref debug: Monadify phase›
(* Same phase with the flag of monadify_tac set. sepref_tac itself uses
   monadify_tac false. The flag is presumably the debug flag, by analogy with
   sepref_tac dbg. The description is kept distinct from sepref_dbg_monadify
   so the two methods can be told apart. *)
method_setup sepref_dbg_monadify_keep = ‹SIMPLE_METHOD_NOPARAM' (Sepref.monadify_tac true)›
‹Sepref debug: Monadify phase, debug mode›
(* Debug methods for the individual sub-phases of monadification, followed
   by translation initialization (TRANS_init) and the translation phase. *)
method_setup sepref_dbg_monadify_arity = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Monadify.arity_tac ctxt)›
‹Sepref debug: Monadify phase: Arity phase›
method_setup sepref_dbg_monadify_comb = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Monadify.comb_tac ctxt)›
‹Sepref debug: Monadify phase: Comb phase›
(* Checks, via CONCL_COND', that Sepref_Monadify.contains_eval does not hold
   for the conclusion. *)
method_setup sepref_dbg_monadify_check_EVAL = ‹SIMPLE_METHOD_NOPARAM' (K (CONCL_COND' (not o Sepref_Monadify.contains_eval)))›
‹Sepref debug: Monadify phase: check_EVAL phase›
method_setup sepref_dbg_monadify_mark_params = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Monadify.mark_params_tac ctxt)›
‹Sepref debug: Monadify phase: mark_params phase›
method_setup sepref_dbg_monadify_dup = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Monadify.dup_tac ctxt)›
‹Sepref debug: Monadify phase: dup phase›
method_setup sepref_dbg_monadify_remove_pass = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Monadify.remove_pass_tac ctxt)›
‹Sepref debug: Monadify phase: remove_pass phase›
method_setup sepref_dbg_opt_init = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => resolve_tac ctxt [@{thm TRANS_init}])›
‹Sepref debug: Translation phase initialization›
method_setup sepref_dbg_trans = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.trans_tac ctxt)›
‹Sepref debug: Translation phase›
(* Sepref.opt_tac already ends with eta-conversion and resolution with CNV_I,
   which closes the CNV goal. Nothing else is chained after it: any further
   step on the same subgoal index would act on whatever goal moves into that
   position, and the eta-conversion would fail (failing the whole method)
   when no such goal remains. *)
method_setup sepref_dbg_opt = ‹SIMPLE_METHOD_NOPARAM' Sepref.opt_tac›
‹Sepref debug: Optimization phase›
(* Debug methods for solving the post-consequence goals (plain and keep
   variants) and the accumulated constraints. *)
method_setup sepref_dbg_cons_solve = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.cons_solve_tac false ctxt)›
‹Sepref debug: Solve post-consequences›
method_setup sepref_dbg_cons_solve_keep = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.cons_solve_tac true ctxt)›
‹Sepref debug: Solve post-consequences, keep intermediate results›
method_setup sepref_dbg_constraints = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt =>
  let
    val solve_and_remove =
      Sepref_Constraints.solve_constraint_slot ctxt
      THEN Sepref_Constraints.remove_slot_tac
  in
    IF_EXGOAL (K solve_and_remove)
  end)›
‹Sepref debug: Solve accumulated constraints›
(* Debug methods running a single mode of operation identification. *)
method_setup sepref_dbg_id_init = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.id_init_tac ctxt)›
‹Sepref debug: Initialize operation identification phase›
method_setup sepref_dbg_id_step = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.id_step_tac ctxt)›
‹Sepref debug: Single step operation identification phase›
method_setup sepref_dbg_id_solve = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref.id_solve_tac ctxt)›
‹Sepref debug: Complete current operation identification goal›
(* Debug methods for stepping through the translation phase. *)
method_setup sepref_dbg_trans_keep = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Translate.trans_keep_tac ctxt)›
‹Sepref debug: Translation phase, stop at failed subgoal›
method_setup sepref_dbg_trans_step = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Translate.trans_step_tac ctxt)›
‹Sepref debug: Translation step›
method_setup sepref_dbg_trans_step_keep = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Translate.trans_step_keep_tac ctxt)›
‹Sepref debug: Translation step, keep unsolved subgoals›
(* Debug methods for side-condition handling. These were the only methods in
   this theory without a description. sepref_dbg_side and sepref_dbg_side_keep
   differ only in the boolean flag passed to side_cond_dispatch_tac. *)
method_setup sepref_dbg_side = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => REPEAT_ALL_NEW_FWD (Sepref_Translate.side_cond_dispatch_tac false (K no_tac) ctxt))›
‹Sepref debug: Side conditions (side_cond_dispatch_tac)›
method_setup sepref_dbg_side_unfold = ‹SIMPLE_METHOD_NOPARAM' (Sepref_Translate.side_unfold_tac)›
‹Sepref debug: Side conditions, unfold (side_unfold_tac)›
method_setup sepref_dbg_side_keep = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => REPEAT_ALL_NEW_FWD (Sepref_Translate.side_cond_dispatch_tac true (K no_tac) ctxt))›
‹Sepref debug: Side conditions, keep unsolved subgoals (side_cond_dispatch_tac)›
(* Debug methods for frame inference. All of them use
   Sepref_Translate.side_fallback_tac as the fallback for side conditions. *)
method_setup sepref_dbg_prepare_frame = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt => Sepref_Frame.prepare_frame_tac ctxt)›
‹Sepref debug: Prepare frame inference›
method_setup sepref_dbg_frame = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt =>
  Sepref_Frame.frame_tac Sepref_Translate.side_fallback_tac ctxt)›
‹Sepref debug: Frame inference›
method_setup sepref_dbg_merge = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt =>
  Sepref_Frame.merge_tac Sepref_Translate.side_fallback_tac ctxt)›
‹Sepref debug: Frame inference, merge›
method_setup sepref_dbg_frame_step = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt =>
  Sepref_Frame.frame_step_tac Sepref_Translate.side_fallback_tac false ctxt)›
‹Sepref debug: Frame inference, single-step›
method_setup sepref_dbg_frame_step_keep = ‹SIMPLE_METHOD_NOPARAM' (fn ctxt =>
  Sepref_Frame.frame_step_tac Sepref_Translate.side_fallback_tac true ctxt)›
‹Sepref debug: Frame inference, single-step, keep partially solved side conditions›
subsection ‹Utilities›
subsubsection ‹Manual hfref-proofs›
(* Run preprocessing, then weaken the postcondition of the resulting goal. *)
method_setup sepref_to_hnr = ‹
let
  fun to_hnr_tac ctxt =
    Sepref.preproc_tac ctxt
    THEN' Sepref_Frame.weaken_post_tac ctxt
in
  SIMPLE_METHOD_NOPARAM' to_hnr_tac
end›
‹Sepref: Convert to hnr-goal and weaken postcondition›
(* As sepref_to_hnr, then optionally introduce hn_refineI0, or hn_refineI
   followed by simplification. The result is simplified with hn_ctxt_def and
   pure_def unfolded. *)
method_setup sepref_to_hoare = ‹
let
  fun sepref_to_hoare_tac ctxt =
    let
      val unfold_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms hn_ctxt_def pure_def}
      val unfold_tac = asm_full_simp_tac unfold_ss
      val to_hnr = Sepref.preproc_tac ctxt THEN' Sepref_Frame.weaken_post_tac ctxt
      val intro_hoare =
        resolve_tac ctxt @{thms hn_refineI0}
        ORELSE' (resolve_tac ctxt @{thms hn_refineI} THEN' unfold_tac)
    in
      to_hnr THEN' TRY o intro_hoare THEN' unfold_tac
    end
in
  SIMPLE_METHOD_NOPARAM' sepref_to_hoare_tac
end
› ‹Sepref: Convert to hoare-triple›
subsubsection ‹Copying of Parameters›
(* Rewriting with this equation wraps a term into an explicit COPY
   (COPY x is provably equal to x). COPY is registered as an operation. *)
lemma fold_COPY: "x = COPY x" by simp
sepref_register COPY
text ‹COPY is treated as a normal operator, so one can simply declare rules for it.›
(* For a pure relation R, COPY is implemented by ureturn: the concrete value
   itself is returned again. *)
lemma hnr_pure_COPY[sepref_fr_rules]:
"CONSTRAINT is_pure R ⟹ (ureturn, RETURNT o COPY) ∈ R⇧k →⇩a R"
by (sep_auto simp: is_pure_conv pure_def intro!: hfrefI hn_refineI0)
subsubsection ‹Short-Circuit Boolean Evaluation›
text ‹Convert boolean operators to short-circuiting.
When applied before monadify, this will generate a short-circuit execution.›
(* Conjunction, disjunction and implication expressed as if-then-else.
   No attribute is declared here, so these rules only take effect when
   unfolded explicitly. *)
lemma short_circuit_conv:
"(a ∧ b) ⟷ (if a then b else False)"
"(a ∨ b) ⟷ (if a then True else b)"
"(a⟶b) ⟷ (if a then b else True)"
by auto
subsubsection ‹Eliminating higher-order›
(* Preprocessing rule: a case_prod whose branches take an extra argument x is
   rewritten into a function of p and x, where the case_prod is applied to
   (f x) and p. *)
lemma ho_prod_move[sepref_preproc]: "case_prod (λa b x. f x a b) = (λp x. case_prod (f x) p)"
by (auto intro!: ext)
(* Also unfold applied function compositions during preprocessing. *)
declare o_apply[sepref_preproc]
subsubsection ‹Precision Proofs›
text ‹
We provide a method that tries to extract equalities from
an assumption of the form
@{text "_ ⊨ P1 * … * Pn ∧⇩A P1' * … * Pn'"},
if it finds a precision rule for Pi and Pi'.
The precision rules are extracted from the constraint rules.
TODO: Extracting the precision rules from the constraint rules
is not a clean solution. It might be better to collect precision rules
separately, and feed them into the constraint solver.
›
(* prec_spec h Γ Γ': the heap h satisfies both Γ and Γ', each extended by
   an arbitrary frame via * true. *)
definition "prec_spec h Γ Γ' ≡ h ⊨ Γ * true ∧⇩A Γ' * true"
(* Introduction from a plain conjunction without frames. *)
lemma prec_specI: "h ⊨ Γ ∧⇩A Γ' ⟹ prec_spec h Γ Γ'"
unfolding prec_spec_def
by (auto simp: mod_and_dist mod_star_trueI)
(* Under a trailing * true, the right or left factor of a product can be dropped. *)
lemma prec_split1_aux: "A*B*true ⟹⇩A A*true"
by (simp add: ent_true_drop(1) entt_refl')
lemma prec_split2_aux: "A*B*true ⟹⇩A B*true"
by (simp add: ent_true_drop(1) ent_true_drop_fst entt_refl')
(* Split a prec_spec over products into prec_specs for the components. *)
lemma prec_spec_splitE:
assumes "prec_spec h (A*B) (C*D)"
obtains "prec_spec h A C" "prec_spec h B D"
(* remove the obtains-premise from the goal; it is still available as the fact "that" *)
apply (thin_tac "⟦_;_⟧ ⟹ _")
apply (rule that)
using assms
apply -
unfolding prec_spec_def
(* first component: A*B*true gives A*true, and C*D*true gives C*true *)
apply (erule entailsD[rotated])
apply (rule ent_conjI)
apply (rule ent_conjE1)
apply (rule prec_split1_aux)
apply (rule ent_conjE2)
apply (rule prec_split1_aux)
(* second component: A*B*true gives B*true, and C*D*true gives D*true *)
apply (erule entailsD[rotated])
apply (rule ent_conjI)
apply (rule ent_conjE1)
apply (rule prec_split2_aux)
apply (rule ent_conjE2)
apply (rule prec_split2_aux)
done
(* For a precise assertion R, a prec_spec of R a p and R a' p on the same
   heap and the same concrete value p yields a = a'. *)
lemma prec_specD:
assumes "precise R"
assumes "prec_spec h (R a p) (R a' p)"
shows "a=a'"
using assms unfolding precise_def prec_spec_def CONSTRAINT_def by blast
ML {*
(* Tactic behind the prec_extract_eqs method. It forward-reasons from an
   assumption with prec_specI, splits the resulting prec_spec along products
   (prec_spec_splitE), and turns components into equalities using precision
   rules resolved with prec_specD. Finally, it drops the remaining prec_spec
   facts. The precision rules are snga_prec, sngr_prec, and all constraint
   rules (safe and unsafe) whose conclusion has the form "precise _". *)
fun prec_extract_eqs_tac ctxt =
  let
    fun is_precise thm =
      (case Thm.concl_of thm of
        @{mpat "Trueprop (precise _)"} => true
      | _ => false)
    val constraint_rls =
      Sepref_Constraints.get_constraint_rules ctxt
      @ Sepref_Constraints.get_safe_constraint_rules ctxt
    val eq_rls =
      (@{thms snga_prec sngr_prec} @ filter is_precise constraint_rls)
      |> map (fn thm => thm RS @{thm prec_specD})
    val thin_rls = @{thms thin_rl[Pure.of "prec_spec a b c" for a b c]}
  in
    forward_tac ctxt @{thms prec_specI}
    THEN' REPEAT_ALL_NEW (ematch_tac ctxt @{thms prec_spec_splitE})
    THEN' REPEAT o dresolve_tac ctxt eq_rls
    THEN' REPEAT o eresolve_tac ctxt thin_rls
  end
*}
method_setup prec_extract_eqs = ‹SIMPLE_METHOD_NOPARAM' prec_extract_eqs_tac›
‹Extract equalities from "_ |= _ & _" assumption, using precision rules›
subsubsection ‹Combinator Rules›
(* Split a merge of three alternatives into two binary merges through an
   intermediate assertion X. prep_comb_rule applies this to premises of
   that shape. *)
lemma split_merge: "⟦A ∨⇩A B ⟹⇩t X; X ∨⇩A C ⟹⇩t D⟧ ⟹ (A ∨⇩A B ∨⇩A C ⟹⇩t D)"
proof -
assume a1: "X ∨⇩A C ⟹⇩t D"
assume "A ∨⇩A B ⟹⇩t X"
then have "A ∨⇩A B ⟹⇩A D * true"
using a1 by (meson ent_disjI1_direct ent_frame_fwd enttD entt_def_true)
then show ?thesis
using a1 by (meson ent_disjI1 entailst_def entt_disjD2 entt_disjE)
qed
ML ‹
(* Preprocess a combinator rule. Each premise is handled as follows:
   - a three-way merge entailment "_ ∨⇩A _ ∨⇩A _ ⟹⇩t _" is split with
     split_merge;
   - an hn_refine premise whose postcondition is not headed by a schematic
     variable is weakened with hn_refine_cons_post;
   - any other premise is left as is (asm_rl).
   This repeats until no premise is changed. The result is then normalized
   with zero_var_indexes. *)
fun prep_comb_rule thm =
  let
    val unchanged = (asm_rl, false)
    fun rule_for_prem prem =
      (case Logic.strip_assums_concl prem of
        @{mpat "Trueprop (_ ∨⇩A _ ∨⇩A _ ⟹⇩t _)"} => (@{thm split_merge}, true)
      | @{mpat "Trueprop (hn_refine _ _ ?G _ _)"} =>
          if is_Var (head_of G) then unchanged
          else (@{thm hn_refine_cons_post}, true)
      | _ => unchanged)
    val (rules, changed) = Thm.prems_of thm |> map rule_for_prem |> split_list
  in
    if exists I changed then prep_comb_rule (thm OF rules)
    else zero_var_indexes thm
  end
›
attribute_setup sepref_prep_comb_rule = ‹Scan.succeed (Thm.rule_attribute [] (K prep_comb_rule))›
‹Preprocess combinator rule: Split merge-rules and add missing frame rules›
end