(* Theory FW_Code

   Imported theories (short names):
   Recursion_Combinators, Floyd_Warshall, IICF_Misc, Asymptotics_1D *)
(* Authors: Lammich, Wimmer *)
theory FW_Code
  imports
    Recursion_Combinators
    Floyd_Warshall.Floyd_Warshall
  "../../Refine_Imperative_HOL/IICF/IICF_Misc"
  "SepLogicTime_RBTreeBasic.Asymptotics_1D" 
begin
 
section ‹Refinement to Efficient Imperative Code›

text ‹
  We will now refine the recursive version of the \fw to an efficient imperative version.
  To this end, we use the Imperative Refinement Framework, yielding an implementation in Imperative HOL.
›

(* Single update step on an uncurried matrix, expressed in the nrest monad:
   reads the entries at (i, j), (i, k) and (k, j), adds the latter two, takes the
   minimum of that sum and the (i, j) entry, and writes the result to position (i, j).
   NOTE(review): the literal 1 handed to every mop_* operation is presumably its
   time cost (six operations, matching the bound 6 in fw_upd'_spec) - confirm. *)
definition fw_upd' :: "('a::linordered_ab_monoid_add) mtx ⇒ nat ⇒ nat ⇒ nat ⇒ 'a mtx nrest" where
  "fw_upd' m k i j =
    do {
        mij ← mop_matrix_get 1 m (i, j);
        mik ← mop_matrix_get 1 m (i, k);
        mkj ← mop_matrix_get 1 m (k, j);
        s ← mop_plus 1 mik mkj;
        ss ← mop_min 1 mij s;
        mop_matrix_set 1 m (i, j) ss
  }"

(* Monadic iteration over the index pairs for a fixed k, defined with RECT:
   - (0, 0):     perform the single update at (0, 0);
   - (Suc i, 0): first recurse on (i, n), then update (Suc i, 0);
   - (i, Suc j): first recurse on (i, j), then update (i, Suc j).
   The recursion is started at the given (m, k, i, j). *)
definition fwi' ::  "('a::linordered_ab_monoid_add) mtx ⇒ nat ⇒ nat ⇒ nat ⇒ nat ⇒ 'a mtx nrest"
where
  "fwi' m n k i j = RECT (λ fw (m, k, i, j).
      case (i, j) of
        (0, 0) ⇒ fw_upd' m k 0 0 |
        (Suc i, 0) ⇒ do {m' ← fw (m, k, i, n); fw_upd' m' k (Suc i) 0} |
        (i, Suc j) ⇒ do {m' ← fw (m, k, i, j); fw_upd' m' k i (Suc j)}
    ) (m, k, i, j)"

(* Unfolding equations for fwi', one per case of its defining case distinction.
   Each equation is obtained by a single RECT_unfold step; refine_mono discharges
   the monotonicity side condition and auto (splitting on nat) closes the rest. *)
lemma fwi'_simps:
  "fwi' m n k 0       0        = fw_upd' m k 0 0"
  "fwi' m n k (Suc i) 0        = do {m' ← fwi' m n k i n; fw_upd' m' k (Suc i) 0}"
  "fwi' m n k i       (Suc j)  = do {m' ← fwi' m n k i j; fw_upd' m' k i (Suc j)}"
unfolding fwi'_def by (subst RECT_unfold, (refine_mono; fail), (auto split: nat.split; fail))+


(* Specification of the single update step: fw_upd' refines the SPEC whose only
   admissible result is the uncurried fw_upd applied to the curried matrix,
   with the time bound enat 6. *)
lemma fw_upd'_spec: "fw_upd' m k i j ≤ SPEC (λr. r = uncurry (fw_upd (curry m) k i j)) (λ_. enat 6)"
  unfolding SPEC_def 
  unfolding fw_upd'_def fw_upd_def upd_def
  apply(rule T_specifies_I)
  apply(vcg' ‹-› rules: matrix_set matrix_get mop_min mop_plus)
  by(auto split: if_splits )  
  

(* Specification of the RECT-based iteration: fwi' refines the SPEC whose only
   admissible result is the uncurried fwi applied to the curried matrix, with
   time bound 6*((j+1)+i*(n+1)).
   The proof follows the induction scheme of fwi; the two recursive cases combine
   the induction hypothesis with fw_upd'_spec.
   The lemma was previously anonymous; it is named here so that it can be referenced. *)
lemma fwi'_spec:
  "fwi' m n k i j ≤ SPEC (λ r. r = uncurry (fwi (curry m) n k i j)) (λ_. 6*((j+1)+i*(n+1)))"
proof (induction "curry m" n k i j arbitrary: m rule: fwi.induct)
  case (1 n k)
  then show ?case apply (simp add:   fwi'_simps)  by (rule fw_upd'_spec)  
next
  case (2 n k i) 
  show ?case apply (simp add:   fwi'_simps)
    unfolding SPEC_def
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules: 2(1)[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]
               fw_upd'_spec[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]       )
    by(auto split: if_splits)    
next
  case (3 n k i j)
  show ?case  apply (simp add:   fwi'_simps)
    unfolding SPEC_def
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules: 3(1)[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]
               fw_upd'_spec[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]       )
    by(auto split: if_splits)    
qed 
 

(* Time bound used for the complete algorithm: 6 * (n+1)^3, written as a product. *)
definition "fw_time (n::nat) = 6*((n+1)*(n+1))*(n+1)"
(* Time bound used for the inner iteration up to index pair (i, j) on dimension n;
   the same expression appears as the bound in the specification of fwi'. *)
definition "fwi_time (i::nat) (j::nat) (n::nat) = 6*((j+1)+i*(n+1))"

(* Relates the two bounds: fw_time n equals (n+1) times fwi_time n n n. *)
lemma fw_time_fwi_time: "fw_time n = (n+1)* fwi_time n n n" unfolding fw_time_def fwi_time_def
  by auto  
     

(* The for_rec2 combinator instantiated with the update step fw_upd' (for fixed k)
   refines the SPEC whose only admissible result is the uncurried fwi applied to
   the curried matrix, with time bound fwi_time i j n.
   Proved along for_rec2.induct; the recursive cases combine the induction
   hypothesis with fw_upd'_spec. *)
lemma for_rec2_fwi:
  "for_rec2 (λ M. fw_upd' M k) M n i j ≤ SPEC (λ M'. M' = uncurry (fwi (curry M) n k i j)) (λ_. fwi_time i j n)"
  unfolding fwi_time_def
proof (induction "λ M. fw_upd' (M :: (nat × nat ⇒ 'a)) k" M n i j rule: for_rec2.induct)
  case (1 a n)
  then show ?case apply simp  by (rule fw_upd'_spec)  
next
  case (2 a n i)
  show ?case apply simp 
    unfolding SPEC_def
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules:  2[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]
               fw_upd'_spec[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]       )
    by(auto split: if_splits)    
next
  case (3 a n i j)
  show ?case apply simp 
    unfolding SPEC_def
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules:  3[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]
               fw_upd'_spec[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]       )
    by(auto split: if_splits)    
qed 

(* Outer loop: folds (nfoldli) over the list [0..<k + 1] with a continuation
   condition that is always True; each step runs the for_rec2 iteration with the
   update step for the current list element over the index range (n, n). *)
definition fw' ::  "('a::linordered_ab_monoid_add) mtx ⇒ nat ⇒ nat ⇒ 'a mtx nrest" where
  "fw' m n k = nfoldli [0..<k + 1] (λ _. True) (λ k M. do { for_rec2 (λ M. fw_upd' M k) M n n n }) m"

(* Specification of the outer loop: fw' refines the SPEC whose only admissible
   result is the uncurried fw applied to the curried matrix, with time bound
   (k+1) * fwi_time n n n.
   Proved by induction on k. In the step case the index list is split into
   [0..<k + 1] @ [k+1] (fact dec) so that nfoldli_append exposes the induction
   hypothesis followed by one more application of for_rec2_fwi. *)
lemma fw'_spec:
  "fw' m n k ≤ SPEC (λ M'. M' = uncurry (fw (curry m) n k))  (λ_. (k+1)*fwi_time n n n)"
  unfolding fw'_def
proof (induction k)
  case 0
  then show ?case apply simp unfolding SPEC_def 
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules:   
               for_rec2_fwi[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]       )
    apply simp apply (vcg' ‹-›) by(auto split: if_splits simp: curry_def)
next
  case (Suc k)
  have dec: "[0..<Suc k + 1] = [0..<k + 1] @ [k+1]" by auto
  show ?case apply(simp only: dec nfoldli_append)  unfolding SPEC_def
    apply(rule T_specifies_I)
    apply(vcg' ‹simp› rules:   
               Suc(1)[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4]
                     )
    apply simp apply (vcg' ‹-› rules: for_rec2_fwi[unfolded SPEC_def, THEN T_specifies_rev, THEN T_conseq4] )
    apply simp apply (vcg' ‹-›) by(auto split: if_splits simp: curry_def)
qed



(* Local context for the Sepref synthesis. It fixes the bound n and a dummy
   variable whose only visible role is to fix the element type 'a together with
   its sort constraints (including heap). *)
context
  fixes n :: nat
  fixes dummy :: "'a::{linordered_ab_monoid_add,zero,heap}"
begin

(*lemma [sepref_import_param]: "((+),(+)::'a⇒_) ∈ Id → Id → Id" by simp
lemma [sepref_import_param]: "(min,min::'a⇒_) ∈ Id → Id → Id" by simp*)

(* Indices are refined by nat_assn; matrices by asmtx_assn of size Suc n with
   identity-refined entries. *)
abbreviation "node_assn ≡ nat_assn"
abbreviation "mtx_assn ≡ asmtx_assn (Suc n) id_assn::('a mtx ⇒_)"

(* Converts an interface-type annotation 'a i_mtx into one for nat × nat ⇒ 'a.
   It is applied with drule in the synthesis of fw_upd_impl. *)
lemma ff: "(bb ::iTYPE('a i_mtx)) ⟹ (bb ::iTYPE(nat × nat ⇒'a))" by auto

(* Imperative implementation of the update step fw_upd', synthesised with Sepref.
   Precondition: k, i and j are all at most n. The matrix argument is annotated
   with ⇧d and the three index arguments with ⇧k.
   The synthesis is run step by step with the sepref_dbg_* methods because the
   interface-type annotation has to be adjusted manually (drule ff) after
   sepref_dbg_cons_init.
   The hfref type below had lost its Isabelle control symbols (⇩a, ⇧d, ⇧k);
   they are restored here. *)
sepref_definition fw_upd_impl is
  "uncurry2 (uncurry fw_upd')" ::
  "[λ (((_,k),i),j). k ≤ n ∧ i ≤ n ∧ j ≤ n]⇩a mtx_assn⇧d *⇩a node_assn⇧k *⇩a node_assn⇧k *⇩a node_assn⇧k
  → mtx_assn"
  unfolding fw_upd'_def[abs_def] 
  apply sepref_dbg_preproc 
   apply sepref_dbg_cons_init
  apply(drule ff)       (* ----------------- TODO -------------- *)
     apply sepref_dbg_id   
     apply sepref_dbg_monadify
     apply sepref_dbg_opt_init                                       
      apply sepref_dbg_trans  
  apply sepref_dbg_opt
  apply sepref_dbg_cons_solve ― ‹Frame rule, recovering the invalidated list 
    or pure elements, propagating recovery over the list structure›
  apply sepref_dbg_cons_solve ― ‹Trivial frame rule›
  apply sepref_dbg_constraints
  done

(* Add the refinement theorem of fw_upd_impl to sepref_fr_rules. *)
declare fw_upd_impl.refine[sepref_fr_rules]

(* Register fw_upd' with Sepref under the given interface type. *)
sepref_register fw_upd' :: "'a i_mtx ⇒ nat ⇒ nat ⇒ nat ⇒ 'a i_mtx nrest"

(* Inner iteration for a given k, specialised to the fixed bound n of the context:
   for_rec2 with the update step fw_upd' over the index range (n, n). *)
definition
  "fwi_impl' (M :: 'a mtx) k = for_rec2 (λ M. fw_upd' M k) M n n n"

(* The outer loop fw' with both of its nat arguments set to the fixed bound n. *)
definition
  "fw_impl' (M :: 'a mtx) = fw' M n n"

(* Local Sepref setup for the fixed n: adds itypeI[of n "TYPE (nat)"] to id_rules
   and IdI[of n] to sepref_import_param.
   NOTE(review): presumably this lets Sepref treat n as a nat parameter refined
   by identity - confirm. *)
context
  notes [id_rules] = itypeI[of n "TYPE (nat)"]
    and [sepref_import_param] = IdI[of n]
begin

(* Imperative implementation of fw_impl', synthesised with Sepref.
   Before synthesis the definition is unfolded down to the underlying loops:
   fw'_def and for_rec2_eq expose nfoldli, nfoldli_assert' is applied
   right-to-left at three occurrences, and nfoldli_def is unfolded.
   The hfref type below had lost its control symbols and arrow
   ("mtx_assnda mtx_assn"); it is restored to mtx_assn⇧d →⇩a mtx_assn. *)
sepref_definition fw_impl is
  "fw_impl'" :: "mtx_assn⇧d →⇩a mtx_assn"      using [[id_debug, goals_limit = 1]]
  unfolding fw_impl'_def[abs_def] fw'_def
  unfolding for_rec2_eq apply(subst (2) nfoldli_assert'[symmetric])
      apply(subst (1) nfoldli_assert'[symmetric])
      apply(subst (0) nfoldli_assert'[symmetric])
    unfolding nfoldli_def
    by sepref
 

(* Imperative implementation of fwi_impl', synthesised with Sepref.
   Precondition: k is at most n. The matrix argument is annotated with ⇧d and
   the index argument with ⇧k.
   The definition is unfolded via for_rec2_eq, nfoldli_assert' is applied
   right-to-left at two occurrences, and nfoldli_def is unfolded before sepref runs.
   The hfref type below had lost its Isabelle control symbols (⇩a, ⇧d, ⇧k);
   they are restored here. *)
sepref_definition fwi_impl is
  "uncurry fwi_impl'" :: "[λ (_,k). k ≤ n]⇩a mtx_assn⇧d *⇩a node_assn⇧k → mtx_assn"
unfolding fwi_impl'_def[abs_def] for_rec2_eq  
      apply(subst (1) nfoldli_assert'[symmetric])
      apply(subst (0) nfoldli_assert'[symmetric])
    unfolding nfoldli_def  
    by sepref

end (* End of sepref setup *)

end (* End of n *)


(* Check that code for fw_impl can be generated for the SML_imp target. *)
export_code fw_impl checking SML_imp

text ‹
  A compact specification for the characteristic property of the \fw.
›
(* Abstract specification with time bound fw_time n. For a result M':
   - if some diagonal entry M' i i with i ≤ n is negative, then M is not cyc_free
     up to n;
   - otherwise every entry M' i j with i, j ≤ n equals D M i j n, and M is
     cyc_free up to n. *)
definition fw_spec where
  "fw_spec n M ≡ SPEC (λ M'.
    if (∃ i ≤ n. M' i i < 0)
    then ¬ cyc_free M n
    else ∀i ≤ n. ∀j ≤ n. M' i j = D M i j n ∧ cyc_free M n) (λ_. fw_time n)"

(* For a cycle_free matrix, the diagonal values D M i i n with i ≤ n are
   non-negative. Derived from D_dest'' and the definition of cycle_free. *)
lemma D_diag_nonnegI:
  assumes "cycle_free M n" "i ≤ n"
  shows "D M i i n ≥ 0"
using assms D_dest''[OF refl, of M i i n] unfolding cycle_free_def by auto


(* The functional algorithm meets the specification: the computation that returns
   exactly FW M n with cost fw_time n refines fw_spec n M.
   After unfolding, three goals remain; they are discharged with
   fw_shortest_path, D_diag_nonnegI and FW_neg_cycle_detect, each used with
   cycle_free_diag_equiv unfolded. *)
lemma fw_fw_spec:
  "SPECT [FW M n ↦ enat (fw_time n)] ≤ fw_spec n M"
unfolding fw_spec_def cycle_free_diag_equiv SPEC_def apply (simp add: le_fun_def)
proof (safe, goal_cases)
  case prems: (1 i)
  with fw_shortest_path[unfolded cycle_free_diag_equiv, OF prems(3)] D_diag_nonnegI show ?case
    by fastforce
next
  case 2 then show ?case using FW_neg_cycle_detect[unfolded cycle_free_diag_equiv]
    by (force intro: fw_shortest_path[symmetric, unfolded cycle_free_diag_equiv])
next
  case 3 then show ?case using FW_neg_cycle_detect[unfolded cycle_free_diag_equiv] by blast
qed

(* Relation between an uncurried matrix Mu and a curried matrix Mc: they are
   related iff curry Mu = Mc. *)
definition
  "mat_curry_rel = {(Mu, Mc). curry Mu = Mc}"

(* Assertion for curried matrices: mtx_assn n composed (hr_comp) with the
   relation br curry (λ_. True). *)
definition
  "mtx_curry_assn n = hr_comp (mtx_assn n) (br curry (λ_. True))"

(* Use the definition right-to-left as an fcomp_norm_unfold rule, so that the
   composed form is folded back into mtx_curry_assn. *)
declare mtx_curry_assn_def[symmetric, fcomp_norm_unfold]

(* fw_impl' (on uncurried matrices) refines fw_spec (on curried matrices) with
   respect to the relation br curry (λ _. True) on argument and result.
   Follows from fw'_spec and fw_fw_spec; fw_time_fwi_time reconciles the two
   time bounds. *)
lemma fw_impl'_correct:
  "(fw_impl', fw_spec) ∈ Id → br curry (λ _. True) → ⟨br curry (λ _. True)⟩ nrest_rel"
 unfolding fw_impl'_def[abs_def]  using fw'_spec fw_fw_spec  
 by (fastforce simp: in_br_conv pw_le_iff refine_pw_simps fw_time_fwi_time intro!: nrest_relI ) 
 

subsection ‹Main Result›

text ‹This is one way to state that ‹fw_impl› fulfills the specification ‹fw_spec›.›
(* Main refinement theorem: fw_impl n implements fw_spec n, with argument and
   result both refined by mtx_curry_assn n and the argument annotated with ⇧d.
   Obtained by composing fw_impl.refine with fw_impl'_correct via FCOMP.
   The hfref type had lost its control symbols and arrow ("...da ...");
   it is restored to (mtx_curry_assn n)⇧d →⇩a mtx_curry_assn n. *)
theorem fw_impl_correct:
  "(fw_impl n, fw_spec n) ∈ (mtx_curry_assn n)⇧d →⇩a mtx_curry_assn n"
using fw_impl.refine[FCOMP fw_impl'_correct[THEN fun_relD, OF IdI]] .

text ‹An alternative version: a Hoare triple for total correctness.›
(* Hoare-triple formulation of fw_impl_correct.
   Precondition: Mi implements M under mtx_curry_assn n, together with
   $(fw_time n). Postcondition: the result Mi' implements some M' that satisfies
   the condition of fw_spec.
   Two extraction defects are repaired in the statement: "∃A M'." (which would
   parse as an ordinary quantifier over A and M') is restored to the
   assertion-level ∃⇩A, and the closing ">t" to >⇩t. The leftover diagnostic
   thm command inside the proof has been removed; it had no effect on the proof. *)
corollary fw_correct:
  "<mtx_curry_assn n M Mi * $(fw_time n)> fw_impl n Mi <λ Mi'. ∃⇩A M'. mtx_curry_assn n M' Mi' * ↑
    (if (∃ i ≤ n. M' i i < 0)
    then ¬ cyc_free M n
    else ∀i ≤ n. ∀j ≤ n. M' i j = D M i j n ∧ cyc_free M n)>⇩t"
unfolding cycle_free_diag_equiv
  apply (rule ht_cons_rule[OF _ _ extract_cost_ub_SPEC[OF fw_impl_correct[THEN hfrefD, unfolded fw_spec_def[unfolded cycle_free_diag_equiv]] ]])
  by sep_auto+ 

(* Asymptotic statement: the bound fw_time is in Θ(n*n*n). Proved by auto2
   after unfolding fw_time_def. *)
lemma fw_time_n_cube: "fw_time ∈ Θ(λn. n*n*n)"
  unfolding fw_time_def by auto2

subsection ‹Alternative Versions for Uncurried Matrices.›

(* FWI lifted to uncurried matrices: curry the argument, apply FWI, and uncurry
   the result. *)
definition "FWI' = uncurry ooo FWI o curry"


(* fwi_impl' n refines the computation that returns FWI' M n k with cost
   fwi_time n n n, under identity relations. Follows from for_rec2_fwi. *)
lemma fwi_impl'_refine_FWI':
  "(fwi_impl' n, (λx. SPECT [ x ↦ fwi_time n n n])  oo PR_CONST (λ M. FWI' M n)) ∈ Id → Id → ⟨Id⟩ nrest_rel"
unfolding fwi_impl'_def[abs_def] FWI_def[abs_def] FWI'_def using for_rec2_fwi
by (force simp: pw_nrest_rel_iff pw_le_iff refine_pw_simps fw_time_fwi_time) 
 
(* Composition of the synthesised refinement theorem with the lemma above. *)
lemmas fwi_impl_refine_FWI' = fwi_impl.refine[FCOMP fwi_impl'_refine_FWI']

(* FW lifted to uncurried matrices: curry the argument, apply FW, and uncurry
   the result. *)
definition "FW' = uncurry oo FW o curry"

(* FW' with its two arguments swapped, so that n comes first. *)
definition "FW'' n M = FW' M n"

(* fw_impl' n refines the computation that returns FW'' n M with cost fw_time n,
   under identity relations. Follows from fw'_spec; fw_time_fwi_time reconciles
   the two time bounds. *)
lemma fw_impl'_refine_FW'':
  "(fw_impl' n, (λx. SPECT [ x ↦ fw_time n]) o PR_CONST (FW'' n)) ∈ Id → ⟨Id⟩ nrest_rel"
unfolding fw_impl'_def[abs_def] FW''_def[abs_def] FW'_def using fw'_spec
by (force simp: pw_le_iff pw_nrest_rel_iff refine_pw_simps fw_time_fwi_time)

(* Composition of the synthesised refinement theorem with the lemma above. *)
lemmas fw_impl_refine_FW'' = fw_impl.refine[FCOMP fw_impl'_refine_FW'']

end