Theory LLVM_Examples

section Examples
theory LLVM_Examples
imports 
  "../ds/LLVM_DS_All"
  "../ds/LLVM_DS_Array_List"
begin

text Examples on top of the Isabelle-LLVM basic layer. 
  For the verification of more complex algorithms, consider using
  Isabelle-LLVM with the Refinement Framework, and the Sepref tool.
  See, e.g., @{file Bin_Search.thy}.


(* TODO: Parts of this file are incomplete; the examples could be more elaborate! *)

subsection Numeric Algorithms

subsubsection Exponentiation

(* exp r: computes 2^r. The accumulator a starts at 1 and is doubled while the
   counter r is decremented, as long as the unsigned test 0 < r holds.
   The input width 'a and the result width 'b are independent type parameters.
   Word multiplication wraps, so the result is 2^r only if 2^r fits into 'b;
   this is exactly the side condition of exp_correct below. *)
definition exp :: "'a::len word  'b::len word llM" where [llvm_code]: "exp r  doM {
  a  ll_const (unsigned 1);
  (a,r)  llc_while 
    (λ(a,r). doM { ll_icmp_ult (unsigned 0) r}) 
    (λ(a,r). doM {
      Mreturn (a*unsigned 2,r-unsigned 1)
    })
    (a,r);
  Mreturn a
}"

(* Monomorphic instances for export: input and result have the same width. *)
abbreviation exp32::"32 word  32 word llM" where "exp32  exp"
abbreviation exp64::"64 word  64 word llM" where "exp64  exp"

(* Export both instances with explicit C signatures into code/exp.ll. *)
export_llvm 
  exp32 is "uint32_t exp32 (uint32_t)" 
  exp64 is "uint64_t exp64 (uint64_t)"
  file "code/exp.ll"

(* Arithmetic helper lemmas for the verification conditions of exp_correct.
   exp_aux1: if 2^k is below the bound N and 0 < t (and t is bounded by k),
   then doubling 2^(k-t) stays below N, because nat (k-t) + 1 does not
   exceed nat k. *)
lemma exp_aux1: 
  assumes "2 ^ nat k < (N::int)" "t  k" "0 < t" 
  shows "2 * 2 ^ nat (k - t) < N"
proof -
  from assms have "nat (k - t) + 1  nat k" by auto
  with assms have "(2::int) ^ (nat (k - t) + 1)  2 ^ nat k"
    using one_le_numeral power_increasing by blast
  thus ?thesis using assms by simp
qed
  
(* exp_aux2: under the same bounds on t, nat (1+k-t) is Suc (nat (k-t)). *)
lemma exp_aux2:  "t  k; 0 < t  nat (1+k-t) = Suc (nat (k-t))" by simp

(* Correctness of exp: for an unsigned argument k (held in ki) such that 2^k
   is representable in the result type 'b (of at least 2 bits), exp returns
   2^k and keeps the argument assertion.
   Proof idea: annotate the while loop with the invariant
   "0 <= r <= k and a = 2^(k-r)", use the remaining counter r (measure nat)
   for termination, run the VCG and close the arithmetic goals with
   exp_aux1/exp_aux2. *)
