header ‹Linearity Analysis›
theory Sepref_Lin_Ana
imports Sepref_Monadify
begin

text {*
  The goal of this phase is to add a flag to each occurrence of a bound
  variable. The flag indicates whether the value stored in this bound
  variable is accessed again (non-linear) or not (linear).
  The intention is that, for linear references to bound variables,
  the content of the variable on the heap may be destroyed.
*}

datatype lin_type -- ‹Type of linearity annotation›
  = LINEAR | NON_LINEAR

definition LIN_ANNOT -- ‹Tag to annotate linearity›
  :: "'a => lin_type => 'a" where [simp]: "LIN_ANNOT x T == x"

abbreviation is_LINEAR ("_⇧L") where "x⇧L == LIN_ANNOT x LINEAR"
abbreviation is_NON_LINEAR ("_⇧N") where "x⇧N == LIN_ANNOT x NON_LINEAR"

text {*
  Internally, this linearity analysis works in two stages.
  First, a constraint system is generated from the program.
  In the second stage, this system is solved to obtain the
  linearity annotations.
*}

text {* In the following, we define constants that represent the constraint system. *}

type_synonym la_skel = unit

consts
  la_seq :: "la_skel => la_skel => la_skel" -- "Sequential evaluation"
  la_choice :: "la_skel => la_skel => la_skel" -- "Alternatives"
  la_rec :: "(la_skel => la_skel) => la_skel" -- "Recursion"
  la_rcall :: "la_skel => la_skel" -- "Recursive call"
  la_op :: "'a => la_skel" -- "Primitive operand"
  la_lambda :: "(la_skel => la_skel) => la_skel" -- "Lambda abstraction"
  SKEL :: "'a => la_skel" -- "Tag to indicate progress of constraint system generation"
  UNSKEL :: "la_skel => 'a" -- "Placed on arguments of recursion and abstraction"

definition lin_ana -- ‹Tag to indicate linearity analysis›
  where [simp]: "lin_ana x ≡ True"

lemma lin_anaI: "lin_ana x" by simp

(* Splits a refinement goal into the linearity-analysis goal for the
   abstract program a (first premise) and the original goal (second premise). *)
lemma lin_ana_init:
  assumes "lin_ana (SKEL a)"
  assumes "hn_refine Γ c Γ' R a"
  shows "hn_refine Γ c Γ' R a"
  by fact

