Theory Sepref_Lin_Ana

theory Sepref_Lin_Ana
imports Sepref_Monadify
header ‹Linearity Analysis›
theory Sepref_Lin_Ana
imports Sepref_Monadify
begin
text {*
  The goal of this phase is to add to each occurrence of a bound variable
  a flag that indicates whether the value stored in this bound variable
  is accessed again (non-linear) or not (linear).

  The intention is that, for linear references to bound variables,
  the content of the variable on the heap may be destroyed.
*}


(* Type of linearity annotation: LINEAR occurrences are not accessed again,
   NON_LINEAR occurrences are. *)
datatype lin_type = LINEAR | NON_LINEAR

(* Tag to annotate linearity: LIN_ANNOT x T is logically just x; T records the
   linearity flag of this particular occurrence of x. *)
definition LIN_ANNOT :: "'a ⇒ lin_type ⇒ 'a"
  where [simp]: "LIN_ANNOT x T ≡ x"

(* Postfix notation for annotated occurrences: linear (is_LINEAR) and
   non-linear (is_NON_LINEAR) occurrences of x.
   NOTE(review): the mixfix "_L"/"_N" and the terms "xL"/"xN" look like an
   extraction-damaged superscript (e.g. "_\<^sup>L"); as written, "xL" would
   presumably be lexed as a single identifier. Confirm against the original theory. *)
abbreviation is_LINEAR ("_L") where "xL == LIN_ANNOT x LINEAR"
abbreviation is_NON_LINEAR ("_N") where "xN == LIN_ANNOT x NON_LINEAR"

text {*
  Internally, this linearity analysis works in two stages. First, a constraint
  system is generated from the program; second, this system is solved to
  obtain the linearity annotations.
*}


text {*
  In the following, we define constants to represent the constraint system.
*}
(* Skeletons of the constraint system. They are only inspected syntactically
   by the ML solver (Sepref_Lin_Ana), so their carrier type is just unit. *)
type_synonym la_skel = unit

consts
  la_seq :: "la_skel ⇒ la_skel ⇒ la_skel" -- ‹Sequential evaluation›
  la_choice :: "la_skel ⇒ la_skel ⇒ la_skel" -- ‹Alternatives›
  la_rec :: "(la_skel ⇒ la_skel) ⇒ la_skel" -- ‹Recursion›
  la_rcall :: "la_skel ⇒ la_skel" -- ‹Recursive call›
  la_op :: "'a ⇒ la_skel" -- ‹Primitive operand›
  la_lambda :: "(la_skel ⇒ la_skel) ⇒ la_skel" -- ‹Lambda abstraction›
  SKEL :: "'a ⇒ la_skel" -- ‹Tag to indicate progress of constraint system generation›
  UNSKEL :: "la_skel ⇒ 'a" -- ‹Placed on arguments of recursion and abstraction›

(* Tag to indicate linearity analysis: marks the side goal that the ML solver
   discharges; logically it is always True. *)
definition lin_ana :: "'a ⇒ bool"
  where [simp]: "lin_ana x ≡ True"

lemma lin_anaI: "lin_ana x"
  unfolding lin_ana_def by (rule TrueI)

(* Entry rule of the analysis: applied backwards to an hn_refine goal, it adds
   the side goal lin_ana (SKEL a) for the abstract program a, in front of the
   unchanged refinement goal. *)
lemma lin_ana_init:
  assumes "lin_ana (SKEL a)"
      and "hn_refine Γ c Γ' R a"
  shows "hn_refine Γ c Γ' R a"
  using assms(2) .


