Theory DF_Solver

theory DF_Solver
imports Refine_Lib
header {* Depth-First Search Solver with Forward Constraints *}
theory DF_Solver
imports "../Automatic_Refinement/Lib/Refine_Lib"
begin

text ‹
  This solver tries to solve a subgoal by repeatedly applying a tactic,
  backtracking in a depth-first manner.
  Apart from normal subgoals, the tactic may also produce constraint subgoals,
  which pose constraints on terms. These constraints are solved recursively by
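  a special set of rules (the constraint rules), unless the head of the
  constrained term is a schematic variable. In this case, solving is delayed
  until the schematic variable is instantiated, or until the depth-first
  search has finished, at which point the forced constraint rules are tried.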
›

(* CONSTRAINT P x is logically just P x; the wrapper only tags a subgoal so
   that the solver can recognize it as a constraint. *)
definition CONSTRAINT where CONSTRAINT_def [simp]: "CONSTRAINT P x == P x"

(* SOLVED is a marker proposition that is trivially true. *)
definition SOLVED where SOLVED_def [simp]: "SOLVED == True"

(* Meta-equation that adds a SOLVED premise in front of any proposition.
   Its only visible use is in the commented-out add_solved_conv inside
   DF_Solver.prepare_constraint_conv. *)
lemma SOLVED_I_eq:
  "PROP P == (SOLVED ==> PROP P)"
  unfolding SOLVED_def by simp

(* Trivial rule; resolving with it succeeds only on subgoals that unify
   with SOLVED. *)
lemma is_SOLVED: "SOLVED ==> SOLVED" .

(* SOLVED holds; DF_SOLVE_FWD_C resolves with this rule to close SOLVED
   subgoals. *)
lemma SOLVED_I: "SOLVED" by simp

(* Elimination: drop the CONSTRAINT tag from a fact. *)
lemma CONSTRAINT_D:
  assumes "CONSTRAINT (P::'a => bool) x"
  shows "P x"
  using assms unfolding CONSTRAINT_def by simp

(* Introduction: add the CONSTRAINT tag to a fact. *)
lemma CONSTRAINT_I:
  assumes "P x"
  shows "CONSTRAINT (P::'a => bool) x"
  using assms unfolding CONSTRAINT_def by simp

(* Identity rule on constraints; DF_Solver.is_constraint_tac resolves with it
   to test whether a subgoal unifies with a CONSTRAINT proposition. *)
lemma is_CONSTRAINT_rl: "CONSTRAINT P x ==> CONSTRAINT P x" .

ML {*
  signature DF_SOLVER = sig
    (* Depth-first solver: DF_SOLVE_FWD_C dbg tac ctxt. When dbg is true and
       the search fails, the proof states in which tac got stuck are
       returned instead of failing. *)
    val DF_SOLVE_FWD_C: bool -> tactic' -> Proof.context -> tactic'

    (* Tactics operating on CONSTRAINT subgoals. *)
    val is_constraint_tac: tactic'
    val defer_constraints: tactic' -> tactic'
    val constraint_tac: Proof.context -> tactic'
    val check_constraints_tac: Proof.context -> tactic
    val force_constraint_tac: Proof.context -> tactic'
    val force_constraints_tac: Proof.context -> tactic

    (* Named rule collection "constraint_rules". *)
    val get_constraint_rules: Proof.context -> thm list
    val add_constraint_rule: thm -> Context.generic -> Context.generic
    val del_constraint_rule: thm -> Context.generic -> Context.generic

    (* Named rule collection "forced_constraint_rules", used for constraints
       whose term is headed by a schematic variable. *)
    val get_forced_constraint_rules: Proof.context -> thm list
    val add_forced_constraint_rule: thm -> Context.generic -> Context.generic
    val del_forced_constraint_rule: thm -> Context.generic -> Context.generic

    (* Theory setup for both rule collections. *)
    val setup: theory -> theory
  end

  structure DF_Solver :DF_SOLVER = struct 
    local
      (* Conversion that brings a rule into "constraint form". Every
         "Trueprop (f a)" occurring as a premise or as the conclusion
         (possibly below meta-implications and meta-quantifiers) is rewritten
         to "Trueprop (CONSTRAINT f a)". The conversion fails with an
         exception if it meets a part of the rule that has none of these
         shapes, e.g. a premise "Trueprop c" where c is not an application.
         NOTE(review): a premise that is already "CONSTRAINT P x" also matches
         "Trueprop (_ _)" and appears to get wrapped a second time; confirm
         whether such rules are meant to be supported. *)
      fun prepare_constraint_conv ctxt = let
        open Conv 
        (* "Trueprop (f a)"  ~>  "Trueprop (CONSTRAINT f a)" *)
        fun CONSTRAINT_conv ct = case term_of ct of
          @{mpat "Trueprop (_ _)"} => 
            HOLogic.Trueprop_conv 
              (rewr_conv @{thm CONSTRAINT_def[symmetric]}) ct
          | _ => raise CTERM ("CONSTRAINT_conv", [ct])

        (* Apply CONSTRAINT_conv at the leaves of the ==> / !! structure. *)
        fun rec_conv ctxt ct = (
          CONSTRAINT_conv
          else_conv 
          implies_conv (rec_conv ctxt) (rec_conv ctxt)
          else_conv
          forall_conv (rec_conv o #2) ctxt
        ) ct

        (*
        fun add_solved_conv ct = case term_of ct of 
          @{mpat "_==>_"} => all_conv ct
        | _ => rewr_conv @{thm SOLVED_I_eq} ct
        *)

      in
        rec_conv ctxt (*then_conv add_solved_conv*)
      end

      (* Transformation applied to rules added to either collection below.
         It puts the rule into constraint form. Rules for which
         prepare_constraint_conv fails are rejected with THM. Shared so
         that both collections accept exactly the same rule shapes. *)
      fun transform_constraint_rule context thm = let
        val ctxt = Context.proof_of context
      in
        case try (Conv.fconv_rule (prepare_constraint_conv ctxt)) thm of
          NONE => raise THM ("Invalid constraint rule",~1,[thm])
        | SOME thm' => [thm']
      end
    in
      (* Rules used to solve constraints on terms whose head is not a
         schematic variable (see constraint_tac_aux). *)
      structure constraint_rules = Named_Sorted_Thms (
        val name = @{binding constraint_rules}
        val description = "Constraint rules"
        val sort = K I
        val transform = transform_constraint_rule
      )

      (* Rules used by force_constraint_tac_aux for constraints on terms
         whose head is a schematic variable. *)
      structure forced_constraint_rules = Named_Sorted_Thms (
        val name = @{binding forced_constraint_rules}
        val description = "Forced Constraint rules"
        val sort = K I
        val transform = transform_constraint_rule
      )
    end

    val add_constraint_rule = constraint_rules.add_thm
    val del_constraint_rule = constraint_rules.del_thm
    val get_constraint_rules = constraint_rules.get

    val add_forced_constraint_rule = forced_constraint_rules.add_thm
    val del_forced_constraint_rule = forced_constraint_rules.del_thm
    val get_forced_constraint_rules = forced_constraint_rules.get

    (* True iff st has a subgoal with index i. The tactics below check this
       first, so that on a nonexistent subgoal they fail (empty result
       sequence), as tactics conventionally do. Without the check they would
       raise an exception from Logic.concl_of_goal / Logic.get_goal. *)
    fun has_subgoal i st = i >= 1 andalso i <= nprems_of st

    (* Try to discharge subgoal i of st with the constraint rules thms:
       - "CONSTRAINT P t" where the head of t is a schematic variable:
         left unchanged; solving is postponed.
       - "CONSTRAINT P t" otherwise: resolve with thms, then process all
         resulting subgoals recursively; a resulting subgoal that is not a
         constraint makes this alternative fail (no_ctac = K no_tac).
       - "SOLVED": left unchanged.
       - any other subgoal: handled by no_ctac. *)
    fun constraint_tac_aux thms no_ctac i st = 
      if not (has_subgoal i st) then Seq.empty else
      case Logic.concl_of_goal (prop_of st) i |> Envir.beta_eta_contract of
        @{mpat "Trueprop (CONSTRAINT _ ?t)"} => 
          if (is_Var (head_of t)) then
            Seq.single st
          else (
              resolve_tac thms THEN_ALL_NEW_FWD constraint_tac_aux thms (K no_tac)
            ) i st
      | @{mpat "Trueprop SOLVED"} => Seq.single st
      | _ => no_ctac i st

    (* Process every constraint subgoal and leave non-constraint subgoals
       unchanged. Commits to the first result (DETERM). *)
    fun check_constraints_tac ctxt = DETERM (ALLGOALS (
      constraint_tac_aux (constraint_rules.get ctxt) (K all_tac)
    ))

    (* Single-subgoal version of check_constraints_tac. *)
    fun constraint_tac ctxt = DETERM o (
      constraint_tac_aux (constraint_rules.get ctxt) (K all_tac)
    )

    (* Like constraint_tac_aux, but constraints whose term is headed by a
       schematic variable are also solved, by resolution with the forced
       constraint rules. All resulting subgoals are processed recursively;
       non-constraint ones make the alternative fail. *)
    fun force_constraint_tac_aux ctxt no_ctac i st = 
      if not (has_subgoal i st) then Seq.empty else
      case Logic.concl_of_goal (prop_of st) i |> Envir.beta_eta_contract of
        @{mpat "Trueprop (CONSTRAINT _ ?t)"} => 
          (
            (
              if (is_Var (head_of t)) then
                resolve_tac (forced_constraint_rules.get ctxt)
              else
                constraint_tac_aux (constraint_rules.get ctxt) (K no_tac)
            ) THEN_ALL_NEW_FWD force_constraint_tac_aux ctxt (K no_tac)
          ) i st
      | @{mpat "Trueprop SOLVED"} => Seq.single st
      | _ => no_ctac i st

    (* Force the constraint in subgoal i; non-constraint subgoals are left
       unchanged. *)
    fun force_constraint_tac ctxt = 
      force_constraint_tac_aux ctxt (K all_tac)

    fun force_constraints_tac ctxt = ALLGOALS (force_constraint_tac ctxt)

    (* Does the (possibly parameterized) subgoal t have conclusion
       "Trueprop (CONSTRAINT _ _)"? *)
    fun is_CONSTRAINT_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false

    (* Defer constraints produced by tac *)
    local
      (* Scan subgoals l..u. A constraint at position l is moved to the end
         of the goal list. This shifts the rest of the range down by one, so
         l stays and u decreases. *)
      fun dc_int l u st = 
        if l>u then 
          Seq.single st 
        else if is_CONSTRAINT_goal (Logic.get_goal (prop_of st) l) then
          (defer_tac l THEN dc_int l (u-1)) st
        else
          dc_int (l+1) u st

    in
      (* Apply tac to subgoal i. Then move every constraint among the
         subgoals that replaced subgoal i to the end of the goal list. *)
      fun defer_constraints tac i st = (
        tac i THEN 
        (fn st' => dc_int i (i + nprems_of st' - nprems_of st) st')
      ) st

    end


    local
      (* Number of subgoals that are not constraints. *)
      fun c_nprems_of st = prems_of st 
        |> filter (not o is_CONSTRAINT_goal) 
        |> length

      (* Apply tactic to subgoals in interval, in a forward manner, 
         skipping over emerging subgoals. Offsets are counted in
         non-constraint subgoals only. This presumes that tac moves any
         constraints it creates out of the interval, as defer_constraints
         does. *)
      fun INTERVAL_FWD tac l u st =
        if l>u then all_tac st 
        else (tac l THEN (fn st' => let
            val ofs = c_nprems_of st' - c_nprems_of st;
          in
            if ofs < ~1 then raise THM (
              "INTERVAL_FWD: Tac solved more than one goal",~1,[st,st'])
            else INTERVAL_FWD tac (l+1+ofs) (u+ofs) st'
          end)) st;

      (* Apply tac2 to all subgoals emerged from tac1, in forward manner.
         This local definition shadows the THEN_ALL_NEW_FWD used further up
         in this structure; it is only visible in DF_SOLVE_FWD_C. *)
      fun (tac1 THEN_ALL_NEW_FWD tac2) i st =
        (tac1 i 
          THEN (fn st' => 
            INTERVAL_FWD tac2 i (i+c_nprems_of st'-c_nprems_of st) st')
        ) st;

    in

      (* Depth-first solver for subgoal i.
         rec_tac closes a subgoal with SOLVED_I, or accepts it unchanged if
         it is a constraint. Otherwise it applies tac, defers the
         constraints tac produced, runs check_constraints_tac, and recurses
         forward into the new subgoals. Presumably THEN_ELSE_COMB' calls
         stuck_tac only when tac' itself fails (compare the commented-out
         variant below). Afterwards the forced constraint rules are tried,
         with TRY, on all subgoals.
         If dbg is set, every proof state in which tac got stuck is
         recorded. If the whole search fails, these states are returned,
         oldest first, instead of an empty sequence. *)
      fun DF_SOLVE_FWD_C dbg tac ctxt = let
        val stuck_list_ref = Unsynchronized.ref []

        fun stuck_tac _ st = if dbg then (
          stuck_list_ref := st :: !stuck_list_ref;
          Seq.empty
        ) else Seq.empty

        fun tac' i =  defer_constraints tac i THEN check_constraints_tac ctxt

        (* Succeed unchanged iff subgoal i exists and is a constraint. *)
        fun is_CONSTRAINT i st =
          if has_subgoal i st
            andalso is_CONSTRAINT_goal (Logic.get_goal (prop_of st) i)
          then Seq.single st
          else Seq.empty

        fun rec_tac i st = (
            rtac @{thm SOLVED_I} ORELSE'
            is_CONSTRAINT ORELSE'
            (tac' THEN_ELSE_COMB' (op THEN_ALL_NEW_FWD, rec_tac, stuck_tac))

            (* (tac' THEN_ALL_NEW_FWD (rec_tac))
            ORELSE' stuck_tac *)
          ) i st

        fun fail_tac _ _ = if dbg then
          Seq.of_list (rev (!stuck_list_ref))
        else Seq.empty
      in
        (rec_tac ORELSE' fail_tac) THEN' (K (ALLGOALS (TRY o (force_constraint_tac ctxt))))
      end

    end

    (* Succeeds, leaving the subgoal as is, iff subgoal i unifies with a
       CONSTRAINT proposition (this may instantiate schematic variables). *)
    val is_constraint_tac = rtac @{thm is_CONSTRAINT_rl}

    (* Set up both named rule collections. *)
    val setup = 
         constraint_rules.setup 
      #> forced_constraint_rules.setup
  end

*}

(* Install the constraint_rules and forced_constraint_rules collections
   declared in DF_Solver. *)
setup DF_Solver.setup

(* Diagnostic proof method: traces every CONSTRAINT subgoal of the current
   goal state, or "No constraints" if there is none. The goal state is
   returned unchanged. *)
method_setup trace_constraints = {*
  let
    fun constraint_goal t =
      (case Logic.strip_assums_concl t of
        @{mpat "Trueprop (CONSTRAINT _ _)"} => true
      | _ => false)

    fun trace_goals ctxt ts =
      tracing (Pretty.string_of (Pretty.block (Pretty.fbreaks
        (map (Syntax.pretty_term ctxt) ts))))

    fun trace_tac ctxt st =
      let
        val constraints = filter constraint_goal (prems_of st)
      in
        (if null constraints then tracing "No constraints"
         else trace_goals ctxt constraints);
        Seq.single st
      end
  in
    Scan.succeed (fn ctxt => SIMPLE_METHOD (trace_tac ctxt))
  end
  *}

end