lemma exp_correct:
  assumes "LENGTH('b::len)  2"
  shows "llvm_htriple 
    (uint.assn k (ki::'a::len word) ** (2^nat k  uints LENGTH('b))) 
    (exp ki) 
    (λr::'b word. uint.assn (2^nat k) r ** uint.assn k ki)"
  unfolding exp_def
  apply (rewrite annotate_llc_while[where 
    I="λ(ai,ri) t. EXS a r. uint.assn a ai ** uint.assn r ri ** d( 0r  rk  a = 2^nat (k-r) ) ** !(t = r)"
    and R="measure nat"
    ])
  apply vcg_monadify  
  apply (vcg'; (clarsimp simp: algebra_simps)?)
  using assms
  apply (simp_all add: exp_aux1 exp_aux2)
  done

(* TODO: can we restore executability?
text ‹Executability of semantics inside Isabelle›
value "run (exp64 32) llvm_empty_memory"
*)

subsubsection Euclid's Algorithm

                       
(* Subtraction-based Euclidean algorithm: while the two values differ, the
   smaller one is subtracted from the larger one; the common value is returned.
   NOTE(review): with a = 0 and b > 0 the a<=b branch leaves the state
   unchanged, so the loop does not terminate. Callers need positive inputs,
   which is the precondition of the correctness lemma below. *)
definition [llvm_code]: "euclid (a::'a::len word) b  doM {
  (a,b)  llc_while 
    (λ(a,b)  ll_cmp (a  b))
    (λ(a,b)  if (ab) then Mreturn (a,b-a) else Mreturn (a-b,b))
    (a,b);
  Mreturn a
}"
  
(* Export the 64-bit instance to code/euclid.ll, using the (debug) option. *)
export_llvm (debug) (*no_while*) 
  "euclid :: 64 word  64 word  64 word llM" is "uint64_t euclid (uint64_t, uint64_t)"
  file "code/euclid.ll"

  
(* Variant of gcd_diff1 with the difference in the second argument:
   gcd a (b-a) = gcd a b. Used for the b-a branch of euclid. *)
lemma gcd_diff1': "gcd (a::int) (b-a) = gcd a b"
  by (metis gcd.commute gcd_diff1)   
  

(* Correctness of euclid: for positive inputs a0, b0 the result is gcd a0 b0.
   Loop invariant: both values stay positive and their gcd equals gcd a0 b0.
   Termination measure: the sum a+b, which strictly decreases in every
   iteration. *)
lemma "llvm_htriple 
  (uint.assn a0 ai ** uint.assn b0 bi ** d(0<a0  0<b0)) 
  (euclid ai bi) 
  (λri. uint.assn (gcd a0 b0) ri)"
  unfolding euclid_def
  apply (rewrite annotate_llc_while[where 
    I="λ(ai,bi) t. EXS a b. uint.assn a ai ** uint.assn b bi 
        ** a(t=a+b) ** d(0<a  0<b  gcd a b = gcd a0 b0)" 
    and R="measure nat"  
  ])
  apply vcg_monadify
  apply (vcg'; clarsimp?)
  apply (simp_all add: gcd_diff1 gcd_diff1')
  done

subsubsection Fibonacci Numbers

(* Naive, doubly recursive Fibonacci numbers via the REC fixed-point combinator:
   for n up to 1 the result is n itself, otherwise fib (n-1) + fib (n-2).
   The number of calls is exponential in n, and the sum wraps at the word width.
   This example is only exported; it is not verified (see the TODO below). *)
definition fib :: "'n::len word  'n word llM" where [llvm_code]: "fib n  REC (λfib' n. 
  if nunsigned 1 then Mreturn n 
  else doM { 
    n1  fib' (n-unsigned 1); 
    n2  fib' (n-unsigned 2); 
    Mreturn (n1+n2)     
  }) n"

(* 64-bit instance used for the export. *)
abbreviation fib64 :: "64 word  64 word llM" where "fib64  fib"
export_llvm thms: fib64
  
(* TODO: Arbitrary fixed-point reasoning not yet supported in VCG!
  set up a rule with pre and post consequence rule, 
  and seplogic-assertions

lemma
  assumes MONO: "⋀x. M.mono_body (λfa. F fa x)"
  assumes "P x s m"
  assumes "wf R"
  assumes "⋀D x s m. ⟦ P x s m; ⋀x' s' m'. ⟦ P x' s' m'; (m',m)∈R ⟧ ⟹ wp (D x') Q s' ⟧ ⟹ wp (F D x) Q s"
  shows "wp (REC F x) Q s"
  using assms(3,2)
  apply (induction m arbitrary: x s rule: wf_induct_rule)
  apply (subst REC_unfold) apply simp apply (rule MONO)
  using assms(4) by simp
  
  

lemma "llvm_htriple (↿uint.assn n ni) (fib ni) (λri. ↿uint.assn x ri)"
  unfolding fib_def
  apply vcg_monadify
  apply vcg
  find_theorems wp REC
*)
  
(* Derive a code equation from fib_def (REC-based definition), so fib can be
   evaluated by Isabelle's code generator. *)
prepare_code_thms (LLVM) [code] fib_def  (* Set up code equation. Required to execute semantics in Isabelle. *)

(*
term "Abs_memory []"

value "map (λn. run (fib64 n) (Abs_memory [])) [0,1,2,3]"
*)

(*
lemmas [named_ss llvm_inline cong] = refl[of "numeral _"]
*)

(* Scratch/debugging area for the LLVM code preprocessor.
   test simply returns its two arguments as a pair. *)
definition test :: "64 word  64 word  _ llM"
where [llvm_code]: "test a b  doM {

  Mreturn (a,b) 
}"

(* Apply the preprocessor's inlining and monadification steps to test_def
   and inspect the result. *)
ML_val 
  local open LLC_Preprocessor
    val ctxt = @{context}
  in

    val thm = @{thm test_def}
      |> cthm_inline ctxt
      |> cthm_monadify ctxt
  
  end



find_theorems llc_while

(* Exploratory: look at the effect of the llvm_pre_simp simpset on test
   (intentionally left unfinished with oops). *)
lemma "foo (test)"
  unfolding test_def
  apply (simp named_ss llvm_pre_simp:)
  oops

export_llvm test


subsubsection Distance between two Points (double)

context begin

  (*
    TODO: Generalize monadification/preprocessor to push nanize into operations!
      Otherwise, we have to flatten by hand!
  *)
  (* A NaN operand makes a double-precision sum NaN (left and right operand). *)
  lemma plus_nan_double1[simp]:
    "is_nan_double a  is_nan_double (a+b)"
    apply transfer
    unfolding plus_float_def fadd_def
    by simp

  lemma plus_nan_double2[simp]:
    "is_nan_double b  is_nan_double (a+b)"
    apply transfer
    unfolding plus_float_def fadd_def
    by simp
    
  (* Facts relating the NaN predicates to bot; the relation symbol is not
     visible here (presumably "is_nan_double is not bot", i.e. NaNs exist). *)
  lemma [simp]: "is_nan_double  bot"  
    using is_nan_double.abs_eq by force
    
  lemma [simp]: "is_nan_single  bot"  
    using is_nan_single.abs_eq by force
  
  (* Point-wise characterisation of ndet_nan_double: presumably it never
     fails, and every result is a NaN with i=0 and an unchanged state. *)
  lemma pw_nan_double[pw_simp]:
    "run ndet_nan_double s  failne"  
    "is_res (run ndet_nan_double s) (x,i,s')  is_nan_double x  i=0  s'=s"
    unfolding ndet_nan_double_def
    by pw+
    
  (* Nanizing the operands before an addition is redundant: only the
     nanization of the sum matters. *)
  lemma "doM {
    a  nanize_double a;
    b  nanize_double b;
    nanize_double (a + b)
  } = nanize_double (a + b)"
    unfolding nanize_double_def
    apply pw' 
    apply fastforce
    done

  (* Euclidean distance sqrt((x1-x2)^2 + (y1-y2)^2) between two double points.
     Every intermediate result is passed through nanize_double by hand
     (see the TODO at the top of this context). *)
  definition ddist :: "double × double  double × double  double llM"
    where [llvm_code]: "ddist p1 p2  doM {
    let (x1,y1) = p1;
    let (x2,y2) = p2;
    dx  nanize_double (x1 - x2);
    dy  nanize_double (y1 - y2);
    dx2  nanize_double (dx*dx);
    dy2  nanize_double (dy*dy);
    dxy2  nanize_double (dx2+dy2);
    nanize_double (dsqrt dxy2)
  }"
  
  export_llvm ddist
  
  interpretation llvm_prim_arith_setup .

  (* There's not much we can prove without a defined rounding mode, at least not in the current setup!
     This lemma only runs the VCG on ddist with trivial pre- and postconditions. *)
  lemma "llvm_htriple  (ddist p1 p2) (λ_. )"
    unfolding ddist_def 
    apply (simp split: prod.split add: Let_def)
    unfolding nanize_double_def ndet_nan_double_def
    apply vcg
    done

  (* TODO: Prove upper and lower bounds. This needs infrastructure that has yet to be designed! *)  

  
  (* Single-precision counterpart of ddist. *)
  definition fdist :: "single × single  single × single  single llM"
    where [llvm_code]: "fdist p1 p2  doM {
    let (x1,y1) = p1;
    let (x2,y2) = p2;
    dx  nanize_single (x1 - x2);
    dy  nanize_single (y1 - y2);
    dx2  nanize_single (dx*dx);
    dy2  nanize_single (dy*dy);
    dxy2  nanize_single (dx2+dy2);
    nanize_single (ssqrt dxy2)
  }"
  
  export_llvm fdist
  
  (* NOTE(review): llvm_prim_arith_setup is already interpreted above in this
     context; this second interpretation appears redundant. *)
  interpretation llvm_prim_arith_setup .

  (* There's not much we can prove without a defined rounding mode, at least not in the current setup!
     This lemma only runs the VCG on fdist with trivial pre- and postconditions. *)
  lemma "llvm_htriple  (fdist p1 p2) (λ_. )"
    unfolding fdist_def 
    apply (simp split: prod.split add: Let_def)
    unfolding nanize_single_def ndet_nan_single_def
    apply vcg
    done

  (* TODO: Prove upper and lower bounds. This needs infrastructure that has yet to be designed! *)  
  
  
end

subsection Unions

(* Enable the llc_compile_union option (presumably required to export union
   types such as ll_sum below). *)
declare [[llc_compile_union=true]]

(* Three-way sum used as an LLVM union: Zero (the default value, see
   init_ll_sum), a left value, or a right value. Constructors, discriminators
   and selectors are hidden and must be referred to qualified,
   e.g. ll_sum.Inl. *)
datatype ('a,'b) ll_sum = is_Zero: Zero | is_Inl: Inl (the_left: 'a) | is_Inr: Inr (the_right: 'b)
hide_const (open) 
  ll_sum.Zero ll_sum.Inl ll_sum.Inr 
  ll_sum.is_Zero ll_sum.is_Inl ll_sum.is_Inr 
  ll_sum.the_left ll_sum.the_right

(* LLVM representation of ll_sum as a two-member union.
   Zero maps to a zero-initialized union of both member structures.
   Inl/Inr map to a union value with the selected member in position 0/1.
   In UN_SEL, the lists before and after the value hold the structures of the
   members before and after the selected one. *)
