Theory Sepref_Definition

section \<open>Sepref-Definition Command\<close>
theory Sepref_Definition
imports Sepref_Translate "Lib/Term_Synth"
keywords "sepref_definition" :: thy_goal
      and "sepref_def" :: thy_goal
      and "sepref_thm" :: thy_goal
begin
subsection \<open>Setup of Extraction-Tools\<close>
  (* Declares a cd_patterns entry for hn_refine goals.
     NOTE(review): the doubled blank inside the pattern looks like a symbol that
     was lost in extraction - confirm the pattern text against the repository. *)
  declare [[cd_patterns "hn_refine _  _ _ _"]]


  subsection \<open>Synthesis setup for sepref-definition goals\<close>
  (* TODO: UNSPEC is an ad-hoc hack to specify the synthesis goal. *)
  (* Fully polymorphic placeholder constant. In this theory it appears as the
     second argument of hf_pres in the hfunspec abbreviation and as the first
     argument of the SPEC_RES_ASSN rule that accepts an arbitrary assertion. *)
  consts UNSPEC::'a  

  (* Postfix notation for an argument specification whose second component is
     still unspecified: R\<^sup>? abbreviates hf_pres R UNSPEC.  The INTRO_KD rule for
     R\<^sup>? maps it to hf_pres R k for a free k.
     Symbols are written in Isabelle's ASCII symbol notation. *)
  abbreviation hfunspec
    :: "('a \<Rightarrow> 'b \<Rightarrow> assn) \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> assn)\<times>('a \<Rightarrow> 'b \<Rightarrow> assn)"
    ("(_\<^sup>?)" [1000] 999)
    where "R\<^sup>? \<equiv> hf_pres R UNSPEC"

  (* Marker constant: SYNTH f R is logically True.  It only pairs the abstract
     function f with a specification R of the refinement relation, so that the
     rule synth_hnrI can take the two apart during term synthesis. *)
  definition SYNTH :: "('a \<Rightarrow> 'r nres) \<Rightarrow> (('ai \<Rightarrow> 'ri llM) \<times> ('a \<Rightarrow> 'r nres)) set \<Rightarrow> bool"
    where "SYNTH f R \<equiv> True"

  (* Side-condition markers used as premises of synth_hnrI.  Each one is defined
     as True and registered as a simp rule, so every synth_rules lemma about
     them is proved by simp; their only purpose is to steer rule selection. *)
  definition [simp]: "CP_UNCURRY _ _ \<equiv> True"
  definition [simp]: "CP_PAT _ _ \<equiv> True"
  definition [simp]: "INTRO_KD _ _ \<equiv> True"
  definition SPEC_RES_ASSN :: "'a \<Rightarrow> 'a \<Rightarrow> bool" where [simp]: "SPEC_RES_ASSN _ _ \<equiv> True"

  (* Synthesis rules for CP_UNCURRY: an unconstrained base case, the uncurry0
     case, and a step that adds one uncurry to both arguments. *)
  lemma [synth_rules]: "CP_UNCURRY f g" by simp
  lemma [synth_rules]: "CP_UNCURRY (uncurry0 f) (uncurry0 g)" by simp
  lemma [synth_rules]: "CP_UNCURRY f g \<Longrightarrow> CP_UNCURRY (uncurry f) (uncurry g)" by simp

  (* Synthesis rules for CP_PAT, with the same shape as the CP_UNCURRY rules:
     base case, uncurry0 case, and one uncurry step on both arguments. *)
  lemma [synth_rules]: "CP_PAT f g" by simp
  lemma [synth_rules]: "CP_PAT (uncurry0 f) (uncurry0 g)" by simp
  lemma [synth_rules]: "CP_PAT f g \<Longrightarrow> CP_PAT (uncurry f) (uncurry g)" by simp
  
  
  (* Synthesis rules for INTRO_KD: products are handled componentwise, an
     unspecified argument R\<^sup>? becomes hf_pres R k for a free k, and arguments
     already annotated with k or d are passed through unchanged. *)
  lemma [synth_rules]: "\<lbrakk>INTRO_KD R1 R1'; INTRO_KD R2 R2'\<rbrakk> \<Longrightarrow> INTRO_KD (R1*\<^sub>aR2) (R1'*\<^sub>aR2')" by simp
  lemma [synth_rules]: "INTRO_KD (R\<^sup>?) (hf_pres R k)" by simp
  lemma [synth_rules]: "INTRO_KD (R\<^sup>k) (R\<^sup>k)" by simp
  lemma [synth_rules]: "INTRO_KD (R\<^sup>d) (R\<^sup>d)" by simp

  (* Synthesis rules for SPEC_RES_ASSN: a given assertion is kept unchanged,
     while UNSPEC may be related to an arbitrary R. *)
  lemma [synth_rules]: "SPEC_RES_ASSN R R" by simp
  lemma [synth_rules]: "SPEC_RES_ASSN UNSPEC R" by simp
  
  (* Main synthesis rule, used by make_hnr_goal in the ML part of this theory:
     under the four marker premises it rewrites a SYNTH f (...) term into a
     pair whose first component mentions fpat and whose second component
     relates fi and f; make_hnr_goal matches that pair as (pattern, goal).
     NOTE(review): the statement appears to have lost its symbols in extraction
     (premise brackets, implication, membership, arrows and subscripts) -
     restore it from the repository before relying on this text. *)
  lemma synth_hnrI:
    "CP_UNCURRY f fi; CP_PAT f fpat; INTRO_KD R R'; SPEC_RES_ASSN S S'  SYNTH_TERM (SYNTH f ([P]a [C]c RdS [CP]c)) ((fpat,SDUMMY)SDUMMY,(fi,f)([P]a [C]c R'dS' [CP]c))" 
    by (simp add: SYNTH_def)


ML \<open>
  structure Sepref_Definition = struct
    (* Turn a term of the form "SYNTH f R" into the refinement goal to prove.

       t    : the SYNTH term, as produced by mk_synth_term
       ctxt : context in which t is declared and the synthesis is run

       Returns ((pat, goal), ctxt'), where pat is the certified pattern passed
       through Definition_Utils.prepare_cd_pattern, goal is the proposition
       wrapped in Trueprop, and ctxt' is ctxt extended by the declaration of t.
       Raises TERM if the synthesized term is not a pair. *)
    fun make_hnr_goal t ctxt = let
      val ctxt = Variable.declare_term t ctxt
      val (pat,goal) = case Term_Synth.synth_term @{thms synth_hnrI} ctxt t of
        @{mpat "(?pat,?goal)"} => (pat,goal) | t => raise TERM("Synthesized term does not match",[t])
      val pat = Thm.cterm_of ctxt pat |> Definition_Utils.prepare_cd_pattern ctxt
      val goal = HOLogic.mk_Trueprop goal
    in
      ((pat,goal),ctxt)
    end

    local 
      open Refine_Util
      (*val flags = parse_bool_config' "prep_code" cfg_prep_code
      val parse_flags = parse_paren_list' flags  
      *)

    in       
      (* Parser for sepref_definition:
           name [attribs_def] is [attribs_ref] term :: term
         Both attribute lists are optional and default to what
         Parse.opt_attribs yields when they are absent. *)
      val sd_parser = Parse.binding -- Parse.opt_attribs --| @{keyword "is"} 
        -- Parse.opt_attribs -- Parse.term --| @{keyword "::"} -- Parse.term
        
      (* Parser for sepref_def: same surface syntax as sd_parser, but an omitted
         first attribute list defaults to [llvm_code] and an omitted second one
         to [sepref_fr_rules]. *)
      val sd_dflt_parser = 
          Parse.binding 
        -- Scan.optional Parse.attribs @{attributes [llvm_code]} 
        --| @{keyword "is"} 
        -- Scan.optional Parse.attribs @{attributes [sepref_fr_rules]} 
        -- Parse.term 
        --| @{keyword "::"} 
        -- Parse.term
        
        
    end  

    (* Parse the two raw term strings and combine them into "SYNTH t r".
       The SYNTH constant is inserted with a dummy type, and the final
       Syntax.check_term type-checks the whole application at once. *)
    fun mk_synth_term ctxt t_raw r_raw = let
        val t = Syntax.parse_term ctxt t_raw
        val r = Syntax.parse_term ctxt r_raw
        val t = Const (@{const_name SYNTH},dummyT)$t$r
      in
        Syntax.check_term ctxt t
      end  

    (* Store thm in the local theory under the binding that
       Definition_Utils.mk_qualified builds from the base name of the given
       binding and "refine_raw"; no attributes are applied.  Returns the
       updated local theory.  Shared by sd_cmd and st_cmd. *)
    fun note_refine_raw name thm lthy =
      Local_Theory.note
        ((Definition_Utils.mk_qualified (Binding.name_of name) "refine_raw",[]),[thm])
        lthy
      |> snd

    (* Command body shared by sepref_definition and sepref_def.

       Opens a proof of the refinement goal derived from the parsed terms.
       Once the proof is finished, the theorem is exported to lthy, stored as
       the raw refinement theorem, and handed to
       Definition_Utils.define_concrete_fun together with the extraction
       pattern; the two theorems returned by that call are printed. *)
    fun sd_cmd ((((name,attribs_def),attribs_ref),t_raw),r_raw) lthy = let
      (*local
        val ctxt = Refine_Util.apply_configs flags lthy
      in
        val flag_prep_code = Config.get ctxt cfg_prep_code
      end
      *)

      val t = mk_synth_term lthy t_raw r_raw

      val ((pat,goal),ctxt) = make_hnr_goal t lthy
      
      fun 
        after_qed [[thm]] ctxt' = let
            (* ctxt' is the context at the end of the proof; the theorem is
               exported from there to the local theory the command started in. *)
            val thm = singleton (Variable.export ctxt' lthy) thm

            val lthy = note_refine_raw name thm lthy

            val ((dthm,rthm),lthy) = Definition_Utils.define_concrete_fun NONE name attribs_def attribs_ref [] thm [pat] lthy

            val _ = Thm.pretty_thm lthy dthm |> Pretty.string_of |> writeln
            val _ = Thm.pretty_thm lthy rthm |> Pretty.string_of |> writeln
          in
            lthy
          end
        | after_qed thmss _ = raise THM ("After-qed: Wrong thmss structure",~1,flat thmss)

    in
      Proof.theorem NONE after_qed [[ (goal,[]) ]] ctxt
    end



    val _ = Outer_Syntax.local_theory_to_proof @{command_keyword "sepref_definition"}
      "Synthesis of imperative program"
      (sd_parser >> sd_cmd)

    val _ = Outer_Syntax.local_theory_to_proof @{command_keyword "sepref_def"}
      "Synthesis of imperative program (default attributes)"
      (sd_dflt_parser >> sd_cmd)
      
      
    (* Parser for sepref_thm:  name is term :: term  (no attribute lists). *)
    val st_parser = Parse.binding --| @{keyword "is"} -- Parse.term --| @{keyword "::"} -- Parse.term

    (* Command body of sepref_thm: proves the same goal as sd_cmd, but after the
       proof only traces the exported theorem and stores it as the raw
       refinement theorem; no constant is defined. *)
    fun st_cmd ((name,t_raw),r_raw) lthy = let
      val t = mk_synth_term lthy t_raw r_raw
      val ((_,goal),ctxt) = make_hnr_goal t lthy
      
      fun 
        after_qed [[thm]] ctxt' = let
            val thm = singleton (Variable.export ctxt' lthy) thm

            val _ = Thm.pretty_thm lthy thm |> Pretty.string_of |> tracing
          in
            note_refine_raw name thm lthy
          end
        | after_qed thmss _ = raise THM ("After-qed: Wrong thmss structure",~1,flat thmss)

    in
      Proof.theorem NONE after_qed [[ (goal,[]) ]] ctxt
    end

    val _ = Outer_Syntax.local_theory_to_proof @{command_keyword "sepref_thm"}
      "Synthesis of imperative program: Only generate raw refinement theorem"
      (st_parser >> st_cmd)

  end
\<close>


end