(* Theory Bits_Natural *)

(* Author: Mathias Fleury
   Minor additions by Peter Lammich
*)
theory Bits_Natural
  imports "HOL-Library.Word" Word_Lib.Aligned "Word_Lib.Word_Lib_Sumo"
begin

(* TODO: Move *)
(* XOR commutes with bintrunc: truncating both operands to n bits and then
   XOR-ing them equals XOR-ing first and truncating the result. *)
lemma bin_trunc_xor':
  "bintrunc n x XOR bintrunc n y = bintrunc n (x XOR y)"
  by (auto simp add: bin_eq_iff bin_nth_ops nth_bintr)

(*lemma uint_xor: "uint (x XOR y) = uint x XOR uint y"
  by (transfer, simp add: bin_trunc_xor')*)

(*instance nat :: semiring_bit_syntax ..*)

(* Instantiates the set_bit class for nat. A bit of a natural number is set by
   converting to int, applying bin_sc, and converting the result back with nat. *)
instantiation nat :: set_bit begin
  definition set_bit_nat :: "nat ⇒ nat ⇒ bool ⇒ nat" where
    "set_bit i n b = nat (bin_sc n b (int i))"

instance
  apply standard
  apply (simp_all add: set_bit_nat_def)
  by (metis bin_nth_sc_gen bin_sign_sc bit_nat_iff bit_of_nat_iff_bit of_nat_0_le_iff sign_Pls_ge_0)
end

(* Instantiates the msb class for nat: the most significant bit of a natural
   number is defined as the msb of its image in int. *)
instantiation nat :: msb
begin
  definition msb_nat :: "nat ⇒ bool" where
    "msb i = msb (int i)"

instance ..
end


(*
instantiation nat :: semiring_bit_shifts
begin

definition set_bits_nat :: "(nat ⇒ bool) ⇒ nat" where
  "set_bits f =
  (if ∃n. ∀n'≥n. ¬ f n' then
     let n = LEAST n. ∀n'≥n. ¬ f n'
     in nat (bl_to_bin (rev (map f [0..<n])))
   else if ∃n. ∀n'≥n. f n' then
     let n = LEAST n. ∀n'≥n. f n'
     in nat (sbintrunc n (bl_to_bin (True # rev (map f [0..<n]))))
   else 0 :: nat)"


definition not_nat :: "nat ⇒ nat" where
  "NOT i = nat (NOT (int i))"

(*
definition shiftl_nat where
  "shiftl x n = nat ((int x) * 2 ^ n)"

definition shiftr_nat where
  "shiftr x n = nat (int x div 2 ^ n)"

definition bitNOT_nat :: "nat ⇒ nat" where
  "bitNOT i = nat (bitNOT (int i))"

definition bitAND_nat :: "nat ⇒ nat ⇒ nat" where
  "bitAND i j = nat (bitAND (int i) (int j))"

definition bitOR_nat :: "nat ⇒ nat ⇒ nat" where
  "bitOR i j = nat (bitOR (int i) (int j))"

definition bitXOR_nat :: "nat ⇒ nat ⇒ nat" where
  "bitXOR i j = nat (bitXOR (int i) (int j))"

definition msb_nat :: "nat ⇒ bool" where
  "msb i = msb (int i)"
*)
instance . .

end
*)

(* Simp rules for right shift on nat: shifting by 0 is the identity, shifting 0
   yields 0, and shifting by Suc n is halving followed by a shift by n. *)
lemma nat_shiftr[simp]:
  "m >> 0 = m"
  ‹((0::nat) >> m) = 0›
  ‹(m >> Suc n) = (m div 2 >> n)› for m :: nat
  by (simp_all add: shiftr_def  drop_bit_Suc)

(* Right shift by n on nat is division by 2^n.
   The proof is by induction on n, generalising over m. *)
lemma nat_shifl_div: ‹m >> n = m div (2^n)› for m :: nat
  by (induction n arbitrary: m) (auto simp: div_mult2_eq)

(* Simp rules for left shift: shifting by 0 is the identity, shifting 0 yields 0,
   and shifting by Suc n is doubling followed by a shift by n. *)
lemma nat_shiftl[simp]:
  "m << 0 = m"
  ‹((0) << m) = 0›
  ‹(m << Suc n) = ((m * 2) << n)› for m :: nat
  by (simp_all add: shiftl_def)

(* Right shift by one position on nat is division by 2. *)
lemma nat_shiftr_div2: ‹m >> 1 = m div 2› for m :: nat
  by auto

(* Left shift by n on nat is multiplication by 2^n.
   NOTE(review): despite its name, this lemma is about << and multiplication;
   the name is kept unchanged so that existing references keep working. *)
lemma nat_shiftr_div: ‹m << n = m * (2^n)› for m :: nat
  by (induction n arbitrary: m) (auto simp: div_mult2_eq)

(* Left shift of a natural number by exactly one position. *)
definition shiftl1 :: ‹nat ⇒ nat› where
  ‹shiftl1 n = n << 1›

(* Right shift of a natural number by exactly one position. *)
definition shiftr1 :: ‹nat ⇒ nat› where
  ‹shiftr1 n = n >> 1›

(*
instantiation natural :: bit_comprehension
begin

context includes natural.lifting begin

lift_definition test_bit_natural :: ‹natural ⇒ nat ⇒ bool› is test_bit .

lift_definition lsb_natural :: ‹natural ⇒ bool› is lsb .

lift_definition set_bit_natural :: "natural ⇒ nat ⇒ bool ⇒ natural" is
  set_bit .

lift_definition set_bits_natural :: ‹(nat ⇒ bool) ⇒ natural›
  is ‹set_bits :: (nat ⇒ bool) ⇒ nat› .

lift_definition shiftl_natural :: ‹natural ⇒ nat ⇒ natural›
  is ‹shiftl :: nat ⇒ nat ⇒ nat› .

lift_definition shiftr_natural :: ‹natural ⇒ nat ⇒ natural›
  is ‹shiftr :: nat ⇒ nat ⇒ nat› .

lift_definition bitNOT_natural :: ‹natural ⇒ natural›
  is ‹bitNOT :: nat ⇒ nat› .

lift_definition bitAND_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitAND :: nat ⇒ nat ⇒ nat› .

lift_definition bitOR_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitOR :: nat ⇒ nat ⇒ nat› .

lift_definition bitXOR_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitXOR :: nat ⇒ nat ⇒ nat› .

lift_definition msb_natural :: ‹natural ⇒ bool›
  is ‹msb :: nat ⇒ bool› .

end

instance ..
end
*)

(* XOR with 1 on nat flips the lowest bit: an even number is incremented,
   an odd number is decremented. *)
lemma bitXOR_1_if_mod_2: ‹L XOR 1 = (if L mod 2 = 0 then L + 1 else L - 1)› for L :: nat
  apply transfer
  apply (subst int_int_eq[symmetric])
  apply (rule bin_rl_eqI)
   apply (auto simp: xor_nat_def)
  unfolding bin_last_def xor_nat_def
       apply presburger+
  done

(* AND with 1 on nat extracts the lowest bit, i.e. the remainder modulo 2. *)
lemma bitAND_1_mod_2: ‹L AND 1 = L mod 2› for L :: nat
  by auto

(*lemma nat_set_bit_0: ‹set_bit x 0 b = nat ((bin_rest (int x)) BIT b)› for x :: nat
  by (auto simp: set_bit_nat_def Bit_def) 
*)  

(* Bit 0 of a natural number is set exactly when the number is 1 modulo 2. *)
lemma nat_test_bit0_iff: ‹n !! 0 ⟷ n mod 2 = 1› for n :: nat
proof -
  (* Express the literal 2 as a cast from nat so that zmod_int can be used
     right-to-left in the next step. *)
  have 2: ‹2 = int 2›
    by auto
  have [simp]: ‹int n mod 2 = 1 ⟷ n mod 2 = Suc 0›
    unfolding 2 zmod_int[symmetric]
    by auto

  show ?thesis
    by (auto simp: bin_last_def zmod_int bit_nat_def odd_iff_mod_2_eq_one)
qed

(* For a positive index m, bit m of 2*n is bit m - 1 of n. *)
lemma test_bit_2: ‹m > 0 ⟹ (2*n) !! m ⟷ n !! (m - 1)› for n :: nat
  by (cases m)
    (auto simp: bit_nat_def)

(* For a positive index m, adding one to an even number 2*n does not change
   bit m. *)
lemma test_bit_Suc_2: ‹m > 0 ⟹ Suc (2 * n) !! m ⟷ (2 * n) !! m› for n :: nat
  apply (cases m)
  by (auto simp: bit_nat_def div_mult2_eq)

(* For m > 0, bit m - 1 of nat (bin_rest (int w)) coincides with bit m of w. *)
lemma bin_rest_prev_eq:
  assumes [simp]: ‹m > 0›
  shows  ‹nat ((bin_rest (int w))) !! (m - Suc (0::nat)) = w !! m›
proof -
  (* Split w into its half m' and its parity: w is either 2 * m' or Suc (2 * m'). *)
  define m' where ‹m' = w div 2›
  have w: ‹w = 2 * m' ∨ w = Suc (2 * m')›
    unfolding m'_def
    by auto
  moreover have ‹bin_nth (int m') (m - Suc 0) = m' !! (m - Suc 0)›
    by (simp add: bit_of_nat_iff_bit)
  ultimately show ?thesis
    by (auto simp: test_bit_2 test_bit_Suc_2)
qed

(* bin_sc applied to a non-negative integer yields a non-negative integer. *)
lemma bin_sc_ge0: ‹w >= 0 ==> (0::int) ≤ bin_sc n b w›
  by (induction n arbitrary: w) auto

(* Two natural numbers with equal bin_to_bl lists (each taken at the number's
   own size) are equal. *)
lemma bin_to_bl_eq_nat:
  ‹bin_to_bl (size a) (int a) = bin_to_bl (size b) (int b) ==> a=b›
  by (metis Nat.size_nat_def size_bin_to_bl)

(* For n < m, bit n of w can be read at index n of the reversed list
   bin_to_bl m (int w). *)
lemma nat_bin_nth_bl: "n < m ⟹ w !! n = nth (rev (bin_to_bl m (int w))) n" for w :: nat
  by (metis bin_nth_bl bit_of_nat_iff_bit)

(* A non-negative integer na has no bit set at any index n with nat na ≤ n.
   NOTE(review): the relation and connective symbols in this lemma were
   reconstructed from the shape of the proof; confirm against the upstream
   theory. *)
lemma bin_nth_ge_size: ‹nat na ≤ n ⟹ 0 ≤ na ⟹ bin_nth na n = False›
proof (induction n arbitrary: na)
  case 0
  then show ?case by auto
next
  case (Suc n na) note IH = this(1) and H = this(2-)
  (* Both auxiliary facts set up the induction hypothesis for bin_rest na. *)
  have ‹na = 1 ∨ 0 ≤ na div 2›
    using H by auto
  moreover have
    ‹na = 0 ∨ na = 1 ∨ nat (na div 2) ≤ n›
    using H by auto
  ultimately show ?case
    using IH[rule_format,  of ‹bin_rest na›] H
    by (auto simp: bit_Suc)
qed

(* Bits of a natural number at indices above its size are not set. *)
lemma test_bit_nat_outside: "n > size w ⟹ ¬w !! n" for w :: nat
  unfolding bit_nat_def
  by (metis Nat.size_nat_def div_less even_zero le_eq_less_or_eq le_less_trans n_less_equal_power_2)

(* Characterises bit n of a via its bit list: the bit is set iff n is below
   size a and the reversed list bin_to_bl (size a) (int a) is True at index n. *)
lemma nat_bin_nth_bl':
  ‹a !! n ⟷ (n < size a ∧ (rev (bin_to_bl (size a) (int a)) ! n))›
  by (metis Nat.size_nat_def bit_nat_def div_less even_zero n_less_equal_power_2 nat_bin_nth_bl not_less_iff_gr_or_eq test_bit_nat_outside)

(* Reading bit m after set_bit w n x: the result is x at the written position n
   and the original bit of w everywhere else.
   The proof rewrites bit tests with nat_bin_nth_bl' and discharges each case
   produced by auto with a separate metis call. *)
lemma nat_set_bit_test_bit: ‹set_bit w n x !! m = (if m = n then x else w !! m)› for w n :: nat
  unfolding nat_bin_nth_bl'
  apply auto
        apply (metis bin_nth_bl bin_nth_sc bin_nth_simps(3) bin_to_bl_def int_nat_eq set_bit_nat_def)
       apply (metis bin_nth_ge_size bin_nth_sc bin_sc_ge0 leI of_nat_less_0_iff set_bit_nat_def)
      apply (metis bin_nth_bl bin_nth_ge_size bin_nth_sc bin_sc_ge0 bin_to_bl_def int_nat_eq leI
      of_nat_less_0_iff set_bit_nat_def)
      apply (metis Generic_set_bit.bit_set_bit_iff bin_to_bl_def nat_bin_nth_bl' size_nat)
    apply (metis Nat.size_nat_def bin_nth_bl bin_nth_sc_gen bin_to_bl_def int_nat_eq nat_bin_nth_bl
      nat_bin_nth_bl' of_nat_less_0_iff of_nat_less_iff set_bit_nat_def)
   apply (metis (full_types) bin_nth_bl bin_nth_ge_size bin_nth_sc_gen bin_sc_ge0 bin_to_bl_def leI of_nat_less_0_iff set_bit_nat_def)
  by (metis bin_nth_bl bin_nth_ge_size bin_nth_sc_gen bin_sc_ge0 bin_to_bl_def int_nat_eq leI of_nat_less_0_iff set_bit_nat_def)

  
  
(* unat distributes over OR; a direct instance of unsigned_or_eq. *)
lemma unat_or: "unat (x OR y) = unat x OR unat y" by (rule unsigned_or_eq)

(* unat distributes over AND; a direct instance of unsigned_and_eq. *)
lemma unat_and: "unat (x AND y) = unat x AND unat y" by (rule unsigned_and_eq)
  
(* unat distributes over XOR; a direct instance of unsigned_xor_eq. *)
lemma unat_xor: "unat (x XOR y) = unat x XOR unat y" by (rule unsigned_xor_eq)
  
  
(* TODO: Add OR-numerals, XOR-numerals! *)  
  
(* Simp rules that evaluate AND on nat numerals one binary digit at a time:
   the four Bit0/Bit1 combinations reduce to AND on the remaining digits,
   and the remaining rules cover the operands 0, 1 and Suc 0. *)
lemma nat_and_numerals [simp]:
  "(numeral (Num.Bit0 x) :: nat) AND (numeral (Num.Bit0 y) :: nat) = (2 :: nat) * (numeral x AND numeral y)"
  "numeral (Num.Bit0 x) AND numeral (Num.Bit1 y) = (2 :: nat) * (numeral x AND numeral y)"
  "numeral (Num.Bit1 x) AND numeral (Num.Bit0 y) = (2 :: nat) * (numeral x AND numeral y)"
  "numeral (Num.Bit1 x) AND numeral (Num.Bit1 y) = (2 :: nat) * (numeral x AND numeral y)+1"
  "0 AND n = 0"
  "n AND 0 = 0"
  "(1::nat) AND numeral (Num.Bit0 y) = 0"
  "(1::nat) AND numeral (Num.Bit1 y) = 1"
  "numeral (Num.Bit0 x) AND (1::nat) = 0"
  "numeral (Num.Bit1 x) AND (1::nat) = 1"
(*  "(Suc 0::nat) AND numeral (Num.Bit0 y) = 0"
  "(Suc 0::nat) AND numeral (Num.Bit1 y) = 1"
  "numeral (Num.Bit0 x) AND (Suc 0::nat) = 0"
  "numeral (Num.Bit1 x) AND (Suc 0::nat) = 1"*)
  "Suc 0 AND Suc 0 = 1"
  for n::nat
  by (auto)
  
  
  
(* AND on nat is commutative; obtained by unfolding and_nat_def and using
   int_and_comm. *)
lemma nat_and_comm: "a AND b = b AND a" for a b :: nat
  unfolding and_nat_def by (auto simp: int_and_comm)

(* AND on nat is bounded above by its first operand.
   The bound is first established on int via AND_upper1 and then carried over
   by unfolding and_nat_def. *)
lemma AND_upper_nat1: "a AND b ≤ a" for a b :: nat
proof -
  have "int a AND int b ≤ int a"
    by (rule AND_upper1) simp
  thus ?thesis unfolding and_nat_def by linarith
qed

(* AND on nat is bounded above by its second operand; follows from
   AND_upper_nat1 with the operands swapped and commutativity of AND. *)
lemma AND_upper_nat2: "a AND b ≤ b" for a b :: nat
  using AND_upper_nat1[of b a] by (simp add: nat_and_comm)
  
(* Simp variants of AND_upper_nat1 composed with order_trans and
   order_le_less_trans, so that a bound on the first operand can be reused
   for a AND b. *)
lemmas AND_upper_nat1' [simp] = order_trans [OF AND_upper_nat1]
lemmas AND_upper_nat1'' [simp] = order_le_less_trans [OF AND_upper_nat1]
  
(* Simp variants of AND_upper_nat2 composed with order_trans and
   order_le_less_trans, so that a bound on the second operand can be reused
   for a AND b. *)
lemmas AND_upper_nat2' [simp] = order_trans [OF AND_upper_nat2]
lemmas AND_upper_nat2'' [simp] = order_le_less_trans [OF AND_upper_nat2]


(* Right-shifting a natural number does not change its msb. *)
lemma msb_shiftr_nat [simp]: "msb ((x :: nat) >> r) ⟷ msb x"
  by (simp add: msb_int_def msb_nat_def)

(* Truncating a non-negative integer preserves any strict upper bound on it. *)
lemma bintrunc_le: ‹a ≥ 0 ⟹ a < b ⟹ bintrunc n a < b›
  by (smt bintr_lt2p bintrunc_mod2p mod_pos_pos_trivial)


(* For a shift amount below the word length, the msb of a right-shifted word is
   set only if nothing was shifted (r = 0) and the msb of x was already set.
   The leftover debugging option `supply [[smt_trace]]` was dropped; the proof
   does not invoke smt. *)
lemma msb_shiftr_word [simp]: "r < LENGTH('a) ⟹ msb ((x :: 'a :: {len} word) >> r) ⟷ ((r = 0 ∧ msb x))"
  apply (cases r)
  apply (auto simp: bl_shiftr word_size msb_word_def 
    simp flip: sint_uint[unfolded One_nat_def] hd_bl_sign_sint)
  done

(* For a shift amount below the word length, and provided the shifted value
   stays below 2 ^ (LENGTH('a) - 1), the msb of the left-shifted word equals
   (r = 0 ∧ msb x). *)
lemma msb_shiftl_word [simp]: "r < LENGTH('a) ⟹  x << r < 2 ^ (LENGTH('a) - Suc 0) ⟹
     msb ((x :: 'a :: {len} word) << r) = (r = 0 ∧ msb x)"
  using less_is_drop_replicate[of ‹x << r› ‹(LENGTH('a) - Suc 0)›]
  apply (cases r)
  apply (auto simp: bl_shiftl word_size msb_word_def 
    simp flip: sint_uint[unfolded One_nat_def] hd_bl_sign_sint)
  done

end