(* Theory Id_Op: operation identification phase (adapted from Autoref).
   The real theory header follows below. *)

header ‹Operation Identification Phase›
theory Id_Op
imports
  Main
  "../Automatic_Refinement/Lib/Refine_Lib"
  "../Automatic_Refinement/Tool/Autoref_Tagging"
begin

text ‹
  The operation identification phase is adapted from the Autoref tool.
  The basic idea is to have a type system that works on so-called 
  interface types (also called conceptual types). Each conceptual type
  denotes an abstract data type, e.g., set, map, priority queue.
  
  Each abstract operation, which must be a constant applied to its arguments,
  is assigned a conceptual type. Additionally, there is a set of 
  \emph{pattern rewrite rules},
  which are applied to subterms before type inference takes place, and 
  which may be backtracked over. 
  This way, encodings of abstract operations in Isabelle/HOL, like 
  @{term [source] "λ_. None"} for the empty map, 
  or @{term [source] "fun_upd m k (Some v)"} for map update, can be rewritten
  to abstract operations, and get properly typed.
›

subsection "Proper Protection of Terms"
text ‹ The following constants are meant to encode abstraction and 
  application as proper HOL constants, and thus avoid strange effects with
  HOL's higher-order unification heuristics and automatic 
  beta- and eta-contraction.

  The first step of operation identification is to protect the term
  by replacing all function applications and abstractions by 
  the constants defined below.
›

(* PROTECT2 x y is just x; the prop argument y only acts as a marker.
   Within this theory it is always instantiated with DUMMY. *)
definition [simp]: "PROTECT2 x (y::prop) ≡ x"
consts DUMMY :: "prop"

