section ‹Capacity Matrix by Fixed-Size Array›
theory Capacity_Matrix_Impl
imports
  Fofu_Impl_Base
  Graph
begin
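  (*
    Capacity matrices (abstract type nat×nat => 'a) are implemented by a single
    heap array of length N*N. Entry (i,j) with i,j < N is stored at the row-major
    index i*N+j; entries outside the N×N square must be 0.
  *)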




  (* Abstract capacity matrix: a total function on index pairs. *)
  type_synonym 'a amtx = "nat×nat => 'a"
  (* Concrete representation: one Imperative/HOL heap array holding all entries. *)
  type_synonym 'a mtx = "'a Heap.array"

  (* Interface type for matrices, used when registering the matrix operations
     with sepref. *)
  typedecl 'a i_mtx

  (* Abstract matrix operations. Semantically they are the identity, function
     application and function update. They are marked [simp] so they unfold away in
     abstract reasoning, while giving sepref distinct constants to which the
     array implementations can be attached. *)
  definition amtx_new_op :: "'a amtx => 'a amtx" where [simp]: "amtx_new_op c ≡ c"
  definition amtx_get_op :: "'a amtx => nat×nat => 'a" where [simp]: "amtx_get_op c e ≡ (c e)"
  definition amtx_set_op :: "'a amtx => nat×nat => 'a => 'a amtx" where [simp]: "amtx_set_op c e v ≡ (c(e:=v))"

  (* Register the matrix operations with sepref, giving each an interface type
     over 'a i_mtx. *)
  sepref_register amtx_new_op "'a amtx => 'a i_mtx"
  sepref_register amtx_get_op "'a i_mtx => nat×nat => 'a"
  sepref_register amtx_set_op "'a i_mtx => nat×nat => 'a => 'a i_mtx"

  (* Pattern rules: rewrite plain application c$e and fun_upd c e v into the
     registered operation constants. Matrix reads/writes written in ordinary
     function notation are then recognised as matrix operations. *)
  lemma pat_amtx_get: "c$e≡amtx_get_op$c$e" by simp
  lemma pat_amtx_set: "fun_upd$c$e$v≡amtx_set_op$c$e$v" by simp

  lemmas amtx_pats = pat_amtx_get pat_amtx_set

  (* Refinement assertion relating an abstract matrix c to a heap array mtx.
     mtx points to a list l of length N*N in which entry (i,j), for i,j < N, is
     stored at the row-major position i*N+j. Entries of c outside the N×N square
     are required to be 0, so the array contents fix c completely. *)
  definition "is_mtx N c mtx ≡ ∃⇩Al. mtx ↦⇩a l * \<up>(
      length l = N*N
    ∧ (∀i<N. ∀j<N. l!(i*N+j) = c (i,j))
    ∧ (∀i j. (i≥N ∨ j≥N) --> c (i,j) = 0))"

  (* is_mtx N is precise: if two abstract matrices are related to the same array
     in the same heap, they are equal. Registered as a sepref constraint rule. *)
  lemma is_mtx_precise[constraint_rules]: "precise (is_mtx N)"
    apply rule
    unfolding is_mtx_def
    apply clarsimp
    (* extract the equalities implied by two points-to assertions on the same array *)
    apply prec_extract_eqs
    (* show the two matrices agree pointwise on every index pair (i,j) *)
    apply (rule ext)
    apply (rename_tac x)
    apply (case_tac x; simp)
    apply (rename_tac i j)
    (* inside the square both values come from the same list; outside both are 0 *)
    apply (case_tac "i<N"; case_tac "j<N"; simp)
    done
    

  (* Allocate a fresh array of N*N cells, filled in row-major order:
     cell k holds c (k div N, k mod N). A single-statement do-block is just
     that statement, so no monadic wrapper is needed. *)
  definition "mtx_new N c ≡ Array.make (N*N) (λk. c (k div N, k mod N))"

  (* Read / write the entry e = (i,j) at the row-major array index i*N + j.
     No range check is done here; the Hoare rules for these operations assume
     i < N and j < N. *)
  definition "mtx_get N mtx e ≡ Array.nth mtx (fst e * N + snd e)"
  definition "mtx_set N mtx e v ≡ Array.upd (fst e * N + snd e) v mtx"

  (* The row-major index of an in-range pair (i,j) lies inside the N*N array. *)
  lemma mtx_idx_valid[simp]: "[|i < (N::nat); j < N|] ==> i * N + j < N * N"
  proof -
    assume "i < N" and "j < N"
    (* row i spans indices i*N .. i*N + (N-1), so j < N stays below (Suc i)*N *)
    have "i * N + j < Suc i * N" using ‹j < N› by simp
    (* and Suc i ≤ N, so that row boundary is still within N*N *)
    also have "… ≤ N * N" using ‹i < N› by (metis Suc_leI mult_le_mono1)
    finally show ?thesis .
  qed

  (* Row-major indexing is injective on the N×N square: two in-range pairs share
     an array index exactly when they are equal. Declared [simp] so the automation
     can decide whether two array positions coincide. *)
  lemma mtx_index_unique[simp]: "[|i<(N::nat); j<N; i'<N; j'<N|] ==> i*N+j = i'*N+j' <-> i=i' ∧ j=j'"
    by (metis ab_semigroup_add_class.add.commute add_diff_cancel_right' div_if div_mult_self3 gr0I not_less0)

  (* Hoare rule for allocation: starting from the empty heap, mtx_new N c yields
     an array representing c, provided the nodes Graph.V c lie in {0..<N}.
     NOTE(review): presumably this precondition is what discharges is_mtx's
     "zero outside the square" requirement (the proof unfolds Graph.V_def and
     Graph.E_def). *)
  lemma mtx_new_rl[sep_heap_rules]: "Graph.V c ⊆ {0..<N} ==> <emp> mtx_new N c <is_mtx N c>"
    by (sep_auto simp: mtx_new_def is_mtx_def Graph.V_def Graph.E_def)

  (* Hoare rule for reading an in-range entry: the matrix is left unchanged and
     the result equals c (i,j). *)
  lemma mtx_get_rl[sep_heap_rules]: "[|i<N; j<N |] ==> <is_mtx N c mtx> mtx_get N mtx (i,j) <λr. is_mtx N c mtx * \<up>(r = c (i,j))>"
    by (sep_auto simp: mtx_get_def is_mtx_def)
    
  (* Hoare rule for writing an in-range entry: the returned array r represents the
     updated matrix c((i,j) := v). The postcondition says nothing about the old
     handle mtx, so callers must use r afterwards. *)
  lemma mtx_set_rl[sep_heap_rules]: "[|i<N; j<N |] 
    ==> <is_mtx N c mtx> mtx_set N mtx (i,j) v <λr. is_mtx N (c((i,j) := v)) r>"
    by (sep_auto simp: mtx_set_def is_mtx_def nth_list_update)
        
  (* Sepref refinement rule: amtx_new_op is implemented by mtx_new N. The initial
     matrix is passed as a pure function (relation nat_rel ×⇩r nat_rel → Id) and is
     kept (⇧k). It must have its nodes Graph.V c within {0..<N}. *)
  lemma mtx_new_fr_rl[sepref_fr_rules]:
    "(mtx_new N, RETURN o amtx_new_op) ∈ [λc. Graph.V c ⊆ {0..<N}]⇩a (pure (nat_rel ×⇩r nat_rel → Id))⇧k → is_mtx N"
    apply rule apply rule
    apply (sep_auto simp: pure_def)
    done

  (* Sepref refinement rule: amtx_get_op is implemented by mtx_get N. The indices
     must be in range and the matrix is kept (⇧k). The element is returned under R,
     which is constrained by IS_PURE_ID. *)
  lemma mtx_get_fr_rl[sepref_fr_rules]:
    "CONSTRAINT IS_PURE_ID R ==> (uncurry (mtx_get N), uncurry (RETURN oo amtx_get_op)) ∈ [λ(_,(i,j)). i<N ∧ j<N]⇩a (is_mtx N)⇧k *⇩a (pure (nat_rel ×⇩r nat_rel))⇧k → R"
    apply rule apply rule
    apply (sep_auto simp: pure_def IS_PURE_ID_def)
    done
    
  (* Sepref refinement rule: amtx_set_op is implemented by mtx_set N. The indices
     must be in range. The input matrix is destroyed (⇧d) and the updated matrix is
     returned. The value is passed under R, which is constrained by IS_PURE_ID. *)
  lemma mtx_set_fr_rl[sepref_fr_rules]: "CONSTRAINT IS_PURE_ID R ==> (uncurry2 (mtx_set N), uncurry2 (RETURN ooo amtx_set_op))
    ∈ [λ((_,(i,j)),_). i<N ∧ j<N]⇩a (is_mtx N)⇧d *⇩a (pure (nat_rel ×⇩r nat_rel))⇧k *⇩a R⇧k → is_mtx N"
    apply rule apply rule
    apply (sep_auto simp: pure_def hn_ctxt_def IS_PURE_ID_def)
    done

end