instantiation ll_sum :: (llvm_rep,llvm_rep) llvm_rep
begin

  fun to_val_ll_sum :: "('a,'b) ll_sum  llvm_val" where
    "to_val_ll_sum ll_sum.Zero = LL_UNION (UN_ZERO_INIT [struct_of TYPE('a),struct_of TYPE('b)])"
  | "to_val_ll_sum (ll_sum.Inl l) = LL_UNION (UN_SEL [] (to_val l) [struct_of TYPE('b)])"
  | "to_val_ll_sum (ll_sum.Inr r) = LL_UNION (UN_SEL [struct_of TYPE('a)] (to_val r) [])"

  (* Inverse of to_val_ll_sum on well-shaped union values; undefined otherwise. *)
  fun from_val_ll_sum :: "llvm_val  ('a,'b) ll_sum" where
    "from_val_ll_sum (LL_UNION (UN_ZERO_INIT _)) = ll_sum.Zero"
  | "from_val_ll_sum (LL_UNION (UN_SEL [] l [_])) = ll_sum.Inl (from_val l)"
  | "from_val_ll_sum (LL_UNION (UN_SEL [_] r [])) = ll_sum.Inr (from_val r)"
  | "from_val_ll_sum _ = undefined"  

  definition struct_of_ll_sum :: "('a,'b) ll_sum itself  llvm_struct" where 
    [simp]: "struct_of_ll_sum _ = VS_UNION [struct_of TYPE('a), struct_of TYPE('b)]"
    
  definition init_ll_sum :: "('a,'b) ll_sum" where [simp]: "init_ll_sum = ll_sum.Zero"  
   
  instance
    apply standard
    apply (all (clarsimp simp: comp_def fun_eq_iff)?)
    subgoal for x by (cases x) auto  
    subgoal for v by (cases v rule: from_val_ll_sum.cases) auto
    subgoal for x by (cases x) auto
    done

end

(* Structure of ll_sum, registered for the code generator (ll_struct_of). *)
lemma struct_of_ll_sum[ll_struct_of]: "struct_of TYPE(('a::llvm_rep, 'b::llvm_rep) ll_sum) = VS_UNION [struct_of TYPE('a), struct_of TYPE('b)]"
  by simp


(* Union constructors: build an ll_sum with the given value stored in member
   0 (left) or member 1 (right). *)
definition ll_sum_mk_left :: "'l  ('l::llvm_rep, 'r::llvm_rep) ll_sum llM" where 
  [llvm_code,llvm_inline]: "ll_sum_mk_left x  ll_make_union TYPE(('l,'r) ll_sum) x 0"

definition ll_sum_mk_right :: "'r  ('l::llvm_rep, 'r::llvm_rep) ll_sum llM" where 
  [llvm_code,llvm_inline]: "ll_sum_mk_right x  ll_make_union TYPE(('l,'r) ll_sum) x 1"

(* Union destructors: read member 0 (left) or member 1 (right). See
   ll_sum_extr_simps for the cases in which their results are specified. *)
definition ll_sum_extr_left :: "('l::llvm_rep, 'r::llvm_rep) ll_sum  'l llM" where 
  [llvm_code,llvm_inline]: "ll_sum_extr_left x  ll_dest_union x 0"

definition ll_sum_extr_right :: "('l::llvm_rep, 'r::llvm_rep) ll_sum  'r llM" where 
  [llvm_code,llvm_inline]: "ll_sum_extr_right x  ll_dest_union x 1"
  
  
(* Regression test: export the union operations for (32 word, double). *)
export_llvm 
  "ll_sum_mk_left :: 32 word  (32 word, double) ll_sum llM"
  "ll_sum_mk_right :: double  (32 word, double) ll_sum llM"
  "ll_sum_extr_left :: (32 word, double) ll_sum  32 word llM"
  "ll_sum_extr_right :: (32 word, double) ll_sum  double llM"
  file "../../regression/gencode/test_basic_union.ll"
  
  
(* The union constructors are pure: they just return Inl/Inr.
   Registered as VCG normalization rules. *)
lemma ll_sum_mk_simps[vcg_normalize_simps]:
  "ll_sum_mk_left l = Mreturn (ll_sum.Inl l)"
  "ll_sum_mk_right r = Mreturn (ll_sum.Inr r)"
  unfolding ll_sum_mk_left_def ll_sum_mk_right_def
  by (simp_all add: ll_make_union_def checked_from_val_def llvm_make_union_def
    llvm_union_can_make_def llvm_union_make_def)

(* The destructors return the stored value when the union holds the matching
   member (is_Inl / is_Inr). Other cases are left unspecified here. *)
lemma ll_sum_extr_simps:
  "ll_sum.is_Inl x  ll_sum_extr_left x = Mreturn (ll_sum.the_left x)"
  "ll_sum.is_Inr x  ll_sum_extr_right x = Mreturn (ll_sum.the_right x)"
  unfolding ll_sum_extr_left_def ll_sum_extr_right_def
  apply (cases x; simp_all add: ll_dest_union_def checked_from_val_def llvm_dest_union_def)
  apply (cases x; simp_all add: ll_dest_union_def checked_from_val_def llvm_dest_union_def)
  done
    
      
(* Hoare rules for the destructors, derived from ll_sum_extr_simps:
   if the union holds the matching member, the result is the stored value. *)
lemma ll_sum_extr_rules[vcg_rules]:
  "llvm_htriple ((ll_sum.is_Inl x)) (ll_sum_extr_left x) (λr. (r=ll_sum.the_left x))"    
  "llvm_htriple ((ll_sum.is_Inr x)) (ll_sum_extr_right x) (λr. (r=ll_sum.the_right x))"    
  supply [vcg_normalize_simps] = ll_sum_extr_simps
  by (vcg)

(* TODO: Test this VCG setup *)      


text Examples and regression tests that use the LLVM-VCG directly, 
i.e., without the Refinement Framework

subsection Custom and Named Structures
(* A copy of the product type, used to test custom/named structures.
   Its representation set is UNIV, so Abs/Rep form a bijection. *)
typedef ('a,'b) my_pair = "UNIV :: ('a::llvm_rep × 'b::llvm_rep) set" by simp

lemmas my_pair_bij[simp] = Abs_my_pair_inverse[simplified] Rep_my_pair_inverse

(* my_pair inherits its LLVM representation from the underlying product type
   via Abs_my_pair/Rep_my_pair. *)
instantiation my_pair :: (llvm_rep,llvm_rep)llvm_rep
begin
  definition "from_val_my_pair  Abs_my_pair o from_val"
  definition "to_val_my_pair  to_val o Rep_my_pair"
  definition [simp]: "struct_of_my_pair (_:: ('a,'b)my_pair itself)  struct_of TYPE('a × 'b)"
  definition "init_my_pair  Abs_my_pair init"

  instance
    apply standard
    unfolding from_val_my_pair_def to_val_my_pair_def struct_of_my_pair_def init_my_pair_def
    apply (auto simp: to_val_word_def init_zero)
    done

end

(* Logical (non-monadic) field selectors of my_pair. *)
definition "my_sel_fst  fst o Rep_my_pair"
definition "my_sel_snd  snd o Rep_my_pair"

(* my_pair is a two-field structure, registered for the code generator. *)
lemma my_pair_struct_of[ll_struct_of]: "struct_of TYPE(('a::llvm_rep,'b::llvm_rep) my_pair) = VS_STRUCT [struct_of TYPE('a), struct_of TYPE('b)]"
  by simp

(*lemma my_pair_to_val[ll_to_val]: "to_val x = LL_STRUCT [to_val (my_sel_fst x), to_val (my_sel_snd x)]"
  by (auto simp: my_sel_fst_def my_sel_snd_def to_val_my_pair_def to_val_prod)
*)  


(* Monadic field access on my_pair, inlined into extractvalue/insertvalue
   on field index 0 (first) or 1 (second). *)