ML {*
  structure Sepref_Lin_Ana = struct
    (* Named theorem collection [sepref_la_skel]: the rewrite rules used by
       lin_ana_tac to turn SKEL-tagged programs into constraint skeletons. *)
    structure skel_eqs = Named_Thms (
      val description = "Sepref.Linearity-Analysis: Skeleton equations"
      val name = @{binding sepref_la_skel}
    )

    local
      (* Wrap every bound-variable occurrence of t as LIN_ANNOT (Bound _) ?a,
         using a distinct schematic variable of type lin_type for each
         occurrence. Variables are named <prefix><k>, where k counts the
         occurrences from left to right. *)
      fun add_annot_vars t = let
        (* Prefix: a variant of "a" not used as a schematic variable name in t,
           followed by "_" *)
        val taken = Name.make_context (map #1 (Term.add_var_names t []))
        val prefix = #1 (Name.variant "a" taken) ^ "_"

        (* annot e u k: e lists the types of the binders enclosing u
           (innermost first), and k is the next unused counter value.
           Returns the advanced counter and the annotated term.
           The names e, t and a in the Bound clause are required by the
           mk_term antiquotation. *)
        fun annot e (u1 $ u2) k = let
              val (k, u1') = annot e u1 k
              val (k, u2') = annot e u2 k
            in (k, u1' $ u2') end
          | annot e (Abs (x, T, body)) k = let
              val (k, body') = annot (T :: e) body k
            in (k, Abs (x, T, body')) end
          | annot e (t as Bound _) k = let
              val a = Var ((prefix ^ string_of_int k, 0), @{typ lin_type})
            in (k + 1, @{mk_term e: "LIN_ANNOT ?t ?a"}) end
          | annot _ u k = (k, u)
      in
        #2 (annot [] t 0)
      end
    in
      (* Conversion adding a schematic linearity annotation to each bound
         variable; the accompanying tactic unfolds LIN_ANNOT_def. *)
      fun add_annot_vars_conv ctxt = let
        val unfold_ss =
          put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}
      in
        Refine_Util.f_tac_conv ctxt add_annot_vars (simp_tac unfold_ss 1)
      end
    end

    local
      (* Keep annotations already fixed to LINEAR / NON_LINEAR (the "_L" / "_N"
         postfix notation for LIN_ANNOT x LINEAR / NON_LINEAR); drop every
         other LIN_ANNOT wrapper, i.e. those whose flag is still unresolved. *)
      fun strip_unresolved t = (case t of
          @{mpat "_L"} => t
        | @{mpat "_N"} => t
        | @{mpat "LIN_ANNOT ?x _"} => x
        | t1 $ t2 => strip_unresolved t1 $ strip_unresolved t2
        | Abs (x, T, body) => Abs (x, T, strip_unresolved body)
        | _ => t)
    in
      (* Conversion removing all unfinished linearity annotations; the
         accompanying tactic unfolds LIN_ANNOT_def. *)
      fun fin_annot_vars_conv ctxt = let
        val unfold_ss =
          put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}
      in
        Refine_Util.f_tac_conv ctxt strip_unresolved (simp_tac unfold_ss 1)
      end
    end

    local
      (* Analysis state of one binder, indexed like de Bruijn indices
         (head of the list = innermost binder):
           Val used  - an ordinary bound variable; used is true iff the
                       variable is accessed by the part of the skeleton that
                       has already been analysed (i.e. the part running later).
           Rec vars  - a recursion binder introduced by la_rec; vars are the
                       env indices, relative to this entry, of the variables
                       the recursive body refers to. *)
      datatype env = Val of bool | Rec of int list

      (* Record an access: ordinary variables become used, recursion
         binders are left unchanged. *)
      fun set_used (Val _) = Val true
        | set_used e = e

      (* Join the states of two alternative branches: a variable counts as
         used if either branch uses it. *)
      fun merge_env [] [] = []
        | merge_env (Val b1 :: r1) (Val b2 :: r2) =
            Val (b1 orelse b2) :: merge_env r1 r2
        | merge_env (Rec l1 :: r1) (Rec _ :: r2) = Rec l1 :: merge_env r1 r2
        | merge_env _ _ = error "merge_env: Unequal length or rec/val mismatch"

      (* Mark the env entries at the given indices as used. *)
      fun mark_used idxs env =
        map_index (fn (j, e) => if member (op =) idxs j then set_used e else e) env

      (* Solve a constraint system (skeleton) term. The skeleton is walked
         backwards: in la_seq the second component is analysed first. So
         when an operand is reached, env tells whether each variable is
         accessed later on. Returns the final env and an instantiation of the
         schematic annotation variables with LINEAR / NON_LINEAR.
         Raises TERM (or ERROR from merge_env) on malformed skeletons. *)
      fun lin_ana (env : env list) @{mpat "la_seq ?s ?t"}
            : env list * (term * term) list =
          let
            val (env, subst_t) = lin_ana env t
            val (env, subst_s) = lin_ana env s
          in
            (env, subst_t @ subst_s)
          end
        | lin_ana env @{mpat "la_choice ?s ?t"} =
          let
            (* both alternatives start from the same state *)
            val (env_s, subst_s) = lin_ana env s
            val (env_t, subst_t) = lin_ana env t
          in
            (merge_env env_s env_t, subst_s @ subst_t)
          end
        | lin_ana env @{mpat "la_rec (λ_. ?f)"} =
          let
            (* loose bound variables of the body, as env indices relative to
               the Rec entry pushed for the recursion binder *)
            val body_vars = add_loose_bnos (f, 1, []) |> map (curry op + 1)
            val (env', subst) = lin_ana (Rec body_vars :: env) f
          in
            (tl env', subst)
          end
        | lin_ana env @{mpat "la_rcall (mpaq_STRUCT (mpaq_Bound ?i))"} =
          let
            val rel_vars = (case nth env i of
                Rec vars => vars
              | _ => raise TERM ("lin_ana: rcall rec/val mismatch",[Bound i]))
          in
            (* The recursive call may access every variable its body refers
               to; shift the stored relative indices to the current env. *)
            (mark_used (map (curry op + i) rel_vars) env, [])
          end
        | lin_ana env @{mpat "la_lambda (λ_. ?f)"} =
          let
            (* the abstracted variable is not accessed after the body *)
            val (env', subst) = lin_ana (Val false :: env) f
          in
            (tl env', subst)
          end
        | lin_ana env @{mpat "la_op ?t"} =
          let
            (* Annotated occurrences of variables bound outside the operand,
               as (env index, annotation variable) pairs *)
            fun collect n @{mpat "LIN_ANNOT (mpaq_STRUCT (mpaq_Bound ?i)) ?a"} =
                  if i >= n then [(i - n, a)] else []
              | collect n (t1 $ t2) = collect n t1 @ collect n t2
              | collect n (Abs (_, _, t)) = collect (n + 1) t
              | collect _ _ = []

            val occs = collect 0 t
            val idxs = map #1 occs

            (* A variable occurring several times in the same operand is
               accessed again by the operand itself, so none of these
               occurrences may be treated as linear. *)
            fun occurs_repeatedly j =
              length (filter (fn k => k = j) idxs) > 1

            fun flag_of (j, a) = (case nth env j of
                Val false =>
                  if occurs_repeatedly j then (a, @{const NON_LINEAR})
                  else (a, @{const LINEAR})
              | Val true => (a, @{const NON_LINEAR})
              | _ =>
                  raise TERM ("lin_ana: Invalid occurence of recursion var",[t]))

            val subst = map flag_of occs
          in
            (* every collected variable is accessed by this operand *)
            (mark_used idxs env, subst)
          end
        | lin_ana _ t = raise TERM ("lin_ana: Invalid",[t])

      (* Instantiate the annotation variables in t with the solution of the
         constraint system t *)
      fun lin_ana_trans t = subst_atomic (#2 (lin_ana [] t)) t

    in
      (* Solve linearity constraint system: As conversion.
         The accompanying tactic unfolds LIN_ANNOT_def. *)
      fun lin_ana_conv ctxt = Refine_Util.f_tac_conv ctxt 
        (lin_ana_trans) 
        (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}) 1)

      (* Solve linearity constraint system: As tactic.
         Applies to a subgoal lin_ana t: closes it with lin_anaI and then
         instantiates the annotation variables of the whole proof state with
         the solution of t. Fails (empty sequence) on any other subgoal. *)
      fun lin_ana_inst_tac i st = case Logic.concl_of_goal (prop_of st) i of
        @{mpat "Trueprop (lin_ana ?t)"} => let
          val thy = theory_of_thm st
          val cert = cterm_of thy
          val (_,s) = lin_ana [] t
          val s = map (pairself cert) s
        in
          ( rtac @{thm lin_anaI} i 
            THEN PRIMITIVE (Thm.instantiate ([],s))
          ) st
        end
      | _ => Seq.empty
  
    end
  
    (* Does some atomic subterm of the term (anything that is neither an
       application nor an abstraction) satisfy P? Applications and
       abstractions are only traversed.
       TODO: Move to a general utility structure. *)
    fun ex_aterm P t = (case t of
        t1 $ t2 => ex_aterm P t1 orelse ex_aterm P t2
      | Abs (_, _, body) => ex_aterm P body
      | _ => P t)

    (* Does the term still contain a SKEL tag, i.e. has the constraint
       system not been fully generated yet? *)
    val contains_skel =
      ex_aterm (fn u => (case u of @{mpat "SKEL"} => true | _ => false))
  
    (* Perform linearity analysis on an hn_refine subgoal:
         1. add a schematic linearity annotation to each bound-variable
            occurrence (presumably in the abstract program of the goal),
         2. generate the constraint system by applying lin_ana_init and
            simplifying with the [sepref_la_skel] equations,
         3. if SKEL tags remain, trace the goal and fail; otherwise solve the
            system and remove the annotations that stayed unresolved. *)
    fun lin_ana_tac ctxt = let
      (* Failure branch: report the goal with unresolved combinators, then fail *)
      fun err_tac i st = let
        val goal = Logic.get_goal (prop_of st) i
        val msg = Pretty.block [
            Pretty.str "Unresolved combinators remain:", Pretty.brk 1,
            Syntax.pretty_term ctxt goal
          ] |> Pretty.string_of
      in
        (tracing msg; Seq.empty)
      end

      open Sepref_Basic

      (* Stage 1: schematic annotations *)
      val annotate_tac =
        CONVERSION (hn_refine_concl_conv_a (K (add_annot_vars_conv ctxt)) ctxt)

      (* Stage 2: constraint system generation *)
      val gen_cs_tac =
        rtac @{thm lin_ana_init}
        THEN' simp_tac (put_simpset HOL_basic_ss ctxt addsimps skel_eqs.get ctxt)

      (* Stage 3 (success branch): solve, then drop unresolved annotations *)
      val solve_cs_tac =
        lin_ana_inst_tac
        THEN' CONVERSION (hn_refine_concl_conv_a (K (fin_annot_vars_conv ctxt)) ctxt)
    in
      annotate_tac
      THEN' gen_cs_tac
      THEN' (COND' contains_skel THEN_ELSE' (err_tac, solve_cs_tac))
    end
  
    (* Theory setup: registers the sepref_la_skel theorem collection *)
    val setup = skel_eqs.setup
  end
*}

(* Register the [sepref_la_skel] theorem collection used below *)
setup Sepref_Lin_Ana.setup

(* Default skeleton equations [sepref_la_skel]. lin_ana_tac simplifies the
   goal lin_ana (SKEL a) with these rules to turn the abstract program a into
   a linearity constraint skeleton:
   - bind becomes la_seq of its two parts;
   - RETURN and PASS arguments become primitive operands (la_op);
   - RECT becomes la_rec (after the operand x), and RCALL becomes la_rcall
     (after its argument x). Abstractions become la_lambda. UNSKEL marks the
     variables bound by these binders;
   - case_prod, Let, case_list, case_option and If first evaluate their
     scrutinee / bound value (la_op). The multi-branch combinators then
     combine their branches with la_choice.
   NOTE(review): "λ2" below looks like an extraction-damaged "λ\<^sub>2"
   (Sepref's lambda notation). Confirm against the original theory. *)
lemma dflt_skel_eqs[sepref_la_skel]:
  "!!a b. SKEL (bind$a$b) ≡ la_seq (SKEL a) (SKEL b)" 
  "!!a. SKEL (RETURN$a) ≡ la_op a"
  "!!a. SKEL (PASS$a) ≡ la_op a"
  "!!f x. SKEL (RECT$(λ2D. f D)$x) 
    ≡ la_seq (la_op x) (la_rec (λD. SKEL (f (UNSKEL D))))"
  "!!D x. SKEL (RCALL$(UNSKEL D)$x) ≡ la_seq (la_op x) (la_rcall D)"
  (* annotations on the recursion variable itself are irrelevant *)
  "!!D a. la_rcall (LIN_ANNOT D a) ≡ la_rcall D"
  "!!f. SKEL (λ2x. f x) ≡ la_lambda (λx. SKEL (f (UNSKEL x)))"
  (* move UNSKEL outwards past annotations, and drop it inside operands *)
  "!!v a. LIN_ANNOT (UNSKEL v) a = UNSKEL (LIN_ANNOT v a)"
  "!!x. la_op (UNSKEL x) = la_op x"
  "!!f p. SKEL (case_prod$f$p) ≡ la_seq (la_op p) (SKEL f)"
  "!!fn fc l. SKEL (case_list$fn$fc$l) 
    ≡ la_seq (la_op l) (la_choice (SKEL fn) (SKEL fc))"
  "!!fn fs ov. SKEL (case_option$fn$fs$ov) 
    ≡ la_seq (la_op ov) (la_choice (SKEL fn) (SKEL fs))"
  "!!v f. SKEL (Let$v$f) ≡ la_seq (la_op v) (SKEL f)"
  "!!b t e. SKEL (If$b$t$e) ≡ la_seq (la_op b) (la_choice (SKEL t) (SKEL e))"
  by simp_all

end