(* Binary search over an array: an abstract NREST program with an explicit
   running-time bound, its correctness proof, an imperative implementation
   synthesized with sepref, and a logarithmic bound on the running time. *)
theory BinarySearch
  imports "../Refine_Imperative_HOL/IICF/Impl/IICF_Array_ListN" 
     "../RefineMonadicVCG" "SepLogicTime_RBTreeBasic.Asymptotics_1D"
begin
section "Binary Search"

(* Midpoint of l and r, rounded down (natural-number division).
   binarysearch uses it as the pivot index of the range [l, r). *)
definition avg :: "nat ⇒ nat ⇒ nat" where
  "avg l r = (l + r) div 2"

(* Cost charged for one list lookup.
   binarysearch passes it to mop_lookup_list, and it is unfolded before sepref
   synthesis. *)
definition "listlookup_time = 1"

(* Running-time bound for binarysearch on a range of n indices:
     T n = 2 + listlookup_time                 if n < 2
     T n = 2 + listlookup_time + T (n div 2)   if n ≥ 2
   The equations are guarded, so the obligations produced by `function` are
   discharged explicitly (force simp_all).
   Termination uses the identity measure: the recursive argument n div 2 is
   smaller than n whenever n ≥ 2. *)
function binarysearch_time :: "nat ⇒ nat" where
  "n < 2 ⟹ binarysearch_time n = 2 + listlookup_time"
| "n ≥ 2 ⟹ binarysearch_time n = 2 + listlookup_time + binarysearch_time (n div 2)"
by force simp_all
termination by (relation "Wellfounded.measure (λn. n)") auto