definition my_fst :: "('a::llvm_rep,'b::llvm_rep)my_pair  'a llM" where [llvm_inline]: "my_fst x  ll_extract_value x 0"
definition my_snd :: "('a::llvm_rep,'b::llvm_rep)my_pair  'b llM" where [llvm_inline]: "my_snd x  ll_extract_value x 1"
definition my_ins_fst :: "('a::llvm_rep,'b::llvm_rep)my_pair  'a  ('a,'b)my_pair llM" where [llvm_inline]: "my_ins_fst x a  ll_insert_value x a 0"
definition my_ins_snd :: "('a::llvm_rep,'b::llvm_rep)my_pair  'b  ('a,'b)my_pair llM" where [llvm_inline]: "my_ins_snd x a  ll_insert_value x a 1"
(*definition my_gep_fst :: "('a::llvm_rep,'b::llvm_rep)my_pair ptr ⇒ 'a ptr llM" where [llvm_inline]: "my_gep_fst x ≡ ll_gep_struct x 0"
definition my_gep_snd :: "('a::llvm_rep,'b::llvm_rep)my_pair ptr ⇒ 'b ptr llM" where [llvm_inline]: "my_gep_snd x ≡ ll_gep_struct x 1"
*)

(* Computes 4*a as (a+a)+(a+a) using word addition (wraps at the word width).
   Generic in the word width; used at 32 and 64 bits by test_named. *)
definition [llvm_code]: "add_add (a::_ word)  doM {
  x  ll_add a a;
  x  ll_add x x;
  Mreturn x
}"

(* Code-generation test for the named structure my_pair: extracts and inserts
   fields of an init pair. The add_add results are shadowed by the extracted
   fields, so the function returns the second field of the init pair. *)
definition [llvm_code]: "test_named (a::32 word) (b::64 word)  doM {
  a  add_add a;
  b  add_add b;
  let n = (init::(32 word,64 word)my_pair);
  a  my_fst n;
  b  my_snd n;
  n  my_ins_fst n init;
  n  my_ins_snd n init;
  
  Mreturn b
}"

(* Register my_pair as the identified (named) LLVM structure "my_pair". *)
lemma my_pair_id_struct[ll_identified_structures]: "ll_is_identified_structure ''my_pair'' TYPE((_,_)my_pair)"
  unfolding ll_is_identified_structure_def
  apply (simp add: )
  done

thm ll_identified_structures



(*lemma [ll_is_pair_type_thms]: "ll_is_pair_type False TYPE(my_pair) TYPE(64 word) TYPE(32 word)"
  unfolding ll_is_pair_type_def
  by auto
*)  

(* Export test_named to code/test_named.ll, using the (debug) option. *)
export_llvm (debug) test_named file "code/test_named.ll"

(* Test for exporting with a user-provided C signature and typedefs: the
   argument type (64 word × 64 word ptr) ptr is exported as larray_t*.
   The function itself just returns 0. *)
definition test_foo :: "(64 word × 64 word ptr) ptr  64 word  64 word llM" 
  where [llvm_code]:
  "test_foo a b  Mreturn 0"

  export_llvm test_foo is int64_t test_foo(larray_t*, elem_t) 
  defines 
    typedef uint64_t elem_t;
    typedef struct {
      int64_t len;
      elem_t *data;
    } larray_t;
  


subsubsection Linked List

(* Singly linked list cell: a data field and a pointer to the next cell.
   It is represented in LLVM as a two-field structure. *)
datatype 'a list_cell = CELL (data: 'a) ("next": "'a list_cell ptr")

instantiation list_cell :: (llvm_rep)llvm_rep
begin
  definition "to_val_list_cell  λCELL a b  LL_STRUCT [to_val a, to_val b]"
  definition "from_val_list_cell p  case llvm_val.the_fields p of [a,b]  CELL (from_val a) (from_val b)"
  definition [simp]: "struct_of_list_cell (_::(('a) list_cell) itself)  VS_STRUCT [struct_of TYPE('a), struct_of TYPE('a list_cell ptr)]"
  definition [simp]: "init_list_cell ::('a) list_cell  CELL init init"
  
  instance
    apply standard
    unfolding from_val_list_cell_def to_val_list_cell_def struct_of_list_cell_def init_list_cell_def
    (* TODO: Clean proof here, not breaking abstraction barriers! *)
    apply (auto simp: to_val_word_def init_zero fun_eq_iff split: list_cell.splits)
    subgoal for v v1 v2 by (cases v) (auto)
    subgoal by (simp add: LLVM_Shallow.null_def to_val_ptr_def)
    done

end

(* Structure of list_cell (recursive through the next pointer), registered
   for the code generator. *)
lemma struct_of_list_cell[ll_struct_of]: 
  "struct_of TYPE('a::llvm_rep list_cell) = VS_STRUCT [struct_of (TYPE('a)), struct_of (TYPE('a list_cell ptr))]"
  by simp

  (*
lemma to_val_list_cell[ll_to_val]: "to_val x = LL_STRUCT [to_val (data x), to_val (next x)]"
  apply (cases x)
  apply (auto simp: to_val_list_cell_def)
  done
  *)

(* Register list_cell as the named LLVM structure "list_cell"
   (presumably needed because the type is recursive). *)
lemma [ll_identified_structures]: "ll_is_identified_structure ''list_cell'' TYPE(_ list_cell)"  
  unfolding ll_is_identified_structure_def
  by (simp)

  
find_theorems "prod_insert_fst"

(* insertvalue on a cell: field 0 replaces the data, field 1 replaces next. *)
lemma cell_insert_value:
  "ll_insert_value (CELL x n) x' 0 = Mreturn (CELL x' n)"
  "ll_insert_value (CELL x n) n' (Suc 0) = Mreturn (CELL x n')"

  apply (simp_all add: ll_insert_value_def llvm_insert_value_def Let_def checked_from_val_def 
                to_val_list_cell_def from_val_list_cell_def)
  done

(* extractvalue on a cell: field 0 is the data, field 1 is next. *)
lemma cell_extract_value:
  "ll_extract_value (CELL x n) 0 = Mreturn x"  
  "ll_extract_value (CELL x n) (Suc 0) = Mreturn n"  
  apply (simp_all add: ll_extract_value_def llvm_extract_value_def Let_def checked_from_val_def 
                to_val_list_cell_def from_val_list_cell_def)
  done
  
find_theorems "ll_insert_value"

(* Preprocessor rule: returning a freshly constructed cell is compiled as two
   insertvalue operations on init. *)
lemma inline_return_cell[llvm_pre_simp]: "Mreturn (CELL a x) = doM {
    r  ll_insert_value init a 0;
    r  ll_insert_value r x 1;
    Mreturn r
  }"
  apply (auto simp: cell_insert_value)
  done

(* Preprocessor rule: a monadic case distinction on a cell is compiled as two
   extractvalue operations followed by the branch. *)
lemma inline_cell_case[llvm_pre_simp]: "(case x of (CELL a n)  f a n) = doM {
  a  ll_extract_value x 0;
  n  ll_extract_value x 1;
  f a n
}"  
  apply (cases x)
  apply (auto simp: cell_extract_value)
  done
  
(* Preprocessor rule: returning a pure case expression on a cell is compiled
   as two extractvalue operations followed by returning f a n. *)
lemma inline_return_cell_case[llvm_pre_simp]: "doM {Mreturn (case x of (CELL a n)  f a n)} = doM {
  a  ll_extract_value x 0;
  n  ll_extract_value x 1;
  Mreturn (f a n)
}"  
  apply (cases x)
  apply (auto simp: cell_extract_value)
  done

