(* Miscellaneous additions to the Refinement Framework and the Sepref tool:
   refinement of FOREACH loops by nfoldli, a SELECTp selection combinator,
   Sepref setup for nfoldli, and list-set operations. *)
theory Refine_Add_Fofu
imports Fofu_Impl_Base Refine_Monadic_Syntax_Sugar
begin

  (* Make "return" denote the Imperative/HOL heap-monad return
     (Heap_Monad.return) in this theory. *)
  notation Heap_Monad.return ("return")



  (* TODO: Integrate into Refinement Framework! *)

  (* LFO_pre_refine: LIST_FOREACH' over the concrete list li refines the
     abstract FOREACHci over the set l, provided
       - li represents l (list_set_rel),
       - the continuation conditions are related (R -> bool_rel),
       - the loop bodies are related (A -> R -> ⟨R⟩nres_rel), and
       - the initial states are related by R.
     No assumption is made about the invariant I. *)
  lemma LFO_pre_refine: (* TODO: Move and generalize! *)
    assumes "(li,l)∈⟨A⟩list_set_rel"
    assumes "(ci,c)∈R -> bool_rel"
    assumes "(fi,f)∈A->R->⟨R⟩nres_rel"
    assumes "(s0i,s0)∈R"
    shows "LIST_FOREACH' (RETURN li) ci fi s0i ≤ \<Down>R (FOREACHci I l c f s0)"
  proof -
    (* A set related to a list by list_set_rel is finite; added to the
       simpset for the proof below. *)
    from assms(1) have [simp]: "finite l" by (auto simp: list_set_rel_def br_def)
    show ?thesis
      (* Rewrite FOREACHci via FOREACHoci_by_LIST_FOREACH, then apply the
         parametricity rule LIST_FOREACH_autoref. *)
      unfolding FOREACHc_def FOREACHci_def FOREACHoci_by_LIST_FOREACH
      apply simp
      apply (rule LIST_FOREACH_autoref[param_fo, THEN nres_relD])
      using assms
      apply auto
      (* Any remaining goals are closed by pointwise reasoning after
         unfolding it_to_sorted_list, nres_rel and list_set_rel. *)
      apply (auto simp: it_to_sorted_list_def nres_rel_def pw_le_iff refine_pw_simps
        list_set_rel_def br_def)
      done
  qed    
      

  (* LFOci_refine: nfoldli over the concrete list li refines FOREACHci over
     the abstract set l. Unlike LFO_pre_refine, the condition and body
     relations are stated pointwise:
       - ci agrees with c on R-related states, and
       - fi refines f on A-related elements and R-related states. *)
  lemma LFOci_refine: (* TODO: Move and generalize! *)
    assumes "(li,l)∈⟨A⟩list_set_rel"
    assumes "!!s si. (si,s)∈R ==> ci si <-> c s"
    assumes "!!x xi s si. [|(xi,x)∈A; (si,s)∈R|] ==> fi xi si ≤ \<Down>R (f x s)"
    assumes "(s0i,s0)∈R"
    shows "nfoldli li ci fi s0i ≤ \<Down>R (FOREACHci I l c f s0)"
  proof -
    (* Instantiate LFO_pre_refine. Unfold fun_rel, nres_rel and
       LIST_FOREACH' (presumably relating it to nfoldli), then reason
       pointwise to match the premises stated above. *)
    from assms LFO_pre_refine[of li l A ci c R fi f s0i s0] show ?thesis
      unfolding fun_rel_def nres_rel_def LIST_FOREACH'_def
      apply (simp add: pw_le_iff refine_pw_simps)
      apply blast+
      done
  qed    

  (* LFOc_refine: variant of LFOci_refine for FOREACHc (no invariant
     argument), obtained by unfolding FOREACHc_def and reusing LFOci_refine. *)
  lemma LFOc_refine: (* TODO: Move and generalize! *)
    assumes "(li,l)∈⟨A⟩list_set_rel"
    assumes "!!s si. (si,s)∈R ==> ci si <-> c s"
    assumes "!!x xi s si. [|(xi,x)∈A; (si,s)∈R|] ==> fi xi si ≤ \<Down>R (f x s)"
    assumes "(s0i,s0)∈R"
    shows "nfoldli li ci fi s0i ≤ \<Down>R (FOREACHc l f s0)"
    unfolding FOREACHc_def
    by (rule LFOci_refine[OF assms])

    
  (* LFO_refine: variant for FOREACH (no continuation condition). The
     concrete nfoldli uses the constant condition (λ_. True). The premises
     produced by LFOc_refine are discharged by the assumptions or by simp. *)
  lemma LFO_refine: (* TODO: Move and generalize! *)
    assumes "(li,l)∈⟨A⟩list_set_rel"
    assumes "!!x xi s si. [|(xi,x)∈A; (si,s)∈R|] ==> fi xi si ≤ \<Down>R (f x s)"
    assumes "(s0i,s0)∈R"
    shows "nfoldli li (λ_. True) fi s0i ≤ \<Down>R (FOREACH l f s0)"
    unfolding FOREACH_def
    apply (rule LFOc_refine)
    apply (rule assms | simp)+
    done

  (* LFOi_refine: variant for FOREACHi (with invariant I, no continuation
     condition). Unfolds FOREACHi_def and reuses LFOci_refine with the
     constant condition (λ_. True). *)
  lemma LFOi_refine: (* TODO: Move and generalize! *)
    assumes "(li,l)∈⟨A⟩list_set_rel"
    assumes "!!x xi s si. [|(xi,x)∈A; (si,s)∈R|] ==> fi xi si ≤ \<Down>R (f x s)"
    assumes "(s0i,s0)∈R"
    shows "nfoldli li (λ_. True) fi s0i ≤ \<Down>R (FOREACHi I l f s0)"
    unfolding FOREACHi_def
    apply (rule LFOci_refine)
    apply (rule assms | simp)+
    done

  (* TODO: Move to refinement framework. Combine with select from CAVA-Base. *)
  (* SELECTp P nondeterministically selects an element satisfying P:
     the result is Some x for some x with P x, or None if no element
     satisfies P (see selectp_pw_simps below). *)
  definition "SELECTp ≡ select o Collect"

  (* VCG rule for SELECTp. To show SELECTp P ≤ SPEC Φ, show that:
       - None satisfies Φ when no x satisfies P, and
       - Some x satisfies Φ for every x with P x. *)
  lemma selectp_rule[refine_vcg]: 
    assumes "∀x. ¬P x ==> RETURN None ≤ SPEC Φ"
    assumes "!!x. P x ==> RETURN (Some x) ≤ SPEC Φ"
    shows "SELECTp P ≤ SPEC Φ"
    using assms unfolding SELECTp_def select_def[abs_def]
    by (auto)

  (* Exact characterization of SELECTp refinement w.r.t. ⟨R⟩option_rel:
       - every P-element is R-related to some Q-element, and
       - if P is unsatisfiable, then so is Q. *)
  lemma selectp_refine_eq:
    "SELECTp P ≤ \<Down>(⟨R⟩option_rel) (SELECTp Q) <-> 
    (∀x. P x --> (∃y. (x,y)∈R ∧ Q y)) ∧ ((∀x. ¬P x) --> (∀y. ¬Q y))"
    by (auto simp: SELECTp_def select_def option_rel_def
      simp: pw_le_iff refine_pw_simps)

  (* Refinement rule for SELECTp (registered with [refine]). It suffices
     that:
       - SPEC P refines SPEC Q w.r.t. R, and
       - Q is unsatisfiable whenever P is.
     Proved via selectp_refine_eq. *)
  lemma selectp_refine[refine]:
    assumes "SPEC P ≤\<Down>R (SPEC Q)"  
    assumes "!!y. ∀x. ¬P x ==> ¬Q y"
    shows "SELECTp P ≤ \<Down>(⟨R⟩option_rel) (SELECTp Q)"
    unfolding selectp_refine_eq
    using assms by (auto simp: pw_le_iff refine_pw_simps)

  (* Identity-relation instance of selectp_refine. It suffices that:
       - P implies Q, and
       - Q is unsatisfiable whenever P is. *)
  lemma selectp_refine_Id[refine]:  
    assumes "!!x. P x ==> Q x"
    assumes "!!y. ∀x. ¬P x ==> ¬Q y"
    shows "SELECTp P ≤ \<Down>Id (SELECTp Q)"
    using selectp_refine[where R=Id, of P Q] assms by auto
    
  (* Pointwise characterization of SELECTp (for refine_pw_simps):
       - it never fails, and
       - r is a possible result iff
           r = None implies no element satisfies P, and
           r = Some x implies P x. *)
  lemma selectp_pw[refine_pw_simps]:
    "nofail (SELECTp P)"  
    "inres (SELECTp P) r <-> (r=None --> (∀x. ¬P x)) ∧ (∀x. r=Some x --> P x)"
    unfolding SELECTp_def select_def[abs_def]
    by auto

  (* Simp-friendly per-case forms of selectp_pw: nofail, and the possible
     results None and Some x. *)
  lemma selectp_pw_simps[simp]:
    "nofail (SELECTp P)"
    "inres (SELECTp P) None <-> (∀x. ¬P x)"
    "inres (SELECTp P) (Some x) <-> P x"
    by (auto simp: refine_pw_simps)

  (* Inside the Refine_Monadic_Syntax context, "selectp x. P x" is binder
     syntax for SELECTp (λx. P x). The term command only checks that the
     syntax parses. *)
  context Refine_Monadic_Syntax begin 
    notation SELECTp (binder "selectp " 10)

    term "selectp x. P x"
  end


