(* Theory ELenses *)

theory ELenses
imports Lenses
begin

  (* TODO: We could use (λ_::unit. undefined) to realize a 'default' lifting operation
    from plain lenses to elenses.

    For lifting optionals to elenses, we could use a default type class.

    Then, we could define coercions to do the lifting automatically!
  *)

  (* Conversions between option and sum types.
     - lift f m maps None to Inl f and Some x to Inr x.
     - lower m forgets the left value: Inl _ becomes None, Inr x becomes Some x.
     - pre m extracts the left value: Some e for Inl e, None for any Inr.
     The left (Inl) side is the one that carries the 'e component of elens. *)
  definition "lift f m ≡ case m of None ⇒ Inl f | Some x ⇒ Inr x"
  definition "lower m ≡ case m of Inl _ ⇒ None | Inr x ⇒ Some x"

  definition "pre m ≡ case m of Inl e ⇒ Some e | _ ⇒ None"

  (* pre on the two sum constructors. *)
  lemma pre_simps[simp]: "pre (Inr x) = None" "pre (Inl l) = Some l"
    by (auto simp: pre_def split: sum.splits)

  (* pre m = None exactly when m is some Inr; pre m = Some l exactly when m = Inl l. *)
  lemma pre_eq_conv[simp]:
    "pre m = None ⟷ (∃r. m = Inr r)"
    "pre m = Some l ⟷ (m = Inl l)"
    by (auto simp: pre_def split: sum.splits)



  (* lift on the two option constructors. *)
  lemma lift_simps[simp]:
    "lift f (None) = Inl f"
    "lift f (Some x) = Inr x"
    by (auto simp: lift_def)

  (* Inversion of lift: an Inr result comes from Some, and an Inl result carries
     exactly the lifting value f and comes from None. *)
  lemma lift_invert[simp]:
    "lift f m = Inr x ⟷ m=Some x"
    "lift f m = Inl f' ⟷ f'=f ∧ m=None"
    unfolding lift_def
    by (auto split: option.splits)

  (* lower on the two sum constructors: an Inl is forgotten, an Inr is kept. *)
  lemma lower_simps[simp]:
    "lower (Inl f) = None"
    "lower (Inr x) = Some x"
    by (auto simp: lower_def)

  (* Inversion of lower: None comes from some Inl, and Some x only from Inr x. *)
  lemma lower_invert[simp]:
    "lower m = None ⟷ (∃f. m = Inl f)"
    "lower m = Some x ⟷ m = Inr x"
    by (auto simp: lower_def split: sum.splits)

  (* lower is a left inverse of lift f, for any value f. *)
  lemma lower_lift_id[simp]: "lower (lift f m) = m"
    by (auto simp: lower_def lift_def split: option.split)




  (* An elens pairs a getter and a putter whose results are sums.
     eget maps a state 'a to Inr b, or to Inl e with e :: 'e.
     eput maps a value 'b and a state 'a to Inr of the new state, or to Inl e.
     Presumably Inl is the error outcome (cf. the name "elens"); confirm against intended use. *)
  datatype ('a,'b,'e) elens = ELENS (eget: "'a ⇒ 'e + 'b") (eput: "'b ⇒ 'a ⇒ 'e+'a")


  (* Equality of sums: s = s' if s' reproduces every Inl and every Inr value of s. *)
  lemma sum_eqI[intro?]:
    assumes "⋀l. s = Inl l ⟹ s'=Inl l"
    assumes "⋀r. s = Inr r ⟹ s'=Inr r"
    shows "s=s'"
    apply (cases s; cases s'; simp)
    using assms by auto

  (* Extensionality for elens: L = L' if eget and eput of L' reproduce every
     Inr result and every Inl result of L. *)
  lemma elens_eqI[intro?]:
    assumes "⋀s x. eget L s = Inr x ⟹ eget L' s = Inr x"
    assumes "⋀s x. eget L s = Inl x ⟹ eget L' s = Inl x"
    assumes "⋀x s s'. eput L x s = Inr s' ⟹ eput L' x s = Inr s'"
    assumes "⋀x s f. eput L x s = Inl f ⟹ eput L' x s = Inl f"
    shows "L = L'"
    apply (cases L; cases L'; simp)
    apply (intro conjI ext sum_eqI)
    using assms by simp_all


  (* Lifting a plain lens (whose get/put results are options) to an elens:
     a None result of get or put becomes Inl f. *)
  definition "lift_get f g s ≡ lift f (g s)"
  definition "lift_put f p x s ≡ lift f (p x s)"
  definition "lift_lens f L ≡ ELENS (lift_get f (get L)) (lift_put f (put L))"

  (* Lowering an elens to a plain lens by forgetting the Inl value. *)
  definition "lower_get g s ≡ lower (g s)"
  definition "lower_put p x s ≡ lower (p x s)"
  definition "lower_lens L ≡ LENS (lower_get (eget L)) (lower_put (eput L))"

  (* Lowering undoes lifting, for the getter and the putter separately. *)
  lemma lower_lift_gp_id[simp]:
    "lower_get (lift_get f g) = g"
    "lower_put (lift_put f p) = p"
    (* ext is removed from the default rules and applied eagerly (intro!) here. *)
    by (auto
      del: ext intro!: ext
      simp: lower_get_def lift_get_def lower_put_def lift_put_def)

  (* lower_lens is a left inverse of lift_lens f, for any value f. *)
  lemma lower_lift_lens_id[simp]:
    "lower_lens (lift_lens f L) = L"
    by (cases L) (auto simp: lower_lens_def lift_lens_def)


  (* get'/put' (from Lenses) of the lowered lens; eget_rewrite1/eput_rewrite1
     below show they give the Inr results when there is no Inl. *)
  abbreviation "eget' L ≡ get' (lower_lens L)"
  abbreviation "eput' L ≡ put' (lower_lens L)"

  (* Left value reported by eget/eput: Some e iff the result is Inl e, else None. *)
  definition "epre_get L s ≡ pre (eget L s)"
  definition "epre_put L x s ≡ pre (eput L x s)"

  (*
  abbreviation "epre_get L ≡ pre_get (lower_lens L)"
  abbreviation "epre_put L ≡ pre_put (lower_lens L)"
  abbreviation "epre_put_single_point L ≡ pre_put_single_point (lower_lens L)"
  *)

  (* elens L: the lowered lens satisfies lens (from Lenses). *)
  abbreviation "elens L ≡ lens (lower_lens L)"
  (* ehlens L: the lowered lens satisfies hlens, and eput x s yields Inl f exactly
     when eget s yields Inl f, for every put value x.
     NOTE(review): the inner connective is reconstructed as ⟷. ehlens_put_Inl_conv
     below is proved from this definition alone and needs both directions. *)
  definition "ehlens L ≡ hlens (lower_lens L) ∧ (∀x s f. eput L x s = Inl f ⟷ eget L s = Inl f)"

  lemma ehlens_imp_hlens[simp]: "ehlens L ⟹ hlens (lower_lens L)"
    by (auto simp: ehlens_def)

  lemma ehlens_put_Inl_conv[simp]: "ehlens L ⟹ eput L x s = Inl f ⟷ eget L s = Inl f"
    by (auto simp: ehlens_def)

  (* TODO: Is reducing to optionals actually the best way? *)
  (* Relating epre_get/epre_put to the precondition predicates of the lowered lens. *)
  lemma lower_epre_get[simp]: "epre_get L s = None ⟷ pre_get (lower_lens L) s"
    by (auto simp: epre_get_def lower_lens_def lower_get_def)

  lemma lower_epre_put[simp]: "epre_put L x s = None ⟷ pre_put_single_point (lower_lens L) x s"
    by (auto simp: epre_put_def lower_lens_def lower_put_def pre_put_single_point_def)

  (* For an elens, the put precondition no longer mentions the put value x. *)
  lemma lower_epre_put'[simp]: "elens L ⟹ epre_put L x s = None ⟷ pre_put (lower_lens L) s"
    by (meson lower_epre_put pre_put_single_point)

  (* Moving between the sum view of eget/eput and the (epre_*, eget'/eput') view. *)
  lemma eget_rewrite1[simp]: "epre_get L s = None ⟹ eget L s = Inr (eget' L s)"
    apply (auto simp: epre_get_def)
    by (simp add: lower_get_def lower_lens_def)

  lemma eget_rewrite2[simp]: "eget L s = Inr x ⟷ epre_get L s = None ∧ x = eget' L s"
    by (auto simp: epre_get_def lower_get_def lower_lens_def)

  lemma eget_nopre_conv[simp]: "eget L s = Inl e ⟷ epre_get L s = Some e"
    unfolding epre_get_def
    by (auto)

  lemma eput_rewrite1[simp]: "epre_put L x s = None ⟹ eput L x s = Inr (eput' L x s)"
    apply (auto simp: epre_put_def)
    by (simp add: lower_put_def lower_lens_def)

  lemma eput_rewrite2[simp]: "eput L x s = Inr s' ⟷ epre_put L x s = None ∧ s' = eput' L x s"
    by (auto simp: epre_put_def lower_put_def lower_lens_def)

  lemma eput_nopre_conv[simp]: "eput L x s = Inl e ⟷ epre_put L x s = Some e"
    unfolding epre_put_def
    by auto

  (* For an elens, if get yields no Inl at s, then put at s yields no Inl for any value x.
     NOTE(review): the closing smt call is brittle across Isabelle versions. A
     structured proof via lower_epre_get and lower_epre_put' may be more robust. *)
  lemma epre_get_imp_pre_put[simp]:
    "elens L ⟹ epre_get L s = None ⟹ epre_put L x s = None"
    unfolding epre_get_def epre_put_def
    apply (auto)
    by (smt LENS_downstage(2) epre_get_def epre_put_def lens.pre_get_imp_putI lift_simps(1) lower_epre_get lower_lens_def lower_lift_id lower_put_def option.exhaust pre_eq_conv(2))

  (* For a lifted lens, epre_get/epre_put report the lifting value e exactly when
     the plain lens's precondition fails. *)
  lemma epre_get_lift_conv[simp]:
    "epre_get (lift_lens e L) s = None ⟷ pre_get L s"
    "epre_get (lift_lens e L) s = Some e' ⟷ e'=e ∧ ¬pre_get L s"
    by (auto simp: epre_get_def lift_lens_def lift_get_def pre_get_def)

  lemma epre_put_lift_conv[simp]:
    "epre_put (lift_lens e L) x s = None ⟷ pre_put_single_point L x s"
    "epre_put (lift_lens e L) x s = Some e' ⟷ e'=e ∧ ¬pre_put_single_point L x s"
    by (auto simp: epre_put_def lift_lens_def lift_put_def pre_put_single_point_def)

  (* A lifted lens is an ehlens iff the plain lens is an hlens.
     NOTE(review): connective reconstructed as ⟷ (the [simp] use suggests an iff);
     confirm against the original theory. *)
  lemma ehlens_lift[simp]: "ehlens (lift_lens f L) ⟷ hlens L"
    unfolding ehlens_def
    apply auto
    done

  (* Under ehlens, put reports exactly the Inl value that get reports. *)
  lemma ehlens_pre_put_conv[simp]: "ehlens L ⟹ epre_put L x s = epre_get L s"
    unfolding ehlens_def
    apply auto
    by (metis option.exhaust)



  (* From here on, the simplifier folds pre (eget L s) back into epre_get L s. *)
  lemmas [simp] = epre_get_def[symmetric]
  (* For a concrete ELENS, epre_get is computed directly from the getter. *)
  lemma epre_get_ELENS[simp]: "epre_get (ELENS g p) s = pre (g s)"
    by (auto simp: epre_get_def)

  (* Likewise, the simplifier folds pre (eput L x s) back into epre_put L x s. *)
  lemmas [simp] = epre_put_def[symmetric]
  (* For a concrete ELENS, epre_put is computed directly from the putter. *)
  lemma epre_put_ELENS[simp]: "epre_put (ELENS g p) x s = pre (p x s)"
    by (auto simp: epre_put_def)






  (* Sequential composition of elenses, L1 then L2.
     - get: apply eget L1, then apply eget L2 to its Inr result.
     - put: read the L1 focus with eget L1, put the new value into it with
       eput L2, then write the updated focus back with eput L1.
     The first Inl produced by any step is returned unchanged. Pattern variables
     are named so that none shadows an outer binder.
     NOTE(review): the infix symbol of the original mixfix annotation was lost
     in this copy of the file; "∙" is a reconstruction. Confirm it against any
     theory that uses the infix notation. *)
  definition ebcomp :: "('a, 'b, 'f) elens ⇒ ('b, 'c, 'f) elens ⇒ ('a, 'c, 'f) elens"
    (infixl "∙" 80)
    where "ebcomp L1 L2 ≡ ELENS
      (λs. case eget L1 s of Inl r ⇒ Inl r | Inr b ⇒ eget L2 b)
      (λx s. case eget L1 s of Inl r ⇒ Inl r | Inr b ⇒
             (case eput L2 x b of Inl r ⇒ Inl r | Inr b' ⇒ eput L1 b' s))
    "

  (* Composition is associative. As a simp rule it reassociates to the left.
     These lemmas are stated with ebcomp in prefix form, so they do not depend
     on its infix notation. *)
  lemma ebcomp_assoc[simp]: "ebcomp L1 (ebcomp L2 L3) = ebcomp (ebcomp L1 L2) L3"
    apply (cases L1; cases L2; cases L3; simp)
    unfolding ebcomp_def
    apply (auto split: sum.splits del: ext intro!: ext)
    done

  (* Lowering commutes with composition; compL is (presumably) the lens
     composition of Lenses, whose definition compL_def is unfolded below. *)
  lemma ebcomp_lower[simp]: "lower_lens (ebcomp L1 L2) = compL (lower_lens L1) (lower_lens L2)"
    apply (cases L1; cases L2; simp)
    unfolding ebcomp_def lower_lens_def compL_def lower_get_def lower_put_def
    apply (auto split: sum.splits Option.bind_splits del: ext intro!: ext)
    done

  (* Composition preserves elens. simp presumably rewrites with ebcomp_lower and
     then uses a compL lemma from Lenses. *)
  lemma ebcomp_elens[simp]: "elens L1 ⟹ elens L2 ⟹ elens (ebcomp L1 L2)"
    by (simp)



  (* Inl value reported by a composite get: L1's, otherwise L2's at the L1 focus. *)
  lemma ebcomp_pre_get[simp]: "epre_get (ebcomp L1 L2) s = (case epre_get L1 s of
      None ⇒ epre_get L2 (eget' L1 s)
    | Some e ⇒ Some e)"
    by (auto simp: ebcomp_def split: option.splits sum.splits)

  (* The composite put ends with eput L1. elens L1 presumably guarantees that this
     step yields no Inl once eget L1 s does not (cf. epre_get_imp_pre_put). *)
  lemma ebcomp_pre_put[simp]: "elens L1 ⟹ epre_put (ebcomp L1 L2) x s = (case epre_get L1 s of
      None ⇒ epre_put L2 x (eget' L1 s)
    | Some e ⇒ Some e)"
    by (auto simp: ebcomp_def split: option.splits sum.splits)

  (* When no step yields Inl, composite get'/put' are the sequential compositions. *)
  lemma ebcomp_get'[simp]: "⟦elens L1; elens L2; epre_get L1 s = None; epre_get L2 (eget' L1 s) = None⟧
    ⟹ eget' (ebcomp L1 L2) s = eget' L2 (eget' L1 s)"
    by (auto)

  lemma ebcomp_put'[simp]: "⟦elens L1; elens L2; epre_get L1 s = None; epre_put L2 x (eget' L1 s) = None⟧
    ⟹ eput' (ebcomp L1 L2) x s =  eput' L1 (eput' L2 x (eget' L1 s)) s"
    by (auto)

  (* Identity elens: get returns the state as Inr, put replaces the state by the value. *)
  definition "eidL ≡ ELENS (Inr) (λx _. Inr x)"

  lemma lift_idL[simp]: "lift_lens f idL = eidL"
    unfolding lift_lens_def eidL_def lift_get_def lift_put_def
    by (auto del: ext intro!: ext)

  lemma lower_eidL[simp]: "lower_lens eidL = idL"
    unfolding lower_lens_def idL_def eidL_def lower_get_def lower_put_def
    by (auto del: ext intro!: ext)

  (* eidL never yields an Inl. *)
  lemma eget_eidL_Inl_conv[simp]: "eget eidL s ≠ Inl f"
    by (auto simp: eidL_def)

  lemma eput_eidL_Inl_conv[simp]: "eput eidL x s ≠ Inl f"
    by (auto simp: eidL_def)

  lemma eidL_pre[simp]: "epre_get eidL s = None" "epre_put eidL x s = None"
    by (auto simp: eidL_def)

  (* eidL is a left unit of composition for any elens. *)
  lemma eid_left_neutral[simp]:
    assumes [simp]: "elens L"
    shows
    "ebcomp eidL L = L"
    by (rule elens_eqI; auto split: option.splits)

  (* eidL is a right unit under ehlens. When eget L s yields Inl e, the composite
     put returns Inl e, so eput L must yield the same Inl e; ehlens provides this. *)
  lemma eid_right_neutral[simp]:
    assumes [simp]: "ehlens L"
    shows
    "ebcomp L eidL = L"
    by (rule elens_eqI; auto split: option.splits)



(* Hide the unqualified names of the generic helpers lift, lower and pre; they
   stay accessible by their qualified names (e.g. ELenses.lift). *)
hide_const (open) lift lower pre

end