(* llist_append x l builds the cell value CELL x l, i.e. x in front of the
   list pointed to by l (a cons, despite the name). No memory is allocated.
   llist_split loads the cell at pointer l and returns its (data, next) pair. *)
definition [llvm_code]: "llist_append x l  Mreturn (CELL x l)"
definition [llvm_code]: "llist_split l  doM {
  c  ll_load l;
  Mreturn (case c of CELL x n  (x,n))
}"  

(* Only llist_append is exported, instantiated to 1-bit words. *)
export_llvm 
  "llist_append::1 word 1 word list_cell ptr  _ llM"
  file "code/list_cell.ll"

  
subsection Array List Examples

(* Array-list example: pushes n, n-1, ..., 1 onto a new array list, then sums
   all n elements (reading indices n-1 down to 0) and returns the 64-bit sum.
   The sum wraps at the word width.
   NOTE(review): no free operation is applied to the array list a before
   returning; confirm whether this leak is intended for the example. *)
definition [llvm_code]: "cr_big_al (n::64 word)  doM {
  a  arl_new TYPE(64 word) TYPE(64);
  (_,a)  llc_while 
    (λ(n,a). ll_icmp_ult (signed_nat 0) n) 
    (λ(n,a). doM { a  arl_push_back a n; n  ll_sub n (signed_nat 1); Mreturn (n,a) }) 
    (n,a);
  
  (_,s)  llc_while 
    (λ(n,s). ll_icmp_ult (signed_nat 0) n) 
    (λ(n,s). doM { n  ll_sub n (signed_nat 1); x  arl_nth a n; sll_add x s; Mreturn (n,s) }) 
    (n,signed_nat 0);
    
  Mreturn s    
}"

(* Let_def is added to the preprocessor simpset globally (for the rest of the theory). *)
declare Let_def[llvm_pre_simp]
export_llvm (debug) cr_big_al is "cr_big_al" file "code/cr_big_al.ll"


subsection Sorting

(* For-loop combinator: runs the body c i s for i = l, l+1, ... while i < h,
   threading the state s through the iterations, and returns the final state. *)
definition [llvm_inline]: "llc_for_range l h c s  doM {
  (_,s)  llc_while (λ(i,s). ll_cmp (i<h)) (λ(i,s). doM { 
    sc i s; 
    i  ll_add i 1; 
    Mreturn (i,s)}
  ) (l,s);
  Mreturn s
}"

(* Hoare rule for llc_for_range. Suppose each body execution for an index i
   in [lo, hi) establishes I (i+1) from I i. Then, for lo <= hi, the loop
   establishes I hi from I lo.
   Proof: while-loop invariant "lo <= i <= hi and I i", with termination
   measure hi-i. *)
lemma llc_for_range_rule:
  assumes [vcg_rules]: "i ii si. llvm_htriple 
      (snat.assn i ii ** d(loi  i<hi) ** I i si) 
      (c ii si) 
      (λsi. I (i+1) si)"
  shows "llvm_htriple
      (snat.assn lo loi ** snat.assn hi hii ** (lohi) ** I lo si)
      (llc_for_range loi hii c si)
      (λsi. I hi si)"
  unfolding llc_for_range_def
  apply (rewrite at 1 to "signed_nat 1" signed_nat_def[symmetric])
  apply (rewrite annotate_llc_while[where 
    I="λ(ii,si) t. EXS i. snat.assn i ii ** (loi  ihi) ** !(t=hi-i) ** I i si" 
    and R="measure id"])
  apply vcg_monadify
  apply vcg'
  done
  
(* Annotated variant of llc_for_range that carries the loop invariant I as an
   extra argument; it is definitionally equal to llc_for_range.
   The vcg_const declaration presumably stops the VCG from unfolding it.
   annotate_llc_for_range rewrites a loop into the annotated form, and
   llc_for_range_annot_rule is the corresponding VCG rule. *)
