(* Sepref tool: assembles the individual Sepref phases into the sepref proof
   method and configures the post-translation optimizations and the code
   extraction setup. *)
header ‹Sepref Tool›
theory Sepref_Tool
imports Sepref_Translate
begin

text ‹In this theory, we set up the sepref tool.›

(* CNV c c': the program c is converted into c'. Declared [simp], so it is
   simplified to plain equality. *)
definition [simp]: "CNV x y ≡ x=y"

(* Entry rule of the operation-identification phase: to refine a, it suffices
   to refine a term a' that is related to a by ID. *)
lemma ID_init:
  assumes "ID a a' TYPE('T)"
      and "hn_refine Γ c Γ' R a'"
    shows "hn_refine Γ c Γ' R a"
  using assms by simp

(* Entry rule of the translation phase: a refinement proved for program c
   carries over to any c' that c is converted into. *)
lemma TRANS_init:
  assumes "hn_refine Γ c Γ' R a"
      and "CNV c c'"
    shows "hn_refine Γ c' Γ' R a"
  using assms by simp

(* Closes a trivial conversion goal. *)
lemma CNV_I: "CNV x x"
  unfolding CNV_def by (rule refl)

ML {*
  structure Sepref = struct
    (* Rewrite rules for the first optimization pass (attribute sepref_opt_simps). *)
    structure sepref_opt_simps = Named_Thms (
      val name = @{binding sepref_opt_simps}
      val description = "Sepref: Post-Translation optimizations, phase 1"
    )

    (* Rewrite rules for the second optimization pass (attribute sepref_opt_simps2). *)
    structure sepref_opt_simps2 = Named_Thms (
      val name = @{binding sepref_opt_simps2}
      val description = "Sepref: Post-Translation optimizations, phase 2"
    )

    (* Operation identification: resolve with ID_init, eta-contract the goal,
       then run Id_Op's identification in Normal mode. *)
    fun id_tac ctxt = EVERY' [
      rtac @{thm ID_init},
      CONVERSION Thm.eta_conversion,
      Id_Op.id_tac Id_Op.Normal ctxt
    ]

    (* Parameter identification: apply Sepref_Param.id_param_conv to the
       (HOL) conclusion of the goal. *)
    fun id_param_tac ctxt =
      let
        val concl_conv = K (Sepref_Param.id_param_conv ctxt)
      in
        CONVERSION (Refine_Util.HOL_concl_conv concl_conv ctxt)
      end

    (* Phases implemented in their own modules. *)
    val monadify_tac = Sepref_Monadify.monadify_tac
    val lin_ana_tac = Sepref_Lin_Ana.lin_ana_tac

    (* Translation: resolve with TRANS_init, then run the translator. *)
    fun trans_tac ctxt = EVERY' [
      rtac @{thm TRANS_init},
      Sepref_Translate.trans_tac ctxt
    ]

    (* Optimization: two simplifier passes over HOL_basic_ss, each followed by
       unfolding SP_def. The first pass uses sepref_opt_simps with SP_cong as
       congruence rule; the second uses sepref_opt_simps2. Both add the
       HOL.let_simp simproc. *)
    fun opt_tac ctxt =
      let
        fun mk_ss thms = put_simpset HOL_basic_ss ctxt addsimps thms
        val let_procs = [@{simproc "HOL.let_simp"}]

        val pass1_ss =
          Simplifier.add_cong @{thm SP_cong}
            (mk_ss (sepref_opt_simps.get ctxt) addsimprocs let_procs)
        val drop_sp_ss = mk_ss @{thms SP_def}
        val pass2_ss = mk_ss (sepref_opt_simps2.get ctxt) addsimprocs let_procs
      in
        EVERY' (map simp_tac [pass1_ss, drop_sp_ss, pass2_ss, drop_sp_ss])
      end

    (* Run the phases in order on the given subgoal index. Every phase is
       wrapped in IF_EXGOAL. If a phase fails, the THEN_ELSE' fallback
       (K all_tac) skips all later phases and keeps the state produced so far. *)
    fun PHASES tacs ctxt =
      let
        fun add_phase tac rest =
          IF_EXGOAL (tac ctxt) THEN_ELSE' (rest, K all_tac)
      in
        fold_rev add_phase tacs (K all_tac)
      end

    (* The complete Sepref pipeline. *)
    fun sepref_tac ctxt =
      let
        val phases = [
          id_param_tac,
          id_tac,
          monadify_tac,
          lin_ana_tac,
          trans_tac,
          opt_tac,
          K (CONVERSION Thm.eta_conversion),
          K (rtac @{thm CNV_I})
        ]
      in
        PHASES phases ctxt
      end

    val setup = sepref_opt_simps.setup #> sepref_opt_simps2.setup
  end
*}

setup Sepref.setup

(* sepref: run the whole Sepref pipeline on goal 1. The goal must be solved
   completely (SOLVED'), and only the first result is kept (DETERM). *)
method_setup sepref = ‹Scan.succeed (fn ctxt =>
  SIMPLE_METHOD (DETERM (SOLVED' (IF_EXGOAL (Sepref.sepref_tac ctxt)) 1)))›
  ‹Automatic refinement to Imperative/HOL›

(* sepref_keep: run the same pipeline on goal 1 without requiring it to be
   solved. If a phase fails, the goals left by the earlier phases remain,
   e.g. for inspecting where synthesis got stuck. *)
method_setup sepref_keep = ‹Scan.succeed (fn ctxt =>
  SIMPLE_METHOD (IF_EXGOAL (Sepref.sepref_tac ctxt) 1))›
  ‹Automatic refinement to Imperative/HOL, keeping unsolved goals›


subsubsection ‹Debugging Methods›
ML ‹
  (* Argument-free method parsers: pass the proof context to a tactic
     builder and wrap the resulting tactic as a method. The primed variant
     takes a subgoal-indexed tactic (SIMPLE_METHOD'), the unprimed one a
     plain tactic (SIMPLE_METHOD). *)
  fun SIMPLE_METHOD_NOPARAM' tac = Scan.succeed (SIMPLE_METHOD' o tac)
  fun SIMPLE_METHOD_NOPARAM tac = Scan.succeed (SIMPLE_METHOD o tac)
›
(* Debug methods: each exposes a single phase tactic of structure Sepref as a
   method, so the pipeline can be run one phase at a time. They are listed in
   the same order as the phases in Sepref.sepref_tac. The final eta-conversion
   and CNV_I steps of that pipeline have no debug method of their own. *)
method_setup sepref_dbg_id_param = ‹SIMPLE_METHOD_NOPARAM' Sepref.id_param_tac›
  ‹Sepref debug: Identify parameters phase›
method_setup sepref_dbg_id = ‹SIMPLE_METHOD_NOPARAM' Sepref.id_tac›
  ‹Sepref debug: Identify operations phase›
method_setup sepref_dbg_monadify = ‹SIMPLE_METHOD_NOPARAM' Sepref.monadify_tac›
  ‹Sepref debug: Monadify phase›
method_setup sepref_dbg_lin_ana = ‹SIMPLE_METHOD_NOPARAM' Sepref.lin_ana_tac›
  ‹Sepref debug: Linearity analysis phase›
method_setup sepref_dbg_trans = ‹SIMPLE_METHOD_NOPARAM' Sepref.trans_tac›
  ‹Sepref debug: Translation phase›
method_setup sepref_dbg_opt = ‹SIMPLE_METHOD_NOPARAM' Sepref.opt_tac›
  ‹Sepref debug: Optimization phase›

(* Finer-grained debug methods that call tactics of Sepref_Translate and
   Sepref_Frame directly rather than a whole phase. *)
method_setup sepref_dbg_trans_step = ‹SIMPLE_METHOD_NOPARAM' (Sepref_Translate.cstep_tac)›
  ‹Sepref debug: Translation phase single step›

method_setup sepref_dbg_prepare_frame = ‹SIMPLE_METHOD_NOPARAM' Sepref_Frame.prepare_frame_tac›
  ‹Sepref debug: Prepare frame inference›

method_setup sepref_dbg_frame = ‹SIMPLE_METHOD_NOPARAM' Sepref_Frame.frame_tac›
  ‹Sepref debug: Frame inference›


(* Rules for the first optimization pass. NOTE(review): presumably the
   return/bind monad laws plus unfolding of id. *)
lemmas [sepref_opt_simps] = return_bind bind_return bind_bind id_def

text ‹We allow the synthesized function to contain tagged function applications.
  This is important to avoid higher-order unification problems when synthesizing
  generic algorithms, for example the to-list algorithm for foreach-loops.›
(* Adding APP_def to the optimization rules unfolds these application tags in
   the synthesized program (presumably APP f x = f x; confirm in
   Autoref_Tagging). *)
lemmas [sepref_opt_simps] = Autoref_Tagging.APP_def


text {* Revert the case-pulling done by monadify: move return out of case
  distinctions and if-then-else. *}
(* Pull return out of a case distinction on a pair. *)
lemma case_prod_return_opt[sepref_opt_simps]:
  "case_prod (λa b. return (f a b)) p = return (case_prod f p)"
  by (cases p) simp_all

(* Pull return out of a case distinction on an option. *)
lemma case_option_return_opt[sepref_opt_simps]:
  "case_option (return fn) (λs. return (fs s)) v = return (case_option fn fs v)"
  by (cases v) simp_all

(* Pull return out of a case distinction on a list. *)
lemma case_list_return[sepref_opt_simps]:
  "case_list (return fn) (λx xs. return (fc x xs)) l = return (case_list fn fc l)"
  by (cases l) simp_all

(* Pull return out of if-then-else. *)
lemma if_return[sepref_opt_simps]:
  "If b (return t) (return e) = return (If b t e)"
  by (cases b) simp_all

text {* In some cases, pushing the return into the case distinction is more convenient. *}
(* Second-pass rule: push return into a pair case distinction. *)
lemma case_prod_opt2[sepref_opt_simps2]:
  "(λx. return (case x of (a,b) => f a b)) 
  = (λ(a,b). return (f a b))"
  by (rule ext) (simp split: prod.split)

subsection {* Setup of Extraction-Tools *}
  (* NOTE(review): presumably tells the concrete-definition extraction that
     the program to extract sits at the ?f position of a hn_refine
     theorem; confirm against Refine_Automation. *)
  declare [[cd_patterns "hn_refine _ ?f _ _ _"]]

  (* TODO: Move *)
  (* Identity tag; used as the pattern of the "trivial" extraction set up
     below. It is unfolded by the simplifier ([simp]) and by code generation
     ([code_unfold]). *)
  definition [simp, code_unfold]: "TRIV_EXTRACTION x ≡ x"

  (* Code equation for a constant defined as TRIV_EXTRACTION B: the tag is
     simply dropped. *)
  lemma TRIV_EXTRACTION_codegen:
    assumes DEF: "f ≡ TRIV_EXTRACTION B"
    shows "f = B"
    unfolding DEF by simp

  (* Reflexivity instance for TRIV_EXTRACTION terms. *)
  lemma TRIV_extraction_cong: 
    "TRIV_EXTRACTION x ≡ TRIV_EXTRACTION x" 
    by (rule Pure.reflexive)

  setup {*
    let
      (* Nothing is done by the generation tactic of this extraction;
         presumably because TRIV_EXTRACTION_codegen has no premise beyond
         DEF. *)
      val no_gen_tac = fn _ => fn _ => all_tac
    in
      Refine_Automation.add_extraction "trivial" {
        pattern = term_of @{cpat "TRIV_EXTRACTION _"},
        gen_thm = @{thm TRIV_EXTRACTION_codegen},
        gen_tac = no_gen_tac
      }
    end
  *}

  (* Unfolding equation for a constant defined as a heap fixed point, valid
     when the body cB is monotone in its recursive argument. *)
  lemma heap_fixp_codegen:
    assumes DEF: "f ≡ heap.fixp_fun cB"
    assumes M: "(!!x. mono_Heap (λf. cB f x))"
    shows "f x = cB f x"
    unfolding DEF
    by (rule fun_cong[of _ _ x], rule heap.mono_body_fixp) (fact M)

  setup {*
    let
      (* Discharge the monotonicity premise of heap_fixp_codegen: first unfold
         TRIV_EXTRACTION in the goal, then run the monotonicity prover.
         TODO: This is a bit hacky; the mono-prover should handle this case
         itself. *)
      fun heap_gen_tac ctxt =
        let
          val unfold_triv_ss =
            put_simpset HOL_basic_ss ctxt addsimps @{thms TRIV_EXTRACTION_def}
        in
          simp_tac unfold_triv_ss THEN' Pf_Mono_Prover.mono_tac ctxt
        end
    in
      Refine_Automation.add_extraction "heap" {
        pattern = term_of @{cpat "heap.fixp_fun _"},
        gen_thm = @{thm heap_fixp_codegen},
        gen_tac = heap_gen_tac
      }
    end
  *}

end