(* Theory SharedNat *)

section ‹Bit-Packed Pairs of Numbers›
theory SharedNat
  imports "../lib/Base_MEC"
begin

subsection ‹Miscellaneous Lemmas›
subsubsection ‹snat lemmas›

(* If both words satisfy snat_invar, a strict inequality between their snat
   interpretations transfers to the (unsigned) word order:
   snat a < snat b implies a < b. *)
lemma snat_less_refine: "snat_invar a  snat_invar b  snat a < snat b  a < b"
  apply(auto simp: snat_def snat_invar_def)
  apply (meson word_sless_alt word_sless_msb_less)
  apply (metis linorder_le_less_linear order_le_imp_less_or_eq signed_eq_0_iff word_gt_0 word_gt_a_gt_0 word_msb_sint)
  by (simp add: sint_eq_uint word_less_def)


(* Masking with mask n is the identity on a word that satisfies snat_invar
   and whose snat value already fits in n bits (snat x < 2 ^ n). *)
lemma snat_mask_id: "snat_invar x  snat x < 2 ^ n  x && mask n = x"
  apply(simp add: snat_def snat_invar_def and_mask_mod_2p)
  apply (simp add: sint_eq_uint)
  done


(* take_bit counterpart of snat_mask_id. The proof splits on how n
   compares with the word length LENGTH('a). *)
lemma snat_take_bit_id: "snat_invar x  snat x < 2 ^ n  take_bit n x = x" 
  apply(simp add: snat_def snat_invar_def take_bit_word_eq_self_iff)
  apply(cases "LENGTH('a)  n")
  apply (auto simp: sint_eq_uint uint_power_lower word_less_def) 
  done

(* Shifting right by Suc n is the same as shifting right by n and then by
   one more position. *)
lemma rshift_suc_n_alt: "(w >> Suc n) = (w >> n) >> 1" for w::"'a::len word"
  apply(simp add: shiftr_shiftr)
  done 
  

(* For len2 words, a right shift by at least one position clears the most
   significant bit (proved by peeling off the final shift by 1). *)
lemma nmsb_rshiftn: "¬msb (w >> Suc n)"  for w::"'a::len2 word"
  apply(rewrite rshift_suc_n_alt)
  by auto


(* After unfolding, snat_invar is exactly the absence of the msb, so any
   right shift by Suc n yields a word satisfying snat_invar. *)
corollary snat_invar_rshift_Sucn: "snat_invar (w >> Suc n)"
  unfolding snat_invar_def
  using nmsb_rshiftn .

(* Same as snat_invar_rshift_Sucn for an arbitrary positive shift amount;
   n > 0 is turned into n = Suc m via gr0_implies_Suc. *)
corollary n_gt_0_snat_invar_rshiftn: "n > 0  snat_invar (w >> n)"
  unfolding snat_invar_def
  using nmsb_rshiftn 
  using gr0_implies_Suc by blast


(* Keeping strictly fewer bits than the word length clears the msb. *)
lemma nmsb_take_bit: "n < LENGTH('a::len2)  ¬msb (take_bit n (w::'a word))"
  apply(rule ccontr)
  apply (auto simp: word_msb_sint signed_take_bit_eq) 
  done

(* Hence take_bit with n < LENGTH('a) produces a word satisfying
   snat_invar (which unfolds to the absence of the msb). *)
corollary snat_invar_take_bit: "n < LENGTH('a::len2)  snat_invar (take_bit n (w::'a word))"
  unfolding snat_invar_def
  using nmsb_take_bit .

(* The low n bits of a word shifted left by n are all zero. *)
lemma take_bit_lshift_is_0: "take_bit n (a << n) = 0" for a::"'a::len word"
  apply (auto simp: take_bit_eq_mask)
  done

(* A word satisfying snat_invar is related by snat_rel to its own snat
   value. *)
lemma snat_invar_id_snat_rel: "snat_invar ii  (ii, snat ii)  snat_rel"
  apply(auto simp: snat_rel_def snat.rel_def in_br_conv snat_invar_def)
  done

(* Variant stated with unat; the proof rewrites unat back to snat using
   snat_eq_unat_aux2 (presumably: snat_invar ii implies snat ii = unat ii). *)
lemma snat_invar_unat_snat_rel: "snat_invar ii  (ii, unat ii)  snat_rel"
  apply(auto simp: snat_eq_unat_aux2[symmetric] snat_invar_id_snat_rel)
  done

(* Variant whose premise is a non-negative signed interpretation instead of
   snat_invar (the two are linked through word_msb_sint). *)
lemma p_sint_id_snat_rel: "sint ii  0  (ii, snat ii)  snat_rel"
  apply(auto simp: snat_rel_def snat.rel_def in_br_conv snat_invar_def word_msb_sint)
  done

(* Characterises snat_invar w by a bound on its unsigned value:
   unat w < max_snat LENGTH('a). After auto, the two remaining directions are
   closed separately: via snat_lt_max_snat, and by contradiction using
   msb_unat_big. *)
lemma snat_invar_unat_bound:"snat_invar w  unat w < max_snat LENGTH('a)" for w::"'a::len2 word"
  apply auto
  apply (metis snat_eq_unat_aux2 snat_lt_max_snat)
  apply (rule ccontr)
  unfolding snat_invar_def
  apply (auto simp: msb_unat_big max_snat_def)
  done

subsubsection ‹unat lemmas›


(* For n below the word length, reducing modulo 2 ^ n commutes with unat. *)
lemma mod_word_2p_unat: "n < LENGTH ('a::len)  (unat (w::'a word)) mod (2 ^ n) = unat (w mod 2 ^ n)"
  by (simp add: unat_mod_distrib)

(* For n below the word length, division by 2 ^ n commutes with unat. *)
lemma div_word_2p_unat: "n < LENGTH ('a::len)  (unat (w::'a word)) div (2 ^ n) = unat (w div 2 ^ n)"
  by (simp add: unat_div_distrib)

(* The unsigned value of take_bit n w is unat w mod 2 ^ n. *)
lemma unat_take_bit_2p: "unat (take_bit n w) = unat w mod 2 ^ n"
  by (metis take_bit_eq_mod unsigned_take_bit_eq)

(* Masking with mask n yields an unsigned value below 2 ^ n. *)
lemma unat_and_mask_lt_2p: "unat (w && mask n) < 2 ^ n" 
  by (metis take_bit_eq_mask take_bit_nat_less_exp unsigned_take_bit_eq)

(* Same bound for take_bit, reduced to the mask version. *)
lemma unat_take_bit_lt_2p: "unat (take_bit n a) < 2 ^ n"
  by(auto simp: take_bit_eq_mask unat_and_mask_lt_2p)
  
(* Every word is related to its unsigned value by unat_rel. *)
lemma unat_unat_rel: "(ii, unat ii)  unat_rel"
  by(auto simp: unat_rel_def unat.rel_def in_br_conv)

subsection ‹Mask and take_bit setup for sepref›


(* A mask of strictly fewer bits than the word length has a clear msb. *)
lemma word_mask_less_length_no_msb: "n < LENGTH('a::len)  ¬(msb::'a word  bool) (mask n)"
  by (metis (no_types, lifting) min.absorb4 signed_take_bit_eq take_bit_minus_one_eq_mask take_bit_take_bit uint_lt_0 uint_sint word_msb_sint)

(* Executable form of mask: (1 << n) - 1. *)
definition mask_impl :: "nat  ('a::len2) word" where "mask_impl n = (1 << n) - 1"

(* mask_impl coincides with mask (via mask_eq), so it refines mask under
   nat_rel / word_rel. *)
lemma mask_refine: "(mask_impl, mask)  nat_rel  word_rel"
  unfolding mask_impl_def
  apply(auto simp: mask_eq)
  done

(* LLVM implementation of mask_impl; n is passed as an snat of the same
   width as the result word. The precondition n < LENGTH('a) bounds the shift
   amount. NOTE(review): presumably required because an LLVM left shift by
   >= the bit width is not well-defined; confirm. *)
sepref_def mask_ll is "RETURN o (mask_impl :: nat  ('a::len2) word)" :: "[λ n. n < LENGTH('a)]a (snat_assn' TYPE('a::len2))k  id_assn"
  unfolding mask_impl_def
  apply sepref
  done

(* Compose the LLVM refinement with mask_refine and register it, so that
   sepref can implement mask directly by mask_ll. *)
lemmas [sepref_fr_rules] = mask_ll.refine[FCOMP mask_refine]

(* Executable form of take_bit n a: AND a with the mask (1 << n) - 1. *)
definition take_bit_impl :: "nat  ('a::len2) word  'a word" where "take_bit_impl n a = a AND ((1 << n) - 1)"

(* take_bit_impl coincides with take_bit (via take_bit_eq_mask and
   mask_eq), hence refines it. *)
lemma take_bit_refine: "(take_bit_impl, take_bit)  nat_rel  word_rel  word_rel"
  unfolding take_bit_impl_def
  apply(auto simp: fun_eq_iff take_bit_eq_mask mask_eq)
  done

(* LLVM implementation of take_bit_impl, with the bit count n passed as an
   snat. Like mask_ll, it requires n < LENGTH('a) as a precondition. *)
sepref_def take_bit_ll is "uncurry (RETURN oo (take_bit_impl :: nat  'a word  'a word))" :: "[λ (n,_). n < LENGTH('a::len2)]a (snat_assn' TYPE('a))k *a id_assnd  id_assn"
  unfolding take_bit_impl_def
  apply sepref
  done


(* Register the composed refinement so sepref implements take_bit by
   take_bit_ll. *)
lemmas [sepref_fr_rules] = take_bit_ll.refine[FCOMP take_bit_refine]


(* Abstract model: a shared nat is a pair (left, right) of naturals.
   shared_left / shared_right project the components. *)
definition shared_left :: "nat × nat  nat" where "shared_left = fst"
definition shared_right :: "nat × nat  nat" where "shared_right = snd"

(* Sum of both components. *)
definition shared_sum :: "nat × nat  nat" where "shared_sum = (λ (l,r). l + r)"

(* Builds the abstract pair from its two components. *)
definition make_tuple :: "nat  nat  nat × nat" where "make_tuple nl nr = (nl,nr)"

(* shared_sum expressed through the two projections. *)
lemma shared_sum_as_left_plus_right: "shared_sum x = shared_left x + shared_right x"
  unfolding shared_left_def shared_right_def shared_sum_def
  apply (cases x) 
  by simp


(* Bit layout of the packed word: the right component occupies the low
   right_size = 44 bits (cf. get_shared_right, which keeps take_bit
   right_size). The left component occupies the remaining left_size high bits
   (cf. make_shared, which shifts it left by right_size). len_size_T is
   presumably the bit width of size_t. *)
definition "right_size = (44::nat)"
definition "left_size = len_size_T - right_size"

(* Both fields are strictly narrower than the whole word. *)
lemma size_bounds: "right_size < len_size_T" "left_size < len_size_T"
  unfolding right_size_def left_size_def by auto

(* Together the two fields cover the whole word. *)
lemma left_right_add: "left_size + right_size = len_size_T"
  unfolding left_size_def right_size_def
  by simp

(* Exclusive upper bounds on the values each field can hold. *)
definition "max_left = 2 ^ left_size"
definition "max_right = 2 ^ right_size"

(* The same bounds expressed with max_unat. *)
lemma max_left_alt: "max_left = max_unat left_size"
  unfolding max_left_def max_unat_def by simp

lemma max_right_alt: "max_right = max_unat right_size"
  unfolding max_right_def max_unat_def by simp

(* Both bounds are positive (simp rules). *)
lemma max_left_gt_0[simp]: "max_left > (0::nat)"
  unfolding max_left_def
  by simp

lemma max_right_gt_0[simp]: "max_right > (0::nat)"
  unfolding max_right_def
  by simp
  

(* The product of the two field ranges is the range of the whole word. *)
lemma max_left_right_unat: "max_left * max_right = max_unat len_size_T"
  unfolding max_left_def max_right_def max_unat_def
  apply(fold monoid_mult_class.power_add)
  unfolding left_right_add
  by presburger

(* The sum of the two field ranges fits below max_snat of the word size.
   This is what makes shared_sum representable as an snat. *)
lemma max_left_plus_right_lt_max_snat: "max_left + max_right < max_snat len_size_T"
  unfolding max_left_def max_right_def left_size_def right_size_def max_snat_def
  by simp
  

(* Concrete representation: both components packed into one size_t word. *)
type_synonym shared_nat = "size_t"

(* The packed word 0, representing the pair (0, 0). *)
definition shared_zero :: "shared_nat" where "shared_zero = 0"
(* Packs nl into the bits above right_size and ORs nr into the low bits.
   The fields only stay separate when nr < 2 ^ right_size and nl fits in the
   high bits; these are the bounds assumed by make_shared_α and
   make_shared_refine. *)
definition make_shared :: "size_t  size_t  shared_nat" where "make_shared nl nr = (nl << right_size) OR nr"

(* Field extraction: the left field by a right shift, the right field by
   keeping the low right_size bits. *)
definition get_shared_left :: "shared_nat  size_t" where "get_shared_left sn = (sn >> right_size)"
definition get_shared_right :: "shared_nat  size_t" where "get_shared_right sn = take_bit right_size sn"

(* Word-level sum of the two fields. *)
definition get_shared_sum :: "shared_nat  size_t" where "get_shared_sum sn = (sn >> right_size) + take_bit right_size sn"


(* get_shared_sum restated via the two field getters. *)
lemma get_shared_sum_alt_def: "get_shared_sum sn = get_shared_left sn + get_shared_right sn"
  unfolding get_shared_sum_def get_shared_left_def get_shared_right_def
  by simp

(* Register make_shared_def with the llvm_inline attribute. *)
lemmas [llvm_inline] = make_shared_def


(* Abstraction function: splits the unsigned value of the word into
   (quotient, remainder) with respect to max_right = 2 ^ right_size. *)
definition "shared_nat_α sn = (unat sn div max_right, unat sn mod max_right)"


(* The zero word abstracts to the pair (0, 0). *)
lemma shared_zero_α[simp]: "shared_nat_α shared_zero = (0,0)"
  unfolding shared_nat_α_def shared_zero_def get_shared_left_def get_shared_right_def
  by simp


(* Correctness of packing: if both inputs satisfy snat_invar and fit into
   their fields (snat nl < max_left, snat nr < max_right), the packed word
   abstracts to (snat nl, snat nr). *)
lemma make_shared_α[simp]: "snat_invar nl  snat_invar nr  snat nl < max_left  snat nr < max_right  
    shared_nat_α (make_shared nl nr) = (snat nl, snat nr)"
  unfolding shared_nat_α_def make_shared_def get_shared_left_def get_shared_right_def max_left_def max_right_def
  apply (clarsimp simp: shiftr_over_or_dist simp flip: shiftr_div_2n')
  apply(subst shiftl_shiftr_id)
  apply (auto simp: right_size_def left_size_def)[2]
  apply (subst(asm) snat_numeral[where 'a = size_T, symmetric])
  apply (simp add: max_snat_def)
  apply (subst(asm) snat_less_refine; (simp add: snat_invar_def; erule less_trans[of nl]; simp)|(simp add: snat_invar_def))
  apply(rewrite mod_word_2p_unat[of right_size "((nl << right_size) || nr)", simplified, OF size_bounds(1)])
  apply (auto simp add: shiftr_le_0 word_ao_dist snat_eq_unat_aux2 mod_2p_is_mask) 
  by (metis take_bit_eq_mask take_bit_nat_eq_self unsigned_take_bit_eq)


(* The unsigned value of the extracted left field is the abstract left
   component. *)
lemma get_left_α[simp]: "unat (get_shared_left sn) = shared_left (shared_nat_α sn)"
  unfolding shared_nat_α_def shared_left_def get_shared_left_def max_right_def
  apply (clarsimp simp add: shiftr_div_2n_w[symmetric] div_word_2p_unat[of right_size sn, simplified, OF size_bounds(1)]
    intro!: snat_eq_unat_aux2)
  done
  

(* The unsigned value of the extracted right field is the abstract right
   component. *)
lemma get_right_α[simp]: "unat (get_shared_right sn) = shared_right (shared_nat_α sn) "
  unfolding shared_nat_α_def shared_right_def get_shared_right_def max_right_def
  apply(clarsimp simp: mod_word_2p_unat[of right_size sn, simplified, OF size_bounds(1)]) 
  apply(rewrite word_eqI_folds(3))
  apply simp
  done


(* Bounds the natural-number sum of the two extracted fields by
   2 ^ (64 - right_size) + 2 ^ right_size. The first summand bounds the
   quotient, the second the remainder.
   NOTE(review): this hard-codes the word length 64 instead of len_size_T /
   left_size. It presumably relies on size_t being a 64-bit word; consider
   stating it in terms of left_size. *)
lemma shared_sum_unat_bounds: "unat (get_shared_left sn) + unat (get_shared_right sn) < 2 ^ (64 - right_size) + 2 ^ right_size"
proof -
  have "unat sn div 2 ^ right_size  2 ^ (64 - right_size)"
    apply(rewrite power_minus_is_div[OF less_imp_le_nat, OF size_bounds(1)])
    apply(rule div_le_mono)
    using unat_lt_max_unat[of sn] 
    unfolding max_unat_def by simp 
  moreover have "unat sn mod 2 ^ right_size < 2 ^ right_size" 
    by force
  ultimately have "unat sn div 2 ^ right_size + unat sn mod 2 ^ right_size < 2 ^ (64 - right_size) + 2 ^ right_size" 
    by linarith
  then show ?thesis 
    by (auto simp: shared_nat_α_def shared_left_def shared_right_def max_right_def)
qed


(* The natural-number sum of the two fields is below max_snat len_size_T. *)
corollary shared_sum_unat_in_max_snat: "unat (get_shared_left sn) + unat (get_shared_right sn) < max_snat len_size_T"
  using shared_sum_unat_bounds[of sn] 
  unfolding right_size_def max_snat_def 
  by simp


(* Word-level version: because the sum cannot overflow, unat of the word
   sum equals the sum of the unats (unat_add_lem'), so the same bound
   applies. *)
lemma unat_shared_sum_bounds: "unat ((get_shared_left sn) + (get_shared_right sn)) < 2 ^ (64 - right_size) + 2 ^ right_size"
  apply(subst unat_add_lem')
  using shared_sum_unat_in_max_snat[unfolded max_snat_def, of sn]
  apply fastforce 
  using shared_sum_unat_bounds[of sn] .

(* The word sum of the two fields is below max_snat len_size_T. *)
corollary unat_shared_sum_in_max_snat_aux: "unat ((get_shared_left sn) + (get_shared_right sn)) < max_snat len_size_T"
  using unat_shared_sum_bounds[of sn] 
  unfolding right_size_def max_snat_def
  by simp

(* The same bound, stated for get_shared_sum. *)
corollary unat_shared_sum_in_max_snat: "unat ((get_shared_sum sn)) < max_snat len_size_T"
  unfolding get_shared_sum_alt_def
  using unat_shared_sum_in_max_snat_aux[of sn] .

(* The left field is a right shift by a positive amount, so it satisfies
   snat_invar. *)
lemma snat_invar_get_shared_left: "snat_invar (get_shared_left sn)"
  unfolding get_shared_left_def right_size_def 
  by (simp add: n_gt_0_snat_invar_rshiftn)

(* The right field keeps fewer bits than the word length, so it satisfies
   snat_invar. *)
lemma snat_invar_get_shared_right: "snat_invar (get_shared_right sn)"
  unfolding get_shared_right_def right_size_def 
  apply(rule snat_invar_take_bit)
  by force

(* The field sum satisfies snat_invar because its unsigned value is below
   max_snat (via snat_invar_unat_bound). *)
lemma snat_invar_get_shared_sum: "snat_invar (get_shared_sum sn)"
  apply(auto simp: snat_invar_unat_bound unat_shared_sum_in_max_snat)
  done
  

(* The unsigned value of get_shared_sum is the abstract shared_sum. The
   proof uses the no-overflow bound to turn the word addition into natural
   addition. *)
lemma shared_sum_α[simp]: "unat (get_shared_sum sn) = shared_sum (shared_nat_α sn)"
  unfolding shared_sum_def shared_nat_α_def get_shared_sum_def
  apply(auto simp: shiftr_div_2n'[symmetric] unat_take_bit_2p[symmetric] max_right_def)
  apply(subst unat_add_lem'[symmetric]) 
  using shared_sum_unat_in_max_snat[of sn] 
  unfolding max_snat_def get_shared_left_def get_shared_right_def
  apply auto
  done


(* Invariant on packed words: each extracted field lies below its field
   range. It holds for every word (see shared_nat_invar_true). *)
definition shared_nat_invar :: "shared_nat  bool" where "shared_nat_invar sn = ((get_shared_left sn < 2 ^ left_size)  (get_shared_right sn < 2 ^ right_size))"


(* Word-level upper bound on the right field. *)
lemma get_shared_right_ub: "get_shared_right sn < 2 ^ right_size"
  proof -
    have "sn && 2 ^ right_size - 1  2 ^ right_size - 1" 
      using Word.word_and_le1 by blast
    also have "... < 2 ^ right_size" 
      unfolding right_size_def by force
    finally show ?thesis
      apply (auto simp: get_shared_right_def mask_eq take_bit_eq_mask) 
      done
  qed


(* The right field is non-negative as a signed word. The proof names the
   word length explicitly as LENGTH(64); NOTE(review): presumably size_t is
   a 64-bit word here, confirm. *)
lemma get_shared_right_lb: "sint (get_shared_right sn)  0"
  apply(simp add: get_shared_right_def take_bit_eq_mod sint_eq_uint_2pl)
  apply(subgoal_tac "(sn mod 2 ^ right_size) < 2 ^ (LENGTH(64) - 1)" )
  apply(drule sint_eq_uint_2pl)
  apply simp
  apply(rule less_trans[of "sn mod 2 ^ right_size" "2 ^ right_size" "2 ^ (LENGTH(64) - 1)"])
  apply(rule word_mod_less_divisor)
  apply (auto simp add: right_size_def)
  done



(* Non-negative signed value means the msb is clear (word_msb_sint). *)
lemma nmsb_get_shared_right: "¬msb (get_shared_right sn)"
  unfolding word_msb_sint
  using get_shared_right_lb[of sn] 
  by fastforce


(* Bound on the snat interpretation of the right field. The msb is clear,
   so snat and unat coincide. *)
lemma get_shared_right_snat_ub: "snat (get_shared_right sn) < 2 ^ right_size" 
  unfolding snat_eq_unat_aux2[unfolded snat_invar_def, OF nmsb_get_shared_right]
  unfolding get_shared_right_def 
  using unat_take_bit_lt_2p by blast


(* Word-level upper bound on the left field, obtained by comparing unsigned
   values. NOTE(review): the literal type (2::64 word) presumably matches
   size_t; confirm. *)
lemma get_shared_left_ub: "(get_shared_left sn < 2 ^ left_size)"
  proof -
    have "unat (sn >> right_size) < 2 ^ left_size" 
      using word_shiftr_lt[of sn right_size] 
      by (auto simp: right_size_def left_size_def)
    hence "unat (sn >> right_size) < unat ((2::64 word) ^ left_size)"
      using unat_p2[where 'a = 64, simplified, OF size_bounds(2)] 
      by presburger
    thus ?thesis 
      by(simp add: word_less_nat_alt get_shared_left_def)
  qed


(* The left field is non-negative as a signed word, since shifting clears
   the msb. *)
lemma get_shared_left_lb: "sint (get_shared_left sn)  0"
  apply(simp add: get_shared_left_def right_size_def) 
  apply(rule ccontr)
  apply (clarsimp simp: not_le simp flip: word_msb_sint)
  done
  

(* shared_nat_invar holds for every word (simp rule). *)
lemma shared_nat_invar_true[simp]: "shared_nat_invar sn"
  apply (auto simp: get_shared_right_ub get_shared_left_ub get_shared_right_lb get_shared_left_lb shared_nat_invar_def)
  done


(* Refinement relation between packed words and abstract pairs, built with
   br from the abstraction function and the invariant. *)
definition "shared_nat_rel = br shared_nat_α shared_nat_invar"
(* Concrete assertion for the packed word itself: a plain word_assn. *)
definition "shared_nat_assn_aux = word_assn"
(* Assertion relating a packed word directly to its abstract pair. *)
abbreviation "shared_nat_assn  pure shared_nat_rel"




(* shared_zero refines the abstract pair (0, 0). *)
lemma shared_zero_refine: "(uncurry0 shared_zero, uncurry0 (0, 0))  unit_rel  shared_nat_rel"
  unfolding shared_nat_rel_def
  apply (auto simp: in_br_conv)
  done

                                
(* make_shared refines make_tuple on size_rel inputs, provided the abstract
   components fit their fields (nl < max_left and nr < max_right). *)
lemma make_shared_refine: "(uncurry make_shared, uncurry make_tuple)  [λ(nl,nr). nl < max_left  nr < max_right]f (size_rel ×r size_rel)  shared_nat_rel"
  unfolding make_tuple_def shared_nat_rel_def
  apply (auto intro!: frefI simp: in_br_conv snat_rel_def snat.rel_def)
  done


(* The word-level left getter refines shared_left, with the result related
   by snat_rel. *)
lemma get_left_refine: "(get_shared_left, shared_left)  shared_nat_rel  snat_rel"
  unfolding shared_nat_rel_def shared_left_def shared_nat_α_def max_right_def
  apply(clarsimp simp: in_br_conv shiftr_div_2n'[symmetric])
  apply(subst snat_eq_unat_aux2[symmetric])
  apply(auto simp: get_shared_left_def right_size_def intro!: n_gt_0_snat_invar_rshiftn snat_invar_id_snat_rel)
  done


(* The word-level right getter refines shared_right, with the result
   related by snat_rel. *)
lemma get_right_refine: "(get_shared_right, shared_right)  shared_nat_rel  snat_rel"
  unfolding shared_nat_rel_def shared_right_def shared_nat_α_def max_right_def
  apply(clarsimp simp: in_br_conv unat_take_bit_2p[symmetric])
  apply(subst snat_eq_unat_aux2[symmetric])
  apply(auto simp: get_shared_right_def right_size_def intro!: snat_invar_take_bit snat_invar_id_snat_rel)
  done


(* The word-level field sum refines shared_sum. The word addition does not
   overflow because the sum is below max_snat, hence below 2 ^ 64. *)
lemma get_shared_sum_refine: "(get_shared_sum, shared_sum)  shared_nat_rel  snat_rel"
  unfolding shared_nat_rel_def shared_sum_def shared_nat_α_def get_shared_sum_def max_right_def
  apply(clarsimp simp: in_br_conv shiftr_div_2n'[symmetric] unat_take_bit_2p[symmetric])
  apply(subst unat_add_lem'[symmetric])
  apply(rule less_trans[OF shared_sum_unat_in_max_snat[unfolded get_shared_left_def get_shared_right_def max_snat_def], of "2 ^ LENGTH(64)"])
  apply(auto intro!: snat_invar_unat_snat_rel snat_invar_get_shared_sum[unfolded get_shared_sum_alt_def get_shared_left_def get_shared_right_def])  
  done
  

(* LLVM implementation of shared_zero as a plain word. *)
sepref_def shared_zero_ll is "uncurry0 (RETURN (PR_CONST shared_zero))" :: "unit_assnk a shared_nat_assn_aux"
  unfolding shared_zero_def shared_nat_assn_aux_def PR_CONST_def
  apply sepref
  done


(* LLVM implementation of make_shared. The shift constant right_size is
   unfolded and annotated as a size_T snat constant before synthesis. *)
sepref_def make_shared_ll is "uncurry (RETURN oo make_shared)" :: "id_assnk *a id_assnk a shared_nat_assn_aux"
  unfolding make_shared_def right_size_def shared_nat_assn_aux_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done


(* LLVM implementation of the left-field getter (a right shift by the
   constant right_size). *)
sepref_def shared_left_ll is "RETURN o get_shared_left" :: "shared_nat_assn_auxk a id_assn"
  unfolding get_shared_left_def shared_nat_assn_aux_def right_size_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done


(* LLVM implementation of the right-field getter (take_bit by the constant
   right_size). *)
sepref_def shared_right_ll is "RETURN o get_shared_right" :: "shared_nat_assn_auxk a id_assn"
  unfolding get_shared_right_def shared_nat_assn_aux_def right_size_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done

(* LLVM implementation of the field sum, synthesised from the
   get_shared_left + get_shared_right form. NOTE(review): after unfolding
   get_shared_sum_alt_def the term no longer mentions right_size, so
   unfolding right_size_def here looks redundant; confirm. *)
sepref_def shared_sum_ll is "RETURN o get_shared_sum" :: "shared_nat_assn_auxk a id_assn"
  unfolding get_shared_sum_alt_def right_size_def
  apply sepref
  done


(* Composing the plain word assertion with shared_nat_rel gives
   shared_nat_assn. Used below to normalise FCOMP results. *)
lemma shared_nat_assn_alt: "hr_comp shared_nat_assn_aux shared_nat_rel = shared_nat_assn"
  unfolding shared_nat_assn_aux_def
  apply auto
  done



(* Compose each LLVM implementation with its abstract refinement lemma and
   register the result as a sepref rule. shared_nat_assn_alt is used as an
   fcomp_norm_unfold rule, so the composed rules are stated with
   shared_nat_assn. *)
context 
  notes[fcomp_norm_unfold] = shared_nat_assn_alt
begin

  lemmas [sepref_fr_rules] = shared_left_ll.refine[FCOMP get_left_refine]
  lemmas [sepref_fr_rules] = shared_right_ll.refine[FCOMP get_right_refine]
  lemmas [sepref_fr_rules] = make_shared_ll.refine[FCOMP make_shared_refine]
  lemmas [sepref_fr_rules] = shared_sum_ll.refine[FCOMP get_shared_sum_refine]

end

(* Strict monotonicity of division when the larger operand is a multiple of
   the divisor. *)
lemma div_less_mono_nat: "(A::nat) < B  B mod n = 0  A div n < B div n"
  using less_mult_imp_div_less by force

(* Any abstract pair in the refinement domain of shared_nat_assn has its
   left component below max_left. For every word ni,
   unat ni div max_right < max_left: unat ni < max_unat len_size_T, and that
   range is max_left times max_right (hence divisible by max_right). *)
lemma rdomp_shared_left_upper:
  assumes "rdomp shared_nat_assn x"
  shows "shared_left x < max_left"
proof -
  {
    fix ni::size_t
    note unat_unat_rel[of ni]
    hence "unat ni < max_unat len_size_T"
      by(auto dest!: in_unat_rel_boundsD)
    moreover have "max_unat len_size_T mod max_right = 0"
      unfolding max_unat_def right_size_def max_right_def by simp
    ultimately have "unat ni div max_right < max_unat len_size_T div max_right"
      using div_less_mono_nat by presburger
    moreover have "max_unat len_size_T div max_right = max_left"
      using max_left_right_unat
      by (metis calculation div_by_0 less_nat_zero_code nonzero_mult_div_cancel_right)
    ultimately have "unat ni div max_right < max_left" by argo
  }
  then show ?thesis
    using assms
    unfolding shared_nat_rel_def shared_nat_α_def shared_left_def
    by(auto simp: in_br_conv max_unat_def)
qed

(* Any abstract pair in the refinement domain has its right component below
   max_right (it is a remainder modulo max_right). *)
lemma rdomp_shared_right_upper:
  assumes "rdomp shared_nat_assn x"
  shows "shared_right x < max_right"
  using assms
  unfolding shared_nat_rel_def
  by (auto simp: in_br_conv shared_right_def shared_nat_α_def max_unat_def intro!: pos_mod_bound)


(* add_less_mono with the summands on the left swapped: a < c and b < d give
   b + a < c + d. *)
lemma nat_add_less_mono_rev: "(a::nat) < c  b < d  b + a < c + d"
  using add_less_mono by linarith

(* The abstract sum of a representable pair is below max_right + max_left. *)
lemma rdomp_shared_sum_upper: "rdomp shared_nat_assn x  shared_sum x < max_right + max_left"
  apply(frule rdomp_shared_left_upper)
  apply(drule rdomp_shared_right_upper)
  unfolding shared_sum_as_left_plus_right
  apply(rule nat_add_less_mono_rev)
  by assumption+
  
(* Hence the abstract sum fits below max_snat len_size_T. *)
corollary shared_nat_snat_boundD: "rdomp shared_nat_assn x  shared_sum x < max_snat len_size_T"
  apply(drule rdomp_shared_sum_upper)
  using max_left_plus_right_lt_max_snat 
  by simp

(* Relational form of shared_nat_snat_boundD. *)
lemma shared_nat_rel_snat_boundD: "(ni,n)  shared_nat_rel  shared_sum n < max_snat len_size_T"
  by (auto simp: Range.intros shared_nat_snat_boundD)


end