(* Parametricity theorems for HOL constants (bool, nat, int, products,
   option, sum, list, set), stated with the relators of Param_Tool.
   A theory file admits exactly one theory header. *)
header {* \isaheader{Parametricity Theorems for HOL} *}
theory Param_HOL
imports Param_Tool
begin

(* If: the conditions must be Id-related (i.e. equal); each branch only
   has to be R-related under the corresponding (positive/negated)
   condition, which the assumptions make available. *)
lemma param_if[param]: 
  assumes "(c,c')∈Id"
  assumes "[|c;c'|] ==> (t,t')∈R"
  assumes "[|¬c;¬c'|] ==> (e,e')∈R"
  shows "(If c t e, If c' t' e')∈R"
  using assms by auto

(* Let is parametric: a related bound value and a related body function
   yield related results. *)
lemma param_Let[param]: 
  "(Let,Let)∈Ra -> (Ra->Rr) -> Rr"
  by (auto dest: fun_relD)

(* The identity function relates every relation R to itself. *)
lemma param_id[param]: "(id,id)∈R->R" unfolding id_def by parametricity

(* Function composition is parametric; proved by unfolding comp_def into a
   lambda term and running the parametricity method. *)
lemma param_fun_comp[param]: "(op o, op o) ∈ (Ra->Rb) -> (Rc->Ra) -> Rc->Rb" 
  unfolding comp_def[abs_def] by parametricity

(* fun_upd is parametric provided equality on keys respects Ra
   (premise (op =, op =) ∈ Ra->Ra->Id); presumably needed because the
   unfolded fun_upd_def compares keys with =. *)
lemma param_fun_upd[param]: "
  (op =, op =) ∈ Ra->Ra->Id 
  ==> (fun_upd,fun_upd) ∈ (Ra->Rb) -> Ra -> Rb -> Ra -> Rb"
  unfolding fun_upd_def[abs_def]
  by (parametricity)

(* The unit value is related to itself by unit_rel. *)
lemma param_unit[param]: "((),())∈unit_rel" by auto

(* The old-style recursor on bool coincides with case_bool. Used as a simp
   rule in param_bool to reduce old.rec_bool to the case combinator. *)
lemma rec_bool_is_case: "old.rec_bool = case_bool"
  by (rule ext)+ (auto split: bool.split)

(* Boolean constants and connectives relate to themselves by Id. The case
   combinator and old recursor are parametric in the result relation R. *)
lemma param_bool[param]:
  "(True,True)∈Id"
  "(False,False)∈Id"
  "(conj,conj)∈Id->Id->Id"
  "(disj,disj)∈Id->Id->Id"
  "(Not,Not)∈Id->Id"
  "(case_bool,case_bool)∈R->R->Id->R"
  "(old.rec_bool,old.rec_bool)∈R->R->Id->R"
  "(op <->, op <->)∈Id->Id->Id"
  "(op -->, op -->)∈Id->Id->Id"
  by (auto split: bool.split simp: rec_bool_is_case)

(* Natural-number constants, comparisons and arithmetic relate to
   themselves by Id. *)
lemma param_nat1[param]:
  "(0, 0::nat) ∈ Id"
  "(Suc, Suc) ∈ Id -> Id"
  "(1, 1::nat) ∈ Id"
  "(numeral n::nat,numeral n::nat) ∈ Id"
  "(op <, op <::nat => _) ∈ Id -> Id -> Id"
  "(op ≤, op ≤::nat => _) ∈ Id -> Id -> Id"
  "(op =, op =::nat => _) ∈ Id -> Id -> Id"
  "(op +::nat=>_,op +)∈Id->Id->Id"
  "(op -::nat=>_,op -)∈Id->Id->Id"
  "(op *::nat=>_,op *)∈Id->Id->Id"
  "(op div::nat=>_,op div)∈Id->Id->Id"
  "(op mod::nat=>_,op mod)∈Id->Id->Id"
  by auto

(* case_nat is parametric in the result relation Ra. The scrutinee and the
   Suc-argument are related by Id. *)
lemma param_case_nat[param]:
  "(case_nat,case_nat)∈Ra -> (Id -> Ra) -> Id -> Ra"
  apply (intro fun_relI)
  apply (auto split: nat.split dest: fun_relD)
  done

