(* Theory SharedNat *)
section ‹Bit-Packed pairs of Numbers›
theory SharedNat
imports "../lib/Base_MEC"
begin
subsection ‹Miscellaneous Lemmas›
subsubsection ‹snat lemmas›
(* On words satisfying snat_invar (msb clear), snat and the word order agree.
   Comparing the snat values is the same as comparing the words with the
   unsigned word order <, which the last step unfolds via word_less_def. *)
lemma snat_less_refine: "snat_invar a ⟹ snat_invar b ⟹ snat a < snat b ⟷ a < b"
apply(auto simp: snat_def snat_invar_def)
apply (meson word_sless_alt word_sless_msb_less)
apply (metis linorder_le_less_linear order_le_imp_less_or_eq signed_eq_0_iff word_gt_0 word_gt_a_gt_0 word_msb_sint)
by (simp add: sint_eq_uint word_less_def)
(* Masking is the identity on small non-negative words.
   If x satisfies snat_invar and its snat value is below 2 ^ n, then
   x && mask n = x. *)
lemma snat_mask_id: "snat_invar x ⟹ snat x < 2 ^ n ⟹ x && mask n = x"
  by (simp add: snat_def snat_invar_def and_mask_mod_2p) (simp add: sint_eq_uint)
(* take_bit variant of snat_mask_id.
   A word satisfying snat_invar whose snat value is below 2 ^ n is left
   unchanged by take_bit n. *)
lemma snat_take_bit_id: "snat_invar x ⟹ snat x < 2 ^ n ⟹ take_bit n x = x"
apply(simp add: snat_def snat_invar_def take_bit_word_eq_self_iff)
(* split on whether n already covers the whole word width *)
apply(cases "LENGTH('a) ≤ n")
apply (auto simp: sint_eq_uint uint_power_lower word_less_def)
done
(* Splitting a shift: shifting right by Suc n equals shifting right by n and
   then by one more bit. *)
lemma rshift_suc_n_alt: "(w >> Suc n) = (w >> n) >> 1" for w::"'a::len word"
  by (simp add: shiftr_shiftr)
(* A right shift by at least one bit always clears the most significant bit.
   The proof first rewrites the shift with rshift_suc_n_alt, so that the
   outermost operation is a one-bit shift. *)
lemma nmsb_rshiftn: "¬msb (w >> Suc n)" for w::"'a::len2 word"
apply(rewrite rshift_suc_n_alt)
by auto
(* Consequently w >> Suc n satisfies snat_invar, which unfolds to ¬msb. *)
corollary snat_invar_rshift_Sucn: "snat_invar (w >> Suc n)"
unfolding snat_invar_def
using nmsb_rshiftn .
(* Same as snat_invar_rshift_Sucn, but for any positive shift amount.
   gr0_implies_Suc turns n > 0 into n = Suc m. *)
