(* Theory LLVM_DS_Array *)

section ‹Arrays›
theory LLVM_DS_Array
imports LLVM_DS_Arith
begin

text ‹Implementing Lists by Arrays›

context begin

  interpretation llvm_prim_mem_setup .

  (* Array assertion relating a HOL list xs to a pointer p:
     - ll_range covers the offsets 0..<int (length xs) of p, where offset k holds
       xs ! nat k, and
     - p carries ll_malloc_tag (int (length xs)), presumably the tag of the
       malloc'd block of that many elements. *)
  definition "array_assn ≡ mk_assn (λxs p. 
    ↿(ll_range {0..<int (length xs)}) ((!) xs o nat) p ** ll_malloc_tag (int (length xs)) p
  )"
  
  (* The array assertion never holds for the null pointer. *)
  lemma array_assn_not_null[simp]: "↿array_assn xs null = sep_false"
    by (auto simp: array_assn_def)
  
  (* Thin wrappers around the primitive memory operations, tagged [llvm_inline]:
     - array_new: allocate n elements of type 'a (ll_malloc);
     - array_free: release the block (ll_free);
     - array_nth: compute the address of element i (ll_ofs_ptr) and load from it;
     - array_upd: compute the address of element i, store x there, and return
       the base pointer a unchanged. *)
  definition [llvm_inline]: "array_new TYPE('a::llvm_rep) n ≡ ll_malloc TYPE('a) n"
  definition [llvm_inline]: "array_free a ≡ ll_free a"
  definition [llvm_inline]: "array_nth a i ≡ doM { p ← ll_ofs_ptr a i; ll_load p }"
  definition [llvm_inline]: "array_upd a i x ≡ doM { p ← ll_ofs_ptr a i; ll_store x p; Mreturn a }"


  (* Congruence rule for ll_range: the assertion depends on the element function
     only on the index set, so f and f' need only agree on I'. *)
  lemma ll_range_cong: "I=I' ⟹ (⋀i. i∈I' ⟹ f i = f' i) ⟹ p=p' 
    ⟹ ↿(ll_range I) f p = ↿(ll_range I') f' p'"
    unfolding ll_range_def 
    by simp
  
  (* Expresses an array of n copies of init via ll_range with a constant element
     function plus the malloc tag; unfolded in the array_new rules below. *)
  lemma array_assn_cnv_range_malloc: 
    "↿array_assn (replicate n init) p = (↿(ll_range {0..<int n}) (λ_. init) p ** ll_malloc_tag (int n) p)"  
    unfolding array_assn_def
    apply (simp add: sep_algebra_simps)
    apply (rule sep_conj_trivial_strip2)
    apply (rule ll_range_cong)
    by auto
  
  (* Expresses the updated list xs[i:=x] via ll_range whose element function is
     updated pointwise at offset int i; unfolded in the array_upd rules below. *)
  lemma array_assn_cnv_range_upd:  
    "↿array_assn (xs[i:=x]) p = (↿(ll_range {0..<int (length xs)}) (((!) xs ∘ nat)(int i := x)) p ** ll_malloc_tag (int (length xs)) p)"
    unfolding array_assn_def
    apply (simp add: sep_algebra_simps)
    apply (rule sep_conj_trivial_strip2)
    apply (rule ll_range_cong)
    by auto
    

  (* If the signed reading of a word is non-negative, it equals the unsigned reading. *)
  lemma pos_sint_to_uint: "0 ≤ sint i ⟹ sint i = uint i"  
    by (smt Suc_n_not_le_n Suc_pred bintrunc_mod2p int_mod_eq' len_gt_0 power_increasing_iff sint_range' uint_sint)
    
  (* array_new with a positive signed-int size n yields an array of nat n copies of init. *)
  lemma array_new_rule_sint[vcg_rules]: "llvm_htriple 
    (↿sint.assn n ni ** ↑⇩d(n>0)) 
    (array_new TYPE('a::llvm_rep) ni) 
    (↿array_assn (replicate (nat n) init))"
    unfolding array_new_def array_assn_cnv_range_malloc sint.assn_def
    supply [simp] = pos_sint_to_uint
    by vcg

  (* array_new with a positive unsigned-int size n yields an array of nat n copies of init. *)
  lemma array_new_rule_uint[vcg_rules]: "llvm_htriple 
    (↿uint.assn n ni ** ↑⇩d(n>0)) 
    (array_new TYPE('a::llvm_rep) ni) 
    (↿array_assn (replicate (nat n) init))"
    unfolding array_new_def array_assn_cnv_range_malloc uint.assn_def
    by vcg

  (* array_new with a positive unsigned-nat size n yields an array of n copies of init. *)
  lemma array_new_rule_unat[vcg_rules]: "llvm_htriple 
    (↿unat.assn n ni ** ↑⇩d(n>0)) 
    (array_new TYPE('a::llvm_rep) ni) 
    (↿array_assn (replicate n init))"
    unfolding array_new_def array_assn_cnv_range_malloc unat.assn_def
    apply (simp add: )
    by vcg

  (* array_new with a positive signed-nat size n yields an array of n copies of init. *)
  lemma array_new_rule_snat[vcg_rules]: "llvm_htriple 
    (↿snat.assn n ni ** ↑⇩d(n>0)) 
    (array_new TYPE('a::llvm_rep) ni) 
    (↿array_assn (replicate n init))"
    unfolding array_new_def array_assn_cnv_range_malloc snat.assn_def
    supply [simp] = cnv_snat_to_uint and [simp del] = nat_uint_eq
    by vcg
    
      
  (* array_free consumes the array assertion and leaves the empty assertion. *)
  lemma array_free_rule[vcg_rules]: "llvm_htriple (↿array_assn xs p) (array_free p) (λ_. □)"
    unfolding array_free_def array_assn_def
    by vcg

  (* If the unsigned value of ii is below max_sint, then ii's signed and unsigned
     readings coincide, and nat-based bounds on ii can be traded for int-based
     ones. Used as simp rules in the uint/unat access rules. *)
  lemma array_cast_index: 
    assumes "uint (ii::'a::len word) < max_sint LENGTH('a)"  
    shows "sint ii = uint ii" "nat (uint ii) < n ⟷ uint ii < int n"
      "unat ii < n ⟷ uint ii < int n"
    using assms
    by (simp_all add: max_sint_def msb_uint_big sint_eq_uint unat_def nat_less_iff del: nat_uint_eq)
    
  (* Index preconditions for the unat/uint access rules: i is a valid index into xs
     and is below max_sint of the word length, matching the assumption of
     array_cast_index. ii occurs only to fix the word length 'a. *)
  abbreviation (input) "in_range_nat i (ii::'a::len word) xs ≡ i<length xs ∧ int i<max_sint LENGTH('a)"  
  abbreviation (input) "in_range_uint i (ii::'a::len word) xs ≡ i<int (length xs) ∧ i<max_sint LENGTH('a)"

  (* Reading with a signed-int index 0≤i<length xs returns xs!nat i; the array is unchanged. *)
  lemma array_nth_rule_sint[vcg_rules]: "llvm_htriple 
    (↿array_assn xs p ** ↿sint.assn i ii ** ↑⇩d(0≤i ∧ i<int (length xs)))
    (array_nth p ii)
    (λr. ↑(r = xs!nat i) ** ↿array_assn xs p)"
    unfolding array_nth_def array_assn_def sint.assn_def
    by vcg

  (* Reading with an unsigned-int index satisfying in_range_uint returns xs!nat i. *)
  lemma array_nth_rule_uint[vcg_rules]: "llvm_htriple 
    (↿array_assn xs p ** ↿uint.assn i ii ** ↑⇩d(in_range_uint i ii xs))
    (array_nth p ii)
    (λr. ↑(r = xs!nat i) ** ↿array_assn xs p)"
    unfolding array_nth_def array_assn_def uint.assn_def
    supply [simp] = array_cast_index
    by vcg
      
  (* Reading with an unsigned-nat index satisfying in_range_nat returns xs!i. *)
  lemma array_nth_rule_unat[vcg_rules]: "llvm_htriple 
    (↿array_assn xs p ** ↿unat.assn i ii ** ↑⇩d(in_range_nat i ii xs))
    (array_nth p ii)
    (λr. ↑(r = xs!i) ** ↿array_assn xs p)"
    unfolding array_nth_def array_assn_def unat.assn_def unat_def
    supply [simp] = array_cast_index
    by vcg

  (* Reading with a signed-nat index i<length xs returns xs!i. *)
  lemma array_nth_rule_snat[vcg_rules]: "llvm_htriple 
    (↿array_assn xs p ** ↿snat.assn i ii ** ↑⇩d(i<length xs))
    (array_nth p ii)
    (λr. ↑(r = xs!i) ** ↿array_assn xs p)"
    unfolding array_nth_def array_assn_def snat.assn_def
    supply [simp] = cnv_snat_to_uint and [simp del] = nat_uint_eq
    by vcg
    
  (* Updating at a signed-int index 0≤i<length xs returns the same pointer, which
     now represents xs[nat i:=x]. *)
  lemma array_upd_rule_sint[vcg_rules]: "llvm_htriple
    (↿array_assn xs p ** ↿sint.assn i ii ** ↑⇩d(0≤i ∧ i < int (length xs)))
    (array_upd p ii x)
    (λr. ↑(r=p) ** ↿array_assn (xs[nat i:=x]) p)"
    unfolding array_assn_cnv_range_upd
    unfolding array_upd_def array_assn_def sint.assn_def
    supply [fri_rules] = fri_abs_cong_rl
    by vcg

  (* Updating at an unsigned-int index satisfying in_range_uint returns the same
     pointer, which now represents xs[nat i:=x]. *)
  lemma array_upd_rule_uint[vcg_rules]: "llvm_htriple
    (↿array_assn xs p ** ↿uint.assn i ii ** ↑⇩d(in_range_uint i ii xs))
    (array_upd p ii x)
    (λr. ↑(r=p) ** ↿array_assn (xs[nat i:=x]) p)"
    unfolding array_assn_cnv_range_upd
    unfolding array_upd_def array_assn_def uint.assn_def
    supply [simp] = array_cast_index
    supply [fri_rules] = fri_abs_cong_rl
    by vcg
        
  (* Updating at an unsigned-nat index satisfying in_range_nat returns the same
     pointer, which now represents xs[i:=x].
     NOTE(review): named _nat, whereas the sibling rules use _unat. *)
  lemma array_upd_rule_nat[vcg_rules]: "llvm_htriple
    (↿array_assn xs p ** ↿unat.assn i ii ** ↑⇩d(in_range_nat i ii xs))
    (array_upd p ii x)
    (λr. ↑(r=p) ** ↿array_assn (xs[i:=x]) p)"
    unfolding array_assn_cnv_range_upd
    unfolding array_upd_def array_assn_def unat.assn_def unat_def
    supply [simp] = array_cast_index
    supply [fri_rules] = fri_abs_cong_rl
    by vcg
    
  (* Updating at a signed-nat index i<length xs returns the same pointer, which
     now represents xs[i:=x]. *)
  lemma array_upd_rule_snat[vcg_rules]: "llvm_htriple
    (↿array_assn xs p ** ↿snat.assn i ii ** ↑⇩d(i<length xs))
    (array_upd p ii x)
    (λr. ↑(r=p) ** ↿array_assn (xs[i:=x]) p)"
    unfolding array_assn_cnv_range_upd
    unfolding array_upd_def array_assn_def snat.assn_def
    supply [simp] = cnv_snat_to_uint and [simp del] = nat_uint_eq
    supply [fri_rules] = fri_abs_cong_rl
    apply vcg
    done

end    
    
subsection ‹Basic Algorithms›

subsubsection ‹Array-Copy›
(* Copy the first n elements of src into dst.
   The loop counter i starts at 0 and runs while i < n (unsigned comparison,
   ll_icmp_ult). Each iteration loads src[i], stores it at dst[i], and
   increments i. *)
definition "arraycpy dst src (n::'a::len2 word) ≡ 
  doM {
    llc_while 
      (λi. ll_icmp_ult i n) 
      (λi. doM { 
        x←array_nth src i;
        array_upd dst i x;
        i←ll_add i (signed_nat 1);
        Mreturn i
      }) (signed_nat 0);
    Mreturn ()
  }"

declare arraycpy_def[llvm_code]
  
(* Export a monomorphic instance: pointers to 8-bit words with a 64-bit length. *)
export_llvm "arraycpy :: 8 word ptr ⇒ _ ⇒ 64 word ⇒ _" is "arraycpy"

(* TODO: Move / REMOVE?*)
(* A natural obtained as unat of a word whose msb is clear is below max_sint. *)
lemma unat_not_msb_imp_less_max_sint: "x ∈ unat ` {w::'a::len word. ¬ msb w} ⟹ int x < max_sint LENGTH('a)"
  by (auto simp: unat_def[abs_def] msb_uint_big max_sint_def simp del: nat_uint_eq)


(* arraycpy with a signed-nat length n no larger than either array: afterwards dst
   holds the first n elements of src followed by its own remaining elements, and
   src is unchanged.
   Proof: annotate the loop with the invariant "the first i elements are copied"
   and the termination measure t = n-i (R = measure id). *)
lemma arraycpy_rule_snat[vcg_rules]: 
  "llvm_htriple 
    (↿array_assn dst dsti ** ↿array_assn src srci ** ↿snat.assn n ni ** ↑⇩d(n≤length src ∧ n≤length dst))
    (arraycpy dsti srci ni)
    (λ_. ↿array_assn (take n src @ drop n dst) dsti ** ↿array_assn src srci)"
  unfolding arraycpy_def
  apply (rewrite annotate_llc_while[where 
    I="λii t. EXS i dst'. ↿snat.assn i ii ** ↿array_assn dst' dsti ** ↿array_assn src srci
      ** ↑⇩d(0≤i ∧ i≤n ∧ dst' = take i src @ drop i dst) ** ↑⇩a(t = n-i)"
    and R = "measure id"
      ])
  apply vcg_monadify
  apply vcg'
  apply (auto simp: list_update_append upd_conv_take_nth_drop take_Suc_conv_app_nth) 
  done

subsubsection ‹Array-Set›
    
(* Set the first n elements of dst to c.
   The loop counter i starts at 0 and runs while i < n (word comparison via
   ll_cmp). Each iteration stores c at dst[i] and increments i. *)
definition arrayset :: "'b::llvm_rep ptr ⇒ 'b ⇒ 'a::len2 word ⇒ unit llM" where
  [llvm_code]: "arrayset dst c n ≡ doM {
    llc_while
      (λi. ll_cmp (i<n))
      (λi. doM {
        array_upd dst i c;
        let i=i+(signed_nat 1);
        Mreturn i
      }) (signed_nat 0);
    Mreturn ()
  }"  
  

(* Not needed: the definition above already carries [llvm_code]. *)
(*declare arrayset_def[llvm_code]*)

(* NOTE(review): the (debug) option looks like a development leftover; confirm it is still wanted. *)
export_llvm (debug) "arrayset :: 32 word ptr ⇒ 32 word ⇒ 64 word ⇒ _"

(* arrayset with a signed-nat length n ≤ length dst: afterwards dst holds n copies
   of c followed by its own remaining elements.
   Proof: loop invariant "the first i elements are c", termination measure t = n-i. *)
lemma arrayset_rule_snat[vcg_rules]: "llvm_htriple 
  (↿array_assn dst dsti ** ↿snat.assn n ni ** ↑⇩d(n≤length dst))
  (arrayset dsti c ni)
  (λ_. ↿array_assn (replicate n c @ drop n dst) dsti)"  
  unfolding arrayset_def
  apply (rewrite annotate_llc_while[where 
    I="λii t. EXS i dst'. ↿snat.assn i ii ** ↿array_assn dst' dsti 
      ** ↑⇩d(0≤i ∧ i≤n ∧ dst' = replicate i c @ drop i dst) ** ↑⇩a(t = n-i)"
    and R="measure id"  
  ])
  apply vcg_monadify
  apply vcg'
  apply (auto simp: nth_append list_update_append replicate_Suc_conv_snoc simp del: replicate_Suc)
  by (metis Cons_nth_drop_Suc less_le_trans list_update_code(2))
  
text ‹Array-Set also works for zero-size arrays and for any pointer, including null.›
(* With n = 0 no array assertion is required, and the pointer p is arbitrary.
   The invariant only tracks the counter; the empty relation R={} is used as the
   termination relation. *)
lemma arrayset_zerosize_rule: "llvm_htriple (↿snat.assn 0 ni) (arrayset p c ni) (λ_. □)"  
  unfolding arrayset_def
  apply (rewrite annotate_llc_while[where I="λii _. EXS i. ↿snat.assn i ii" and R="{}"])
  apply vcg_monadify
  apply vcg
  done
  

subsubsection ‹Array-New-Init›
  
(* Allocate an array of n elements (array_new), set every element to c
   (arrayset), and return the new pointer. *)
definition "array_new_init n (c::'a::llvm_rep) ≡ doM { 
  r ← array_new TYPE('a) n; 
  arrayset r c n;
  Mreturn r
}"

(* For a positive signed-nat size n, array_new_init yields an array of n copies of c. *)
lemma array_new_init_rule[vcg_rules]: "llvm_htriple   
  (↿snat.assn n ni ** ↑⇩d(n>0)) 
  (array_new_init ni c) 
  (λr. ↿array_assn (replicate n c) r)"
  unfolding array_new_init_def
  by vcg


end