(* Theory IEEE_Float_Extend *)

section ‹Converting Float to Double›
theory IEEE_Float_Extend
imports "../More_Eisbach_Tools" IEEE_Fp_Add_Basic "../LLVM_More_Word_Lemmas" IEEE_Float_To_Word
begin
  text ‹
    LLVM code only accepts double (64 bit) constants, 
    even for the float (32 bit) type. In that case, 
    LLVM requires that the specified double is exactly 
    representable as a float. 
    
    In this theory, we formalize this conversion and prove it correct, and reflect it to Isabelle-ML, 
    such that we can use the verified conversion in our code generator.
  ›

  subsection ‹Find Highest Bit›
  (*
    highest_bit n: given n>0, find h such that
    
      2^h≤n<2^(h+1)

    The function halves n until it drops below 2 and counts the steps;
    for n<2 (including n=0) the result is 0.
  *)  
  fun highest_bit :: "nat  nat"  where
    "highest_bit n = (if n<2 then 0 else 1 + highest_bit (n div 2) )"  
  (* The recursive equation is taken out of the simpset; the proofs below
     unfold it one step at a time with subst highest_bit.simps. *)
  declare highest_bit.simps[simp del]  
      
  (* Lower half of the specification of highest_bit: relates 2^highest_bit n
     to n under the premise 0<n. Proved with the induction rule generated for
     highest_bit, unfolding the defining equation once, then auto. *)
  lemma highest_bit_lower: "0<n  2^highest_bit n  n"  
    apply (induction n rule: highest_bit.induct)
    apply (subst highest_bit.simps)
    by auto
    
  (* Upper half of the specification: n is strictly below 2^(highest_bit n+1).
     No positivity premise is stated here. Same proof pattern as above. *)
  lemma highest_bit_upper: "n < 2^(highest_bit n+1)"  
    apply (induction n rule: highest_bit.induct)
    apply (subst highest_bit.simps)
    by auto
    
  (* Both bounds bundled under one name. *)
  lemmas highest_bit_bounds = highest_bit_lower highest_bit_upper
    
  (* lower_bits n: n with the value 2^highest_bit n subtracted.
     The subtraction is on nat and therefore truncating. *)
  definition "lower_bits n  n - 2^highest_bit n"
  
  (* The remainder is strictly smaller than the power of two that was removed. *)
  lemma lower_bits_upper: "lower_bits n < 2^highest_bit n"
    unfolding lower_bits_def
    using highest_bit_bounds[of n]
    by (auto)
  
  (* Decomposition of n (premise 0<n) into 2^highest_bit n plus lower_bits n. *)
  lemma highest_lower_char: "0<n  2^highest_bit n + lower_bits n = n"
    by (simp add: highest_bit_lower lower_bits_def)
  
  text ‹Can also be expressed with floorlog!›
  (* highest_bit n equals floorlog 2 n - 1.
     NOTE(review): the proof is one smt call with a long explicit fact list
     (it looks machine-generated); it may have to be re-found after library updates. *)
  lemma highest_bit_is_floorlog: "highest_bit n = floorlog 2 n - 1"  
    by (smt (verit, ccfv_SIG) Suc_pred' add_diff_cancel_right' floorlog_ge_SucI floorlog_leI highest_bit.simps highest_bit_lower highest_bit_upper le_less_trans nat_less_le one_less_numeral_iff semiring_norm(76) zero_less_diff zero_order(1))

  (* Evaluation rules for highest_bit on the literals 0, 1, Suc 0 and on
     numerals. They are stated separately because highest_bit.simps itself
     has been removed from the simpset. Each statement is proved by one
     unfolding step followed by simp; the trailing fail rejects the step
     unless it closes the goal completely. *)
  lemma compute_highest_bit:
    "highest_bit 0 = 0"  
    "highest_bit 1 = 0"
    "highest_bit (Suc 0) = 0"
    "highest_bit (numeral n) = (if numeral n < (2::nat) then 0 else 1 + highest_bit (numeral n div 2))"
    by (subst highest_bit.simps; simp; fail)+

  (* Monotonicity of highest_bit, obtained from the floorlog characterisation
     together with floorlog_mono and diff_le_mono. *)
  lemma highest_bit_mono: "ab  highest_bit a  highest_bit b"  
    unfolding highest_bit_is_floorlog
    using diff_le_mono floorlog_mono by presburger

  (* Variant of highest_bit_mono with the strict premise a<b. *)
  lemma highest_bit_mono': "a<b  highest_bit a  highest_bit b"  
    using highest_bit_mono by simp
    
    
    
  
  subsection ‹Extend Floating Point Number›

  text ‹Locale to summarize the correctness criterion:
    all (meaningful) properties are preserved. In case of NaN, 
    the payload is bit-shifted.
  ›
  (* correct_extension f1 f2: specification of "f2 is a correct extension of f1",
     where f1 has format ('e1,'f1) and f2 has format ('e2,'f2).
     ΔF is the difference of the two fraction widths (nat subtraction).
     All assumptions are also declared as simp rules. *)
  locale correct_extension =
    fixes f1 :: "('e1,'f1) float"
      and f2 :: "('e2,'f2) float"
    fixes ΔF
    defines "ΔF  LENGTH('f2) - LENGTH('f1)"  
    (* the sign is unchanged *)
    assumes preserves_sign[simp]: "sign f2 = sign f1"  
    (* the classification as finite / infinity carries over between f2 and f1 *)
    assumes preserves_finiteness[simp]: 
      "is_finite f2  is_finite f1"
      "is_infinity f2  is_infinity f1"
    (* likewise for NaN *)
    assumes preserves_nan[simp]: "is_nan f2  is_nan f1"
    (* for finite f1 the represented value is unchanged *)
    assumes preserves_valof[simp]: "is_finite f1  valof f2 = valof f1"
    (* for NaN, the fraction field of f2 is the fraction of f1 multiplied by 2^ΔF *)
    assumes preserves_nan_payload[simp]: "is_nan f1  fraction f2 = fraction f1 * 2^ΔF"
  begin
  
    (* intentionally empty: the locale only packages the assumptions above *)
  
  end
  

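  text ‹We only cover the case where ‹f1 ≤ 2^(e2-1) - 2^(e1-1)›,
    where ‹f1› is the fraction width of the source format and ‹e1›, ‹e2› are
    the exponent widths of the source and target formats.
    In this case, a denormal number is always extended to a normal number.
  ›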
  (* float_extend_conv: setting for extending a float of format ('e1,'f1)
     to a float of format ('e2,'f2). F1 and F2 only carry the two formats
     as types (itself arguments). The assumptions constrain the exponent
     widths (ELEN'), the fraction widths (FLEN'), and relate the source
     fraction width to 2^(LENGTH('e2)-1)-2^(LENGTH('e1)-1) (E2_cond'). *)
  locale float_extend_conv =
    fixes F1 :: "('e1,'f1) float itself"
    fixes F2 :: "('e2,'f2) float itself"
    assumes ELEN': "LENGTH('e1)  LENGTH('e2)"
    assumes FLEN': "LENGTH('f1)  LENGTH('f2)"
    assumes E2_cond': "LENGTH('f1)  2^(LENGTH('e2)-1)-2^(LENGTH('e1)-1)"
  begin
  
    (* Local abbreviations shared by all definitions and proofs in this context:
       E1/E2 and F1/F2 are the exponent and fraction widths, B1/B2 the biases,
       and ΔE, ΔF, ΔB the differences target minus source (nat subtraction).
       f is the source float that is converted.
       Note: the names F1 and F2 fixed here coincide with the names of the
       locale parameters above; inside this context they denote the widths. *)
    context
      fixes E1 E2 F1 F2 ΔE ΔF ΔB B1 B2
      fixes f :: "('e1,'f1) float"
      defines "E1  LENGTH('e1)"
      defines "E2  LENGTH('e2)"
      defines "F1  LENGTH('f1)"
      defines "F2  LENGTH('f2)"
      
      defines B1_def': "B1  bias TYPE(('e1,'f1)float)"
      defines B2_def': "B2  bias TYPE(('e2,'f2)float)"
      
      defines "ΔE  E2 - E1"
      defines "ΔF  F2 - F1"
      defines "ΔB  B2 - B1"
      
    begin
  
      (* Result used when f is normal: same sign, exponent field increased by the
         bias difference ΔB, fraction field multiplied by 2^ΔF.
         Abs_float' is defined outside this file; from its use here it takes the
         sign, exponent and fraction fields, in that order. *)
      definition conv_norm :: "('e2,'f2) float" where                 
        "conv_norm  Abs_float' (sign f) (exponent f + ΔB) (2^ΔF*fraction f)"
    
      (* Result used when f is denormal. With h the highest bit of the source
         fraction and r the bits below it:
           exponent field = ΔB+1 - LENGTH('f1) + h
           fraction field = 2^(LENGTH('f2)-h) times r
         conv_denorm_normal below shows that the result is a normal number,
         conv_denorm_correct that value and sign are preserved. *)
      definition conv_denorm :: "('e2,'f2) float" where 
        "conv_denorm  let
          h = highest_bit (fraction f);
          r = lower_bits  (fraction f)
        in
          Abs_float' (sign f) (ΔB+1 - LENGTH('f1) + h) (2^(LENGTH('f2)-h) * r)"
    
      (* Result used when f is a zero: the zero of the target format
         selected by the sign of f (0 for sign 0, minus_zero otherwise). *)
      definition conv_zero :: "('e2,'f2) float" where
        "conv_zero  if sign f = 0 then 0 else minus_zero"

      (* Result used when f is an infinity: plus_infinity for sign 0,
         minus_infinity otherwise. *)
      definition conv_inf :: "('e2,'f2) float" where
        "conv_inf  if sign f = 0 then plus_infinity else minus_infinity"

      (* Result used in the remaining (NaN) case: same sign, exponent field set to
         emax of the target format, fraction field multiplied by 2^ΔF. *)
      definition conv_nan :: "('e2,'f2) float" where
        "conv_nan  Abs_float' (sign f) (emax TYPE(('e2, 'f2)float)) (fraction f * 2^ΔF)"
                
      (* The conversion proper: case distinction on the class of f.
         The tests are tried in the order normal, denormal, infinity, zero;
         everything else falls through to conv_nan. *)
      definition conv :: "('e2,'f2) float" where
        "conv  
             if is_normal f then conv_norm
        else if is_denormal f then conv_denorm
        else if is_infinity f then conv_inf
        else if is_zero f then conv_zero
        else conv_nan" 
  
      (* Equation for conv with all case definitions unfolded. The right-hand
         side is left schematic (?foo) and is determined by the proof: after
         unfolding the definitions and the local identity TAG, the final step
         closes the goal and thereby instantiates ?foo.
         This equation is installed as code equation further below
         (float_extend_32_64_code). *)
      schematic_goal conv_eq_unfolded: 
        defines "TAG  λx. x"
        shows "conv = TAG ?foo" 
        unfolding conv_def conv_norm_def conv_denorm_def conv_inf_def conv_zero_def conv_nan_def
        unfolding TAG_def ..
        
                  
      (* Locale assumption ELEN' restated with the abbreviations E1, E2. *)
      lemma ELEN: "E1  E2"    
        unfolding E1_def E2_def using ELEN' by auto
                    
      (* Locale assumption FLEN' restated with the abbreviations F1, F2. *)
      lemma FLEN: "F1  F2"    
        unfolding F1_def F2_def using FLEN' by auto
          
      (* Non-zero / positivity facts for the four widths, stated in several
         syntactic shapes and registered as simp rules. Individual facts are
         referenced by position elsewhere, e.g. LEN_ne_Z(5) is "E1>0". *)
      lemma LEN_ne_Z[simp]: 
        "E10" "E20" "F10" "F20" 
        "E1>0" "E2>0" "F1>0" "F2>0" 
        "E1Suc 0" "E2Suc 0" "F1Suc 0" "F2Suc 0" 
        unfolding E1_def E2_def F1_def F2_def 
        by (auto simp: Suc_leI)

      (* Explicit form of the source bias, obtained by unfolding bias_def. *)
      lemma B1_def: "B1 = 2^(E1-1)-1"  
        unfolding B1_def' bias_def E1_def ..
        
      (* Explicit form of the target bias. *)
      lemma B2_def: "B2 = 2^(E2-1)-1"  
        unfolding B2_def' bias_def E2_def ..
        
      (* The fraction field of f is below 2^F1 (from fraction_upper_bound); simp rule. *)
      lemma fr_bound[simp]: "fraction f < 2 ^ F1"  
        unfolding F1_def by (simp add: fraction_upper_bound)
        
      (* The exponent field of f is below 2^E1 (from exponent_upper_bound); simp rule. *)
      lemma exp_bound[simp]: "exponent f < 2 ^ E1"  
        unfolding E1_def by (simp add: exponent_upper_bound)
        
      (* shortcut_defs: the defining equations of the width and bias abbreviations.
         shortcut_folds: the same equations reversed and added to the simpset, so
         that simp rewrites the LENGTH/bias terms back to E1, E2, F1, F2, B1, B2.
         NOTE(review): no_atp presumably keeps these facts out of Sledgehammer's
         fact selection - confirm. *)
      lemmas shortcut_defs[no_atp] = E1_def E2_def F1_def F2_def B1_def' B2_def'
      lemmas shortcut_folds[simp, no_atp] = shortcut_defs[symmetric]

      (* Target widths written as source width plus difference. ELEN and FLEN are
         needed because ΔE and ΔF are defined by (truncating) nat subtraction. *)
      lemma LEN2_conv:
        "E2 = E1 + ΔE"  
        "F2 = F1 + ΔF"  
        unfolding ΔE_def ΔF_def using ELEN FLEN by auto

      (* Relation between the two biases, derived from ELEN via the explicit bias formulas. *)
      lemma BLEN: "B1  B2"  
        unfolding B1_def B2_def using ELEN 
        by (meson diff_le_mono one_le_numeral power_increasing)
        
      (* Target bias written as source bias plus difference. *)
      lemma bias2_conv: "B2 = B1 + ΔB"  
        unfolding ΔB_def using BLEN by simp
      
                      
      (* Closed form of the bias difference as a product of 2^(E1-1) and 2^ΔE - 1. *)
      lemma ΔB_alt: "ΔB = 2^(E1-1)*(2^ΔE - 1)"
        unfolding ΔB_def ΔE_def B1_def B2_def
        using ELEN
        by (auto simp: algebra_simps Suc_le_eq simp flip: power_add)
      
  
      (* Arithmetic helper for the simplifier: recombines an exponent of the
         shape E2 - E1 + (E1 - Suc 0). *)
      lemma aux_simp1[simp]: "E2 - E1 + (E1 - Suc 0) = E2 - 1"  
        by (simp add: ELEN Suc_le_eq)
        
        
      (* The bias difference is below 2^E2; simp rule. *)
      lemma ΔB_lt[simp]: "ΔB < 2 ^ E2" 
        unfolding ΔB_alt ΔE_def
        using ELEN 
        by (auto simp: algebra_simps LEN2_conv less_imp_diff_less simp flip: power_add)
        
  
      (* Slightly stronger: even Suc ΔB is below 2^E2; simp rule.
         NOTE(review): the closing smt call carries a long explicit fact list and
         looks machine-generated; it may be sensitive to library changes. *)
      lemma ΔB_lt'[simp]: "Suc ΔB < 2 ^ E2" 
        unfolding ΔB_alt ΔE_def
        using ELEN 
        apply (auto simp: algebra_simps LEN2_conv simp flip: power_add)
        by (smt (verit, best) LEN2_conv(1) LEN_ne_Z(6) Suc_pred diff_less less_imp_diff_less linorder_neqE_nat not_less_eq one_less_numeral_iff pos2 power_Suc power_less_power_Suc semiring_norm(76) zero_less_power)
        
      (* The sign field of f is below 2; proved by the case rule sign_cases. Simp rule. *)
      lemma sign_lt[simp]: "sign f < 2"  
        by (cases f rule: sign_cases) auto
        

        
      (* H and L: highest bit and lower bits of the fraction of f
         (the quantities named h and r in conv_denorm). *)
      abbreviation "H  highest_bit (fraction f)"   
      abbreviation "L  lower_bits (fraction f)"   
        
      (* The highest bit position of the fraction is below the fraction width F1. *)
      lemma H_upper: "H < F1"
        by (metis LEN_ne_Z(3) floorlog_leI fr_bound highest_bit.simps highest_bit_is_floorlog highest_bit_lower leD less_imp_diff_less less_nat_zero_code nat_less_le nat_neq_iff one_less_numeral_iff pos2 semiring_norm(76))
                      
      (* Unnamed simp rule: an exponent term of the shape Suc (ΔB + H) - F1 is below 2^E2. *)
      lemma [simp]: "Suc (ΔB + H) - F1 < 2 ^ E2"
        using H_upper ΔB_lt' by linarith
        
      (* Unnamed simp rule: the fraction term 2^(F2 - H) times L is below 2^F2. *)
      lemma [simp]: "2 ^ (F2 - H) * L < 2 ^ F2"   
        using FLEN H_upper diff_le_self lower_bits_upper nat_less_power_trans by auto

      (* Locale assumption E2_cond' restated with the abbreviations F1, E1, E2. *)
      lemma E2_cond: "F1  2^(E2-1)-2^(E1-1)" using E2_cond' by simp
        
      (* Consequence of E2_cond: Suc ΔB exceeds the source fraction width. Simp rule. *)
      lemma ΔB_bound[simp]: "Suc ΔB > F1"
        by (metis B1_def B2_def E2_cond ΔB_def diff_diff_left le_add_diff_inverse le_imp_less_Suc one_le_numeral one_le_power)
          
      (* Unnamed simp rule derived from ΔB_bound, relating Suc ΔB and F1. *)
      lemma [simp]: "Suc ΔB  F1" 
        using ΔB_bound by linarith
             
      (* Unnamed simp rule derived from ΔB_bound: a negated comparison of
         Suc (ΔB+H) with F1. *)
      lemma [simp]: "¬(Suc (ΔB+H)  F1)"  
        using ΔB_bound by linarith

        
      (* Unnamed simp rule: the source fraction multiplied by 2^ΔF is below 2^F2. *)
      lemma [simp]: "fraction f * 2 ^ ΔF < 2 ^ F2"
        by (simp add: LEN2_conv(2) nat_mult_power_less_eq)
        
      (* Unnamed simp rule: emax of the target format equals 2^E2 - 1. *)
      lemma [simp]: "emax TYPE(('e2, 'f2) IEEE.float) = 2 ^ E2 - 1"  
        unfolding emax_def E2_def
        using unat_minus_one_word by blast
        
      (* Disabled variant (strict bound instead of the equation above), kept for reference. *)
      (*lemma [simp]: "emax TYPE(('e2, 'f2) IEEE.float) < 2 ^ E2"  
        unfolding emax_def E2_def
        by blast
      *)
        
                
      (* The exponent field used by conv_norm (exponent f + ΔB) is below 2^E2. Simp rule.
         After unfolding ΔB, the goal is proved by a calculation:
         bound exponent f by 2^E1, split 2^E1 into two copies of 2^(E1-1),
         cancel one copy against the subtracted 2^(E1-1), and finally compare
         the remaining sum with 2^E2. *)
      lemma exp_ΔB_bound[simp]: "(IEEE.exponent f + ΔB) < 2 ^ E2"  
        unfolding ΔB_alt ΔE_def
      proof (simp add: algebra_simps flip: power_add )
        show "exponent f + (2 ^ (E2 - Suc 0) - 2 ^ (E1 - Suc 0)) < 2 ^ E2" (is "?lhs<_")
        proof -
          (* rewrite Suc 0 to 1 in the exponents *)
          have "?lhs = exponent f + (2^(E2-1) - 2^(E1-1))" using ELEN by auto
          (* exponent f < 2^E1 *)
          also have " < 2^E1 + (2^(E2-1) - 2^(E1-1))" using exponent_upper_bound[of f] by (simp)
          (* 2^E1 = 2^(E1-1) + 2^(E1-1), using E1>0 *)
          also have " = 2^(E1-1) + 2^(E1-1) + (2^(E2-1) - 2^(E1-1))" using ELEN 
            by (metis LEN_ne_Z(5) Suc_diff_1 mult.commute mult_2_right power_Suc)
          also have " = 2^(E1-1) + 2^(E2-1)" using ELEN by simp
          also have "  2^(E2-1) + 2^(E2-1)" using ELEN by simp
          (* the two halves add up to 2^E2, using E2>0 *)
          also have "  2^E2"
            by (metis LEN_ne_Z(6) Suc_pred' le_refl mult.commute mult_2_right power_Suc)
          finally show ?thesis .
        qed        
      qed
  
      (* Correctness of the normal case: under is_normal f, conv_norm has the
         same value (valof) and the same sign as f.
         Proof: unfold the value formula valof_eq and the definition, rewrite
         the target bias and widths as source plus difference, then field arithmetic. *)
      theorem conv_norm_correct: 
        "is_normal f  valof (conv_norm) = valof f"
        "is_normal f  sign conv_norm = sign f"
        unfolding is_normal_def valof_eq conv_norm_def
        apply simp_all
        apply (simp add: bias2_conv LEN2_conv)
        apply (simp add: field_simps power_add of_nat_diff fraction_upper_bound)
        done

        
        
      (* If the exponent field of f is strictly below 2^E1 - 1, then the shifted
         exponent field exponent f + ΔB is strictly below 2^E2 - 1.
         Used by conv_norm_normal. *)
      lemma exponent_bound_finite:
        assumes "exponent f < 2 ^ E1 - Suc 0"  
        shows "exponent f + ΔB < 2 ^ E2 - Suc 0"  
      proof -
      
        (* 2^E1 splits into two copies of 2^(E1-1), using E1>0 *)
        have "(2::nat)^E1 = (2::nat)^(E1-1) + (2::nat)^(E1-1)"
          by (metis LEN_ne_Z(5) Suc_diff_1 mult.commute mult_2_right power_Suc)
        (* one copy cancels against the subtracted 2^(E1-1) *)
        then have "(2::nat) ^ E1 + (2 ^ (E2 - Suc 0) - 2 ^ (E1 - Suc 0)) = 2 ^ (E1 - Suc 0) + 2 ^ (E2 - Suc 0)"
          using ELEN 
          by (simp add: algebra_simps)
        also have "(2::nat) ^ (E1 - Suc 0)  2 ^ (E2 - Suc 0)"  
          using ELEN by simp
        finally  
        have "(2::nat) ^ E1 + (2 ^ (E2 - Suc 0) - 2 ^ (E1 - Suc 0))  2 ^ E2"
          apply (simp)
          by (metis LEN_ne_Z(6) One_nat_def Suc_diff_1 mult.commute mult_2_right power_Suc)
        (* the same relation with Suc 0 subtracted on both sides *)
        then have 1: "2 ^ E1 - Suc 0 + (2 ^ (E2 - Suc 0) - 2 ^ (E1 - Suc 0))  2 ^ E2 - Suc 0" 
          by auto
        
        (* combine fact 1 with the assumption after unfolding ΔB and the biases *)
        show ?thesis using assms
          unfolding ΔB_def B1_def B2_def 
          apply (auto simp: algebra_split_simps)
          using 1 by linarith
      qed  
        
      (* A normal f is converted to a normal number; the exponent range
         condition is discharged by exponent_bound_finite. *)
      lemma conv_norm_normal: "is_normal f  is_normal conv_norm"
        unfolding is_normal_def valof_eq conv_norm_def emax_eq
        by (simp_all add: exponent_bound_finite)
        
                 
      (* Correctness of the denormal case: under is_denormal f, conv_denorm has
         the same value (valof) and the same sign as f.
         Proof outline: expand valof on the right-hand side, split the source
         fraction into 2^H + L (highest_lower_char), unfold conv_denorm, and
         simplify. The remaining arithmetic goal is handled in the Isar part. *)
      theorem conv_denorm_correct: 
        "is_denormal f  valof (conv_denorm) = valof f"
        "is_denormal f  sign (conv_denorm) = sign f"
        unfolding is_denormal_def
        apply (rewrite in "_ = " valof_eq)
        apply simp_all
        apply (subst highest_lower_char[of "fraction f", symmetric], simp)
        unfolding valof_eq conv_denorm_def Let_def
        apply simp_all
        apply (clarsimp simp: )
        apply (simp add: bias2_conv LEN2_conv)
      proof goal_cases
        case 1
        assume "exponent f = 0" and "0 < fraction f"
  
        (* the scaled lower bits divided by 2^F2 reduce to L / 2^H *)
        have 1: "2 ^ (F1 + ΔF - H) * real L / 2 ^ (F1 + ΔF) = real L / 2^H"
          using H_upper 
          by (simp add: power_diff)
        
        (* two local simp facts that resolve the truncating subtraction of F1,
           justified by ΔB_bound *)
        have [simp]: "F1 + (H + (Suc (ΔB + H) - F1)) = 2*H + ΔB + 1"
          apply (simp)
          using ΔB_bound by linarith
          
        have [simp]: "F1 + (Suc (ΔB + H) - F1) = ΔB + H + 1"  
          apply (simp)
          using ΔB_bound by linarith
          
        (* core equation between the magnitude expression of the result and
           the magnitude expression of f *)
        have "2 ^ (Suc ΔB - F1 + H) * (1 + 2 ^ (F1 + ΔF - H) * real L / 2 ^ (F1 + ΔF)) / 2 ^ (B1 + ΔB) =
          2 * (2 ^ H + real L) / (2 ^ B1 * 2 ^ F1)" 
          unfolding 1
          apply simp
          apply (auto simp: field_split_simps simp flip: power_add)
          done
            
        (* finish by case distinction on the sign of f *)
        then show ?case 
          apply (cases f rule: sign_cases)
          apply (auto simp: field_simps)
          done
          
      qed

      (* A denormal f is converted to a normal number. The two subgoals left
         after auto are arithmetic side conditions, closed with ΔB_bound and
         with H_upper together with ΔB_lt'. *)
      lemma conv_denorm_normal: "is_denormal f  is_normal conv_denorm"
        unfolding conv_denorm_def is_denormal_def is_normal_def Let_def
        apply auto
        subgoal using ΔB_bound by linarith
        subgoal using H_upper ΔB_lt' by linarith
        done
      
      (* Zero case: for a zero f, conv_zero is a zero with the sign of f. *)
      theorem conv_zero_correct: "is_zero f  is_zero conv_zero  sign conv_zero = sign f"
        unfolding conv_zero_def
        apply (cases f rule: sign_cases)
        by auto
      
      (* Infinity case: for an infinite f, conv_inf is an infinity with the sign of f. *)
      theorem conv_inf_correct: "is_infinity f  is_infinity conv_inf  sign conv_inf = sign f"
        unfolding conv_inf_def
        apply (cases f rule: sign_cases)
        by auto
        
      (* NaN case: for a NaN f, conv_nan is a NaN with the sign of f, and its
         fraction field is the fraction of f multiplied by 2^ΔF. *)
      theorem conv_nan_correct: "is_nan f  is_nan conv_nan  sign conv_nan = sign f  fraction conv_nan = fraction f * 2^ΔF"
        unfolding conv_nan_def is_nan_def
        by (simp)
        
      (* Generic classification facts (for an arbitrary float ff), each obtained
         directly from is_finite_def: normal, denormal and zero floats are finite. *)
      lemma normal_imp_finite: "is_normal ff  is_finite ff"
        by (simp add: is_finite_def)

      lemma denormal_imp_finite: "is_denormal ff  is_finite ff"
        by (simp add: is_finite_def)

      lemma zero_imp_finite: "is_zero ff  is_finite ff"  
        by (simp add: is_finite_def)
                        
      (* An infinite f is not finite; derived from the library fact finite_infinity. *)
      lemma infinity_finite: "is_infinity f  ¬is_finite f"
        using finite_infinity by auto
        
      (* Unnamed simp/intro rule: the infinity result is never finite. *)
      lemma [simp, intro!]: "¬is_finite conv_inf"  
        unfolding conv_inf_def by auto
        
      (* A NaN is not an infinity (from float_distinct). *)
      lemma nan_imp_not_infinity: "is_nan ff  ¬is_infinity ff"  
        using float_distinct(1) by blast
        
      (* The results of the normal and denormal cases are not NaN, because both
         are normal numbers (conv_norm_normal, conv_denorm_normal). *)
      lemma conv_norm_nan: "is_normal f  ¬ is_nan conv_norm"  
        by (meson conv_norm_normal float_distinct(2))

      lemma conv_denorm_nan: "is_denormal f  ¬ is_nan conv_denorm"  
        by (meson conv_denorm_normal float_distinct(2))
        
      (* Unnamed simp/intro rules: the zero and infinity results are never NaN. *)
      lemma [simp, intro!]: "¬is_nan conv_zero"  
        unfolding conv_zero_def
        by (simp add: is_nan_def)
        
      lemma [simp, intro!]: "¬is_nan conv_inf"  
        unfolding conv_inf_def
        by (simp add: is_nan_def)
                
      (* The six properties required by the locale correct_extension, stated for
         f and conv: sign, finiteness, infinity, NaN status, value for finite f,
         and the NaN fraction.
         The supply command makes the per-case results and classification facts
         available to simp and as intro rules for this proof only. Each of the six
         statements is then proved by the same step: case distinction on the
         class of f (float_cases') followed by simp with conv_def; the fifth
         statement additionally uses val_zero. *)
      lemma conv_correct_aux:
        "sign conv = sign f"  
        "is_finite conv  is_finite f"
        "is_infinity conv  is_infinity f"
        "is_nan conv  is_nan f"
        "is_finite f  valof conv = valof f"
        "is_nan f  fraction conv = fraction f * 2^ΔF"
        supply [simp, intro] = conv_denorm_normal conv_norm_normal normal_imp_finite denormal_imp_finite zero_imp_finite infinity_finite finite_infinity
          conv_nan_correct conv_norm_correct conv_denorm_correct conv_zero_correct conv_inf_correct
          nan_imp_not_infinity conv_norm_nan conv_denorm_nan
        
        apply (cases f rule: float_cases'; simp add: conv_def)
        apply (cases f rule: float_cases'; simp add: conv_def) 
        apply (cases f rule: float_cases'; simp add: conv_def) 
        apply (cases f rule: float_cases'; simp add: conv_def) 
        apply (cases f rule: float_cases'; simp add: conv_def val_zero) 
        apply (cases f rule: float_cases'; simp add: conv_def) 
        done
              
      (* Main result: conv is a correct extension of f in the sense of the
         locale correct_extension. All locale assumptions follow from
         conv_correct_aux after unfolding ΔF. *)
      lemma conv_correct: "correct_extension f conv"  
        apply unfold_locales
        apply (simp_all add: conv_correct_aux ΔF_def)
        done
        
    end
  end    

subsection ‹Standard Float Sizes›
  
  
  (* Instance of float_extend_conv for the source format (8,23) and the target
     format (11,52). The interpretation is placed inside an anonymous context;
     only the definition and the two lemma collections below are named for use
     outside. *)
  context begin
    (* the three locale assumptions are numeric facts about these widths,
       discharged by simp *)
    interpretation float_extend_conv "TYPE((8,23) float)" "TYPE((11,52)float)"
      apply unfold_locales
      by simp_all
  
    (* float_extend_32_64 is conv in this instance. The code attribute of the
       defining equation is deleted ... *)
    definition [code del]: "float_extend_32_64  conv"
      
    (* ... and the fully unfolded equation is registered as code equation instead. *)
    lemmas float_extend_32_64_code[code] = conv_eq_unfolded[folded float_extend_32_64_def]  

    (* Correctness theorem conv_correct, restated for float_extend_32_64. *)
    lemmas float_extend_32_64_correct = conv_correct[folded float_extend_32_64_def]
          
  end

  (* Composition (applied right to left) of float_of_fp32, float_extend_32_64
     and fp64_of_float. The two outer functions are defined outside this file.
     NOTE(review): by their names they convert between 32/64-bit words and
     floats, making this the word-level version of the extension - confirm. *)
  definition "fext_word_32_64 = fp64_of_float o float_extend_32_64 o float_of_fp32"
  


end