definition llc_for_range_annot :: "(nat  'b::llvm_rep  ll_assn)
   'a::len word  'a word  ('a word  'b  'b llM)  'b  'b llM"
  where [llvm_inline]: "llc_for_range_annot I  llc_for_range"  
declare [[vcg_const "llc_for_range_annot I"]]
  
lemmas annotate_llc_for_range = llc_for_range_annot_def[symmetric]

lemmas llc_for_range_annot_rule[vcg_rules] 
  = llc_for_range_rule[where I=I, unfolded annotate_llc_for_range[of I]] for I


(* TODO: Move *)
(* Reduction rule for assertions indexed by a set: if A splits over disjoint
   unions of index sets, then the common indices of I and I' can be cancelled,
   reducing A I vs. A I' to A (I-I') vs. A (I'-I). *)
lemma sep_red_idx_setI:  
  assumes "I I'. II'={}  A (II') = (A I ** A I')"
  shows "is_sep_red (A (I-I')) (A (I'-I)) (A I) (A I')"
proof -
  define I1 where "I1  I-I'"
  define I2 where "I2  I'-I"
  define C where "C  II'"

  have S1: "I = I1  C" "I'=I2  C" and S2: "I-I' = I1" "I'-I=I2" and DJ: "I1C={}" "I2C={}"
    unfolding I1_def I2_def C_def by auto

  show ?thesis  
    apply (rule is_sep_redI)
    apply (simp only: S2; simp only: S1)
    apply (auto simp: DJ assms)
    by (simp add: conj_entails_mono sep_conj_left_commute)
    
qed    

(* Instance of sep_red_idx_setI for iterated separating conjunction over an
   index set. *)
lemma sep_set_img_reduce:
  "is_sep_red (⋃*iI-I'. f i) (⋃*iI'-I. f i) (⋃*iI. f i) (⋃*iI'. f i)"
  by (rule sep_red_idx_setI) simp

(* TODO: Move *)  
  
(* Any reduction is valid when the left-hand side is sep_false. *)
lemma is_sep_red_false[simp]: "is_sep_red P' Q' sep_false Q"
  by (auto simp: is_sep_red_def)

  
(* TODO: Move *)  
(* Pull a pure assertion in the premise of an entailment out into a
   meta-level assumption. *)
lemma entails_pre_pure[sep_algebra_simps]: 
  "(Φ  Q)  (Φ  Q)"  
  "(Φ**P  Q)  (Φ  PQ)"  
  by (auto simp: entails_def sep_algebra_simps )
  
  
  
(* List assertion restricted to an index set I: both lists have the same
   length, every index in I is in range, and A relates the abstract element
   as!i to the concrete element cs!i for each i in I (separating conjunction). *)
definition "lstr_assn A I  mk_assn (λas cs. (length cs = length as  (iI. i<length as)) ** (⋃*iI. A (as!i) (cs!i)))"

(* lstr_assn over a disjoint union of index sets splits into a separating
   conjunction. *)
lemma lstr_assn_union: "II'={}  
  (lstr_assn A (II')) as cs = ((lstr_assn A I) as cs ** (lstr_assn A I') as cs)"
  by (auto simp: lstr_assn_def sep_algebra_simps )

  
(* Frame reduction for lstr_assn: common indices of I and I' cancel out.
   The primed variant additionally carries a (currently unused) PRECOND
   about I and I'. *)
lemma lstr_assn_red: "is_sep_red 
  ((lstr_assn A (I-I')) as cs) ((lstr_assn A (I'-I)) as cs)
  ((lstr_assn A I) as cs) ((lstr_assn A I') as cs)"  
  by (rule sep_red_idx_setI) (simp add: lstr_assn_union)

lemma lstr_assn_red': "PRECOND (SOLVE_AUTO (II'{}))  is_sep_red 
  ((lstr_assn A (I-I')) as cs) ((lstr_assn A (I'-I)) as cs)
  ((lstr_assn A I) as cs) ((lstr_assn A I') as cs)"  
  by (rule sep_red_idx_setI) (simp add: lstr_assn_union)
  
    
(* Unfolding lemmas for lstr_assn: singleton and empty index sets, and the
   cases in which it is unsatisfiable (length mismatch or index out of range). *)
lemma lstr_assn_singleton: "(lstr_assn A {i}) as cs = ((length cs = length as  i<length as) ** A (as!i) (cs!i))"  
  by (auto simp: lstr_assn_def sep_algebra_simps)
  
lemma lstr_assn_empty: "(lstr_assn A {}) as cs = (length cs = length as)"  
  by (auto simp: lstr_assn_def sep_algebra_simps)
    
lemma lstr_assn_out_of_range: 
  "¬(length cs = length as  (iI. i<length as))  (lstr_assn A I) as cs = sep_false"  
  "iI  ¬i<length as  (lstr_assn A I) as cs = sep_false"  
  "iI  ¬i<length cs  (lstr_assn A I) as cs = sep_false"  
  "length cs  length as  (lstr_assn A I) as cs = sep_false"  
  by (auto simp: lstr_assn_def sep_algebra_simps)
  
  
  
(* Frame-inference reduction: given the assertion for all indices except i,
   plus A ai (cs!i) for index i, we obtain lstr_assn over I for the updated
   abstract list as[i:=ai]. Requires i in I and in range. *)
lemma lstr_assn_idx_left[fri_red_rules]:
  assumes "PRECOND (SOLVE_AUTO (length cs = length as  iI  i<length as))"
  shows "is_sep_red  ((lstr_assn A (I-{i})) as cs) (A ai (cs!i)) ((lstr_assn A I) (as[i:=ai]) cs)"
proof -

  from assms have [simp]: "{i} - I = {}" "length cs = length as" "i<length as" and "iI" 
    unfolding vcg_tag_defs by auto

  have "(⋃*iI - {i}. A (as ! i) (cs ! i)) 
    = (⋃*iaI - {i}. A (as[i := ai] ! ia) (cs ! ia))"
    by (rule sep_set_img_cong) auto
  then have 1: "(lstr_assn A (I-{i})) as cs = (lstr_assn A (I-{i})) (as[i:=ai]) cs"
    by (auto simp: lstr_assn_def sep_algebra_simps)
  
  show ?thesis
    using lstr_assn_red[of A "{i}" I "as[i:=ai]" cs]
    by (simp add: 1 lstr_assn_singleton lstr_assn_empty sep_algebra_simps)
    
qed
  
(* Frame-inference reduction in the other direction: lstr_assn over I
   provides the element assertion A (as!i) (cs!i) for an index i in I, leaving
   the assertion for I-{i} as frame. *)
lemma lstr_assn_idx_right[fri_red_rules]:
  assumes "PRECOND (SOLVE_AUTO (iI))"
  shows "is_sep_red ((lstr_assn A (I-{i})) as cs)  ((lstr_assn A I) as cs) (A (as!i) (cs!i))"
proof -  
  from assms have [simp]: "{i} - I = {}" "iI" 
    unfolding vcg_tag_defs by auto
  
  show ?thesis
    using lstr_assn_red[of A I "{i}" "as" cs]
    apply (cases "length cs = length as  (iI. i<length as )"; simp add: lstr_assn_out_of_range)
    apply (simp add: lstr_assn_singleton lstr_assn_empty sep_algebra_simps)
    done
qed  
  
(* TODO: Move *)
(* lstr_assn is pure if the element assertion A is pure. *)
lemma is_pure_lst_assn[is_pure_rule]: "is_pure A  is_pure (lstr_assn A I)"
  unfolding lstr_assn_def is_pure_def
  by (auto simp: sep_is_pure_assn_conjI sep_is_pure_assn_imgI)
  
(* The pure part of lstr_assn implies the length and index-range facts. *)
lemma vcg_prep_lstr_assn: (* TODO: Need mechanism to recursively prepare pure parts of A! *)
  "pure_part ((lstr_assn A I) as cs)  length cs = length as  (iI. i<length as)"
  by (auto simp: lstr_assn_def sep_algebra_simps 
    simp del: pred_lift_extract_simps
    dest: pure_part_split_conj)


(* TODO: Move *)  
(* Frame-inference rule for pure assertions, with the pure fact discharged by
   automation (variant of pure_fri_rule). *)
lemma pure_fri_auto_rule: "PRECOND (SOLVE_AUTO (pA a c))    pA a c"
  using pure_fri_rule
  unfolding vcg_tag_defs .


(* The pure part of an iterated separating conjunction implies the pure part
   of every conjunct. *)
lemma pure_part_prepD: "pure_part (⋃*iI. f i)  iI. pure_part (f i)"
  by (metis Set.set_insert pure_part_split_conj sep_set_img_insert)

(* For a pure assertion, its pure part implies the underlying relation. *)
lemma pure_part_imp_pure_assn: "is_pure A  pure_part (A a c)  pA a c"
  by (simp add: extract_pure_assn)  
  
  
    
(* Abstract array assertion: p points to an array holding concrete values cs
   that are related element-wise to the abstract list as by A.
   A must be pure; otherwise the assertion is unsatisfiable. *)
definition "aa_assn A  mk_assn (λas p. EXS cs. 
  array_assn cs p ** (is_pure A  list_all2 (pA) as cs))"  

   
(* Reading index i (in range) of an abstract array yields a concrete value
   related by A to as!i. *)
lemma aa_nth_rule[vcg_rules]: "llvm_htriple 
  ((aa_assn A) as p ** snat.assn i ii ** d(i<length as))
  (array_nth p ii)
  (λc. (aa_assn A) as p ** A (as!i) c)"
  unfolding aa_assn_def
  apply (clarsimp simp: list_all2_conv_all_nth)
  supply pure_fri_auto_rule[fri_rules]
  apply vcg
  done  

(* Writing a concrete value c (related by A to a) at index i updates the
   abstract list to as[i:=a]. The impure case is trivial because aa_assn is
   then unsatisfiable. *)
lemma aa_upd_rule[vcg_rules]: "llvm_htriple 
  ((aa_assn A) as p ** snat.assn i ii ** A a c ** d(i<length as))
  (array_upd p ii c)
  (λc. (aa_assn A) (as[i:=a]) p)"
proof (cases "is_pure A")
  case [is_pure_rule,simp]: True
  (*note thin_dr_pure[vcg_prep_external_drules del]*)
  note [simp] = nth_list_update pure_part_imp_pure_assn
  
  show ?thesis
    unfolding aa_assn_def list_all2_conv_all_nth
    supply pure_fri_auto_rule[fri_rules]
    apply vcg
    done
qed (clarsimp simp: aa_assn_def)      




(* Swap A[i] and A[j] in place. The swap is guarded by a comparison of i and
   j (presumably i ~= j, which skips the trivial self-swap); otherwise the
   function does nothing. *)
definition [llvm_inline]: "qs_swap A i j  doM {
  llc_if (ll_cmp' (ij)) (doM {
    x  array_nth A i;
    y  array_nth A j;
    array_upd A i y;
    array_upd A j x;
    Mreturn ()
  }) (Mreturn ())
}"

(* Lomuto-style partition of the half-open range [lo, hi) of A (hi is
   exclusive).
   The pivot is the last element A[hi-1]. Scanning j over [lo, hi-1), each
   element with A[j] < pivot (HOL < on words) is swapped to position i, and i
   is incremented. Finally the pivot is swapped to position i, which is
   returned. The result lies in [lo, hi-1]. *)
definition [llvm_code]: "qs_partition A lo hi  doM {
  hi  ll_sub hi (signed_nat 1);
  pivot  array_nth A hi;
  let i = lo;
  
  i  llc_for_range lo hi (λj i. doM {
    Aj  array_nth A j;
    if Aj < pivot then doM {
      qs_swap A i j;
      i  ll_add i (signed_nat 1);
      Mreturn i
    } else Mreturn i
  }) i;
  
  qs_swap A i hi;
  Mreturn i
}"


(* Quicksort of the half-open range [lo, hi) of array A.
   qs_partition treats hi as exclusive: it uses A[hi-1] as pivot and returns
   the final pivot position p with lo <= p < hi. Hence the two recursive calls
   cover the half-open ranges [lo, p) and [p+1, hi); the pivot itself is
   already in place.
   Fix: the left call previously used (lo, p-1). With an exclusive upper
   bound, that skipped index p-1 (e.g. [2,1,3] was left unsorted), and for
   p = 0 it wrapped around to the maximal word value. *)
definition [llvm_code]: "qs_quicksort A lo hi  doM {
  REC (λquicksort (lo,hi). doM {
    if lo < hi then doM {
      p  qs_partition A lo hi;
      quicksort (lo, p);
      quicksort (p+1,hi)
    } else
      Mreturn ()
  
  }) (lo,hi);
  Mreturn ()
}"