ML {*
structure Sepref_Lin_Ana = struct

  (* Rewrite rules that translate program terms into la_* skeletons *)
  structure skel_eqs = Named_Thms (
    val name = @{binding sepref_la_skel}
    val description = "Sepref.Linearity-Analysis: Skeleton equations"
  )

  local
    (* Wrap every Bound occurrence b in t as LIN_ANNOT b ?a_i, where ?a_i is
       a fresh schematic variable of type lin_type (one per occurrence,
       numbered left to right).
       NOTE(review): the prefix is chosen to differ from existing schematic
       variable names, but the numbered names prefix^i themselves are not
       checked for clashes. *)
    fun add_annot_vars t = let
      val prefix = let
        val context = Name.make_context (Term.add_var_names t [] |> map #1)
      in
        (Name.variant "a" context |> #1) ^ "_"
      end

      (* e: binder types of the enclosing abstractions (needed to type
            the Bound inside mk_term);
         i: counter for the next fresh annotation variable *)
      fun f (e,i) (t1$t2) = let
            val (i,t1) = f (e,i) t1
            val (i,t2) = f (e,i) t2
          in
            (i,t1$t2)
          end
        | f (e,i) (Abs(x,T,t)) = let
            val (i,t) = f (T::e,i) t
          in
            (i,Abs (x,T,t))
          end
        | f (e,i) (t as Bound _) = let
            val a = Var ((prefix^string_of_int i,0),@{typ lin_type})
            val t = @{mk_term e: "LIN_ANNOT ?t ?a"}
          in
            (i+1,t)
          end
        | f (_,i) t = (i,t)
    in
      f ([],0) t |> #2
    end
  in
    (* Add schematic linearity annotation to each bound variable *)
    fun add_annot_vars_conv ctxt = Refine_Util.f_tac_conv ctxt
      (add_annot_vars)
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}) 1)
  end

  local
    (* Keep finished annotations (x⇧L, x⇧N); drop LIN_ANNOT wrappers whose
       annotation was not instantiated to a concrete lin_type constant *)
    fun fin_annot_vars (t as @{mpat "_⇧L"}) = t
      | fin_annot_vars (t as @{mpat "_⇧N"}) = t
      | fin_annot_vars (@{mpat "LIN_ANNOT ?x _"}) = x
      | fin_annot_vars (t1$t2) = fin_annot_vars t1 $ fin_annot_vars t2
      | fin_annot_vars (Abs (x,T,t)) = Abs (x,T,fin_annot_vars t)
      | fin_annot_vars t = t
  in
    (* Remove all unfinished linearity annotations *)
    fun fin_annot_vars_conv ctxt = Refine_Util.f_tac_conv ctxt
      (fin_annot_vars)
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}) 1)
  end

  local
    (* Analysis environment, indexed like de Bruijn indices
       (entry 0 = innermost binder):
         Val b  -- a lambda-bound variable; b records whether the variable
                   is used at a point the backward analysis has already
                   visited, i.e. later in evaluation order.
         Rec l  -- a recursion variable; l lists the indices of the outer
                   variables used by the recursion body, relative to this
                   entry (so they are shifted by the entry's own index at
                   a recursive call). *)
    datatype env = Val of bool | Rec of int list

    fun set_used (Val _) = Val true
      | set_used x = x

    (* Join of the environments of two alternative branches: a variable
       counts as used if it is used in either branch *)
    fun merge_env [] [] = []
      | merge_env (Val b1::r1) (Val b2::r2) = Val (b1 orelse b2) :: merge_env r1 r2
      | merge_env (Rec l1::r1) (Rec _::r2) = Rec l1 :: merge_env r1 r2
      | merge_env _ _ = error "merge_env: Unequal length or rec/val mismatch"

    (* Backward analysis of a skeleton term. Returns the environment before
       the term (in evaluation order) and a substitution that maps each
       annotation variable that was reached to LINEAR or NON_LINEAR. *)
    fun lin_ana (env : env list ) @{mpat "la_seq ?s ?t"} : env list * (term * term) list = let
          (* t is evaluated after s, so it is analysed first *)
          val (env,s1) = lin_ana env t
          val (env,s2) = lin_ana env s
        in
          (env,s1@s2)
        end
      | lin_ana env @{mpat "la_choice ?s ?t"} = let
          val (env1,s1) = lin_ana env s
          val (env2,s2) = lin_ana env t
        in
          (merge_env env1 env2, s1@s2)
        end
      | lin_ana env @{mpat "la_rec (λ_. ?f)"} = let
          (* Outer variables free in the body, as indices into the env
             that has the Rec entry at position 0.
             NOTE(review): the body is analysed only once (no fixpoint);
             outer variables used after a non-tail recursive call inside
             the body may still be marked linear although the enclosing
             activation uses them afterwards -- confirm this is intended. *)
          val f_used = add_loose_bnos (f,1,[]) |> map (curry op + 1)
          val env = Rec f_used :: env
          val (env,s) = lin_ana env f
        in
          (tl env,s)
        end
      | lin_ana env @{mpat "la_rcall (mpaq_STRUCT (mpaq_Bound ?i))"} = let
          (* A recursive call (re-)uses every outer variable of the body *)
          val used = case nth env i of
              Rec used => used
            | _ => raise TERM ("lin_ana: rcall rec/val mismatch",[Bound i])
          val used = map (curry op + i) used
          val env = map_index (fn (i,e) => if member op= used i then set_used e else e) env
        in
          (env,[])
        end
      | lin_ana env @{mpat "la_lambda (λ_. ?f)"} = let
          val (env,s) = lin_ana (Val false::env) f
        in
          (tl env,s)
        end
      | lin_ana env @{mpat "la_op ?t"} = let
          (* Collect loose bound vars with their annotations *)
          fun collect n @{mpat "LIN_ANNOT (mpaq_STRUCT (mpaq_Bound ?i)) ?a"} =
                if i>=n then [(i-n,a)] else []
            | collect n (t1$t2) = collect n t1 @ collect n t2
            | collect n (Abs (_,_,t)) = collect (n+1) t
            | collect _ _ = []

          val used = collect 0 t

          (* A variable occurring several times in the same operand is also
             accessed by its other occurrence(s), and all of them alias the
             same value; hence none of these occurrences may be linear. *)
          fun occurs_once i = length (filter (fn (j,_) => j=i) used) = 1

          (* Check whether they are used in env … add subst to result *)
          val s = map (fn (i,a) =>
              case nth env i of
                Val false =>
                  (a, if occurs_once i then @{const LINEAR} else @{const NON_LINEAR})
              | Val true => (a,@{const NON_LINEAR})
              | _ => raise TERM ("lin_ana: Invalid occurence of recursion var",[t])
            ) used

          (* Mark them as used in env *)
          val used = map #1 used
          val env = map_index (fn (i,e) => if member op= used i then set_used e else e) env
        in
          (env,s)
        end
      | lin_ana _ t = raise TERM ("lin_ana: Invalid",[t])

    (* Solve the constraint system t and apply the resulting
       instantiation to t *)
    fun lin_ana_trans t = let
      val (_,s) = lin_ana [] t
      val res = subst_atomic s t
    in
      res
    end
  in
    (* Solve linearity constraint system: As conversion *)
    fun lin_ana_conv ctxt = Refine_Util.f_tac_conv ctxt
      (lin_ana_trans)
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms LIN_ANNOT_def}) 1)

    (* Solve linearity constraint system: As tactic.
       Closes a goal "lin_ana t" and instantiates the schematic annotation
       variables of the whole proof state; fails on any other goal shape. *)
    fun lin_ana_inst_tac i st =
      case Logic.concl_of_goal (prop_of st) i of
        @{mpat "Trueprop (lin_ana ?t)"} => let
          val thy = theory_of_thm st
          val cert = cterm_of thy
          val (_,s) = lin_ana [] t
          val s = map (pairself cert) s
        in
          (
            rtac @{thm lin_anaI} i
            THEN PRIMITIVE (Thm.instantiate ([],s))
          ) st
        end
      | _ => Seq.empty
  end

  (* TODO: Move *)
  (* Does some atom (non-application, non-abstraction subterm) satisfy P? *)
  fun ex_aterm P (t1$t2) = ex_aterm P t1 orelse ex_aterm P t2
    | ex_aterm P (Abs (_,_,t)) = ex_aterm P t
    | ex_aterm P t = P t

  (* Any SKEL left means some construct had no skeleton equation *)
  val contains_skel = ex_aterm (fn @{mpat "SKEL"} => true | _ => false)

  (* Perform linearity analysis *)
  fun lin_ana_tac ctxt = let
    (* Report the unresolved goal via tracing, then fail *)
    fun err_tac i st = let
      val g = Logic.get_goal (prop_of st) i
      val _ = Pretty.block [
          Pretty.str "Unresolved combinators remain:",
          Pretty.brk 1,
          Syntax.pretty_term ctxt g
        ] |> Pretty.string_of |> tracing
    in
      Seq.empty
    end

    open Sepref_Basic
  in
    (* Add schematic annotations *)
    CONVERSION (hn_refine_concl_conv_a (K (add_annot_vars_conv ctxt)) ctxt)
    (* Generate constraint system: goal i becomes "lin_ana (SKEL a)",
       the annotated hn_refine goal follows at i+1 *)
    THEN' rtac @{thm lin_ana_init}
    THEN' simp_tac (put_simpset HOL_basic_ss ctxt addsimps skel_eqs.get ctxt)
    THEN' (
      COND' contains_skel
      THEN_ELSE' (
        err_tac, (* CS not fully generated *)
        (* Solve CS; this closes goal i, so the hn_refine goal moves back
           to index i for the final clean-up conversion *)
        lin_ana_inst_tac
        THEN' CONVERSION (hn_refine_concl_conv_a (K (fin_annot_vars_conv ctxt)) ctxt)
      ))
  end

  val setup = skel_eqs.setup