(* setsum_impl g S: sums the results of the nondeterministic computations
   g x over all x ∈ S, iterating with foreach from the initial value 0.
   NOTE(review): `return` is declared earlier in this theory as notation for
   Heap_Monad.return, while this body is in nres; confirm which constant
   `return` resolves to here. *)
definition setsum_impl :: "('a => 'b::comm_monoid_add nres) => 'a set => 'b nres" where
  "setsum_impl g S ≡ foreach S (λx a. do { b \<leftarrow> g x; return (a+b)}) 0"

(* Correctness of setsum_impl: if S is finite and each gi x (for x ∈ S)
   returns exactly g x, then setsum_impl gi S returns setsum g S.
   Proof: FOREACH_rule with invariant "a = setsum g (S - it)", where it is
   presumably the set of elements not yet processed. *)
lemma setsum_imp_correct: 
  assumes [simp]: "finite S"
  assumes [THEN order_trans, refine_vcg]: "!!x. x∈S ==> gi x ≤ (spec r. r=g x)"
  shows "setsum_impl gi S ≤ (spec r. r=setsum g S)"
  unfolding setsum_impl_def
  apply (refine_vcg FOREACH_rule[where I="λit a. a = setsum g (S - it)"])
  apply (auto simp: it_step_insert_iff algebra_simps)
  done



    (* TODO: Move *)



    (* TODO: Move. Should this replace hn_refine_cons? *)
      
    (* TODO: Move *)
    (* prod.swap is parametric: it maps (A ×⇩r B)-related pairs to
       (B ×⇩r A)-related pairs. Registered for autoref ([param]) and
       imported into Sepref ([sepref_import_param]). *)
    lemma param_prod_swap[param]: "(prod.swap, prod.swap)∈A×\<^sub>rB \<rightarrow> B×\<^sub>rA" by auto
    lemmas [sepref_import_param] = param_prod_swap
    