(* rec_nat is parametric in the result relation R. The proof introduces
   all fun_rel arguments and then inducts on the natural-number argument n'.
   NOTE(review): the `goal1` case name belongs to older Isabelle releases. *)
lemma param_rec_nat[param]: 
  "(rec_nat,rec_nat) ∈ R -> (Id -> R -> R) -> Id -> R"
  apply (intro fun_relI)
proof -
  case (goal1 s s' f f' n n') thus ?case
    apply (induct n' arbitrary: n s s')
    apply (fastforce simp: fun_rel_def)+
    done
qed

(* Integer constants, comparisons and arithmetic relate to themselves by Id. *)
lemma param_int[param]:
  "(0, 0::int) ∈ Id"
  "(1, 1::int) ∈ Id"
  "(numeral n::int,numeral n::int) ∈ Id"
  "(op <, op <::int => _) ∈ Id -> Id -> Id"
  "(op ≤, op ≤::int => _) ∈ Id -> Id -> Id"
  "(op =, op =::int => _) ∈ Id -> Id -> Id"
  "(op +::int=>_,op +)∈Id->Id->Id"
  "(op -::int=>_,op -)∈Id->Id->Id"
  "(op *::int=>_,op *)∈Id->Id->Id"
  "(op div::int=>_,op div)∈Id->Id->Id"
  "(op mod::int=>_,op mod)∈Id->Id->Id"
  by auto

(* The old-style recursor on products coincides with case_prod. Used as a
   simp rule in param_prod. The split rule must be the one for products
   (prod.split), matching rec_bool_is_case / rec_sum_is_case. *)
lemma rec_prod_is_case: "old.rec_prod = case_prod"
  by (rule ext)+ (auto split: prod.split)

(* Pair construction, case/recursion and projections are parametric with
   respect to the product relator ⟨Ra,Rb⟩prod_rel. *)
lemma param_prod[param]:
  "(Pair,Pair)∈Ra -> Rb -> ⟨Ra,Rb⟩prod_rel"
  "(case_prod,case_prod) ∈ (Ra -> Rb -> Rr) -> ⟨Ra,Rb⟩prod_rel -> Rr"
  "(old.rec_prod,old.rec_prod) ∈ (Ra -> Rb -> Rr) -> ⟨Ra,Rb⟩prod_rel -> Rr"
  "(fst,fst)∈⟨Ra,Rb⟩prod_rel -> Ra"
  "(snd,snd)∈⟨Ra,Rb⟩prod_rel -> Rb"
  by (auto dest: fun_relD split: prod.split 
    simp: prod_rel_def rec_prod_is_case)

(* Rule form of case_prod parametricity (not tagged [param]). When relating
   the bodies it additionally provides p=(a,b) and p'=(a',b'), so the body
   relation may depend on the destructured values. *)
lemma param_case_prod':
  "[| (p,p')∈⟨Ra,Rb⟩prod_rel;
     !!a b a' b'. [| p=(a,b); p'=(a',b'); (a,a')∈Ra; (b,b')∈Rb |] 
      ==> (f a b, f' a' b')∈R
    |] ==> (case_prod f p, case_prod f' p') ∈ R"
  by (auto split: prod.split)

(* map_prod is parametric; proved by unfolding its definition. *)
lemma param_map_prod[param]: 
  "(map_prod, map_prod) 
  ∈ (Ra->Rb) -> (Rc->Rd) -> ⟨Ra,Rc⟩prod_rel -> ⟨Rb,Rd⟩prod_rel"
  unfolding map_prod_def[abs_def]
  by parametricity

(* apfst (map over the first component) is parametric. *)
lemma param_apfst[param]: 
  "(apfst,apfst)∈(Ra->Rb)->⟨Ra,Rc⟩prod_rel->⟨Rb,Rc⟩prod_rel"
  unfolding apfst_def[abs_def] by parametricity

(* apsnd (map over the second component) is parametric. *)
lemma param_apsnd[param]: 
  "(apsnd,apsnd)∈(Rb->Rc)->⟨Ra,Rb⟩prod_rel->⟨Ra,Rc⟩prod_rel"
  unfolding apsnd_def[abs_def] by parametricity

(* curry is parametric: related uncurried functions give related curried ones. *)
lemma param_curry[param]: 
  "(curry,curry) ∈ (⟨Ra,Rb⟩prod_rel -> Rc) -> Ra -> Rb -> Rc"
  unfolding curry_def by parametricity

(* Within the partial_function_definitions locale, relates fixp_fun F x and
   fixp_fun F' x' by Ra. It requires monotone F and F', Ra-/Rb-related F and F',
   admissibility of the invariant, and that the bottom element (lub {}) is
   related to every fixp_fun F' xa. The proof uses ccpo fixed-point
   induction on F, then unfolds fixp_fun F' once.
   NOTE(review): the lemma is unnamed and hence cannot be referenced by name
   later; confirm whether it is meant to be used or is an exploratory
   leftover. *)
context partial_function_definitions begin
  lemma 
    assumes M: "monotone le_fun le_fun F" 
    and M': "monotone le_fun le_fun F'"
    assumes ADM: 
      "admissible (λa. ∀x xa. (x, xa) ∈ Rb --> (a x, fixp_fun F' xa) ∈ Ra)"
    assumes bot: "!!x xa. (x, xa) ∈ Rb ==> (lub {}, fixp_fun F' xa) ∈ Ra"
    assumes F: "(F,F')∈(Rb->Ra)->Rb->Ra"
    assumes A: "(x,x')∈Rb"
    shows "(fixp_fun F x, fixp_fun F' x')∈Ra"
    using A
    apply (induct arbitrary: x x' rule: ccpo.fixp_induct[OF ccpo _ M])
    apply (rule ADM)
    apply(simp add: fun_lub_def bot)
    apply (subst ccpo.fixp_unfold[OF ccpo M'])
    apply (parametricity add: F)
    done
end


(* Option constructors, case combinator and recursor are parametric with
   respect to the option relator ⟨R⟩option_rel. *)
lemma param_option[param]:
  "(None,None)∈⟨R⟩option_rel"
  "(Some,Some)∈R -> ⟨R⟩option_rel"
  "(case_option,case_option)∈Rr->(R -> Rr)->⟨R⟩option_rel -> Rr"
  "(rec_option,rec_option)∈Rr->(R -> Rr)->⟨R⟩option_rel -> Rr"
  by (auto split: option.split 
    simp: option_rel_def case_option_def[symmetric]
    dest: fun_relD)

(* Rule form of case_option parametricity (not tagged [param]). Each branch
   is proved with the corresponding equations x=None/x'=None or
   x=Some v/x'=Some v' available. *)
lemma param_case_option':
  "[| (x,x')∈⟨Rv⟩option_rel; 
     [|x=None; x'=None |] ==> (fn,fn')∈R;  
     !!v v'. [| x=Some v; x'=Some v'; (v,v')∈Rv |] ==> (fs v, fs' v')∈R
   |] ==> (case_option fn fs x, case_option fn' fs' x') ∈ R"
  by (auto split: option.split)

(* `the` is related on option_rel-related values when the left one is not None. *)
lemma the_paramL: "[|l≠None; (l,r)∈⟨R⟩option_rel|] ==> (the l, the r)∈R"
  apply (cases l)
  by (auto elim: option_relE)

(* `the` is related on option_rel-related values when the right one is not
   None. The proof still splits on l; option_relE handles the
   correspondence with r. *)
lemma the_paramR: "[|r≠None; (l,r)∈⟨R⟩option_rel|] ==> (the l, the r)∈R"
  apply (cases l)
  by (auto elim: option_relE)

(* the_default is parametric; proved by unfolding its definition. *)
lemma the_default_param[param]: 
  "(the_default, the_default) ∈ R -> ⟨R⟩option_rel -> R"
  unfolding the_default_def
  by parametricity

(* The old-style recursor on sums coincides with case_sum. Used as a simp
   rule in param_sum. *)
lemma rec_sum_is_case: "old.rec_sum = case_sum"
  by (rule ext)+ (auto split: sum.split)

(* Sum constructors, case combinator and old recursor are parametric with
   respect to the sum relator ⟨Rl,Rr⟩sum_rel. *)
lemma param_sum[param]:
  "(Inl,Inl) ∈ Rl -> ⟨Rl,Rr⟩sum_rel"
  "(Inr,Inr) ∈ Rr -> ⟨Rl,Rr⟩sum_rel"
  "(case_sum,case_sum) ∈ (Rl -> R) -> (Rr -> R) -> ⟨Rl,Rr⟩sum_rel -> R"
  "(old.rec_sum,old.rec_sum) ∈ (Rl -> R) -> (Rr -> R) -> ⟨Rl,Rr⟩sum_rel -> R"
  by (fastforce split: sum.split dest: fun_relD 
    simp: rec_sum_is_case)+

(* Rule form of case_sum parametricity (not tagged [param]). Each branch is
   proved with the constructor equations for s and s' available. *)
lemma param_case_sum':
  "[| (s,s')∈⟨Rl,Rr⟩sum_rel;
     !!l l'. [| s=Inl l; s'=Inl l'; (l,l')∈Rl |] ==> (fl l, fl' l')∈R;
     !!r r'. [| s=Inr r; s'=Inr r'; (r,r')∈Rr |] ==> (fr r, fr' r')∈R
   |] ==> (case_sum fl fr s, case_sum fl' fr' s')∈R"
  by (auto split: sum.split)

(* Discriminators for the sum type: is_Inl holds exactly on Inl values and
   is_Inr exactly on Inr values. *)
primrec is_Inl where "is_Inl (Inl _) = True" | "is_Inl (Inr _) = False"
primrec is_Inr where "is_Inr (Inr _) = True" | "is_Inr (Inl _) = False"

(* is_Inl is parametric; proved by unfolding its primrec definition. *)
lemma is_Inl_param[param]: "(is_Inl,is_Inl) ∈ ⟨Ra,Rb⟩sum_rel -> bool_rel"
  unfolding is_Inl_def by parametricity
(* is_Inr is parametric; proved by unfolding its primrec definition. *)
lemma is_Inr_param[param]: "(is_Inr,is_Inr) ∈ ⟨Ra,Rb⟩sum_rel -> bool_rel"
  unfolding is_Inr_def by parametricity

(* projl relates sum_rel-related values when the right-hand value s is an
   Inl (side condition is_Inl s). *)
lemma sum_projl_param[param]: 
  "[|is_Inl s; (s',s)∈⟨Ra,Rb⟩sum_rel|] 
  ==> (Sum_Type.sum.projl s',Sum_Type.sum.projl s) ∈ Ra"
  apply (cases s)
  apply (auto elim: sum_relE)
  done

(* projr relates sum_rel-related values when the right-hand value s is an
   Inr (side condition is_Inr s). *)
lemma sum_projr_param[param]: 
  "[|is_Inr s; (s',s)∈⟨Ra,Rb⟩sum_rel|] 
  ==> (Sum_Type.sum.projr s',Sum_Type.sum.projr s) ∈ Rb"
  apply (cases s)
  apply (auto elim: sum_relE)
  done




(* List append is parametric; follows from list_all2_appendI via list_rel_def. *)
lemma param_append[param]: 
  "(append, append)∈⟨R⟩list_rel -> ⟨R⟩list_rel -> ⟨R⟩list_rel"
  by (auto simp: list_rel_def list_all2_appendI)

(* List constructors and case_list are parametric with respect to ⟨R⟩list_rel. *)
lemma param_list1[param]:
  "(Nil,Nil)∈⟨R⟩list_rel"
  "(Cons,Cons)∈R -> ⟨R⟩list_rel -> ⟨R⟩list_rel"
  "(case_list,case_list)∈Rr->(R->⟨R⟩list_rel->Rr)->⟨R⟩list_rel->Rr"
  apply (force dest: fun_relD split: list.split)+
  done

(* rec_list is parametric. After introducing the fun_rel arguments, the
   proof inducts on the list_rel premise (goal1(3)), generalizing over the
   accumulator pair a, a'. *)
lemma param_rec_list[param]: 
  "(rec_list,rec_list) 
  ∈ Ra -> (Rb -> ⟨Rb⟩list_rel -> Ra -> Ra) -> ⟨Rb⟩list_rel -> Ra"
proof (intro fun_relI)
  case (goal1 a a' f f' l l')
  from goal1(3) show ?case
    using goal1(1,2)
    apply (induct arbitrary: a a')
    apply simp
    apply (fastforce dest: fun_relD)
    done
qed

(* Rule form of case_list parametricity (not tagged [param]). Each branch is
   proved with the shape equations for l and l' available. *)
lemma param_case_list':
  "[| (l,l')∈⟨Rb⟩list_rel;
     [|l=[]; l'=[]|] ==> (n,n')∈Ra;  
     !!x xs x' xs'. [| l=x#xs; l'=x'#xs'; (x,x')∈Rb; (xs,xs')∈⟨Rb⟩list_rel |] 
     ==> (c x xs, c' x' xs')∈Ra
   |] ==> (case_list n c l, case_list n' c' l') ∈ Ra"
  by (auto split: list.split)
    
(* map is parametric; proved from its recursor characterization map_rec. *)
lemma param_map[param]: 
  "(map,map)∈(R1->R2) -> ⟨R1⟩list_rel -> ⟨R2⟩list_rel"
  unfolding map_rec[abs_def] by (parametricity)
    
(* fold, foldl and foldr are parametric in the element relation Re and the
   state relation Rs. *)
lemma param_fold[param]: 
  "(fold,fold)∈(Re->Rs->Rs) -> ⟨Re⟩list_rel -> Rs -> Rs"
  "(foldl,foldl)∈(Rs->Re->Rs) -> Rs -> ⟨Re⟩list_rel -> Rs"
  "(foldr,foldr)∈(Re->Rs->Rs) -> ⟨Re⟩list_rel -> Rs -> Rs"
  unfolding List.fold_def List.foldr_def List.foldl_def
  by (parametricity)+

(* take is parametric. The relation is left schematic (?R) and is
   instantiated by the parametricity proof of the unfolded definition. *)
schematic_lemma param_take[param]: "(take,take)∈(?R::(_×_) set)"
  unfolding take_def 
  by (parametricity)

(* drop is parametric. The relation ?R is instantiated by the proof. *)
schematic_lemma param_drop[param]: "(drop,drop)∈(?R::(_×_) set)"
  unfolding drop_def 
  by (parametricity)

(* length is parametric. It is unfolded through its size_list definition;
   the relation ?R is instantiated by the proof. *)
schematic_lemma param_length[param]: 
  "(length,length)∈(?R::(_×_) set)"
  unfolding size_list_overloaded_def size_list_def 
  by (parametricity)

(* List equality parameterized by an element equality `eq`: two lists are
   equal iff both are empty, or their heads satisfy eq and their tails are
   list_eq. Lists of different shape are unequal. *)
fun list_eq :: "('a => 'a => bool) => 'a list => 'a list => bool" where
  "list_eq eq [] [] <-> True"
| "list_eq eq (a#l) (a'#l') 
     <-> (if eq a a' then list_eq eq l l' else False)"
| "list_eq _ _ _ <-> False"

(* list_eq is parametric. The proof inducts along list_eq.induct on the
   right-hand arguments and repeatedly simplifies, eliminates list_rel
   facts, and applies parametricity. *)
lemma param_list_eq[param]: "
  (list_eq,list_eq) ∈ 
    (R -> R -> Id) -> ⟨R⟩list_rel -> ⟨R⟩list_rel -> Id"
proof (intro fun_relI)
  case (goal1 eq eq' l1 l1' l2 l2')
  thus ?case
    apply -
    apply (induct eq' l1' l2' arbitrary: l1 l2 rule: list_eq.induct)
    apply (simp_all only: list_eq.simps |
      elim list_relE |
      parametricity 
    )+
    done
qed

(* list_eq instantiated with HOL equality is HOL equality on lists.
   Used by param_list_equals to transfer param_list_eq to op =. *)
lemma id_list_eq_aux[simp]: "(list_eq op =) = (op =)"
proof (intro ext)
  fix l1 l2 :: "'a list"
  show "list_eq op = l1 l2 = (l1 = l2)"
    apply (induct "op = :: 'a => _" l1 l2 rule: list_eq.induct)
    apply simp_all
    done
qed

(* Equality on lists is parametric whenever equality on elements respects R.
   Equality is rewritten back to list_eq op = so that param_list_eq applies. *)
lemma param_list_equals[param]:
  "[| (op =, op =) ∈ R->R->Id |] 
  ==> (op =, op =) ∈ ⟨R⟩list_rel -> ⟨R⟩list_rel -> Id"
  unfolding id_list_eq_aux[symmetric]
  by (parametricity) 

(* tl is parametric; proved by unfolding its definition. *)
lemma param_tl[param]:
  "(tl,tl) ∈ ⟨R⟩list_rel -> ⟨R⟩list_rel"
  unfolding tl_def[abs_def]
  by (parametricity)


(* Structurally recursive "all elements satisfy P". Bounded ∀ over set l
   is rewritten to this via list_all_rec_eq. *)
primrec list_all_rec where
  "list_all_rec P [] <-> True"
| "list_all_rec P (a#l) <-> P a ∧ list_all_rec P l"

(* Structurally recursive "some element satisfies P". Bounded ∃ over set l
   is rewritten to this via list_ex_rec_eq. *)
primrec list_ex_rec where
  "list_ex_rec P [] <-> False"
| "list_ex_rec P (a#l) <-> P a ∨ list_ex_rec P l"

(* Bounded ∀ over the elements of a list equals list_all_rec. *)
lemma list_all_rec_eq: "(∀x∈set l. P x) = list_all_rec P l"
  by (induct l) auto

(* Bounded ∃ over the elements of a list equals list_ex_rec. *)
lemma list_ex_rec_eq: "(∃x∈set l. P x) = list_ex_rec P l"
  by (induct l) auto

(* Bounded ∀ over set l is parametric. It is rewritten to the primrec
   list_all_rec, whose unfolded definition is handled by parametricity. *)
lemma param_list_ball[param]:
  "[|(P,P')∈(Ra->Id); (l,l')∈⟨Ra⟩ list_rel|] 
    ==> (∀x∈set l. P x, ∀x∈set l'. P' x) ∈ Id"
  unfolding list_all_rec_eq
  unfolding list_all_rec_def
  by (parametricity)

(* Bounded ∃ over set l is parametric. It is rewritten to the primrec
   list_ex_rec, whose unfolded definition is handled by parametricity.
   NOTE(review): unlike param_list_ball this unfolds with [abs_def];
   presumably harmless, but confirm it is intended. *)
lemma param_list_bex[param]:
  "[|(P,P')∈(Ra->Id); (l,l')∈⟨Ra⟩ list_rel|] 
    ==> (∃x∈set l. P x, ∃x∈set l'. P' x) ∈ Id"
  unfolding list_ex_rec_eq[abs_def]
  unfolding list_ex_rec_def
  by (parametricity)

(* rev is parametric; proved by unfolding its definition. *)
lemma param_rev[param]: "(rev,rev) ∈ ⟨R⟩list_rel -> ⟨R⟩list_rel"
  unfolding rev_def
  by (parametricity)
  
(* Bounded ∀ over a set is parametric with respect to ⟨Ra⟩set_rel. *)
lemma param_Ball[param]: "(Ball,Ball)∈⟨Ra⟩set_rel->(Ra->Id)->Id"
  by (auto simp: set_rel_def dest: fun_relD)
(* Bounded ∃ over a set is parametric with respect to ⟨Ra⟩set_rel. After
   the initial auto, a witness is moved across via set_mp and DomainE
   before finishing with fun_relD. *)
lemma param_Bex[param]: "(Bex,Bex)∈⟨Ra⟩set_rel->(Ra->Id)->Id"
  apply (auto simp: set_rel_def dest: fun_relD)
  apply (drule (1) set_mp)
  apply (erule DomainE)
  apply (auto dest: fun_relD)
  done

(* foldli (fold with a continuation condition of type Rs->Id) is parametric. *)
lemma param_foldli[param]: "(foldli, foldli) 
  ∈ ⟨Re⟩list_rel -> (Rs->Id) -> (Re->Rs->Rs) -> Rs -> Rs"
  unfolding foldli_def
  by parametricity

(* foldri (fold with a continuation condition of type Rs->Id) is parametric. *)
lemma param_foldri[param]: "(foldri, foldri) 
  ∈ ⟨Re⟩list_rel -> (Rs->Id) -> (Re->Rs->Rs) -> Rs -> Rs"
  unfolding foldri_def[abs_def]
  by parametricity

(* Indexing is parametric for in-bounds indices (i' < length l'). The proof
   inducts along the list_rel premise, generalizing over the index pair. *)
lemma param_nth[param]: 
  assumes I: "i'<length l'"
  assumes IR: "(i,i')∈nat_rel"
  assumes LR: "(l,l')∈⟨R⟩list_rel" 
  shows "(l!i,l'!i') ∈ R"
  using LR I IR
  by (induct arbitrary: i i' rule: list_rel_induct) 
     (auto simp: nth.simps split: nat.split)

(* replicate is parametric in the element relation. *)
lemma param_replicate[param]:
  "(replicate,replicate) ∈ nat_rel -> R -> ⟨R⟩list_rel"
  unfolding replicate_def by parametricity

(* list_update is parametric in the element relation Ra; the index is
   related by nat_rel. *)
lemma param_list_update[param]: 
  "(list_update,list_update) ∈ ⟨Ra⟩list_rel -> nat_rel -> Ra -> ⟨Ra⟩list_rel"
  unfolding list_update_def[abs_def] by parametricity

(* zip is parametric: related lists zip to lists related elementwise by
   the product relator. *)
lemma param_zip[param]:
  "(zip, zip) ∈ ⟨Ra⟩list_rel -> ⟨Rb⟩list_rel -> ⟨⟨Ra,Rb⟩prod_rel⟩list_rel"
    unfolding zip_def by parametricity

(* upt (the interval list [i..<j]) relates to itself via nat_rel. *)
lemma param_upt[param]:
  "(upt, upt) ∈ nat_rel -> nat_rel -> ⟨nat_rel⟩list_rel"
   unfolding upt_def[abs_def] by parametricity

(* concat is parametric in the element relation. *)
lemma param_concat[param]: "(concat, concat) ∈ 
    ⟨⟨R⟩list_rel⟩list_rel -> ⟨R⟩list_rel"
unfolding concat_def[abs_def] by parametricity

(* List.all_interval_nat is parametric. The parametricity method leaves one
   side goal, which is closed by simp. *)
lemma param_all_interval_nat[param]: 
  "(List.all_interval_nat, List.all_interval_nat) 
  ∈ (nat_rel -> bool_rel) -> nat_rel -> nat_rel -> bool_rel"
  unfolding List.all_interval_nat_def[abs_def]
  apply parametricity
  apply simp
  done


subsection {*Sets*}

(* The empty set is related to itself by ⟨R⟩set_rel for any R. *)
lemma param_empty[param]:
  "({},{})∈⟨R⟩set_rel" by (auto simp: set_rel_def)

(* insert is parametric provided R is single-valued. *)
lemma param_insert[param]:
  "single_valued R ==> (insert,insert)∈R->⟨R⟩set_rel->⟨R⟩set_rel"
  by (auto simp: set_rel_def dest: single_valuedD)

(* Set union is parametric with no side condition on R. *)
lemma param_union[param]:
  "(op ∪, op ∪) ∈ ⟨R⟩set_rel -> ⟨R⟩set_rel -> ⟨R⟩set_rel"
  by (auto simp: set_rel_def)

(* Set intersection is parametric provided the converse R¯ is single-valued. *)
lemma param_inter[param]:
  assumes "single_valued (R¯)"
  shows "(op ∩, op ∩) ∈ ⟨R⟩set_rel -> ⟨R⟩set_rel -> ⟨R⟩set_rel"
  using assms by (auto dest: single_valuedD simp: set_rel_def)

(* Set difference is parametric provided the converse R¯ is single-valued. *)
lemma param_diff[param]:
  assumes "single_valued (R¯)"
  shows "(op -, op -) ∈ ⟨R⟩set_rel -> ⟨R⟩set_rel -> ⟨R⟩set_rel"
  using assms 
  by (auto dest: single_valuedD simp: set_rel_def)

(* `set` maps list_rel-related lists to set_rel-related sets, provided Ra is
   single-valued. The proof inducts on the list_rel premise; the Cons case
   uses single-valuedness (A) via the parametricity method. *)
lemma param_set[param]: 
  "single_valued Ra ==> (set,set)∈⟨Ra⟩list_rel -> ⟨Ra⟩set_rel"
proof 
  fix l l'
  assume A: "single_valued Ra"
  assume "(l,l')∈⟨Ra⟩list_rel"
  thus "(set l, set l')∈⟨Ra⟩set_rel"
    apply (induct)
    apply simp
    apply simp
    using A apply (parametricity)
    done
qed

end