(* Real-valued copy of binarysearch_time.
   It is used to state the Landau-symbol bound in binarysearch_time'_Theta. *)
definition binarysearch_time' :: "nat ⇒ real" where
  "binarysearch_time' n = real (binarysearch_time n)"

(* Rewrite natural division by 2 into ceiling/floor form.
   These are passed as `rew:` rules to master_theorem2 in
   binarysearch_time'_Theta. *)
lemma div_2_to_rounding:
  "n - n div 2 = nat ⌈n / 2⌉" "n div 2 = nat ⌊n / 2⌋" by linarith+

(* The real-valued running time is in Θ(ln n).
   The recursion equation binarysearch_time.simps(2) is fed to the
   master_theorem2 proof method (argument 2.3), together with the rounding
   rewrites above. The remaining side goals are closed by auto2 and auto. *)
lemma binarysearch_time'_Theta: "(λn. binarysearch_time' n) ∈ Θ(λn. ln (real n))"
  apply (master_theorem2 2.3 recursion: binarysearch_time.simps(2) rew: binarysearch_time'_def div_2_to_rounding)
  unfolding listlookup_time_def
  prefer 2 apply auto2
  by (auto simp: binarysearch_time'_def)

(* binarysearch_time is monotone.
   The proof is by strong induction on n:
   - If m < 2, T m is the base value, which never exceeds T n.
   - Otherwise both m and n use the recursive equation, and the induction
     hypothesis applies to their halves.
   binarysearch_correct uses this to bound the cost of recursive calls. *)
lemma binarysearch_mono:
  "m ≤ n ⟹ binarysearch_time m ≤ binarysearch_time n" 
proof (induction n arbitrary: m rule: less_induct)
  case (less n)
  show ?case
  proof (cases "m<2")
    case True
    then show ?thesis apply (cases "n<2") by auto
  next
    case False
    then show ?thesis using less(2) by (auto intro: less(1))
  qed
qed

(* Specification of the search.
   The result is True exactly when some index i with l ≤ i < r satisfies
   xs ! i = x. The cost is binarysearch_time applied to the range size r - l. *)
definition binarysearch_SPEC :: "nat ⇒ nat ⇒ 'a list ⇒ 'a ⇒ bool nrest" where
  "binarysearch_SPEC l r xs x
   = SPECT (emb (λs. s ⟷ (∃i. l ≤ i ∧ i < r ∧ xs ! i = x)) (binarysearch_time (r-l)) )"

(* Abstract binary search for x in the half-open index range [l, r) of xs.
   It is a recursive NREST program (RECT) over the pair (l, r):
   - empty range (l ≥ r): answer False;
   - single element (l + 1 ≥ r): compare xs ! l with x;
   - otherwise look up the midpoint m = avg l r, answer True on a match,
     recurse into [m + 1, r) when xs ! m < x, and into [l, m) otherwise.
   Every lookup is charged listlookup_time through mop_lookup_list. The
   ASSERTs record the index-bound proof obligations.
   Correctness additionally needs xs to be sorted (see binarysearch_correct). *)
definition "binarysearch l r x xs ≡
    RECT (λfw (l,r).
      if l ≥ r then RETURNT False
    else if l + 1 ≥ r then do {
              ASSERT (l < length xs);
             xsi ← mop_lookup_list listlookup_time xs l;
                                RETURNT (xsi = x) }
    else do {
        m ← RETURNT (avg l r);
        ASSERT (m < length xs);
        xm ← mop_lookup_list listlookup_time xs m;
      (if xm = x then RETURNT True
      else if xm < x then fw (m + 1, r)
      else fw (l, m))
      }
  ) (l,r)"

(* Derive the unfolding equations binarysearch.code for the RECT definition.
   binarysearch_correct unfolds one recursion step with them.
   print_theorems and thm below only display theorems (diagnostic output). *)
prepare_code_thms binarysearch_def
print_theorems
thm binarysearch.code(1,2) 

 
(* Size bounds for the two recursive sub-ranges of binarysearch.
   The right part [avg l r + 1, r) has at most (r - l) div 2 indices. *)
lemma avg_diff1: "(l::nat) ≤ r ⟹ r - (avg l r + 1) ≤ (r - l) div 2"
  unfolding avg_def by simp
(* The left part [l, avg l r) has at most (r - l) div 2 indices. *)
lemma avg_diff2: "(l::nat) ≤ r ⟹ avg l r - l ≤ (r - l) div 2"
  unfolding avg_def by simp

(* In a range with at least two indices the midpoint lies strictly inside:
   l < avg l r < r. The [backward] attribute registers these rules for auto2. *)
lemma avg_between [backward] :
  "l + 1 < r ⟹ r > avg l r"
  "l + 1 < r ⟹ avg l r > l"
  unfolding avg_def by auto

(* Correctness and time bound of the abstract program.
   For sorted xs and l ≤ r ≤ length xs, binarysearch refines
   binarysearch_SPEC.
   Proof outline:
   - unfold one RECT step via binarysearch.code;
   - strong induction on the range size r - l;
   - discharge recursive calls through the induction hypothesis less(1);
   - bound their cost with binarysearch_mono plus avg_diff1 / avg_diff2. *)
lemma binarysearch_correct: "sorted xs ⟹ l ≤ r ⟹ r ≤ length xs ⟹
   binarysearch l r x xs ≤ binarysearch_SPEC l r xs x"
  unfolding binarysearch_SPEC_def 
  apply(rule T_specifies_I)
    apply(subst binarysearch.code(1))
proof(induct "r-l" arbitrary: l r rule: less_induct)
  case less
  from less(2-4) show ?case apply(subst binarysearch.code(2))  unfolding mop_lookup_list_def
     apply (vcg'‹simp› rules: less(1)[THEN T_conseq4] )   
    unfolding Some_le_emb'_conv Some_eq_emb'_conv
    (* The order of the following subgoals is fixed by the VC generator. *)
    subgoal by auto 
    subgoal using le_less_Suc_eq by fastforce
    subgoal apply (simp) by auto2 
    subgoal by(simp add: avg_def)  
    subgoal by(simp add: avg_def)  
    (* presumably the cost of the recursive call on [avg l r + 1, r),
       bounded via avg_diff1 *)
    subgoal 
      apply (rule allI conjI) apply auto2
        using binarysearch_mono[OF avg_diff1] 
        by (simp add: le_SucI)
    subgoal by(simp add: avg_def)    
    subgoal by(simp add: avg_def)   
    subgoal by(simp add: avg_def)
    (* presumably the cost of the recursive call on [l, avg l r),
       bounded via avg_diff2 *)
    subgoal 
      apply (rule allI conjI) apply auto2  
        using binarysearch_mono[OF avg_diff2] 
        by (simp add: le_SucI) 
    subgoal by auto2
    done
  
qed
 
(* Synthesize an imperative implementation of binarysearch with sepref.
   Arguments, in uncurry3 order:
   - l and r are natural numbers;
   - x is refined by the identity;
   - xs is an array.
   All arguments are kept (⇧k, read-only), and the result is a bool.
   The restored signature matches the hn_ctxt frame used in
   binary_search_impl_correct. *)
sepref_definition binarysearch_impl is 
  "uncurry3 binarysearch" :: "nat_assn⇧k *⇩a nat_assn⇧k *⇩a id_assn⇧k *⇩a array_assn⇧k →⇩a bool_assn"
  unfolding binarysearch_def avg_def  listlookup_time_def
  using [[goals_limit = 3]] 
  by sepref

(* Diagnostic only: display the synthesized refinement theorem and its
   combination with binarysearch_correct.
   binary_search_impl_correct is built from this combination. *)
thm binarysearch_impl.refine[to_hnr]
thm hnr_refine[OF binarysearch_correct ] binarysearch_impl.refine[to_hnr, unfolded autoref_tag_defs]
thm  hnr_refine[OF binarysearch_correct, OF _ _ _ binarysearch_impl.refine[to_hnr, unfolded autoref_tag_defs], no_vars] 

(* The implementation refines the specification.
   This chains binarysearch_correct with the refinement theorem produced by
   sepref. The array and the three scalar arguments are preserved in the
   postcondition frame. *)
lemma binary_search_impl_correct: 
  assumes "sorted xs" "l ≤ r" "r ≤ length xs"
  shows "hn_refine (hn_ctxt array_assn xs bi * hn_val Id x bia * hn_val nat_rel r bib * hn_val nat_rel l ai)
            (binarysearch_impl ai bib bia bi)
            (hn_ctxt array_assn xs bi * hn_val Id x bia * hn_val nat_rel r bib * hn_val nat_rel l ai) 
            bool_assn (binarysearch_SPEC l r xs x)"
  using assms hnr_refine[OF binarysearch_correct, OF _ _ _ binarysearch_impl.refine[to_hnr, unfolded autoref_tag_defs]] by metis

(* Diagnostic only: display the cost-extraction instance that
   binary_search_correct' applies. *)
thm extract_cost_ub'[OF binary_search_impl_correct[unfolded  binarysearch_SPEC_def], where Cost_ub="binarysearch_time (r - l)" ]
(* Hoare triple with time credits, still stated over the hn_ctxt frame.
   Given binarysearch_time (r - l) time credits, binarysearch_impl leaves the
   array unchanged and returns whether x occurs in xs at an index in [l, r).
   It is extracted from binary_search_impl_correct via extract_cost_ub'.
   The triple uses the <_> _ <_>⇩t form; the garbled suffix ">t" is restored
   to "⇩t". *)
lemma binary_search_correct': "sorted xs ⟹ r ≤ length xs ⟹ l ≤ r ⟹ 
     <hn_ctxt array_assn xs p * hn_val Id x bia * hn_val nat_rel r bib * hn_val nat_rel l ai * timeCredit_assn (binarysearch_time (r - l))> 
        binarysearch_impl ai bib bia p
       <λra. hn_ctxt array_assn xs p * ↑ (ra ⟷ (∃i≥l. i < r ∧ xs ! i = x))>⇩t"
  apply(rule extract_cost_ub'[OF binary_search_impl_correct[unfolded  binarysearch_SPEC_def], where Cost_ub="binarysearch_time (r - l)" ])
       apply auto
     apply(subst in_ran_emb_special_case) apply (simp_all add: pure_def) apply auto
   by (metis (no_types, lifting) ent_true_drop(1) entails_ex entt_refl') 


subsection ‹Final Hoare triple and run-time claim.›

(* Final Hoare triple over plain arguments.
   Given the array and binarysearch_time (r - l) time credits,
   binarysearch_impl l r x p returns whether x occurs in xs at an index in
   [l, r), and the array is preserved.
   Obtained from binary_search_correct' by unfolding hn_ctxt and pure and
   applying the consequence rule. The triple uses the <_> _ <_>⇩t form; the
   garbled suffix ">t" is restored to "⇩t". *)
lemma binary_search_correct: "sorted xs ⟹ r ≤ length xs ⟹ l ≤ r ⟹ 
     <array_assn xs p * timeCredit_assn (binarysearch_time (r - l))> 
        binarysearch_impl l r x p
       <λra.   array_assn xs p * ↑ (ra ⟷ (∃i≥l. i < r ∧ xs ! i = x))>⇩t"
  apply(rule ht_cons_rule[OF _ _ binary_search_correct'[ unfolded hn_ctxt_def pure_def ]])
  by (sep_auto )+

(* Run-time claim: binarysearch_time grows logarithmically in the range size.
   It follows from binarysearch_time'_Theta by unfolding the real-valued
   wrapper. *)
lemma binary_search_time_ln: "binarysearch_time ∈ Θ(λn. ln (real n))"
  using binarysearch_time'_Theta unfolding binarysearch_time'_def by auto

end