(* Theory Simple_Memory *)

section ‹Simple Memory Model›
theory Simple_Memory
imports "../../lib/LLVM_Integer" "../../lib/LLVM_Float_Types" "../../lib/MM/MMonad" 
begin

  text ‹Here, we combine a model of LLVM values with our generic block-based memory model.›

  (* LLVM pointers: either the null pointer, or an address of the
     block-based memory model (an address consists of a block and an index,
     cf. ADDR b i below). *)
  datatype llvm_ptr =
      is_null: PTR_NULL
    | is_addr: PTR_ADDR (the_addr: addr)
  hide_const (open)
    llvm_ptr.is_null
    llvm_ptr.is_addr llvm_ptr.the_addr

    
  (* Store the current state of the lifting setup named memory.lifting and then
     deactivate it for the rest of this theory. The setup presumably comes from
     the imported MMonad theory (TODO confirm). *)
  lifting_update memory.lifting
  lifting_forget memory.lifting

          
  subsection ‹LLVM Values›
  
  (* Structures (shapes) of LLVM values, as computed by llvm_struct_of_val.
     Structs and unions are described by the structures of their
     fields/variants; integers carry their width. *)
  datatype llvm_struct =
      is_struct: VS_STRUCT (the_fields: "llvm_struct list")
    | is_union: VS_UNION (the_variants: "llvm_struct list")
    | is_int: VS_INT (the_width: nat)
    | is_single: VS_SINGLE
    | is_double: VS_DOUBLE
    | is_ptr: VS_PTR
  hide_const (open)
    llvm_struct.is_struct llvm_struct.is_union llvm_struct.is_int
    llvm_struct.is_single llvm_struct.is_double llvm_struct.is_ptr
    llvm_struct.the_fields llvm_struct.the_variants llvm_struct.the_width
  

  (* Union contents. A union is either zero-initialized, in which case only the
     structures of all its variants are recorded, or it has one selected variant
     holding a value, together with the structures of the variants to its left
     and to its right. *)
  datatype 'v llvm_union =
      is_zero_init: UN_ZERO_INIT (structs: "llvm_struct list")
    | is_sel: UN_SEL (lefts: "llvm_struct list") (the_val: 'v) (rights: "llvm_struct list")

  hide_const (open)
    llvm_union.is_zero_init llvm_union.is_sel
    llvm_union.structs llvm_union.lefts llvm_union.the_val llvm_union.rights
  
    
      
  
  (* LLVM values: structs, unions, integers, single/double floats and pointers. *)
  datatype llvm_val =
      is_struct: LL_STRUCT (the_fields: "llvm_val list")
    | is_union: LL_UNION (the_union: "llvm_val llvm_union")
    | is_int: LL_INT (the_int: lint)
    | is_single: LL_SINGLE (the_single: single)
    | is_double: LL_DOUBLE (the_double: double)
      (* TODO: Similar to lint, we could encode different floating-point layouts here,
           and restrict the code-generator to only accept the ones supported by LLVM. *)
    | is_ptr: LL_PTR (the_ptr: llvm_ptr)
  hide_const (open)
    llvm_val.is_struct llvm_val.is_union llvm_val.is_int
    llvm_val.is_single llvm_val.is_double llvm_val.is_ptr
    llvm_val.the_fields llvm_val.the_union llvm_val.the_int
    llvm_val.the_single llvm_val.the_double llvm_val.the_ptr

  (* Structures of all variants of a union, given a function f computing the
     structure of the selected variant's value. *)
  fun llvm_struct_of_union :: "('v ⇒ llvm_struct) ⇒ 'v llvm_union ⇒ llvm_struct list" where
    "llvm_struct_of_union _ (UN_ZERO_INIT shapes) = shapes"
  | "llvm_struct_of_union f (UN_SEL xs y ys) = xs @ (f y # ys)"

  (* Set of values stored in a union: empty if zero-initialized,
     otherwise the value of the selected variant. *)
  fun vals_of_union :: "'v llvm_union ⇒ 'v set" where
    "vals_of_union (UN_ZERO_INIT _) = {}"
  | "vals_of_union (UN_SEL _ y _) = {y}"
        
  (* Congruence rule for the function package (fundef_cong): the structure
     function only needs to agree on the values stored in the union. This lets
     llvm_struct_of_val recurse through llvm_struct_of_union. *)
  lemma llvm_struct_of_union_cong[fundef_cong]:
    "u=u' ⟹ (⋀x. x∈vals_of_union u' ⟹ sov x = sov' x) ⟹ llvm_struct_of_union sov u = llvm_struct_of_union sov' u'"
    by (cases u; auto)
  
  (* Termination helper: a value stored in a union is smaller than the union. *)
  lemma vals_of_union_smaller[termination_simp]:
    "x ∈ vals_of_union un ⟹ size x < Suc (size_llvm_union size un)"
    apply (cases un) by auto
    
  (* Structure (shape) of an LLVM value. *)
  fun llvm_struct_of_val :: "llvm_val ⇒ llvm_struct" where
    "llvm_struct_of_val (LL_STRUCT xs) = VS_STRUCT (map llvm_struct_of_val xs)"
  | "llvm_struct_of_val (LL_UNION u) = VS_UNION (llvm_struct_of_union llvm_struct_of_val u)"
  | "llvm_struct_of_val (LL_INT x) = VS_INT (width x)"
  | "llvm_struct_of_val (LL_SINGLE _) = VS_SINGLE"
  | "llvm_struct_of_val (LL_DOUBLE _) = VS_DOUBLE"
  | "llvm_struct_of_val (LL_PTR _) = VS_PTR"

  
  
  (* Zero-initialized value of a given structure. Integers and floats are built
     from 0, pointers are null, and unions are left as UN_ZERO_INIT. *)
  fun llvm_zero_initializer :: "llvm_struct ⇒ llvm_val" where
    "llvm_zero_initializer (VS_STRUCT ss) = LL_STRUCT (map llvm_zero_initializer ss)"
  | "llvm_zero_initializer (VS_UNION ss) = LL_UNION (UN_ZERO_INIT ss)"
  | "llvm_zero_initializer (VS_INT n) = LL_INT (lconst n 0)"
  | "llvm_zero_initializer VS_SINGLE = LL_SINGLE (single_of_word 0)"
  | "llvm_zero_initializer VS_DOUBLE = LL_DOUBLE (double_of_word 0)"
  | "llvm_zero_initializer VS_PTR = LL_PTR PTR_NULL"
  
  (* The zero initializer of a structure has exactly that structure. *)
  lemma struct_of_llvm_zero_initializer[simp]:
    "llvm_struct_of_val (llvm_zero_initializer s) = s"
    by (induction s) (simp_all add: map_idI)

  (*type_synonym llvm_memory = "llvm_val memory"
  translations (type) "llvm_memory" ↽ (type) "llvm_val memory"
  *)

  (* Monad of LLVM computations: the imported monad M with llvm_val as the type
     of values held in memory. The print translation displays
     ('a, llvm_val) M as 'a llM. *)
  type_synonym 'a llM = "('a,llvm_val) M"
  translations
    (type) "'a llM" ↽ (type) "('a, llvm_val) M"

    
  subsection ‹Raw operations on values›
  
  
  (* Number of variants of a union. *)
  fun llvm_union_len :: "'v llvm_union ⇒ nat" where
    "llvm_union_len (UN_ZERO_INIT shapes) = length shapes"
  | "llvm_union_len (UN_SEL xs _ ys) = Suc (length xs + length ys)"
  
  (* Can variant i of a union be read? For a zero-initialized union, any
     in-range variant can. For a union with a selected variant, only the
     selected one (at position length ls) can. *)
  fun llvm_union_can_dest where
    "llvm_union_can_dest (UN_ZERO_INIT ss) i ⟷ i < length ss"
  | "llvm_union_can_dest (UN_SEL ls v rs) i ⟷ i = length ls"

  (* Read variant n of a union. A zero-initialized union yields the zero
     initializer of variant n's structure. A union with a selected variant
     yields its value when n is the selected position; any other n gives an
     unspecified value (callers check llvm_union_can_dest first). *)
  fun llvm_union_dest :: "llvm_val llvm_union ⇒ nat ⇒ llvm_val" where
    "llvm_union_dest (UN_ZERO_INIT shapes) n = llvm_zero_initializer (shapes ! n)"
  | "llvm_union_dest (UN_SEL xs y ys) n = (if n = length xs then y else undefined xs y ys)"

  (* A union over variant structures ss can be built with value v in variant i
     iff i is in range and v has the structure of variant i. *)
  definition "llvm_union_can_make ss v i ≡ i<length ss ∧ llvm_struct_of_val v = ss!i"
  (* Build a union over variant structures ss, with v selected at position i. *)
  definition "llvm_union_make ss v i ≡ UN_SEL (take i ss) v (drop (Suc i) ss)"

  (* Simplification rules for a union built by llvm_union_make,
     under the precondition llvm_union_can_make ss v i. *)
  context
    fixes ss v i un
    assumes can_make: "llvm_union_can_make ss v i"
    defines [simp]: "un ≡ llvm_union_make ss v i"
  begin
    lemma un_make_simps[simp]:
      "llvm_union_len un = length ss"
      "llvm_union_can_dest un j ⟷ j=i"
      "j=i ⟹ llvm_union_dest un j = v"
      "llvm_struct_of_union llvm_struct_of_val un = ss"
      using can_make
      by (auto simp: llvm_union_make_def llvm_union_can_make_def Cons_nth_drop_Suc)

  end
  
    
    
  
  context
    includes monad_syntax_M
  begin
    
  (* Address of a non-null pointer value. Fails on the null pointer and on
     non-pointer values. *)
  definition llvm_extract_addr :: "llvm_val ⇒ addr llM" where
    "llvm_extract_addr v ≡ case v of LL_PTR (PTR_ADDR a) ⇒ return a | _ ⇒ fail"

  (* Pointer (null or address) of a pointer value. Fails on non-pointer values. *)
  definition llvm_extract_ptr :: "llvm_val ⇒ llvm_ptr llM" where
    "llvm_extract_ptr v ≡ case v of LL_PTR p ⇒ return p | _ ⇒ fail"
    
  (* Integer value read as a signed integer (lint_to_sint). Fails on non-integer values. *)
  definition llvm_extract_sint :: "llvm_val ⇒ int llM" where
    "llvm_extract_sint v ≡ case v of LL_INT i ⇒ return (lint_to_sint i) | _ ⇒ fail"
        
  (* Integer value read as a natural number (nat of lint_to_uint). Fails on
     non-integer values. *)
  definition llvm_extract_unat :: "llvm_val ⇒ nat llM" where
    "llvm_extract_unat v ≡ case v of LL_INT i ⇒ return (nat (lint_to_uint i)) | _ ⇒ fail"

  (* Field i of a struct value. Asserts that i is in range; fails on non-struct values. *)
  definition llvm_extract_value :: "llvm_val ⇒ nat ⇒ llvm_val llM" where
  "llvm_extract_value v i ≡ case v of
    LL_STRUCT vs ⇒ doM {
      assert (i<length vs);
      return (vs!i)
    }
  | _ ⇒ fail"
      
  (* Replace field i of a struct value by x. Asserts that i is in range and
     that x has the same structure as the field it replaces, so the struct's
     structure is preserved. Fails on non-struct values. *)
  definition llvm_insert_value :: "llvm_val ⇒ llvm_val ⇒ nat ⇒ llvm_val llM" where
  "llvm_insert_value v x i ≡ case v of
    LL_STRUCT vs ⇒ doM {
      assert (i<length vs);
      assert (llvm_struct_of_val (vs!i) = llvm_struct_of_val x);
      return (LL_STRUCT (vs[i:=x]))
    }
  | _ ⇒ fail"

  (* Read variant i of a union value. Asserts llvm_union_can_dest un i, then
     returns llvm_union_dest un i. Fails on non-union values. *)
  definition llvm_dest_union :: "llvm_val ⇒ nat ⇒ llvm_val llM" where
    "llvm_dest_union v i ≡ case v of
      LL_UNION un ⇒ doM {
        assert (llvm_union_can_dest un i);
        return (llvm_union_dest un i)
      }
    | _ ⇒ fail"
  
  (* Build a union value of structure s, with x stored in variant i.
     s must be VS_UNION ss (otherwise this fails), and llvm_union_can_make ss x i
     is asserted: i is in range and x has the structure of variant i.
     Uses doM, like every other definition in this monad_syntax_M context. *)
  definition llvm_make_union :: "llvm_struct ⇒ llvm_val ⇒ nat ⇒ llvm_val llM" where
    "llvm_make_union s x i ≡ case s of
      VS_UNION ss ⇒ doM {
        assert (llvm_union_can_make ss x i);
        return (LL_UNION (llvm_union_make ss x i))
      }
    | _ ⇒ fail"
  
  subsection ‹Interface functions›
  
  subsubsection ‹Typed arguments›
    
  (* Check that address a is valid (delegates to Mvalid_addr). *)
  (* TODO: redundancy with is_valid_addr! *)
  definition llvmt_check_addr :: "addr ⇒ unit llM" where "llvmt_check_addr a ≡ doM {
    Mvalid_addr a
  }"
    
  (* Load the value stored at address a (delegates to Mload). *)
  definition llvmt_load :: "addr ⇒ llvm_val llM" where "llvmt_load a ≡ doM {
    Mload a
  }"
  
  (* Store x at address a. The current value is loaded first and x must have
     the same structure, so a store never changes the structure of a memory
     cell. *)
  definition "llvmt_store x a ≡ doM {
    xorig ← llvmt_load a;
    assert (llvm_struct_of_val x = llvm_struct_of_val xorig);
    Mstore a x
  }"
  
  (* Allocate a new block of n zero-initialized values of structure s (via
     Mmalloc). The result is used as a block number by llvmt_allocp. *)
  definition "llvmt_alloc s n ≡ doM {
    Mmalloc (replicate n (llvm_zero_initializer s))
  }"
  
  (* Free block b (delegates to Mfree). *)
  definition llvmt_free :: "nat ⇒ unit llM" where "llvmt_free b ≡ doM {
    Mfree b
  }"

  (* Free the block that pointer p points into. Asserts that p is non-null
     and points to index 0, i.e. the start of its block. *)
  definition "llvmt_freep p ≡ doM {
    assert (llvm_ptr.is_addr p);
    let a = llvm_ptr.the_addr p;

    assert (addr.index a=0);
    llvmt_free (addr.block a);
    return ()
  }"

  (* Allocate n zero-initialized values of structure s and return a pointer
     to the first element (index 0) of the new block. *)
  definition "llvmt_allocp s n ≡ doM {
    b ← llvmt_alloc s n;
    return (PTR_ADDR (ADDR b 0))
  }"
    
  
  (* Check a pointer. The null pointer is always accepted; any other pointer
     must point to a valid address (Mvalid_addr). *)
  definition llvmt_check_ptr :: "llvm_ptr ⇒ unit llM" where "llvmt_check_ptr p ≡
    if llvm_ptr.is_null p then return ()
    else doM {
      let a = llvm_ptr.the_addr p;
      Mvalid_addr a ― ‹TODO: support 1-beyond-end pointers!›
    }"
      
  (* Move pointer p by ofs elements within its block. Asserts that p is
     non-null. The resulting pointer is checked with llvmt_check_ptr before it
     is returned. *)
  definition "llvmt_ofs_ptr p ofs ≡ doM {
    assert (llvm_ptr.is_addr p);
    let a = llvm_ptr.the_addr p;
    let b = addr.block a;
    let i = addr.index a;
    let i = i + ofs;
    let r = PTR_ADDR (ADDR b i);
    llvmt_check_ptr r;
    return r
  }"
    
  (* Precondition for comparing two pointers. A comparison with the null
     pointer is always allowed; otherwise both pointers must pass
     llvmt_check_ptr.
     The connective is ∨: llvmt_check_ptr already accepts null, so an ∧ here
     would make this special case redundant. *)
  definition "llvmt_check_ptrcmp p1 p2 ≡
    if p1=PTR_NULL ∨ p2=PTR_NULL then
      return ()
    else doM {
      llvmt_check_ptr p1;
      llvmt_check_ptr p2
    }"
  
  (* Pointer equality, after checking the comparison precondition. *)
  definition "llvmt_ptr_eq p1 p2 ≡ doM {
    llvmt_check_ptrcmp p1 p2;
    return (p1 = p2)
  }"
  
  (* Pointer disequality, after checking the comparison precondition. *)
  definition "llvmt_ptr_neq p1 p2 ≡ doM {
    llvmt_check_ptrcmp p1 p2;
    return (p1 ≠ p2)
  }"
  
  subsubsection ‹Embedded arguments›

  (* Load through a pointer value a. Fails if a is null or not a pointer. *)
  definition "llvm_load a ≡ doM {
    a ← llvm_extract_addr a;
    llvmt_load a
  }"
  
  (* Store x through a pointer value a. Fails if a is null or not a pointer. *)
  definition "llvm_store x a ≡ doM {
    a ← llvm_extract_addr a;
    llvmt_store x a
  }"
  
  (* Allocate a block of n zero-initialized values of structure s, where n is
     an integer value read as unsigned. Returns a pointer value to its first
     element. *)
  definition "llvm_alloc s n ≡ doM {
    n ← llvm_extract_unat n;
    p ← llvmt_allocp s n;
    return (LL_PTR p)
  }"

  (* Block of an address whose index is 0; fails otherwise.
     llvmt_freep performs the same check with assert instead. *)
  definition "llvm_extract_base_block a ≡ case a of ADDR b i ⇒ if i=0 then return b else fail"
  
  (* Free the block a pointer value points to. Fails on non-pointer values. *)
  definition "llvm_free p ≡ doM {
    p ← llvm_extract_ptr p;
    llvmt_freep p
  }"
  
  (* Offset a pointer value by an integer value ofs, read as signed. *)
  definition "llvm_ofs_ptr p ofs ≡ doM {
    p ← llvm_extract_ptr p;
    ofs ← llvm_extract_sint ofs;
    r ← llvmt_ofs_ptr p ofs;
    return (LL_PTR r)
  }"
  
  
  (* Compare two pointer values for equality. The result is encoded as an
     integer value via bool_to_lint. *)
  definition "llvm_ptr_eq p1 p2 ≡ doM {
    p1 ← llvm_extract_ptr p1;
    p2 ← llvm_extract_ptr p2;
    r ← llvmt_ptr_eq p1 p2;
    return (LL_INT (bool_to_lint r))
  }"
  
  (* Compare two pointer values for disequality. The result is encoded as an
     integer value via bool_to_lint. *)
  definition "llvm_ptr_neq p1 p2 ≡ doM {
    p1 ← llvm_extract_ptr p1;
    p2 ← llvm_extract_ptr p2;
    r ← llvmt_ptr_neq p1 p2;
    return (LL_INT (bool_to_lint r))
  }"
  
end
end