corollary n_gt_0_snat_invar_rshiftn: "n > 0 ⟹ snat_invar (w >> n)"
unfolding snat_invar_def
using nmsb_rshiftn
using gr0_implies_Suc by blast
(* Keeping fewer than LENGTH('a) low bits always clears the most significant
   bit.
   Proved by contradiction, reading msb through the signed interpretation. *)
lemma nmsb_take_bit: "n < LENGTH('a::len2) ⟹ ¬msb (take_bit n (w::'a word))"
  by (rule ccontr) (auto simp: word_msb_sint signed_take_bit_eq)
(* snat_invar form of nmsb_take_bit: truncating to fewer than LENGTH('a) bits
   yields a word that satisfies snat_invar. *)
corollary snat_invar_take_bit: "n < LENGTH('a::len2) ⟹ snat_invar (take_bit n (w::'a word))"
unfolding snat_invar_def
using nmsb_take_bit .
(* Shifting left by n zeroes the low n bits, so keeping only those bits
   gives 0. *)
lemma take_bit_lshift_is_0: "take_bit n (a << n) = 0" for a::"'a::len word"
  by (auto simp: take_bit_eq_mask)
(* Any word satisfying snat_invar is related to its own snat value by
   snat_rel. *)
lemma snat_invar_id_snat_rel: "snat_invar ii ⟹ (ii, snat ii) ∈ snat_rel"
  by (auto simp: snat_rel_def snat.rel_def in_br_conv snat_invar_def)
(* Variant of snat_invar_id_snat_rel with the abstract value written as unat.
   Under snat_invar, snat and unat coincide (snat_eq_unat_aux2). *)
lemma snat_invar_unat_snat_rel: "snat_invar ii ⟹ (ii, unat ii) ∈ snat_rel"
  by (auto simp: snat_eq_unat_aux2[symmetric] snat_invar_id_snat_rel)
(* Variant of snat_invar_id_snat_rel whose premise is a non-negative signed
   value instead of snat_invar. The two are linked via word_msb_sint. *)
lemma p_sint_id_snat_rel: "sint ii ≥ 0 ⟹ (ii, snat ii) ∈ snat_rel"
  by (auto simp: snat_rel_def snat.rel_def in_br_conv snat_invar_def word_msb_sint)
(* snat_invar w holds exactly when the unsigned value of w is below
   max_snat LENGTH('a).
   Left-to-right follows from snat_lt_max_snat.
   Right-to-left is by contradiction: a set msb forces a large unsigned
   value (msb_unat_big). *)
lemma snat_invar_unat_bound:"snat_invar w ⟷ unat w < max_snat LENGTH('a)" for w::"'a::len2 word"
apply auto
apply (metis snat_eq_unat_aux2 snat_lt_max_snat)
apply (rule ccontr)
unfolding snat_invar_def
apply (auto simp: msb_unat_big max_snat_def)
done
subsubsection ‹unat lemmas›
(* For n < LENGTH('a), unat commutes with mod 2 ^ n. *)
lemma mod_word_2p_unat: "n < LENGTH ('a::len) ⟹ (unat (w::'a word)) mod (2 ^ n) = unat (w mod 2 ^ n)"
by (simp add: unat_mod_distrib)
(* For n < LENGTH('a), unat commutes with div 2 ^ n. *)
lemma div_word_2p_unat: "n < LENGTH ('a::len) ⟹ (unat (w::'a word)) div (2 ^ n) = unat (w div 2 ^ n)"
by (simp add: unat_div_distrib)
(* The unsigned value of take_bit n w is unat w mod 2 ^ n. *)
lemma unat_take_bit_2p: "unat (take_bit n w) = unat w mod 2 ^ n"
by (metis take_bit_eq_mod unsigned_take_bit_eq)
(* Masking with mask n bounds the unsigned value strictly by 2 ^ n. *)
lemma unat_and_mask_lt_2p: "unat (w && mask n) < 2 ^ n"
by (metis take_bit_eq_mask take_bit_nat_less_exp unsigned_take_bit_eq)
(* take_bit form of unat_and_mask_lt_2p, via take_bit_eq_mask. *)
lemma unat_take_bit_lt_2p: "unat (take_bit n a) < 2 ^ n"
by(auto simp: take_bit_eq_mask unat_and_mask_lt_2p)
(* Every word is related to its unsigned value by unat_rel (no premise). *)
lemma unat_unat_rel: "(ii, unat ii) ∈ unat_rel"
by(auto simp: unat_rel_def unat.rel_def in_br_conv)
subsection ‹Mask and take_bit setup for sepref›
(* For n < LENGTH('a), mask n has its most significant bit clear. *)
lemma word_mask_less_length_no_msb: "n < LENGTH('a::len) ⟹ ¬(msb::'a word ⇒ bool) (mask n)"
by (metis (no_types, lifting) min.absorb4 signed_take_bit_eq take_bit_minus_one_eq_mask take_bit_take_bit uint_lt_0 uint_sint word_msb_sint)
(* Executable formulation of mask: (1 << n) - 1.
   It is related to mask by mask_refine. *)
definition mask_impl :: "nat ⇒ ('a::len2) word" where "mask_impl n = (1 << n) - 1"
(* mask_impl computes exactly mask (as functions nat ⇒ word), via mask_eq. *)
lemma mask_refine: "(mask_impl, mask) ∈ nat_rel → word_rel"
  unfolding mask_impl_def by (auto simp: mask_eq)
(* LLVM implementation of mask_impl.
   The bit count n is passed as an snat of the word's own type and must be
   below LENGTH('a). The synthesized rule is composed with mask_refine and
   registered as a sepref rule for mask itself. *)
sepref_def mask_ll is "RETURN o (mask_impl :: nat ⇒ ('a::len2) word)" :: "[λ n. n < LENGTH('a)]⇩a (snat_assn' TYPE('a::len2))⇧k → id_assn"
unfolding mask_impl_def
apply sepref
done
lemmas [sepref_fr_rules] = mask_ll.refine[FCOMP mask_refine]
(* Executable formulation of take_bit: AND with the low-bit mask
   (1 << n) - 1. It is related to take_bit by take_bit_refine. *)
definition take_bit_impl :: "nat ⇒ ('a::len2) word ⇒ 'a word" where "take_bit_impl n a = a AND ((1 << n) - 1)"
(* take_bit_impl computes exactly take_bit.
   Both sides are compared extensionally via take_bit_eq_mask and mask_eq. *)
lemma take_bit_refine: "(take_bit_impl, take_bit) ∈ nat_rel → word_rel → word_rel"
  unfolding take_bit_impl_def by (auto simp: fun_eq_iff take_bit_eq_mask mask_eq)
(* LLVM implementation of take_bit_impl.
   The bit count is an snat that must be below LENGTH('a). The result is
   composed with take_bit_refine and registered as a sepref rule for take_bit.
   NOTE(review): the word argument is annotated id_assn⇧d (destroyed) even
   though it is only read; ⇧k may have been intended. Confirm before
   changing, since the registered rule's signature depends on it. *)
sepref_def take_bit_ll is "uncurry (RETURN oo (take_bit_impl :: nat ⇒ 'a word ⇒ 'a word))" :: "[λ (n,_). n < LENGTH('a::len2)]⇩a (snat_assn' TYPE('a))⇧k *⇩a id_assn⇧d → id_assn"
unfolding take_bit_impl_def
apply sepref
done
lemmas [sepref_fr_rules] = take_bit_ll.refine[FCOMP take_bit_refine]
(* Abstract view of a shared number: a pair (left, right) of naturals.
   These projections, the sum, and the pair constructor are the abstract
   operations that the word-level functions of this theory refine. *)
definition shared_left :: "nat × nat ⇒ nat" where "shared_left = fst"
definition shared_right :: "nat × nat ⇒ nat" where "shared_right = snd"
definition shared_sum :: "nat × nat ⇒ nat" where "shared_sum = (λ (l,r). l + r)"
definition make_tuple :: "nat ⇒ nat ⇒ nat × nat" where "make_tuple nl nr = (nl,nr)"
(* The abstract sum is the left component plus the right component. *)
lemma shared_sum_as_left_plus_right: "shared_sum x = shared_left x + shared_right x"
unfolding shared_left_def shared_right_def shared_sum_def
apply (cases x)
by simp
(* Bit layout of the packed word.
   The low right_size = 44 bits hold the right component (extracted with
   take_bit right_size). The remaining left_size = len_size_T - right_size
   high bits hold the left component. *)
definition "right_size = (44::nat)"
definition "left_size = len_size_T - right_size"
(* Each field width is strictly smaller than the full word width. *)
lemma size_bounds: "right_size < len_size_T" "left_size < len_size_T"
unfolding right_size_def left_size_def by auto
(* The two field widths add up to the full word width len_size_T. *)
lemma left_right_add: "left_size + right_size = len_size_T"
unfolding left_size_def right_size_def
by simp
(* Exclusive upper bounds of the left and right components:
   2 ^ left_size and 2 ^ right_size. *)
definition "max_left = 2 ^ left_size"
definition "max_right = 2 ^ right_size"
(* The component bounds expressed via max_unat. *)
lemma max_left_alt: "max_left = max_unat left_size"
unfolding max_left_def max_unat_def by simp
lemma max_right_alt: "max_right = max_unat right_size"
unfolding max_right_def max_unat_def by simp
(* Both bounds are positive. These are simp rules, used e.g. to discharge
   division-by-zero side conditions. *)
lemma max_left_gt_0[simp]: "max_left > (0::nat)"
unfolding max_left_def
by simp
lemma max_right_gt_0[simp]: "max_right > (0::nat)"
unfolding max_right_def
by simp
(* The product of the two bounds is max_unat len_size_T.
   power_add merges 2 ^ left_size * 2 ^ right_size into a single power, whose
   exponent is then rewritten with left_right_add. *)
lemma max_left_right_unat: "max_left * max_right = max_unat len_size_T"
unfolding max_left_def max_right_def max_unat_def
apply(fold monoid_mult_class.power_add)
unfolding left_right_add
by presburger
(* Even the sum of both exclusive bounds stays below max_snat len_size_T.
   Any component sum therefore fits in a size_T snat. *)
lemma max_left_plus_right_lt_max_snat: "max_left + max_right < max_snat len_size_T"
unfolding max_left_def max_right_def left_size_def right_size_def max_snat_def
by simp
(* Concrete representation of a shared pair: a single size_t word.
   - make_shared shifts nl left by right_size and ORs in nr. The two fields
     only stay separate when nr < 2 ^ right_size and nl < 2 ^ left_size
     (see the premises of make_shared_α and make_shared_refine).
   - get_shared_left extracts the high field by a right shift.
   - get_shared_right extracts the low field with take_bit.
   - get_shared_sum adds both extracted fields using word addition. *)
type_synonym shared_nat = "size_t"
definition shared_zero :: "shared_nat" where "shared_zero = 0"
definition make_shared :: "size_t ⇒ size_t ⇒ shared_nat" where "make_shared nl nr = (nl << right_size) OR nr"
definition get_shared_left :: "shared_nat ⇒ size_t" where "get_shared_left sn = (sn >> right_size)"
definition get_shared_right :: "shared_nat ⇒ size_t" where "get_shared_right sn = take_bit right_size sn"
definition get_shared_sum :: "shared_nat ⇒ size_t" where "get_shared_sum sn = (sn >> right_size) + take_bit right_size sn"
(* get_shared_sum is the word sum of the two getters. *)
lemma get_shared_sum_alt_def: "get_shared_sum sn = get_shared_left sn + get_shared_right sn"
unfolding get_shared_sum_def get_shared_left_def get_shared_right_def
by simp
(* Tag make_shared_def with llvm_inline, presumably so the LLVM code
   generator inlines make_shared. *)
lemmas [llvm_inline] = make_shared_def
(* Abstraction function: decode a word into
   (unat sn div max_right, unat sn mod max_right), i.e. (high field, low field). *)
definition "shared_nat_α sn = (unat sn div max_right, unat sn mod max_right)"
(* The zero word decodes to the pair (0, 0). *)
lemma shared_zero_α[simp]: "shared_nat_α shared_zero = (0,0)"
unfolding shared_nat_α_def shared_zero_def get_shared_left_def get_shared_right_def
by simp
(* Encoding correctness: packing two in-range non-negative words and then
   decoding gives back their snat values.
   Main proof steps:
   - distribute >> over OR;
   - cancel the << / >> pair with shiftl_shiftr_id, using the size bounds;
   - compare against the numeral bound through snat_less_refine;
   - show the shifted left field vanishes modulo 2 ^ right_size. *)
lemma make_shared_α[simp]: "snat_invar nl ⟹ snat_invar nr ⟹ snat nl < max_left ⟹ snat nr < max_right ⟹
shared_nat_α (make_shared nl nr) = (snat nl, snat nr)"
unfolding shared_nat_α_def make_shared_def get_shared_left_def get_shared_right_def max_left_def max_right_def
apply (clarsimp simp: shiftr_over_or_dist simp flip: shiftr_div_2n')
apply(subst shiftl_shiftr_id)
apply (auto simp: right_size_def left_size_def)[2]
apply (subst(asm) snat_numeral[where 'a = size_T, symmetric])
apply (simp add: max_snat_def)
apply (subst(asm) snat_less_refine; (simp add: snat_invar_def; erule less_trans[of nl]; simp)|(simp add: snat_invar_def))
apply(rewrite mod_word_2p_unat[of right_size "((nl << right_size) || nr)", simplified, OF size_bounds(1)])
apply (auto simp add: shiftr_le_0 word_ao_dist snat_eq_unat_aux2 mod_2p_is_mask)
by (metis take_bit_eq_mask take_bit_nat_eq_self unsigned_take_bit_eq)
(* The unsigned value of the left getter is the abstract left component.
   The shift becomes a division by 2 ^ right_size, which commutes with unat. *)
lemma get_left_α[simp]: "unat (get_shared_left sn) = shared_left (shared_nat_α sn)"
unfolding shared_nat_α_def shared_left_def get_shared_left_def max_right_def
apply (clarsimp simp add: shiftr_div_2n_w[symmetric] div_word_2p_unat[of right_size sn, simplified, OF size_bounds(1)]
intro!: snat_eq_unat_aux2)
done
(* The unsigned value of the right getter is the abstract right component.
   take_bit is related to mod 2 ^ right_size, which commutes with unat. *)
lemma get_right_α[simp]: "unat (get_shared_right sn) = shared_right (shared_nat_α sn) "
unfolding shared_nat_α_def shared_right_def get_shared_right_def max_right_def
apply(clarsimp simp: mod_word_2p_unat[of right_size sn, simplified, OF size_bounds(1)])
apply(rewrite word_eqI_folds(3))
apply simp
done
(* Numeric bound on the sum of the two extracted fields.
   NOTE(review): the 64 is presumably the bit width of size_t (len_size_T).
   It is hard-coded here and in lemmas that reuse this exact statement, so
   changing it requires updating those as well. *)
lemma shared_sum_unat_bounds: "unat (get_shared_left sn) + unat (get_shared_right sn) < 2 ^ (64 - right_size) + 2 ^ right_size"
proof -
(* left field: dividing a value below max_unat by 2 ^ right_size *)
have "unat sn div 2 ^ right_size ≤ 2 ^ (64 - right_size)"
apply(rewrite power_minus_is_div[OF less_imp_le_nat, OF size_bounds(1)])
apply(rule div_le_mono)
using unat_lt_max_unat[of sn]
unfolding max_unat_def by simp
(* right field: a remainder is below its divisor *)
moreover have "unat sn mod 2 ^ right_size < 2 ^ right_size"
by force
ultimately have "unat sn div 2 ^ right_size + unat sn mod 2 ^ right_size < 2 ^ (64 - right_size) + 2 ^ right_size"
by linarith
then show ?thesis
by (auto simp: shared_nat_α_def shared_left_def shared_right_def max_right_def)
qed
(* Hence the numeric sum of the two fields is below max_snat len_size_T. *)
corollary shared_sum_unat_in_max_snat: "unat (get_shared_left sn) + unat (get_shared_right sn) < max_snat len_size_T"
using shared_sum_unat_bounds[of sn]
unfolding right_size_def max_snat_def
by simp
(* Same bound for the word-level sum of the fields.
   The word addition does not wrap (unat_add_lem'), so unat of the word sum
   equals the numeric sum bounded above. *)
lemma unat_shared_sum_bounds: "unat ((get_shared_left sn) + (get_shared_right sn)) < 2 ^ (64 - right_size) + 2 ^ right_size"
apply(subst unat_add_lem')
using shared_sum_unat_in_max_snat[unfolded max_snat_def, of sn]
apply fastforce
using shared_sum_unat_bounds[of sn] .
(* The word-level sum of the fields is below max_snat len_size_T. *)
corollary unat_shared_sum_in_max_snat_aux: "unat ((get_shared_left sn) + (get_shared_right sn)) < max_snat len_size_T"
using unat_shared_sum_bounds[of sn]
unfolding right_size_def max_snat_def
by simp
(* The same bound, stated for get_shared_sum. *)
corollary unat_shared_sum_in_max_snat: "unat ((get_shared_sum sn)) < max_snat len_size_T"
unfolding get_shared_sum_alt_def
using unat_shared_sum_in_max_snat_aux[of sn] .
(* The left getter yields a valid snat: it is a right shift by
   right_size > 0. *)
lemma snat_invar_get_shared_left: "snat_invar (get_shared_left sn)"
unfolding get_shared_left_def right_size_def
by (simp add: n_gt_0_snat_invar_rshiftn)
(* The right getter yields a valid snat: it is take_bit with fewer bits than
   the word width. The final step discharges that width side condition. *)
lemma snat_invar_get_shared_right: "snat_invar (get_shared_right sn)"
unfolding get_shared_right_def right_size_def
apply(rule snat_invar_take_bit)
by force
(* The sum getter also yields a valid snat.
   Its unsigned value is below max_snat (unat_shared_sum_in_max_snat). *)
lemma snat_invar_get_shared_sum: "snat_invar (get_shared_sum sn)"
  by (auto simp: snat_invar_unat_bound unat_shared_sum_in_max_snat)
(* The unsigned value of get_shared_sum is the abstract sum.
   The word addition is shown not to wrap using
   shared_sum_unat_in_max_snat. *)
lemma shared_sum_α[simp]: "unat (get_shared_sum sn) = shared_sum (shared_nat_α sn)"
unfolding shared_sum_def shared_nat_α_def get_shared_sum_def
apply(auto simp: shiftr_div_2n'[symmetric] unat_take_bit_2p[symmetric] max_right_def)
apply(subst unat_add_lem'[symmetric])
using shared_sum_unat_in_max_snat[of sn]
unfolding max_snat_def get_shared_left_def get_shared_right_def
apply auto
done
(* Invariant of the concrete representation: both extracted fields lie
   within their widths. It holds for every word (shared_nat_invar_true), so
   in effect it imposes no restriction. *)
definition shared_nat_invar :: "shared_nat ⇒ bool" where "shared_nat_invar sn = ((get_shared_left sn < 2 ^ left_size) ∧ (get_shared_right sn < 2 ^ right_size))"
(* Word-level upper bound on the right field.
   Chain: sn && (2 ^ right_size - 1) ≤ 2 ^ right_size - 1 < 2 ^ right_size. *)
lemma get_shared_right_ub: "get_shared_right sn < 2 ^ right_size"
proof -
have "sn && 2 ^ right_size - 1 ≤ 2 ^ right_size - 1"
using Word.word_and_le1 by blast
also have "... < 2 ^ right_size"
unfolding right_size_def by force
finally show ?thesis
apply (auto simp: get_shared_right_def mask_eq take_bit_eq_mask)
done
qed
(* The right field is non-negative when read as a signed word.
   It equals sn mod 2 ^ right_size, which is below 2 ^ (LENGTH(64) - 1), so
   sint and uint agree on it.
   NOTE(review): LENGTH(64) hard-codes the width, presumably that of size_t. *)
lemma get_shared_right_lb: "sint (get_shared_right sn) ≥ 0"
apply(simp add: get_shared_right_def take_bit_eq_mod sint_eq_uint_2pl)
apply(subgoal_tac "(sn mod 2 ^ right_size) < 2 ^ (LENGTH(64) - 1)" )
apply(drule sint_eq_uint_2pl)
apply simp
apply(rule less_trans[of "sn mod 2 ^ right_size" "2 ^ right_size" "2 ^ (LENGTH(64) - 1)"])
apply(rule word_mod_less_divisor)
apply (auto simp add: right_size_def)
done
(* msb form of get_shared_right_lb, via word_msb_sint. *)
lemma nmsb_get_shared_right: "¬msb (get_shared_right sn)"
unfolding word_msb_sint
using get_shared_right_lb[of sn]
by fastforce
(* The snat value of the right field is below 2 ^ right_size.
   Since msb is clear, snat equals unat, which take_bit bounds. *)
lemma get_shared_right_snat_ub: "snat (get_shared_right sn) < 2 ^ right_size"
unfolding snat_eq_unat_aux2[unfolded snat_invar_def, OF nmsb_get_shared_right]
unfolding get_shared_right_def
using unat_take_bit_lt_2p by blast
(* Word-level upper bound on the left field.
   word_shiftr_lt bounds its unsigned value by 2 ^ left_size. unat_p2 (for a
   64-bit word) transfers the bound to the word order. *)
lemma get_shared_left_ub: "(get_shared_left sn < 2 ^ left_size)"
proof -
have "unat (sn >> right_size) < 2 ^ left_size"
using word_shiftr_lt[of sn right_size]
by (auto simp: right_size_def left_size_def)
hence "unat (sn >> right_size) < unat ((2::64 word) ^ left_size)"
using unat_p2[where 'a = 64, simplified, OF size_bounds(2)]
by presburger
thus ?thesis
by(simp add: word_less_nat_alt get_shared_left_def)
qed
(* The left field is non-negative when read as a signed word.
   Proved by contradiction via word_msb_sint, since a positive right shift
   clears the msb. *)
lemma get_shared_left_lb: "sint (get_shared_left sn) ≥ 0"
apply(simp add: get_shared_left_def right_size_def)
apply(rule ccontr)
apply (clarsimp simp: not_le simp flip: word_msb_sint)
done
(* The representation invariant holds for every word, so shared_nat_rel
   relates every word to its decoding. *)
lemma shared_nat_invar_true[simp]: "shared_nat_invar sn"
  by (auto simp: get_shared_right_ub get_shared_left_ub get_shared_right_lb get_shared_left_lb shared_nat_invar_def)
(* shared_nat_rel relates a word to its decoded pair, using shared_nat_α and
   shared_nat_invar.
   shared_nat_assn_aux is the plain word assertion used by the LLVM
   implementations below. shared_nat_assn is the pure assertion for
   shared_nat_rel; shared_nat_assn_alt shows it equals the composition of
   the two. *)
definition "shared_nat_rel = br shared_nat_α shared_nat_invar"
definition "shared_nat_assn_aux = word_assn"
abbreviation "shared_nat_assn ≡ pure shared_nat_rel"
(* The zero word refines the abstract pair (0, 0). *)
lemma shared_zero_refine: "(uncurry0 shared_zero, uncurry0 (0, 0)) ∈ unit_rel → shared_nat_rel"
  unfolding shared_nat_rel_def by (auto simp: in_br_conv)
(* make_shared refines make_tuple.
   Precondition: the left argument is below max_left and the right argument
   is below max_right, so the two bit fields do not overlap. *)
lemma make_shared_refine: "(uncurry make_shared, uncurry make_tuple) ∈ [λ(nl,nr). nl < max_left ∧ nr < max_right]⇩f (size_rel ×⇩r size_rel) → shared_nat_rel"
  unfolding make_tuple_def shared_nat_rel_def
  by (auto intro!: frefI simp: in_br_conv snat_rel_def snat.rel_def)
(* get_shared_left refines shared_left, with the result as an snat.
   The shift is a valid snat (positive shift amount), and snat and unat
   coincide on it. *)
lemma get_left_refine: "(get_shared_left, shared_left) ∈ shared_nat_rel → snat_rel"
unfolding shared_nat_rel_def shared_left_def shared_nat_α_def max_right_def
apply(clarsimp simp: in_br_conv shiftr_div_2n'[symmetric])
apply(subst snat_eq_unat_aux2[symmetric])
apply(auto simp: get_shared_left_def right_size_def intro!: n_gt_0_snat_invar_rshiftn snat_invar_id_snat_rel)
done
(* get_shared_right refines shared_right, with the result as an snat.
   take_bit yields a valid snat, and snat and unat coincide on it. *)
lemma get_right_refine: "(get_shared_right, shared_right) ∈ shared_nat_rel → snat_rel"
unfolding shared_nat_rel_def shared_right_def shared_nat_α_def max_right_def
apply(clarsimp simp: in_br_conv unat_take_bit_2p[symmetric])
apply(subst snat_eq_unat_aux2[symmetric])
apply(auto simp: get_shared_right_def right_size_def intro!: snat_invar_take_bit snat_invar_id_snat_rel)
done
(* get_shared_sum refines shared_sum, with the result as an snat.
   The word addition does not wrap, because the numeric sum is below
   max_snat and hence below 2 ^ LENGTH(64). *)
lemma get_shared_sum_refine: "(get_shared_sum, shared_sum) ∈ shared_nat_rel → snat_rel"
unfolding shared_nat_rel_def shared_sum_def shared_nat_α_def get_shared_sum_def max_right_def
apply(clarsimp simp: in_br_conv shiftr_div_2n'[symmetric] unat_take_bit_2p[symmetric])
apply(subst unat_add_lem'[symmetric])
apply(rule less_trans[OF shared_sum_unat_in_max_snat[unfolded get_shared_left_def get_shared_right_def max_snat_def], of "2 ^ LENGTH(64)"])
apply(auto intro!: snat_invar_unat_snat_rel snat_invar_get_shared_sum[unfolded get_shared_sum_alt_def get_shared_left_def get_shared_right_def])
done
(* LLVM implementation of shared_zero: the constant 0 word.
   PR_CONST protects the constant during sepref's processing and is
   unfolded here. *)
sepref_def shared_zero_ll is "uncurry0 (RETURN (PR_CONST shared_zero))" :: "unit_assn⇧k →⇩a shared_nat_assn_aux"
unfolding shared_zero_def shared_nat_assn_aux_def PR_CONST_def
apply sepref
done
(* LLVM implementation of make_shared (shift left and OR).
   annot_snat_const TYPE(size_T) presumably annotates the numeral shift
   amount (right_size = 44) as a size_T snat constant before synthesis. *)
sepref_def make_shared_ll is "uncurry (RETURN oo make_shared)" :: "id_assn⇧k *⇩a id_assn⇧k →⇩a shared_nat_assn_aux"
unfolding make_shared_def right_size_def shared_nat_assn_aux_def
apply (annot_snat_const "TYPE(size_T)")
apply sepref
done
(* LLVM implementation of get_shared_left: a right shift by the constant
   right_size, annotated as a size_T snat constant. *)
sepref_def shared_left_ll is "RETURN o get_shared_left" :: "shared_nat_assn_aux⇧k →⇩a id_assn"
unfolding get_shared_left_def shared_nat_assn_aux_def right_size_def
apply (annot_snat_const "TYPE(size_T)")
apply sepref
done
(* LLVM implementation of get_shared_right: take_bit with the constant
   right_size. It relies on the take_bit sepref rule registered after
   take_bit_ll. *)
sepref_def shared_right_ll is "RETURN o get_shared_right" :: "shared_nat_assn_aux⇧k →⇩a id_assn"
unfolding get_shared_right_def shared_nat_assn_aux_def right_size_def
apply (annot_snat_const "TYPE(size_T)")
apply sepref
done
(* LLVM implementation of get_shared_sum, obtained by unfolding it into
   get_shared_left sn + get_shared_right sn (word addition).
   NOTE(review): presumably the getters are synthesized via the rules of
   shared_left_ll / shared_right_ll declared just above. *)
sepref_def shared_sum_ll is "RETURN o get_shared_sum" :: "shared_nat_assn_aux⇧k →⇩a id_assn"
unfolding get_shared_sum_alt_def right_size_def
apply sepref
done
(* Composing the word assertion with shared_nat_rel gives shared_nat_assn.
   Used as the fcomp normalisation rule when registering the composed
   refinement rules. *)
lemma shared_nat_assn_alt: "hr_comp shared_nat_assn_aux shared_nat_rel = shared_nat_assn"
  unfolding shared_nat_assn_aux_def by auto
(* Compose each LLVM implementation with its abstract refinement lemma
   (FCOMP) and register the result as a sepref rule.
   shared_nat_assn_alt, active as fcomp_norm_unfold only inside this
   context, rewrites the composed assertion into shared_nat_assn. *)
context
notes[fcomp_norm_unfold] = shared_nat_assn_alt
begin
lemmas [sepref_fr_rules] = shared_left_ll.refine[FCOMP get_left_refine]
lemmas [sepref_fr_rules] = shared_right_ll.refine[FCOMP get_right_refine]
lemmas [sepref_fr_rules] = make_shared_ll.refine[FCOMP make_shared_refine]
lemmas [sepref_fr_rules] = shared_sum_ll.refine[FCOMP get_shared_sum_refine]
end
(* Strict monotonicity of nat division when the divisor divides the larger
   operand: A < B and n dvd B imply A div n < B div n. *)
lemma div_less_mono_nat: "(A::nat) < B ⟹ B mod n = 0 ⟹ A div n < B div n"
using less_mult_imp_div_less by force
(* Every abstract pair in the range of shared_nat_assn has its left
   component below max_left.
   For any word ni: unat ni < max_unat len_size_T. That bound is divisible
   by max_right, so unat ni div max_right is below
   max_unat len_size_T div max_right, which equals max_left. *)
lemma rdomp_shared_left_upper:
assumes "rdomp shared_nat_assn x"
shows "shared_left x < max_left"
proof -
{
fix ni::size_t
note unat_unat_rel[of ni]
hence "unat ni < max_unat len_size_T"
by(auto dest!: in_unat_rel_boundsD)
moreover have "max_unat len_size_T mod max_right = 0"
unfolding max_unat_def right_size_def max_right_def by simp
ultimately have "unat ni div max_right < max_unat len_size_T div max_right"
using div_less_mono_nat by presburger
moreover have "max_unat len_size_T div max_right = max_left"
using max_left_right_unat
by (metis calculation div_by_0 less_nat_zero_code nonzero_mult_div_cancel_right)
ultimately have "unat ni div max_right < max_left" by argo
}
then show ?thesis
using assms
unfolding shared_nat_rel_def shared_nat_α_def shared_left_def
by(auto simp: in_br_conv max_unat_def)
qed
(* Every abstract pair in the range of shared_nat_assn has its right
   component below max_right. It is a remainder modulo max_right
   (pos_mod_bound). *)
lemma rdomp_shared_right_upper:
assumes "rdomp shared_nat_assn x"
shows "shared_right x < max_right"
using assms
unfolding shared_nat_rel_def
by (auto simp: in_br_conv shared_right_def shared_nat_α_def max_unat_def intro!: pos_mod_bound)
(* add_less_mono with the summands on the left-hand side swapped. *)
lemma nat_add_less_mono_rev: "(a::nat) < c ⟹ b < d ⟹ b + a < c + d"
using add_less_mono by linarith
(* The abstract sum of any representable pair is below
   max_right + max_left, by combining the two component bounds. *)
lemma rdomp_shared_sum_upper: "rdomp shared_nat_assn x ⟹ shared_sum x < max_right + max_left"
apply(frule rdomp_shared_left_upper)
apply(drule rdomp_shared_right_upper)
unfolding shared_sum_as_left_plus_right
apply(rule nat_add_less_mono_rev)
by assumption+
(* Hence the abstract sum of any representable pair fits in a size_T snat. *)
corollary shared_nat_snat_boundD: "rdomp shared_nat_assn x ⟹ shared_sum x < max_snat len_size_T"
apply(drule rdomp_shared_sum_upper)
using max_left_plus_right_lt_max_snat
by simp
(* Relational form of shared_nat_snat_boundD: an abstract pair related to
   some word has its sum below max_snat len_size_T. *)
lemma shared_nat_rel_snat_boundD: "(ni,n) ∈ shared_nat_rel ⟹ shared_sum n < max_snat len_size_T"
by (auto simp: Range.intros shared_nat_snat_boundD)
end