(* Concrete syntax (#t#) for PROTECT2 t DUMMY. *)
abbreviation PROTECT2_syn ("'(#_#')") where "PROTECT2_syn t ≡ PROTECT2 t DUMMY"

(* Protected abstraction: the binder λ2x. t x abbreviates
   λx. PROTECT2 (t x) DUMMY, the shape that the ML function protect builds
   for every abstraction. The binder symbol matches its uses below. *)
abbreviation (input)ABS2 :: "('a=>'b)=>'a=>'b" (binder "λ2" 10)
  where "ABS2 f ≡ (λx. PROTECT2 (f x) DUMMY)"

(* Beta rule for a protected abstraction applied with $. *)
lemma beta: "(λ2x. f x)$x ≡ f x" by simp

text ‹
  Another version of @{const "APP"}. Treated like @{const APP} by our tool.
  Required to avoid infinite pattern rewriting in some cases, e.g., map-lookup.
›

(* Alternative application f$'a, which simp unfolds to f a. The typing rule
   app'_rule treats it like $ and produces an ordinary $ application in the
   identified term. Pattern rules use it on right-hand sides to avoid
   re-matching their own result. *)
definition APP' (infixl "$''" 900) where [simp, autoref_tag_defs]: "f$'a ≡ f a"

text ‹
  Sometimes, whole terms should be protected from being processed by our tool.
  For example, our tool should not look into numerals. For this reason,
  the @{text "PR_CONST"} tag indicates terms that our tool shall handle as
  atomic constants and never look into.

  The special form @{text "UNPROTECT"} can be used inside pattern rewrite rules.
  It removes the protection from its argument and then wraps
  it into a @{text "PR_CONST"}.
›
(* PR_CONST x is logically x; identification treats PR_CONST x as an atomic
   constant (see ID_PR_CONST_trigger and id_pr_const_rename_tac). *)
definition [simp, autoref_tag_defs]: "PR_CONST x ≡ x" -- "Tag to protect constant"
(* UNPROTECT x is logically x; do_unprotect_tac rewrites a goal on
   UNPROTECT x to one on PR_CONST x with the protection of x removed. *)
definition [simp, autoref_tag_defs]: "UNPROTECT x ≡ x" -- {* Gets 
  converted to @{term PR_CONST}, after unprotecting its content*}


subsection {* Operation Identification *}

text ‹ Indicator predicate for conceptual typing of a constant ›
(* c ::i TYPE(I) declares I as the interface type of c. The predicate is
   trivially True (and [simp]); it only carries information. Theorems of
   this form are collected as id_rules and used via id_rule. *)
definition intf_type :: "'a => 'b itself => bool" (infix "::i" 10) where
  [simp]: "c::iI ≡ True"

(* Every interface typing statement is provable. *)
lemma itypeI: "c::iI" by simp

text ‹ Wrapper predicate for conceptual type inference ›
(* ID t t' T: t' is the identified form of t, with interface type T.
   Logically this is just t = t'; the itself-argument T is ignored by the
   definition and only guides rule selection. *)
definition ID :: "'a => 'a => 'c itself => bool" 
  where [simp]: "ID t t' T ≡ t=t'"

subsubsection {* Conceptual Typing Rules *}

(* An ID premise yields the meta-equation x ≡ y; id_pr_const_rename_tac
   uses this to turn ID premises into rewrite rules. *)
lemma ID_unfold_vars: "ID x y T ==> x≡y" by simp
(* Trivial rule whose conclusion only unifies with goals whose first ID
   argument has the form PR_CONST ?x; it acts as a guard at the start of
   id_pr_const_rename_tac. *)
lemma ID_PR_CONST_trigger: "ID (PR_CONST x) y T ==> ID (PR_CONST x) y T" .

(* Combined with each pattern rule p ≡ p' (via RS in id_tac):
   identify p by identifying its rewritten form p'. *)
lemma pat_rule:
  "[| p≡p'; ID p' t' T |] ==> ID p t' T" by simp

(* Application: identify function and argument; the result has the
   codomain type. *)
lemma app_rule:
  "[| ID f f' TYPE('a=>'b); ID x x' TYPE('a)|] ==> ID (f$x) (f'$x') TYPE('b)"
  by simp

(* Same for $'; the identified term uses ordinary $. *)
lemma app'_rule:
  "[| ID f f' TYPE('a=>'b); ID x x' TYPE('a)|] ==> ID (f$'x) (f'$x') TYPE('b)"
  by simp

(* Protected abstraction: identify the body under an ID assumption for the
   bound variable. *)
lemma abs_rule:
  "[| !!x x'. ID x x' TYPE('a) ==> ID (t x) (t' x') TYPE('b) |] ==>
    ID (λ2x. t x) (λ2x'. t' x') TYPE('a=>'b)"
  by simp

(* Turns an interface typing c ::i I into ID c c I; combined with every
   id_rules theorem in id_tac. *)
lemma id_rule: "c::iI ==> ID c c I" by simp

(* Catch-all typing; instantiated by fallback_tac for constants that have
   no id_rules entry. *)
lemma fallback_rule:
  "ID (c::'a) c TYPE('c)"
  by simp

(* Reduces a goal on UNPROTECT x to one on PR_CONST x (first step of
   do_unprotect_tac). *)
lemma unprotect_rl1: "ID (PR_CONST x) t T ==> ID (UNPROTECT x) t T"
  by simp

subsection ‹ ML-Level code ›
ML {*
(* NOTE(review): THEN_ELSE_COMB' only receives an infix status here; no
   binding with this name is defined in this file. *)
infix 0 THEN_ELSE_COMB'

(* Search tacticals used by the operation identification. *)
signature ID_OP_TACTICAL = sig
  (* Apply a tactic and recursively solve all new subgoals with it. *)
  val SOLVE_FWD: tactic' -> tactic'
  (* Depth-first solving; the flag enables returning stuck states. *)
  val DF_SOLVE_FWD: bool -> tactic' -> tactic'
end

structure Id_Op_Tactical :ID_OP_TACTICAL = struct

  (* SOLVE_FWD tac: apply tac, then solve every resulting subgoal the same
     way. Succeeds only if the subgoal is solved completely (SOLVED'). *)
  fun SOLVE_FWD tac i st = SOLVED' (
    tac 
    THEN_ALL_NEW_FWD (SOLVE_FWD tac)) i st


  (* Search for a solution with a depth-first strategy. If dbg is true and
     no solution is found, return the states in which the search got stuck
     (in the order they were encountered) instead of failing.

     The stuck-state list is allocated per application of the tactic to a
     subgoal, not once per DF_SOLVE_FWD dbg tac. This keeps states from
     earlier runs of the same tactic value (other subgoals, backtracking)
     out of the debug result. *)
  fun DF_SOLVE_FWD dbg tac i st = let
    val stuck_list_ref = Unsynchronized.ref []

    (* Reached when tac does not apply: record the state (debug only), fail. *)
    fun stuck_tac _ stuck_st = (
      if dbg then stuck_list_ref := stuck_st :: !stuck_list_ref else ();
      Seq.empty
    )

    fun rec_tac j s = (
        (tac THEN_ALL_NEW_FWD (SOLVED' rec_tac))
        ORELSE' stuck_tac
      ) j s

    (* Only tried after rec_tac has produced no result at all. *)
    fun fail_tac _ _ = if dbg then
      Seq.of_list (rev (!stuck_list_ref))
    else Seq.empty
  in
    (rec_tac ORELSE' fail_tac) i st
  end

end
*}


ML {*

  (* Operation identification: protect a term with explicit application and
     abstraction constants, then infer interface types for it by resolution
     with the typing rules, pattern rules and id_rules. *)
  structure Id_Op = struct

    (* Apply cnv to the first argument t of a term ID t t' T; raise CTERM
       on any other term. *)
    fun id_a_conv cnv ct = case term_of ct of
      @{mpat "ID _ _ _"} => Conv.fun_conv (Conv.fun_conv (Conv.arg_conv cnv)) ct
    | _ => raise CTERM("id_a_conv",[ct])

    (* Protect a term: rebuild every application with the HOL constant $
       and wrap the body of every abstraction into PROTECT2 _ DUMMY.
       env holds the types of the enclosing binders, innermost first. *)
    fun 
      protect env (t1$t2) = let
        val t1 = protect env t1
        val t2 = protect env t2
      in
        @{mk_term env: "?t1.0 $ ?t2.0"}
      end
    | protect env (Abs (x,T,t)) = let
        val t = protect (T::env) t
      in
        @{mk_term env: "λv_x::?'v_T. PROTECT2 ?t DUMMY"}
      end
    | protect _ t = t

    (* Conversion from a term to its protected form; the equation is proved
       by simp with PROTECT2_def and APP_def. *)
    fun protect_conv ctxt = Refine_Util.f_tac_conv ctxt
      (protect []) 
      (simp_tac 
        (put_simpset HOL_basic_ss ctxt addsimps @{thms PROTECT2_def APP_def}) 1)

    (* Inverse direction: rewrite with PROTECT2_def and APP_def to remove
       the protection. *)
    fun unprotect_conv ctxt
      = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt 
        addsimps @{thms PROTECT2_def APP_def})

    (* For a goal ID (UNPROTECT x) t T: switch to ID (PR_CONST x) t T with
       unprotect_rl1, then remove the protection inside the first ID
       argument. *)
    fun do_unprotect_tac ctxt =
      rtac @{thm unprotect_rl1} THEN'
      CONVERSION (Refine_Util.HOL_concl_conv (fn ctxt => id_a_conv (unprotect_conv ctxt)) ctxt)

    (* Configuration id_debug (default false): makes the depth-first search
       return stuck states on failure. *)
    val cfg_id_debug = 
      Attrib.setup_config_bool @{binding id_debug} (K false)

    (* Configuration id_trace_fallback (default false): trace every use of
       the fallback rule. *)
    val cfg_id_trace_fallback = 
      Attrib.setup_config_bool @{binding id_trace_fallback} (K false)

    (* Split an id_rules theorem c ::i TYPE(T) into the pair (c, T). *)
    fun dest_id_rl thm = case concl_of thm of
      @{mpat (typs) "Trueprop (?c::iTYPE(?'v_T))"} => (c,T)
    | _ => raise THM("dest_id_rl",~1,[thm])

    
    (* Named theorem collections; their attributes are registered by setup
       below. *)
    structure id_rules = Named_Thms (
      val name = @{binding id_rules};
      val description = "Operation identification rules"
    )

    structure pat_rules = Named_Thms (
      val name = @{binding pat_rules};
      val description = "Operation pattern rules"
    )

    structure def_pat_rules = Named_Thms (
      val name = @{binding def_pat_rules};
      val description = "Definite operation pattern rules (not backtracked over)"
    )

    (* Init: only protect the goal; Step: a single step_tac; Normal: protect
       then solve; Solve: solve an already protected goal. *)
    datatype id_tac_mode = Init | Step | Normal | Solve

    (* Build the identification tactic for mode ss in context ctxt. *)
    fun id_tac ss ctxt = let
      open Id_Op_Tactical
      val thy = Proof_Context.theory_of ctxt
      val certT = ctyp_of thy
      val cert = cterm_of thy

      val id_rules = id_rules.get ctxt
      val pat_rules = pat_rules.get ctxt
      val def_pat_rules = def_pat_rules.get ctxt

      (* Backtrackable rules: pattern rules (as pat_rule instances), the
         structural rules for $, $' and protected abstraction, and the
         typing rules (as id_rule instances). *)
      val rl_net = Tactic.build_net (
        (pat_rules |> map (fn thm => thm RS @{thm pat_rule})) 
        @ @{thms app_rule app'_rule abs_rule} 
        @ (id_rules |> map (fn thm => thm RS @{thm id_rule}))
      )

      (* Definite pattern rules; step_tac tries them before rl_net. *)
      val def_rl_net = Tactic.build_net (
        (def_pat_rules |> map (fn thm => thm RS @{thm pat_rule}))
      )  

      (* For a goal ID (PR_CONST x) y T: in the focused goal, turn the ID
         premises into equations (ID_unfold_vars), rewrite with them inside
         the first ID argument, then close the goal with id_rule and one of
         the id_rules. *)
      val id_pr_const_rename_tac = 
          rtac @{thm ID_PR_CONST_trigger} THEN'
          Subgoal.FOCUS (fn { context=ctxt, prems, ... } => 
            let
              fun is_ID @{mpat "Trueprop (ID _ _ _)"} = true | is_ID _ = false
              val prems = filter (prop_of #> is_ID) prems
              val eqs = map (fn thm => thm RS @{thm ID_unfold_vars}) prems
              val conv = Conv.rewrs_conv eqs
              val conv = fn ctxt => (Conv.top_sweep_conv (K conv) ctxt)
              val conv = fn ctxt => Conv.fun2_conv (Conv.arg_conv (conv ctxt))
              val conv = Refine_Util.HOL_concl_conv conv ctxt
            in CONVERSION conv 1 end 
          ) ctxt THEN'
          rtac @{thm id_rule} THEN'
          resolve_tac id_rules 

      (* Names of constants that have an entry in id_rules. *)
      val ityping = id_rules 
        |> map dest_id_rl
        |> filter (is_Const o #1)
        |> map (apfst (#1 o dest_Const))
        |> Symtab.make_list

      val has_type = Symtab.defined ityping

      (* Last resort for a goal ID c _ _ where c is a constant without an
         id_rules entry: instantiate fallback_rule with the type of c and
         the declared constant type (Sign.the_const_constraint) as the
         interface type. NOTE(review): this relies on instantiate' seeing
         the type variables of fallback_rule in the order 'a, 'c. Fails
         (empty sequence) in all other cases. *)
      val fallback_tac = IF_EXGOAL (fn i => fn st =>
        case Logic.concl_of_goal (prop_of st) i of
          @{mpat "Trueprop (ID ?c _ _)"} => ( case c of
            Const (name,cT) => ( 
              case try (Sign.the_const_constraint thy) name of
                SOME T => 
                  if not (has_type name) then 
                    let
                      val thm = @{thm fallback_rule} 
                        |> Drule.instantiate' 
                             [SOME (certT cT), SOME (certT T)] 
                             [SOME (cert c)]
                      (* Optional tracing: the andalso expression is evaluated
                         only for its side effect; its value (false) is
                         discarded. *)
                      val _ = Config.get ctxt cfg_id_trace_fallback       
                        andalso let 
                          open Pretty
                          val p = block [str "ID_OP: Applying fallback rule: ", Display.pretty_thm ctxt thm]
                        in 
                          string_of p |> tracing; 
                          false
                        end
                    in
                      rtac thm i st
                    end
                  else Seq.empty
              | _ => Seq.empty
            )
          | _ => Seq.empty
          )
        | _ => Seq.empty
      )

      (* Protect the first argument of the ID goal. *)
      val init_tac = CONVERSION (
        Refine_Util.HOL_concl_conv (fn ctxt => (id_a_conv (protect_conv ctxt))) 
          ctxt
      )

      (* One identification step. FIRST' commits to the first alternative
         that yields a result; backtracking happens only among the results
         of that alternative. *)
      val step_tac = (FIRST' [
        atac, 
        resolve_from_net_tac def_rl_net, 
        resolve_from_net_tac rl_net, 
        id_pr_const_rename_tac,
        do_unprotect_tac ctxt, 
        fallback_tac])

      (* Depth-first search with step_tac; id_debug enables stuck states. *)
      val solve_tac = DF_SOLVE_FWD (Config.get ctxt cfg_id_debug) step_tac  

    in
      case ss of
        Init => init_tac 
      | Step => step_tac 
      | Normal => init_tac THEN' solve_tac
      | Solve => solve_tac

    end

    (* Theory setup: register the three named theorem collections. *)
    val setup = I
      #> id_rules.setup
      #> pat_rules.setup
      #> def_pat_rules.setup
  end

*}

(* Register the id_rules, pat_rules and def_pat_rules attributes. *)
setup Id_Op.setup


subsection ‹Default Setup›

subsubsection ‹Maps›
(* Interface type of maps. *)
typedecl ('k,'v) i_map

(* Abstract map operations. They are [simp], so they unfold to their HOL
   definitions. *)
definition [simp]: "op_map_empty ≡ Map.empty"
definition [simp]: "op_map_is_empty m ≡ m = Map.empty"
definition [simp]: "op_map_update k v m ≡ m(k\<mapsto>v)"
definition [simp]: "op_map_delete k m ≡ m|`(-{k})"
definition [simp]: "op_map_lookup k m ≡ m k::'a option"

(* Pattern rules: recognise HOL encodings of the map operations in protected
   terms. The empty map λ_. None appears as a protected abstraction, hence
   the λ2 binder (same form as in pat_map_is_empty). *)
lemma pat_map_empty[pat_rules]: "λ2_. None ≡ op_map_empty" by simp

lemma pat_map_is_empty[pat_rules]: 
  "op =$m$(λ2_. None) ≡ op_map_is_empty$m" 
  "op =$(λ2_. None)$m ≡ op_map_is_empty$m" 
  "op =$(dom$m)${} ≡ op_map_is_empty$m"
  "op =${}$(dom$m) ≡ op_map_is_empty$m"
  unfolding atomize_eq
  by auto

(* Right-hand sides below use $'. The lookup pattern m$k matches every
   $-application, so a $ on the right-hand side would be rewritten again. *)
lemma pat_map_update[pat_rules]: 
  "fun_upd$m$k$(Some$v) ≡ op_map_update$'k$'v$'m"
  by simp
lemma pat_map_lookup[pat_rules]: "m$k ≡ op_map_lookup$'k$'m"
  by simp
lemma op_map_delete_pat[pat_rules]: 
  "op |` $ m $ (uminus $ (insert $ k $ {})) ≡ op_map_delete$'k$'m"
  by simp

(* Interface types of the abstract map operations over i_map. *)
lemma id_map_empty[id_rules]: "op_map_empty ::i TYPE(('k,'v) i_map)"
  by simp

lemma id_map_is_empty[id_rules]: "op_map_is_empty ::i TYPE(('k,'v) i_map => bool)"
  by simp

lemma id_map_update[id_rules]: 
  "op_map_update ::i TYPE('k => 'v => ('k,'v) i_map => ('k,'v) i_map)"
  by simp

lemma id_map_lookup[id_rules]: 
  "op_map_lookup ::i TYPE('k => ('k,'v) i_map => 'v option)"
  by simp

lemma id_map_delete[id_rules]: 
  "op_map_delete ::i TYPE('k => ('k,'v) i_map => ('k,'v) i_map)"
  by simp

subsubsection ‹Numerals› 
(* Numerals are not looked into: the definite pattern wraps numeral$x into
   UNPROTECT, which do_unprotect_tac turns into a PR_CONST term. *)
lemma pat_numeral[def_pat_rules]: "numeral$x ≡ UNPROTECT (numeral$x)" by simp

(* Any PR_CONST-wrapped nat or int term gets interface type nat or int. *)
lemma id_nat_const[id_rules]: "(PR_CONST (a::nat)) ::i TYPE(nat)" by simp
lemma id_int_const[id_rules]: "(PR_CONST (a::int)) ::i TYPE(int)" by simp

subsection ‹Example›
(* Example: identify a term combining map update and restriction,
   Map.empty and a case distinction. ?c receives the identified term and
   ?T its interface type. With id_debug set, a failing search returns the
   stuck states instead of failing. *)
schematic_lemma 
  "ID (λa b. (b(1::int\<mapsto>2::nat) |`(-{3})) a, Map.empty, λa. case a of None => Some a | Some _ => None) (?c) (?T::?'d itself)"
  (*"TERM (?c,?T)"*)
  using [[id_debug]]
  by (tactic {* Id_Op.id_tac Id_Op.Normal @{context} 1  *})  

end