(* Refinement Setup for nfoldli -> move to Sepref-Foreach *)
(* Monadify arity rule for nfoldli. The list s, continuation condition c,
   body f and initial state σ are passed as explicit arguments, with c and
   f wrapped in λ⇩2 abstractions. *)
lemma nfoldli_arities[sepref_monadify_arity]:
  "nfoldli ≡ λ\<^sub>2s c f σ. SP (nfoldli)$s$(λ\<^sub>2x. c$x)$(λ\<^sub>2x σ. f$x$σ)$σ"
  by (simp_all)

(* Monadify combinator rule for nfoldli. It first evaluates the list s and
   the initial state σ, then continues with monadic_nfoldli, whose condition
   is evaluated monadically (EVAL$(c x)). Proved via nfoldli_to_monadic. *)
lemma nfoldli_comb[sepref_monadify_comb]:
  "!!s c f σ. (nfoldli)$s$(λ\<^sub>2x. c x)$f$σ ≡ 
    Refine_Basic.bind$(EVAL$s)$(λ\<^sub>2s. Refine_Basic.bind$(EVAL$σ)$(λ\<^sub>2σ. 
      SP (monadic_nfoldli)$s$(λ\<^sub>2x. (EVAL$(c x)))$f$σ
    ))"
  by (simp_all add: nfoldli_to_monadic)

