section ‹Matrices by Array (Row-Major)›
theory IICF_Array_Matrix
imports
  "../Intf/IICF_Matrix"
  (* Separation_Logic_Imperative_HOL.Array_Blit *)
begin

  (* Heap assertion for an N x M matrix stored in the array mtx.
     The array content l has length N * M, cell i * M + j of l equals
     c (i,j) for all i < N and j < M (row-major layout), and the abstract
     matrix c is 0 at every index outside the N x M range. *)
  definition is_amtx where [rewrite_ent]:
    "is_amtx N M c mtx = (∃⇩Al. mtx ↦⇩a l * ↑(
        length l = N*M
      ∧ (∀i<N. ∀j<M. l!(i*M+j) = c (i,j))
      ∧ (∀i j. (i≥N ∨ j≥M) ⟶ c (i,j) = 0)))"

  (*
  lemma is_amtx_precise[safe_constraint_rules]: "precise (is_amtx N M)"
    apply rule
    unfolding is_amtx_def
    apply clarsimp
    (*
    apply prec_extract_eqs
    apply (rule ext)
    apply (rename_tac x)
    apply (case_tac x; simp)
    apply (rename_tac i j)
    apply (case_tac "i<N"; case_tac "j<M"; simp)
    done *)
    sorry
  *)

  (* Any abstract matrix in the range of is_amtx N M has its non-zero
     entries inside the index rectangle {0..<N} x {0..<M}. *)
  lemma is_amtx_bounded:
    shows "rdomp (is_amtx N M) m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    unfolding rdomp_def
    apply (clarsimp simp: mtx_nonzero_def is_amtx_def)
    by (meson not_less)

  (*definition "mtx_new N M c ≡ do {
    Array.make (N*M) (λi. c (i div M, i mod M))
  }"*)

  (* Imperative for-loop: runs f i on the state for i, i+1, ... while the
     counter is below u, and returns the final state. If i ≥ u the state
     is returned unchanged. *)
  partial_function (heap) imp_for' :: "nat ⇒ nat ⇒ (nat ⇒ 'a ⇒ 'a Heap) ⇒ 'a ⇒ 'a Heap" where
    "imp_for' i u f s = (if i ≥ u then return s else f i s ⤜ imp_for' (i + 1) u f)"

  declare imp_for'.simps[code]

  (* The two unfolding cases of imp_for' as conditional simp rules. *)
  lemma imp_for'_simps[simp]:
    "i ≥ u ⟹ imp_for' i u f s = return s"
    "i < u ⟹ imp_for' i u f s = f i s ⤜ imp_for' (i + 1) u f"
    by (auto simp: imp_for'.simps)

  (* Unnamed auxiliary fact: equal credit amounts give an entailment. *)
  lemma "a=b ⟹ timeCredit_assn a ⟹⇩A timeCredit_assn b"
    by auto

  (* Hoare rule for imp_for' with loop invariant I.
     LESS: the loop bounds satisfy l ≤ u.
     STEP: one iteration at index i (with l ≤ i < u) re-establishes the
           invariant for i+1 and consumes t time credits.
     The precondition carries t * (u - l) + 1 credits in total. *)
  lemma imp_for'_rule:
    assumes LESS: "l≤u"
    assumes STEP: "⋀i s. ⟦ l≤i; i<u ⟧ ⟹ <I i s * timeCredit_assn t> f i s <I (i+1)>⇩t"
    shows "<I l s * timeCredit_assn (t*(u-l)+1)> imp_for' l u f s <I u>⇩t"
    using LESS
  proof (induction arbitrary: s rule: inc_induct)
    case base
    thus ?case by auto2
  next
    case (step k)
    (* split off the credits for the iteration at index k *)
    then have f: "(t * (u - k)) = t + (t * (u - (k+1)))"
      by (metis Suc_diff_Suc add.assoc add.commute mult_Suc_right plus_1_eq_Suc)
    have s: "Suc k = k + 1" by simp
    show ?case
      using step.hyps
      by (sep_auto heap: STEP step.IH[unfolded s] simp: f simp del: add_Suc One_nat_def)
  qed

  (* Variant of imp_for'_rule with an arbitrary precondition P that
     entails the initial invariant together with the required credits. *)
  lemma imp_for'_rule':
    assumes LESS: "l≤u"
    assumes PRE: "P ⟹⇩A I l s * timeCredit_assn (t*(u-l)+1)"
    assumes STEP: "⋀i s. ⟦ l≤i; i<u ⟧ ⟹ <I i s * timeCredit_assn t> f i s <I (i+1)>⇩t"
    shows "<P> imp_for' l u f s <I u>⇩t"
    apply (rule pre_rule[OF PRE])
    apply (rule imp_for'_rule)
    by fact+

  (* Build an N x M matrix from the function c: allocate an array of
     N * M zeros, then write c (i,j) into cell k for k = 0 ..< N * M,
     where the loop state (i,j,m) advances j and wraps to the next row
     once j reaches M. *)
  definition "mtx_tabulate N M c ≡ do {
    m ← Array.new (N*M) 0;
    (_,_,m) ← imp_for' 0 (N*M) (λk (i,j,m). do {
      Array.upd k (c (i,j)) m;
      let j=j+1;
      if j<M then return (i,j,m)
      else return (i+1,0,m)
    }) (0,0,m);
    return m
  }"

  (* definition "amtx_copy ≡ array_copy" *)

  (* Array of N * M cells, all holding v. *)
  definition [rewrite]: "amtx_dflt N M v = Array.make (N*M) (λi. v)"

  (* Read the cell addressed by the index pair e (row-major, row width M). *)
  definition [rewrite]: "mtx_get M mtx e = Array.nth mtx (fst e * M + snd e)"

  (* Write v into the cell addressed by the index pair e. *)
  definition [rewrite]: "mtx_set M mtx e v = Array.upd (fst e * M + snd e) v mtx"

  (* For column indices below M, the linear index i * M + j determines
     the pair (i,j) uniquely. *)
  lemma mtx_idx_unique_conv[simp]:
    fixes M :: nat
    assumes "j<M" "j'<M"
    shows "(i * M + j = i' * M + j') ⟷ (i=i' ∧ j=j')"
    using assms
    apply auto
    subgoal
      by (metis add_right_cancel div_if div_mult_self3 linorder_neqE_nat not_less0)
    subgoal
      using ‹⟦j < M; j' < M; i * M + j = i' * M + j'⟧ ⟹ i = i'› by auto
    done

  (* Same uniqueness statement, phrased with all four index bounds. *)
  lemma mtx_index_unique[simp]:
    "⟦i<(N::nat); j<M; i'<N; j'<M⟧ ⟹ i*M+j = i'*M+j' ⟷ i=i' ∧ j=j'"
    by (metis ab_semigroup_add_class.add.commute add_diff_cancel_right' div_if div_mult_self3 gr0I not_less0)

  (* Hoare triple for mtx_tabulate: given N * M * 3 + 3 time credits and
     a function c whose non-zero entries lie in range, the result
     satisfies is_amtx N M c. The proof distinguishes M = 0 from 0 < M. *)
  lemma mtx_tabulate_rl[sep_heap_rules]:
    assumes NONZ: "mtx_nonzero c ⊆ {0..<N}×{0..<M}"
    shows "<timeCredit_assn (N*M*3+3)> mtx_tabulate N M c <IICF_Array_Matrix.is_amtx N M c>⇩t"
  proof (cases "M=0")
    case True
    thus ?thesis
      unfolding mtx_tabulate_def
      using mtx_nonzeroD[OF _ NONZ]
      by (sep_auto simp: is_amtx_def zero_time simp del: add_Suc One_nat_def add_2_eq_Suc')
  next
    case False
    hence M_POS: "0<M" by auto
    show ?thesis
      unfolding mtx_tabulate_def
      (* loop invariant: k is the linear index of (i,j), all rows before i
         and all cells of row i before column j are already filled *)
      apply (sep_auto
        decon: imp_for'_rule'[where
          I="λk (i,j,mi). ∃⇩Am. mi ↦⇩a m
            * ↑( k=i*M+j ∧ j<M ∧ k≤N*M ∧ length m = N*M )
            * ↑( ∀i'<i. ∀j<M. m!(i'*M+j) = c (i',j) )
            * ↑( ∀j'<j. m!(i*M+j') = c (i,j') )* timeCredit_assn 1 "
          and t="2" ]
        simp: nth_list_update M_POS dest: Suc_lessI
        simp del: One_nat_def add_2_eq_Suc' )
      (* the setup is wrong here, somehow I need to ensure that within $ expressions,
         natural numbers are not rewritten,
         but the simplification rules should be applied when solving side-conditions *)
      (* apply (auto simp del: One_nat_def add_2_eq_Suc') [] *)
      apply auto[]
      apply auto[]
      apply (auto simp: nth_list_update M_POS dest: Suc_lessI)[]
      apply (sep_auto)
      unfolding is_amtx_def
      using mtx_nonzeroD[OF _ NONZ]
      apply sep_auto
      by (metis add.right_neutral M_POS mtx_idx_unique_conv)
  qed

  (*
  lemma mtx_copy_rl[sep_heap_rules]:
    "<is_amtx N M c mtx> amtx_copy mtx <λr. is_amtx N M c mtx * is_amtx N M c r>"
    by (sep_auto simp: amtx_copy_def is_amtx_def)
  *)

  (* A relation A preserves zero uniquely if the image of {0} under A and
     under the converse of A is exactly {0}. *)
  definition "PRES_ZERO_UNIQUE A ≡ (A``{0}={0} ∧ A¯``{0} = {0})"

  lemma IS_ID_imp_PRES_ZERO_UNIQUE[constraint_rules]: "IS_ID A ⟹ PRES_ZERO_UNIQUE A"
    unfolding IS_ID_def PRES_ZERO_UNIQUE_def by auto

  (* Abstract matrix that is v inside the N x M range and 0 outside. *)
  definition op_amtx_dfltNxM :: "nat ⇒ nat ⇒ 'a::zero ⇒ nat×nat⇒'a" where
    [simp]: "op_amtx_dfltNxM N M v = (λ(i,j). if i<N ∧ j<M then v else 0)"

  (* Pointwise form of op_amtx_dfltNxM, registered as an auto2 rewrite rule. *)
  lemma opt_amtx_dfltNxM[rewrite]:
    "op_amtx_dfltNxM N M v (i,j) = (if i<N ∧ j<M then v else 0)"
    unfolding op_amtx_dfltNxM_def by simp

  (*
  context fixes N M::nat begin
    sepref_decl_op (no_def) op_amtx_dfltNxM: "op_amtx_dfltNxM N M" :: "A → ⟨A⟩mtx_rel"
      where "CONSTRAINT PRES_ZERO_UNIQUE A"
      apply (rule fref_ncI)
      unfolding op_amtx_dfltNxM_def[abs_def] mtx_rel_def
      apply parametricity
      by (auto simp add: PRES_ZERO_UNIQUE_def)
  end
  *)

  (* Unnamed auxiliary facts about map over [0..<n], proved with auto2.
     Being unnamed, they cannot be referenced later by name. *)
  lemma "length [0..<N * M] = N * M"
    by auto2

  lemma "i< length[0..<N] ⟹ map (λi. k) [0..<N] ! i = k"
    by auto2

  lemma "i< N ⟹ map (λi. k) [0..<N] ! i = k"
    @proof
      @have "i< length [0..<N]"
    @qed

  declare upt_zero_length [rewrite_arg]

  lemma "i< N*M ⟹ map (λi. k) [0..<N*M] ! i = k"
    by auto2

  (* A row-major index built from in-range coordinates is in range. *)
  lemma mtx_idx_valid[simp,backward]: "⟦i < (N::nat); j < M⟧ ⟹ i * M + j < N * M"
    by (rule mlex_bound)

  lemma "i<N ⟹ j<M ⟹ map (λi. k) [0..<N * M] ! (i * M + j) = k"
    @proof
      @have "i * M + j < N*M"
    @qed

  (* Abandoned variant of the previous fact; the proof attempt ends in oops. *)
  lemma "i<N ⟹ j<M ⟹ map (λi. k) [0..<N * M] ! (i * M + j) = (if i < N ∧ j < M then k else 0)"
    @proof
      @have "i * M + j < N*M"
    oops

  (* Hoare triple for amtx_dflt: with N * M + 1 time credits the result
     represents op_amtx_dfltNxM N M k. *)
  lemma mtx_dflt_rl:
    "<timeCredit_assn (N*M+1)> amtx_dflt N M k <is_amtx N M (op_amtx_dfltNxM N M k)>"
    (* by (sep_auto simp: amtx_dflt_def is_amtx_def) *)
    by auto2

  (* Index bound, registered as an auto2 backward rule. *)
  lemma ij[backward]: "i<N ⟹ j<M ⟹ i * M + j < N* (M::nat)"
    by auto2

  (* Hoare triples for mtx_get: one time credit, the matrix is unchanged
     and the result is the abstract entry. First for an explicit pair,
     then for an arbitrary key k. *)
  lemma mtx_get_rl':
    "⟦i<N; j<M ⟧ ⟹ <timeCredit_assn 1 * is_amtx N M c mtx> mtx_get M mtx (i,j) <λr. is_amtx N M c mtx * ↑(r = c (i,j))>"
    by auto2

  lemma mtx_get_rl:
    "⟦fst k<N; snd k<M ⟧ ⟹ <timeCredit_assn 1 * is_amtx N M c mtx> mtx_get M mtx k <λr. is_amtx N M c mtx * ↑(r = c k)>"
    by auto2

  (* Unnamed auxiliary facts about list update, proved with auto2. *)
  lemma "n<length l ⟹ l[n:=1] ! n = 1"
    by auto2

  lemma "i<N ⟹ j<M ⟹ ia<N ⟹ ja<M ⟹ length l = N * M
      ⟹ l[i * M + j := v] ! (ia * M + ja) = (if i * M + j = ia * M + ja then v else l ! (ia * M + ja))"
    @proof
      @have "ia * M + ja < length l"
    @qed

  lemma "j<J ⟹ snd (i, j) < J "
    by auto2

  (* Lookup after update at a row-major index, registered as an auto2
     rewrite rule. *)
  lemma a[rewrite]:
    "length l = N * M ⟹ ia<N ⟹ ja<M
      ⟹ l[k := v] ! (ia * M + ja) = (if k = ia * M + ja then v else l ! (ia * M + ja))"
    @proof
      @have "ia * M + ja < length l"
    @qed

  (* Hoare triples for mtx_set: one time credit, the returned array
     represents the abstract matrix updated at the given index. First for
     an explicit pair, then for an arbitrary key k. *)
  lemma mtx_set_rl':
    "⟦i<N; j<M ⟧ ⟹ <timeCredit_assn 1 * is_amtx N M c mtx> mtx_set M mtx (i,j) v <λr. is_amtx N M (c((i,j) := v)) r>"
    unfolding mtx_set_def is_amtx_def
    by (sep_auto simp del: One_nat_def)

  lemma mtx_set_rl:
    "⟦fst k<N; snd k<M ⟧ ⟹ <timeCredit_assn 1 * is_amtx N M c mtx> mtx_set M mtx k v <λr. is_amtx N M (c(k := v)) r>"
    using mtx_set_rl' by force

  (* Matrix assertion with element refinement: is_amtx N M composed with
     the matrix relation over the pure part of A. *)
  definition "amtx_assn N M A ≡ hr_comp (is_amtx N M) (⟨the_pure A⟩mtx_rel)"
  lemmas [fcomp_norm_unfold] = amtx_assn_def[symmetric]
  lemmas [safe_constraint_rules] = CN_FALSEI[of is_pure "amtx_assn N M A" for N M A]

  lemma [intf_of_assn]:
    "intf_of_assn A TYPE('a) ⟹ intf_of_assn (amtx_assn N M A) TYPE('a i_mtx)"
    by simp

  (* Square matrices: both dimensions are N. *)
  abbreviation "asmtx_assn N A ≡ amtx_assn N N A"

  (* Matrices related via a zero-preserving relation have zeros at the
     same indices. *)
  lemma mtx_rel_pres_zero:
    assumes "PRES_ZERO_UNIQUE A"
    assumes "(m,m')∈⟨A⟩mtx_rel"
    shows "m ij = 0 ⟷ m' ij = 0"
    using assms
    apply1 (clarsimp simp: IS_PURE_def PRES_ZERO_UNIQUE_def is_pure_conv mtx_rel_def)
    apply (drule fun_relD)
    applyS (rule IdI[of ij])
    applyS auto
    done

  (* Counterpart of is_amtx_bounded for amtx_assn, for pure and
     zero-preserving element assertions. *)
  lemma amtx_assn_bounded:
    assumes "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A"
    shows "rdomp (amtx_assn N M A) m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    apply (clarsimp simp: mtx_nonzero_def amtx_assn_def rdomp_hrcomp_conv)
    apply (drule is_amtx_bounded)
    using assms
    by (fastforce simp: IS_PURE_def is_pure_conv mtx_rel_pres_zero[symmetric] mtx_nonzero_def)

  (* lemma mtx_tabulate_aref: "(mtx_tabulate N M, RETURN o op_mtx_new) ∈ [λc. 
mtx_nonzero c ⊆ {0..<N}×{0..<M}]⇩a id_assn⇧k → IICF_Array_Matrix.is_amtx N M"
    by sepref_to_hoare sep_auto

  lemma mtx_copy_aref: "(amtx_copy, RETURN o op_mtx_copy) ∈ (is_amtx N M)⇧k →⇩a is_amtx N M"
    apply rule apply rule
    apply (sep_auto simp: pure_def)
    done
  *)

  (* Functions related pointwise by a sub-relation of Id have the same
     set of non-zero indices. *)
  lemma mtx_nonzero_bid_eq:
    assumes "R⊆Id"
    assumes "(a, a') ∈ Id → R"
    shows "mtx_nonzero a = mtx_nonzero a'"
    using assms
    apply (clarsimp simp: mtx_nonzero_def)
    apply (metis fun_relE2 pair_in_Id_conv subsetCE)
    done

  (* Same conclusion when the element relation preserves zero uniquely. *)
  lemma mtx_nonzero_zu_eq:
    assumes "PRES_ZERO_UNIQUE R"
    assumes "(a, a') ∈ Id → R"
    shows "mtx_nonzero a = mtx_nonzero a'"
    using assms
    apply (clarsimp simp: mtx_nonzero_def PRES_ZERO_UNIQUE_def)
    by (metis (no_types, hide_lams) IdI Image_singleton_iff converse_iff singletonD tagged_fun_relD_none)

subsection "implementation of interface"

  lemma p: "the_pure id_assn = Id"
    by simp

  (* Split a pure assertion off the front of a separating conjunction:
     yields the relation membership and the remaining frame. *)
  lemma extractpureD: "h ⊨ pure R a c * F ⟹ (c,a) ∈ R ∧ h ⊨ F"
    by (simp add: pure_def)

  (* Abandoned first attempt at the update rule, with the key passed as
     hn_val Id; the proof ends in oops, so no theorem is stored. *)
  lemma mop_matrix_update_rule[sepref_fr_rules]:
    "1 ≤ t ⟹ fst k' < M ⟹ snd k' < M ⟹
      hn_refine
        (hn_val Id v' v * hn_val Id k' k * hn_ctxt (asmtx_assn M (pure Id)) m' m)
        (PR_CONST (mtx_set M) m k v)
        (hn_val Id v' v * hn_val Id k' k * hn_invalid (asmtx_assn M (pure Id)) m' m)
        (asmtx_assn M (pure Id))
        ( PR_CONST (mop_matrix_set t) $ m' $ k' $ v')"
    apply (rule hn_refine_preI)
    unfolding mop_matrix_set_def autoref_tag_defs
    apply (rule extract_cost_otherway[OF _ mtx_set_rl, where F="hn_val Id v' v * hn_val Id k' k * hn_invalid (asmtx_assn M (pure Id)) m' m" ])
    oops

  (* Refinement rule: mtx_set M implements mop_matrix_set t on square
     M x M matrices with identity element refinement. Requires a cost
     bound of at least 1 and both key components below M. The matrix
     argument is invalidated in the post-context. *)
  lemma mop_matrix_update_rule[sepref_fr_rules]:
    "1 ≤ t ⟹ fst k' < M ⟹ snd k' < M ⟹
      hn_refine
        (hn_val Id v' v * hn_ctxt (prod_assn id_assn id_assn) k' k * hn_ctxt (asmtx_assn M (pure Id)) m' m)
        (PR_CONST (mtx_set M) m k v)
        (hn_val Id v' v * hn_ctxt (prod_assn id_assn id_assn) k' k * hn_invalid (asmtx_assn M (pure Id)) m' m)
        (asmtx_assn M (pure Id))
        ( PR_CONST (mop_matrix_set t) $ m' $ k' $ v')"
    apply (rule hn_refine_preI)
    unfolding mop_matrix_set_def autoref_tag_defs
    apply (rule extract_cost_otherway[OF _ mtx_set_rl, where F="hn_val Id v' v * hn_val Id k' k * hn_invalid (asmtx_assn M (pure Id)) m' m" ])
    unfolding mult.assoc
    (* reorder both sides of the entailment so that matching conjuncts
       can be peeled off one at a time *)
    apply (rotatel)
    apply (rotatel)
    apply rotater
    apply rotater
    apply rotater
    apply rotater
    apply swapr
    apply taker
    apply (rule isolate_first)
    apply (simp add: gr_def hn_ctxt_def)
    apply (rule ent_trans)
    apply (rule invalidate_clone[where R="asmtx_assn M id_assn"])
    apply (rule match_first)
    unfolding amtx_assn_def
    apply simp
    apply (rule entails_triv)
    unfolding hn_ctxt_def
    apply (rotatel)
    apply (rule match_first)
    apply (rule isolate_first)
    subgoal by simp
    apply (rule entails_triv)
    subgoal by (auto dest: extractpureD)
    subgoal by (auto dest: extractpureD)
    subgoal
      apply rotatel
      apply rotatel
      apply rotatel
      apply rotater
      apply rotater
      apply (rule match_first)
      apply simp
      apply (simp only: ex_distrib_star' pure_def hr_comp_def)
      apply (auto simp add: ex_distrib_star')
      apply (rule ent_ex_postI[where x="m'(k' := v')"])
      by simp
    subgoal by simp
    done

  (* Refinement rule: mtx_get M implements mop_matrix_get t on square
     M x M matrices with identity element refinement. Requires a cost
     bound of at least 1 and both key components below M. Key and matrix
     are preserved in the post-context. *)
  lemma mop_matrix_get_rule[sepref_fr_rules]:
    "1 ≤ t ⟹ fst k' < M ⟹ snd k' < M ⟹
      hn_refine
        (hn_ctxt (prod_assn id_assn id_assn) k' k * hn_ctxt (asmtx_assn M (pure Id)) m' m)
        (PR_CONST (mtx_get M) m k)
        (hn_ctxt (prod_assn id_assn id_assn) k' k* hn_ctxt (asmtx_assn M (pure Id)) m' m)
        id_assn
        ( PR_CONST (mop_matrix_get t) $ m' $ k')"
    apply (rule hn_refine_preI)
    unfolding autoref_tag_defs mop_matrix_get_def
    apply (rule extract_cost_otherway[OF _ mtx_get_rl])
    unfolding mult.assoc
    unfolding hn_ctxt_def
    apply rotatel
    apply rotatel
    apply (rule match_first)
    apply rotater
    apply (rule match_first)
    apply (simp add: amtx_assn_def)
    apply (rule entails_triv)
    subgoal by (auto dest: extractpureD)
    subgoal by (auto dest: extractpureD)
    subgoal
      apply rotater
      unfolding amtx_assn_def
      apply simp
      apply (simp add: pure_def )
      apply safe
      apply (rule inst_ex_assn[where x="m' k'"])
      by auto
    subgoal by auto
    done

  (*
  lemma op_mtx_new_fref': "CONSTRAINT PRES_ZERO_UNIQUE A ⟹ (RETURN ∘ op_mtx_new, RETURN ∘ op_mtx_new) ∈ (nat_rel ×⇩r 
nat_rel → A) →⇩f ⟨⟨A⟩mtx_rel⟩nrest_rel"
    by (rule op_mtx_new.fref)

  sepref_decl_impl (no_register) amtx_new_by_tab: mtx_tabulate_aref uses op_mtx_new_fref'
    by (auto simp: mtx_nonzero_zu_eq)

  sepref_decl_impl amtx_copy: mtx_copy_aref .
  *)

  (* op_mtx_new tagged with the intended dimensions N and M; by definition
     it is just op_mtx_new, and the equation is a simp rule. *)
  definition [simp]: "op_amtx_new (N::nat) (M::nat) ≡ op_mtx_new"

  (*
  lemma amtx_fold_custom_new:
    "op_mtx_new ≡ op_amtx_new N M"
    apply simp done
    "mop_mtx_new ≡ λc. RETURN (op_amtx_new N M c)"
    by (auto simp: mop_mtx_new_alt[abs_def])
  *)

  context
    fixes N M :: nat
      and t :: "nat ⇒ nat ⇒ nat"
  begin

    (* Abstract matrix creation with cost: yields op_amtx_new N M c at
       cost t N M. *)
    definition "mop_amtx_new c = SPECT [op_amtx_new N M c ↦ t N M] "

    sepref_register "mop_amtx_new"

  end

  (* Lift is_amtx to amtx_assn: if xi and x are related pointwise through
     the pure part of A, then is_amtx for xi entails amtx_assn for x. *)
  lemma is_amtx_impl_amtx_assn:
    "(xi, x) ∈ Id → the_pure A ⟹ is_amtx N M xi r ⟹⇩A amtx_assn N M A x r"
    by (simp add: hr_compI mtx_rel_def amtx_assn_def)

  (* Refinement rule: mtx_tabulate N M implements mop_amtx_new N M t for
     argument functions whose non-zero entries lie in the N x M range,
     provided A is pure and zero-preserving and the cost bound t N M is
     at least N * M * 3 + 3. *)
  lemma amtx_new_hnr[sepref_fr_rules]:
    fixes A :: "'a::zero ⇒ 'b::{zero,heap} ⇒ assn"
    shows "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A ⟹ N*M*3+3 ≤ t N M
      ⟹ (mtx_tabulate N M, ( PR_CONST (mop_amtx_new N M t)))
        ∈ [λx. mtx_nonzero x ⊆ {0..<N} × {0..<M}]⇩a (pure (nat_rel ×⇩r nat_rel → the_pure A))⇧k → amtx_assn N M A"
    apply sepref_to_hoare
    apply (rule hn_refine_preI)
    unfolding autoref_tag_defs constraint_abbrevs mop_amtx_new_def
    subgoal for x xi
      (* reduce to the Hoare triple mtx_tabulate_rl, keeping the pure
         argument relation as frame *)
      apply (rule extract_cost_otherway[OF _ mtx_tabulate_rl, where F = "↑ ((xi, x) ∈ nat_rel ×⇩r nat_rel → the_pure A)"])
      apply sep_auto
      subgoal
        by (auto dest: mtx_nonzero_zu_eq)
      subgoal
        apply sep_auto
        apply (rule isolate_first)
        apply (rule is_amtx_impl_amtx_assn)
        by auto
      subgoal
        by auto
      done
    done

  (*
  lemma amtx_fold_custom_new: "⋀c tt. SPECT [op_mtx_new c ↦ tt N] = mop_amtx_new N M (λN M. 
tt N N) c" by(auto simp: mop_amtx_new_def) *) (* lemma [def_pat_rules]: "op_amtx_new$N$M ≡ UNPROTECT (op_amtx_new N M)" by simp context fixes N M :: nat notes [param] = IdI[of N] IdI[of M] begin lemma mtx_dflt_aref: "(amtx_dflt N M, RETURN o PR_CONST (op_amtx_dfltNxM N M)) ∈ id_assn⇧k →⇩a is_amtx N M" apply rule apply rule apply (sep_auto simp: pure_def) done sepref_decl_impl amtx_dflt: mtx_dflt_aref . lemma amtx_get_aref: "(uncurry (mtx_get M), uncurry (RETURN oo op_mtx_get)) ∈ [λ(_,(i,j)). i<N ∧ j<M]⇩a (is_amtx N M)⇧k *⇩a (prod_assn nat_assn nat_assn)⇧k → id_assn" apply rule apply rule apply (sep_auto simp: pure_def) done sepref_decl_impl amtx_get: amtx_get_aref . lemma amtx_set_aref: "(uncurry2 (mtx_set M), uncurry2 (RETURN ooo op_mtx_set)) ∈ [λ((_,(i,j)),_). i<N ∧ j<M]⇩a (is_amtx N M)⇧d *⇩a (prod_assn nat_assn nat_assn)⇧k *⇩a id_assn⇧k → is_amtx N M" apply rule apply (rule hn_refine_preI) apply rule apply (sep_auto simp: pure_def hn_ctxt_def invalid_assn_def) done sepref_decl_impl amtx_set: amtx_set_aref . lemma amtx_get_aref': "(uncurry (mtx_get M), uncurry (RETURN oo op_mtx_get)) ∈ (is_amtx N M)⇧k *⇩a (prod_assn (pure (nbn_rel N)) (pure (nbn_rel M)))⇧k →⇩a id_assn" apply rule apply rule apply (sep_auto simp: pure_def IS_PURE_def IS_ID_def) done sepref_decl_impl amtx_get': amtx_get_aref' . lemma amtx_set_aref': "(uncurry2 (mtx_set M), uncurry2 (RETURN ooo op_mtx_set)) ∈ (is_amtx N M)⇧d *⇩a (prod_assn (pure (nbn_rel N)) (pure (nbn_rel M)))⇧k *⇩a id_assn⇧k →⇩a is_amtx N M" apply rule apply (rule hn_refine_preI) apply rule apply (sep_auto simp: pure_def hn_ctxt_def invalid_assn_def IS_PURE_def IS_ID_def) done sepref_decl_impl amtx_set': amtx_set_aref' . end subsection ‹Pointwise Operations› context fixes M N :: nat begin sepref_decl_op amtx_lin_get: "λf i. op_mtx_get f (i div M, i mod M)" :: "⟨A⟩mtx_rel → nat_rel → A" unfolding op_mtx_get_def mtx_rel_def by (rule frefI) (parametricity; simp) sepref_decl_op amtx_lin_set: "λf i x. 
op_mtx_set f (i div M, i mod M) x" :: "⟨A⟩mtx_rel → nat_rel → A → ⟨A⟩mtx_rel" unfolding op_mtx_set_def mtx_rel_def apply (rule frefI) apply parametricity by simp_all lemma op_amtx_lin_get_aref: "(uncurry Array.nth, uncurry (RETURN oo PR_CONST op_amtx_lin_get)) ∈ [λ(_,i). i<N*M]⇩a (is_amtx N M)⇧k *⇩a nat_assn⇧k → id_assn" apply sepref_to_hoare unfolding is_amtx_def apply sep_auto apply (metis mult.commute div_eq_0_iff div_mult2_eq div_mult_mod_eq mod_less_divisor mult_is_0 not_less0) done sepref_decl_impl amtx_lin_get: op_amtx_lin_get_aref by auto lemma op_amtx_lin_set_aref: "(uncurry2 (λm i x. Array.upd i x m), uncurry2 (RETURN ooo PR_CONST op_amtx_lin_set)) ∈ [λ((_,i),_). i<N*M]⇩a (is_amtx N M)⇧d *⇩a nat_assn⇧k *⇩a id_assn⇧k → is_amtx N M" proof - have [simp]: "i < N * M ⟹ ¬(M ≤ i mod M)" for i by (cases "N = 0 ∨ M = 0") (auto simp add: not_le) have [simp]: "i < N * M ⟹ ¬(N ≤ i div M)" for i apply (cases "N = 0 ∨ M = 0") apply (auto simp add: not_le) apply (metis mult.commute div_eq_0_iff div_mult2_eq neq0_conv) done show ?thesis apply sepref_to_hoare unfolding is_amtx_def by (sep_auto simp: nth_list_update) qed sepref_decl_impl amtx_lin_set: op_amtx_lin_set_aref by auto end lemma amtx_fold_lin_get: "m (i div M, i mod M) = op_amtx_lin_get M m i" by simp lemma amtx_fold_lin_set: "m ((i div M, i mod M) := x) = op_amtx_lin_set M m i x" by simp locale amtx_pointwise_unop_impl = mtx_pointwise_unop_loc + fixes A :: "'a ⇒ 'ai::{zero,heap} ⇒ assn" fixes fi :: "nat×nat ⇒ 'ai ⇒ 'ai Heap" assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ (prod_assn nat_assn nat_assn)⇧k *⇩a A⇧k →⇩a A" begin lemma this_loc: "amtx_pointwise_unop_impl N M f A fi" by unfold_locales context assumes PURE: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" begin context notes [[sepref_register_adhoc f N M]] notes [sepref_import_param] = IdI[of N] IdI[of M] notes [sepref_fr_rules] = fi_hnr notes [safe_constraint_rules] = PURE notes [simp] = algebra_simps begin sepref_thm opr_fold_impl1 is "RETURN o 
opr_fold_impl" :: "(amtx_assn N M A)⇧d →⇩a amtx_assn N M A" unfolding opr_fold_impl_def fold_prod_divmod_conv' apply (rewrite amtx_fold_lin_set) apply (rewrite in "f _ ⌑" amtx_fold_lin_get) by sepref end end concrete_definition (in -) amtx_pointwise_unnop_fold_impl1 uses amtx_pointwise_unop_impl.opr_fold_impl1.refine_raw prepare_code_thms (in -) amtx_pointwise_unnop_fold_impl1_def lemma op_hnr[sepref_fr_rules]: assumes PURE: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" shows "(amtx_pointwise_unnop_fold_impl1 N M fi, RETURN ∘ PR_CONST (mtx_pointwise_unop f)) ∈ (amtx_assn N M A)⇧d →⇩a amtx_assn N M A" unfolding PR_CONST_def apply (rule hfref_weaken_pre'[OF _ amtx_pointwise_unnop_fold_impl1.refine[OF this_loc PURE,FCOMP opr_fold_impl_refine]]) by (simp add: amtx_assn_bounded[OF PURE]) end locale amtx_pointwise_binop_impl = mtx_pointwise_binop_loc + fixes A :: "'a ⇒ 'ai::{zero,heap} ⇒ assn" fixes fi :: "'ai ⇒ 'ai ⇒ 'ai Heap" assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a A" begin lemma this_loc: "amtx_pointwise_binop_impl f A fi" by unfold_locales context notes [[sepref_register_adhoc f N M]] notes [sepref_import_param] = IdI[of N] IdI[of M] notes [sepref_fr_rules] = fi_hnr assumes PURE[safe_constraint_rules]: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" notes [simp] = algebra_simps begin sepref_thm opr_fold_impl1 is "uncurry (RETURN oo opr_fold_impl)" :: "(amtx_assn N M A)⇧d*⇩a(amtx_assn N M A)⇧k →⇩a amtx_assn N M A" unfolding opr_fold_impl_def[abs_def] fold_prod_divmod_conv' apply (rewrite amtx_fold_lin_set) apply (rewrite in "f ⌑ _" amtx_fold_lin_get) apply (rewrite in "f _ ⌑" amtx_fold_lin_get) by sepref end concrete_definition (in -) amtx_pointwise_binop_fold_impl1 for fi N M uses amtx_pointwise_binop_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_" prepare_code_thms (in -) amtx_pointwise_binop_fold_impl1_def lemma op_hnr[sepref_fr_rules]: assumes PURE: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" shows "(uncurry (amtx_pointwise_binop_fold_impl1 
fi N M), uncurry (RETURN oo PR_CONST (mtx_pointwise_binop f))) ∈ (amtx_assn N M A)⇧d *⇩a (amtx_assn N M A)⇧k →⇩a amtx_assn N M A" unfolding PR_CONST_def apply (rule hfref_weaken_pre'[OF _ amtx_pointwise_binop_fold_impl1.refine[OF this_loc PURE,FCOMP opr_fold_impl_refine]]) apply (auto dest: amtx_assn_bounded[OF PURE]) done end locale amtx_pointwise_cmpop_impl = mtx_pointwise_cmpop_loc + fixes A :: "'a ⇒ 'ai::{zero,heap} ⇒ assn" fixes fi :: "'ai ⇒ 'ai ⇒ bool Heap" fixes gi :: "'ai ⇒ 'ai ⇒ bool Heap" assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn" assumes gi_hnr: "(uncurry gi,uncurry (RETURN oo g)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn" begin lemma this_loc: "amtx_pointwise_cmpop_impl f g A fi gi" by unfold_locales context notes [[sepref_register_adhoc f g N M]] notes [sepref_import_param] = IdI[of N] IdI[of M] notes [sepref_fr_rules] = fi_hnr gi_hnr assumes PURE[safe_constraint_rules]: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" begin sepref_thm opr_fold_impl1 is "uncurry opr_fold_impl" :: "(amtx_assn N M A)⇧d*⇩a(amtx_assn N M A)⇧k →⇩a bool_assn" unfolding opr_fold_impl_def[abs_def] nfoldli_prod_divmod_conv apply (rewrite in "f ⌑ _" amtx_fold_lin_get) apply (rewrite in "f _ ⌑" amtx_fold_lin_get) apply (rewrite in "g ⌑ _" amtx_fold_lin_get) apply (rewrite in "g _ ⌑" amtx_fold_lin_get) by sepref end concrete_definition (in -) amtx_pointwise_cmpop_fold_impl1 for N M fi gi uses amtx_pointwise_cmpop_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_" prepare_code_thms (in -) amtx_pointwise_cmpop_fold_impl1_def lemma op_hnr[sepref_fr_rules]: assumes PURE: "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" shows "(uncurry (amtx_pointwise_cmpop_fold_impl1 N M fi gi), uncurry (RETURN oo PR_CONST (mtx_pointwise_cmpop f g))) ∈ (amtx_assn N M A)⇧d *⇩a (amtx_assn N M A)⇧k →⇩a bool_assn" unfolding PR_CONST_def apply (rule hfref_weaken_pre'[OF _ amtx_pointwise_cmpop_fold_impl1.refine[OF this_loc PURE,FCOMP opr_fold_impl_refine]]) apply (auto dest: amtx_assn_bounded[OF 
PURE]) done end subsection ‹Regression Test and Usage Example› context begin text ‹To work with a matrix, the dimension should be fixed in a context› context fixes N M :: nat ― ‹We also register the dimension as an operation, such that we can use it like a constant› notes [[sepref_register_adhoc N M]] notes [sepref_import_param] = IdI[of N] IdI[of M] ― ‹Finally, we fix a type variable with the required type classes for matrix entries› fixes dummy:: "'a::{times,zero,heap}" begin text ‹First, we implement scalar multiplication with destructive update of the matrix:› private definition scmul :: "'a ⇒ 'a mtx ⇒ 'a mtx nres" where "scmul x m ≡ nfoldli [0..<N] (λ_. True) (λi m. nfoldli [0..<M] (λ_. True) (λj m. do { let mij = m(i,j); RETURN (m((i,j) := x * mij)) } ) m ) m" text ‹After declaration of an implementation for multiplication, refinement is straightforward. Note that we use the fixed @{term N} in the refinement assertions.› private lemma times_param: "(( * ),( * )::'a⇒_) ∈ Id → Id → Id" by simp context notes [sepref_import_param] = times_param begin sepref_definition scmul_impl is "uncurry scmul" :: "(id_assn⇧k *⇩a (amtx_assn N M id_assn)⇧d →⇩a amtx_assn N M id_assn)" unfolding scmul_def[abs_def] by sepref end text ‹Initialization with default value› private definition "init_test ≡ do { let m = op_amtx_dfltNxM 10 5 (0::nat); RETURN (m(1,2)) }" private sepref_definition init_test_impl is "uncurry0 init_test" :: "unit_assn⇧k→⇩anat_assn" unfolding init_test_def by sepref text ‹Initialization from function diagonal is more complicated: First, we have to define the function as a new constant› (* TODO: PR_CONST option for sepref-register! *) qualified definition "diagonalN k ≡ λ(i,j). 
if i=j ∧ j<N then k else 0" text ‹If it carries implicit parameters, we have to wrap it into a @{term PR_CONST} tag:› private sepref_register "PR_CONST diagonalN" private lemma [def_pat_rules]: "IICF_Array_Matrix.diagonalN$N ≡ UNPROTECT diagonalN" by simp text ‹Then, we have to implement the constant, where the result assertion must be for a pure function. Note that, due to technical reasons, we need the ‹the_pure› in the function type, and the refinement rule to be parameterized over an assertion variable (here ‹A›). Of course, you can constrain ‹A› further, e.g., @{term "CONSTRAINT (IS_PURE IS_ID) (A::int ⇒ int ⇒ assn)"} › private lemma diagonalN_hnr[sepref_fr_rules]: assumes "CONSTRAINT (IS_PURE PRES_ZERO_UNIQUE) A" (*assumes "CONSTRAINT (IS_PURE IS_ID) (A::int ⇒ int ⇒ assn)"*) shows "(return o diagonalN, RETURN o (PR_CONST diagonalN)) ∈ A⇧k →⇩a pure (nat_rel ×⇩r nat_rel → the_pure A)" using assms apply sepref_to_hoare apply (sep_auto simp: diagonalN_def is_pure_conv IS_PURE_def PRES_ZERO_UNIQUE_def (*IS_ID_def*)) done text ‹In order to discharge preconditions, we need to prove some auxiliary lemma that non-zero indexes are within range› lemma diagonal_nonzero_ltN[simp]: "(a,b)∈mtx_nonzero (diagonalN k) ⟹ a<N ∧ b<N" by (auto simp: mtx_nonzero_def diagonalN_def split: if_split_asm) private definition "init_test2 ≡ do { ASSERT (N>2); ― ‹Ensure that the coordinate ‹(1,2)› is valid› let m = op_mtx_new (diagonalN (1::int)); RETURN (m(1,2)) }" private sepref_definition init_test2_impl is "uncurry0 init_test2" :: "unit_assn⇧k→⇩aint_assn" unfolding init_test2_def amtx_fold_custom_new[of N N] by sepref end export_code scmul_impl in SML_imp end hide_const scmul_impl *)
(* Hide the short name is_amtx from here on; because of the (open) flag
   the constant stays accessible under its fully qualified name. *)
hide_const(open) is_amtx end