end
*}

setup Sepref_Lin_Ana.setup

lemma dflt_skel_eqs[sepref_la_skel]:
  "!!a b. SKEL (bind$a$b) ≡ la_seq (SKEL a) (SKEL b)"
  "!!a. SKEL (RETURN$a) ≡ la_op a"
  "!!a. SKEL (PASS$a) ≡ la_op a"
  "!!f x. SKEL (RECT$(λ⇩2D. f D)$x) ≡ la_seq (la_op x) (la_rec (λD. SKEL (f (UNSKEL D))))"
  "!!D x. SKEL (RCALL$(UNSKEL D)$x) ≡ la_seq (la_op x) (la_rcall D)"
  "!!D a. la_rcall (LIN_ANNOT D a) ≡ la_rcall D"
  "!!f. SKEL (λ⇩2x. f x) ≡ la_lambda (λx. SKEL (f (UNSKEL x)))"
  "!!v a. LIN_ANNOT (UNSKEL v) a = UNSKEL (LIN_ANNOT v a)"
  "!!x. la_op (UNSKEL x) = la_op x"
  "!!f p. SKEL (case_prod$f$p) ≡ la_seq (la_op p) (SKEL f)"
  "!!fn fc l. SKEL (case_list$fn$fc$l) ≡ la_seq (la_op l) (la_choice (SKEL fn) (SKEL fc))"
  "!!fn fs ov. SKEL (case_option$fn$fs$ov) ≡ la_seq (la_op ov) (la_choice (SKEL fn) (SKEL fs))"
  "!!v f. SKEL (Let$v$f) ≡ la_seq (la_op v) (SKEL f)"
  "!!b t e. SKEL (If$b$t$e) ≡ la_seq (la_op b) (la_choice (SKEL t) (SKEL e))"
  by simp_all

end