text {* Setup for linearity analysis. *}
(* Linearity-analysis skeleton for monadic_nfoldli. The list and initial
   state are used first (la_op (s,σ)). Then comes a recursive loop (la_rec)
   in which the condition c and then the body f precede the recursive
   call. *)
lemma monadic_nfoldli_skel[sepref_la_skel]:
  "!!s c f σ. SKEL (monadic_nfoldli$s$c$f$σ) = 
    la_seq 
      (la_op (s,σ)) 
      (la_rec (λD. la_seq (SKEL c) (la_seq (SKEL f) (la_rcall D)))
      )" by simp


(* Core refinement lemma for nfoldli. If the condition c refines c' and the
   body f refines f' (both under frame Γ), then imp_nfoldli over the
   concrete list l refines monadic_nfoldli over the abstract list l'
   (related by hn_list Rl). The list assertion is preserved in the
   postcondition; the state assertion is invalidated (hn_invalid s' s).
   Proof by induction with hn_list_aux.induct, generalizing over the
   state. *)
lemma monadic_nfoldli_refine_aux':
  assumes c_ref: "!!s s'. hn_refine 
    (Γ * hn_ctxt Rs s' s) 
    (c s) 
    (Γ * hn_ctxt Rs s' s) 
    (pure bool_rel)
    (c' s')"
  assumes f_ref: "!!x x' s s'. hn_refine 
    (Γ * hn_ctxt Rl x' x * hn_ctxt Rs s' s)
    (f x s)
    (Γ * hn_ctxt Rl x' x * hn_invalid s' s) Rs
    (f' x' s')"

  shows "hn_refine 
    (Γ * hn_list Rl l' l * hn_ctxt Rs s' s) 
    (imp_nfoldli l c f s) 
    (Γ * hn_list Rl l' l * hn_invalid s' s) Rs
    (monadic_nfoldli l' c' f' s')"

  (* NOTE(review): `pRl` does not occur in the statement; an instantiation
     such as `p≡Rl` may have been intended. Confirm against the original
     source. *)
  apply (induct pRl l' l 
    arbitrary: s s'
    rule: hn_list_aux.induct)

  (* First case (presumably the empty list): the result is a RETURN of the
     state (hnr_RETURN_pass). *)
  apply simp
  apply (rule hn_refine_cons_post)
  apply (rule hn_refine_frame[OF hnr_RETURN_pass])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})
  apply (simp add: hn_ctxt_def ent_true_drop)

  (* Step case: unfold one iteration of both loops and bind the result of
     the condition (c_ref). *)
  apply (simp only: imp_nfoldli_simps monadic_nfoldli_simp)
  apply (rule hnr_bind)
  apply (rule hn_refine_frame[OF c_ref])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Branch on the condition. Then-branch: run the body (f_ref) with
     manual frame matching, followed by the recursive call. *)
  apply (rule hnr_If)
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})
  apply (rule hnr_bind)
  apply (rule hn_refine_frame[OF f_ref])
  apply (simp add: assn_aci)
  apply (fr_rot_rhs 1)
  apply (fr_rot 2)
  apply (rule fr_refl)
  apply (rule fr_refl)
  apply (rule fr_refl)
  apply (rule ent_refl)

  (* Recursive call, discharged by the induction hypothesis (rprems). *)
  apply (rule hn_refine_frame)
  apply rprems

  (* Frame goal of the recursive call: conjuncts are rotated and matched
     one by one. *)
  apply (simp add: assn_aci)
  apply (fr_rot_rhs 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
  apply (rule ent_refl | rule fr_refl | fr_rot 1)
 
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Else-branch: return the current state. *)
  apply (rule hn_refine_frame[OF hnr_RETURN_pass])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1 *})

  (* Merge the postconditions of the two branches and close the remaining
     entailments. *)
  apply (simp add: assn_assoc)
  apply (tactic {* Sepref_Frame.merge_tac @{context} 1 *})
  apply (simp only: sup.idem, rule ent_refl)
  apply simp
  apply solve_entails
  apply (rule, sep_auto)
  apply (rule, sep_auto)
  done