(* TODO: Prepare-code-thms after inlining! *)
(* prepare_code_thms  qs_partition_def[unfolded llc_for_range_def] *)


(*prepare_code_thms [llvm_code] qs_quicksort_def*)


(* Inspect the code dependencies of the 64-bit instance (debugging aid), then
   export it to code/qs_quicksort.ll. *)
llvm_deps foo: "qs_quicksort :: 64 word ptr  64 word  64 word  unit llM"


export_llvm "qs_quicksort :: 64 word ptr  64 word  64 word  unit llM" is "qs_quicksort"
  file code/qs_quicksort.ll

  
(* qs_swap on an abstract array: swaps the abstract elements at in-range
   indices i and j. *)
lemma qs_swap_aa_rule[vcg_rules]: "llvm_htriple 
  ((aa_assn A) xs p ** snat.assn i ii ** snat.assn j ji ** d(i<length xs  j<length xs))
  (qs_swap p ii ji)
  (λ_. (aa_assn A) (swap xs i j) p)"  
  unfolding qs_swap_def swap_def
  apply vcg_monadify
  apply vcg'
  done
  
(* qs_swap on a plain array assertion: swaps the elements at in-range indices
   i and j. *)
lemma qs_swap_rule[vcg_rules]: "llvm_htriple 
  (array_assn xs A ** snat.assn i ii ** snat.assn j ji ** d(i<length xs  j<length xs))
  (qs_swap A ii ji)
  (λ_. array_assn (swap xs i j) A)"  
  unfolding qs_swap_def swap_def
  apply vcg_monadify
  apply vcg'
  done
  

  
    
(* at_idxs xs is: the elements of xs at the indices in the list is, in that
   order (equivalently, map (nth xs) is). *)
fun at_idxs :: "'a list  nat list  'a list" (infixl "¡" 100) where
  "at_idxs xs [] = []"
| "at_idxs xs (i#is) = xs!i # at_idxs xs is"  
  
lemma at_idxs_eq_map_nth: "at_idxs xs is = map (nth xs) is"
  by (induction "is") auto

lemma at_idxs_append[simp]: "at_idxs xs (is1@is2) = at_idxs xs is1 @ at_idxs xs is2"  
  by (induction is1) auto
  
(* Selecting a contiguous index range is take, or more generally
   Misc.slice. *)
lemma at_idxs_ran_zero: "hilength xs  at_idxs xs [0..<hi] = take hi xs"  
  by (induction hi) (auto simp: take_Suc_conv_app_nth)
  
lemma at_idxs_slice: "hilength xs  at_idxs xs [lo..<hi] = Misc.slice lo hi xs"
  apply (induction lo)
  apply (auto simp: Misc.slice_def at_idxs_ran_zero)
  by (simp add: at_idxs_eq_map_nth drop_take map_nth_upt_drop_take_conv)

(* TODO: Move *)     
(* The pure part of an iterated separating conjunction implies the pure part
   of each conjunct. Finite index sets are handled by induction; the infinite
   case follows directly from the assumption by simp. *)
lemma pure_part_split_img:
  assumes "pure_part (⋃*iI. f i)"  
  shows "(iI. pure_part (f i))"  
proof (cases "finite I")
  assume "finite I"
  then show ?thesis using assms
    by (induction) (auto dest: pure_part_split_conj)
next
  assume "infinite I" with assms show ?thesis by simp    
qed

  
(* The pure part of lstr_assn yields equal lengths plus, for each i in I,
   range membership and the pure part of the element assertion. *)
lemma "pure_part ((lstr_assn A I) as cs)  (length cs = length as)  (iI. i<length as  pure_part (A (as!i) (cs!i)))"
  by (auto simp: lstr_assn_def is_pure_def list_all2_conv_all_nth sep_algebra_simps 
    simp del: pred_lift_extract_simps
    dest!: pure_part_split_conj pure_part_split_img)

(* TODO: Move *)    
(* Split a fresh index i off lstr_assn over insert i I. *)
lemma lstr_assn_insert: "iI  (lstr_assn A (insert i I)) as cs = ((i < length as) ** A (as!i) (cs!i) ** (lstr_assn A I) as cs)"
  by (auto simp: lstr_assn_def sep_algebra_simps)
    

(* Frame-inference rule: if lstr_assn over I appears in the pure assumptions
   and i is in I, then the pure element relation between as!i and cs!i can be
   derived. *)
lemma fri_lstr_pure_rl[fri_rules]:
  "PRECOND (SOLVE_ASM (p(lstr_assn A I) as cs))  PRECOND (SOLVE_AUTO (iI))    pA (as!i) (cs!i)"
  unfolding vcg_tag_defs
  by (auto simp: dr_assn_pure_asm_prefix_def lstr_assn_insert dr_assn_pure_prefix_def
    simp: sep_algebra_simps
    elim!: Set.set_insert dest!: pure_part_split_conj)
  

(* Swapping elements preserves the list length. *)
lemma length_swap[simp]: "length (swap xs i j) = length xs"
  by (auto simp: swap_def)    

  
(* at_idxs only depends on the elements at the selected indices; updates and
   swaps outside those indices do not change it. Swaps preserve the multiset
   of elements. *)
lemma at_idxs_cong:
  assumes "i. iList.set I  xs!i = ys!i"
  shows "xs¡I = ys¡I"
  using assms 
  apply (induction I)
  apply auto
  done
    