(* Sepref combinator rule for monadic_nfoldli. It reduces refinement of a
   monadic_nfoldli application to:
     - refinement of the condition (c_ref) and of the body (f_ref), and
     - the frame-inference goals FR, C_FR and F_FR.
   In the postcondition the list assertion (hn_list Rk s' s) is retained
   and the initial state is invalidated. Proved from
   monadic_nfoldli_refine_aux'. *)
lemma hn_monadic_nfoldli_rl'[sepref_comb_rules]:
  assumes "INDEP Rk" "INDEP Rσ"
  assumes FR: "P \<Longrightarrow>\<^sub>A Γ * hn_list Rk s' s * hn_ctxt Rσ σ' σ"
  assumes c_ref: "!!σ σ'. hn_refine 
    (Γ * hn_ctxt Rσ σ' σ) 
    (c σ) 
    (Γc σ' σ) 
    (pure bool_rel) 
    (c' σ')"
  assumes C_FR: 
    "!!σ' σ. TERM monadic_nfoldli ==> 
      Γc σ' σ \<Longrightarrow>\<^sub>A Γ * hn_ctxt Rσ σ' σ"

  assumes f_ref: "!!x' x σ' σ. hn_refine 
    (Γ * hn_ctxt Rk x' x * hn_ctxt Rσ σ' σ)
    (f x σ)
    (Γf x' x σ' σ) Rσ
    (f' x' σ')"
  assumes F_FR: "!!x' x σ' σ. TERM monadic_nfoldli ==> Γf x' x σ' σ \<Longrightarrow>\<^sub>A 
    Γ * hn_ctxt Rk x' x * hn_ctxt Pfσ σ' σ"

  shows "hn_refine 
    P 
    (imp_nfoldli s c f σ) 
    (Γ * hn_list Rk s' s * hn_invalid σ' σ)
    Rσ
    ((monadic_nfoldli)
      $(LIN_ANNOT s' a)$(λ\<^sub>2σ'. c' σ')$(λ\<^sub>2x' σ'. f' x' σ')$(σ'\<^sup>L)
    )"
  unfolding APP_def PROTECT2_def LIN_ANNOT_def PR_CONST_def
  apply (rule hn_refine_cons_pre[OF FR])
  apply (rule hn_refine_cons[rotated])
  apply (rule monadic_nfoldli_refine_aux')
  (* Condition: c_ref, with its postcondition weakened via C_FR. *)
  apply (rule hn_refine_cons_post)
  apply (rule c_ref)
  apply (rule ent_trans[OF C_FR[OF TERMI]])
  apply (rule ent_refl)

  (* Body: f_ref, with its postcondition weakened via F_FR. *)
  apply (rule hn_refine_cons_post)
  apply (rule f_ref)
  apply (rule ent_trans[OF F_FR[OF TERMI]])
  apply (tactic {* Sepref_Frame.frame_tac @{context} 1*})
  apply (rule ent_refl)
  apply (rule ent_refl)
  apply (rule ent_refl)
  done


  (* TODO: Move *)
  (* A set represented by a list via list_set_rel is finite. *)
  lemma lsr_finite[simp, intro]: "(l,s)∈⟨R⟩list_set_rel ==> finite s"
    by (auto simp: list_set_rel_def br_def)



  (* TODO: Move *)
  (* op_empty_ls: the empty set as a separate constant. It is registered
     with Sepref and implemented via list_set_autoref_empty. *)
  definition [simp]: "op_empty_ls ≡ {}"
  sepref_register op_empty_ls
  lemmas [sepref_import_param] = list_set_autoref_empty[folded op_empty_ls_def]

  (* ls_ins_dj_imp x l: implementation of disjoint insertion into a
     list-represented set; it conses x onto l without a membership check.
     Registered as a sepref_opt_simps rule. *)
  definition ls_ins_dj_imp :: "_=>_=>_ Heap" where [sepref_opt_simps]: "ls_ins_dj_imp x l ≡ return (x#l)"
  (* op_set_ins_dj: Set.insert under a separate name, for use when the
     element is known not to be in the set (see ls_op_ins_dj_rule). *)
  definition [simp]: "op_set_ins_dj ≡ Set.insert"

  (* Sepref rule: Set.insert x s on a list-represented set is implemented by
     ls_ins_dj_imp (cons), under the side precondition x ∉ s. Both arguments
     are kept (⇧k); the result is again a pure list-set. *)
  lemma ls_ins_dj_rule[sepref_fr_rules]: 
    "(uncurry (ls_ins_dj_imp), uncurry (RETURN oo Set.insert)) 
      ∈ [λ(x,s). SIDE_PRECOND (x∉s)]\<^sub>a (pure R)\<^sup>k *\<^sub>a (pure (⟨R⟩list_set_rel))\<^sup>k \<rightarrow> pure (⟨R⟩list_set_rel)"
    apply rule
    apply rule
    (* TODO: Much too low-level reasoning *)
    apply (sep_auto simp: pure_def ls_ins_dj_imp_def intro: list_set_autoref_insert_dj[simplified])
    done

  (* Same as ls_ins_dj_rule, stated for op_set_ins_dj. Since op_set_ins_dj
     is a [simp] alias of Set.insert, simp reduces it to ls_ins_dj_rule. *)
  lemma ls_op_ins_dj_rule[sepref_fr_rules]: 
    "(uncurry (ls_ins_dj_imp), uncurry (RETURN oo op_set_ins_dj)) 
      ∈ [λ(x,s). SIDE_PRECOND (x∉s)]\<^sub>a (pure R)\<^sup>k *\<^sub>a (pure (⟨R⟩list_set_rel))\<^sup>k \<rightarrow> pure (⟨R⟩list_set_rel)"
    using ls_ins_dj_rule
    by simp

  (* TODO: This messes up code generation with some odd error msg! Why?  
  (* TODO: Move to imperative-HOL. Or at least to imp-hol-add *)
  context begin
    setup_lifting type_definition_integer 
  
    lift_definition integer_encode :: "integer => nat" is int_encode .
  
    lemma integer_encode_eq: "integer_encode x = integer_encode y <-> x = y"
      apply transfer
      by (rule inj_int_encode [THEN inj_eq])

    lifting_update integer.lifting
    lifting_forget integer.lifting
  end  

  instance integer :: countable
    by (rule countable_classI [of integer_encode]) (simp add: integer_encode_eq)

  instance integer :: heap ..
  *)

  (* int_of_integer preserves and reflects strict order. *)
  lemma int_of_integer_less_iff: "int_of_integer x < int_of_integer y <-> x<y"
    by (simp add: less_integer_def)

  (* On non-negative integers, nat_of_integer preserves and reflects strict
     order. Proved by unfolding nat_of_integer into int_of_integer. *)
  lemma nat_of_integer_less_iff: "x≥0 ==> y≥0 ==> nat_of_integer x < nat_of_integer y <-> x<y"
    unfolding nat_of_integer.rep_eq
    by (auto simp: int_of_integer_less_iff nat_less_eq_zless int_of_integer_less_iff[of 0, simplified])
    
  (*(* TODO: Move *)
  lemma param_integer[param]:
    "(0, 0::integer) ∈ Id"
    "(1, 1::integer) ∈ Id"
    "(numeral n::integer,numeral n::integer) ∈ Id"
    "(op <, op <::integer => _) ∈ Id -> Id -> Id"
    "(op ≤, op ≤::integer => _) ∈ Id -> Id -> Id"
    "(op =, op =::integer => _) ∈ Id -> Id -> Id"
    "(op +::integer=>_,op +)∈Id->Id->Id"
    "(op -::integer=>_,op -)∈Id->Id->Id"
    "(op *::integer=>_,op * )∈Id->Id->Id"
    "(op div::integer=>_,op div)∈Id->Id->Id"
    "(op mod::integer=>_,op mod)∈Id->Id->Id"
    by auto
  
  lemmas [sepref_import_param] = param_integer  
  
  lemmas [id_rules] = 
    itypeI[Pure.of 0 "TYPE (integer)"]
  *)  

end