lemma at_idxs_upd_out[simp]: "iList.set I  xs[i:=x] ¡ I = xs¡I"
  by (auto intro: at_idxs_cong simp: nth_list_update')
  
lemma at_idxs_swap_out[simp]: "iList.set I  jList.set I  (swap xs i j)¡I = xs¡I"  
  unfolding swap_def
  by auto

lemma mset_swap'[simp]: "i<length xs; j<length xs  mset (swap xs i j) = mset xs"
  unfolding swap_def
  apply (auto simp: mset_swap)
  done  
  
  
find_theorems at_idxs Misc.slice  
find_theorems mset nth    



(* Work in progress (ends with oops): specification of qs_partition on
   [lo, hi). It states that p is in range, that the parts outside [lo, hi)
   are unchanged, that the multiset is preserved, and that elements are
   ordered around the pivot.
   NOTE(review): the two ordering clauses refer to the input list as, not to
   the result as'; presumably as' was intended. *)
lemma "llvm_htriple 
  ((aa_assn snat.assn) as A ** snat.assn lo loi  ** snat.assn hi hii 
    ** d(lo<hi  hilength as)) 
  (qs_partition A loi hii)
  (λpi. EXS as' p. (aa_assn snat.assn) as' A ** snat.assn p pi 
    ** ( lop  p<hi 
         length as' = length as
         as'¡[0..<lo] = as¡[0..<lo]     
         as'¡[hi..<length as] = as¡[hi..<length as]
         mset (as') = mset (as)
         (i{lo..<p}. as!i  as!p)
         (i{p..<hi}. as!p  as!i)
         ))"
  unfolding qs_partition_def
  apply (rewrite annotate_llc_for_range[where 
    I="λj ii. EXS i as'. snat.assn i ii ** (aa_assn snat.assn) as' A 
      ** (length as'=length as 
         loi  i<hi
         as'¡[0..<lo] = as¡[0..<lo]     
         as'¡[hi..<length as] = as¡[hi..<length as]
         mset (as') = mset (as)
      )
    
    "])
  apply vcg_monadify
  apply vcg'
  apply clarsimp_all
  apply auto
  prefer 2
  apply (subst at_idxs_swap_out)
  apply simp 
  apply simp
  apply linarith
  apply simp
  oops 
(*  
xxx, ctd here: sharpen invariant!
  
    
  xxx, try "arr_assn A ≡ array o lst A"
  try to set up rules for nth and upd, using a set of externalized indexes (and their intermediate values).
    supplement frame inference by internalize/externalize rules
  
  
  
  
  apply vcg_try_solve
  apply vcg_try_solve
  
  apply vcg_rl back back
  apply vcg_try_solve
  apply (fri_dbg_step) back
  apply vcg_try_solve
  
  
  
  oops
  xxx, ctd here: Intro-trule for pure lstr-assn
  
  oops
  xxx, ctd here: The array itself contains data, which needs to be abstracted over!
    we will need to relate xs!i to some abstract value!
  

  oops
  
  
  
  
  xxx, integrate reduction rules into frame inference!
  xxx: simplify the resulting set differences during frame inference!
    Most important: Elimination of empty sets!
    
      


  xxx, ctd here: Integrate into frame inference  
    "cut" is a bad name for this concept
        
        
  find_theorems sep_set_img  
    
  ML_val ‹@{term ‹⋃*x∈y. p›}›  
    
  lemma
    assumes "↿(lstr_assn A (I-I')) as cs ⊢ ↿(lstr_assn A (I'-I)) as cs"  
    shows "↿(lstr_assn A I) as cs ⊢ ↿(lstr_assn A I') as cs"
    
    
    oops
  xxx, ctd here: do list_assn, with index set. 
  
  derive rules to split/join those assertions. also rules for pure-case.
  in practice, let the lstr-assertions fragment, until some rule/frame forces a re-union.
    
    
    
      
    
      
  thm vcg_frame_erules
  
  apply vcg_rl
         
         
  
term "xs¡[2..<5]"

find_consts "nat ⇒ nat ⇒ _ list ⇒ _ list"  
  
*)  
  

subsection More Floating Point

(* Select the AVX512 rounding-mode constant from a runtime code rmi and pass
   it to f: 0 = to nearest, 1 = towards +inf, 2 = towards -inf, any other
   value = towards zero. *)
abbreviation "rm_tmpl f (rmi::64 word)  
  if rmi=unsigned 0 then f AVX512_FROUND_TO_NEAREST_NO_EXC
  else if rmi=unsigned 1 then f AVX512_FROUND_TO_POS_INF_NO_EXC
  else if rmi=unsigned 2 then f AVX512_FROUND_TO_NEG_INF_NO_EXC
  else f AVX512_FROUND_TO_ZERO_NO_EXC
"  
  
(* AVX512F scalar operations with a runtime-selected rounding mode.
   if_distribR is used as a preprocessor rule, presumably to push the
   arguments of rm_tmpl's if-chain into each branch. AVX512F compilation is
   enabled only within this context. *)
context
  notes [llvm_pre_simp] = if_distribR
  notes [[llc_compile_avx512f=true]]
begin

(* Double precision (sd): x is the rounding-mode code (see rm_tmpl). *)
definition [llvm_code]: "avx512_64_add   x a b = rm_tmpl ll_x86_avx512_add_sd_round x a b"
definition [llvm_code]: "avx512_64_sub   x a b = rm_tmpl ll_x86_avx512_sub_sd_round x a b"
definition [llvm_code]: "avx512_64_mul   x a b = rm_tmpl ll_x86_avx512_mul_sd_round x a b"
definition [llvm_code]: "avx512_64_div   x a b = rm_tmpl ll_x86_avx512_div_sd_round x a b"
definition [llvm_code]: "avx512_64_sqrt  x a = rm_tmpl ll_x86_avx512_sqrt_sd x a"
definition [llvm_code]: "avx512_64_fmadd x a b c = rm_tmpl ll_x86_avx512_vfmadd_f64 x a b c"

(* Single precision (ss): same pattern as above. *)
definition [llvm_code]: "avx512_32_add   x a b = rm_tmpl ll_x86_avx512_add_ss_round x a b"
definition [llvm_code]: "avx512_32_sub   x a b = rm_tmpl ll_x86_avx512_sub_ss_round x a b"
definition [llvm_code]: "avx512_32_mul   x a b = rm_tmpl ll_x86_avx512_mul_ss_round x a b"
definition [llvm_code]: "avx512_32_div   x a b = rm_tmpl ll_x86_avx512_div_ss_round x a b"
definition [llvm_code]: "avx512_32_sqrt  x a = rm_tmpl ll_x86_avx512_sqrt_ss x a"
definition [llvm_code]: "avx512_32_fmadd x a b c = rm_tmpl ll_x86_avx512_vfmadd_f32 x a b c"

(* Regression test export.
   NOTE(review): several exported names carry trailing spaces inside the
   quotes (alignment); presumably export_llvm trims them, so verify this. *)
export_llvm 
  avx512_64_add    is "avx512_64_add  "
  avx512_64_sub    is "avx512_64_sub  "
  avx512_64_mul    is "avx512_64_mul  "
  avx512_64_div    is "avx512_64_div  "
  avx512_64_sqrt   is "avx512_64_sqrt "
  avx512_64_fmadd  is "avx512_64_fmadd"
  avx512_32_add    is "avx512_32_add  "
  avx512_32_sub    is "avx512_32_sub  "
  avx512_32_mul    is "avx512_32_mul  "
  avx512_32_div    is "avx512_32_div  "
  avx512_32_sqrt   is "avx512_32_sqrt "
  avx512_32_fmadd  is "avx512_32_fmadd"     
  file "../../regression/gencode/test_avx512f_ops.ll"
  
  
       
end  
  
  
  





  

end