Theory LLVM_Shallow

section ‹Shallow Embedding of LLVM Semantics›
theory LLVM_Shallow
imports Main  
  "LLVM_Memory"
  "../../cost/Abstract_Cost"
begin

  text ‹We define a type synonym for the LLVM monad›
  (* 'a llM abbreviates the monad type M with result type 'a and the remaining
     parameters fixed to unit, cost, llvm_memory and err. *)
  type_synonym 'a llM = "('a,unit,cost,llvm_memory,err) M"
  (* Type-level translation relating the abbreviation 'a llM to the expanded monad type.
     NOTE(review): presumably cost expands to (char list, nat) acost - confirm in Abstract_Cost. *)
  translations
    (type) "'a llM"  (type) "('a, unit, (char list, nat) acost, llvm_memory, err) M"
  
  subsection ‹Shallow Embedding of Values›  

  text ‹We use a type class to characterize types that can be injected into the value type.
    We will instantiate this type class to obtain injections from types of shape 
    ‹T = T×T | _ word | _ ptr›.
  
    Although this type class can be instantiated by other types, those will not be accepted 
    by the code generator.
    
    We also define a class ‹llvm_repv›, which additionally contains ‹unit›. 
    This is required for void functions, and if-then-else statements that produce no result.
    
    Again, while this class might be instantiated for other types, those will be rejected
    by the code generator.
  ›
  
  (* Base class: declares no operations and no assumptions. *)
  class llvm_repv  
    
  (* Types that can be converted to and from llvm_val.
       to_val    - conversion into llvm_val
       from_val  - conversion back from llvm_val
       struct_of - the llvm_vstruct associated with the type, selected via TYPE('a)
       init      - a distinguished default value of the type *)
  class llvm_rep = llvm_repv +
    fixes to_val :: "'a  llvm_val"
      and from_val :: "llvm_val  'a"
      and struct_of :: "'a itself  llvm_vstruct"
      and init :: 'a
    (* from_val after to_val is the identity. *)
    assumes from_to_id[simp]: "from_val o to_val = id"
    (* Relates the structure condition llvm_vstruct v = struct_of TYPE('a) to the
       round trip to_val (from_val v) = v (presumably as premise and conclusion). *)
    assumes to_from_id[simp]: "llvm_vstruct v = struct_of TYPE('a)  to_val (from_val v) = v"
    (* Every value produced by to_val has the structure declared for the type. *)
    assumes struct_of_matches[simp]: "llvm_vstruct (to_val x) = (struct_of TYPE('a))"
    (* init is mapped to the zero initializer of the type's structure. *)
    assumes init_zero: "to_val init = llvm_zero_initializer (struct_of TYPE('a))"
    
  begin
  
    (* Applied (pointwise) form of from_to_id. *)
    lemma from_to_id'[simp]: "from_val (to_val x) = x" 
      using pointfree_idE[OF from_to_id] .
  
    (* Unnamed lemma connecting equality of to_val images with equality of the
       arguments; it is derived from from_to_id' by metis. *)
    lemma "to_val x = to_val y  x=y"  
      by (metis from_to_id')
      
  end
  
  text ‹We use a phantom type to attach the type of the pointed to value to a pointer.›
  (* Typed pointer: a single constructor PTR wrapping an llvm_ptr, with selector the_raw_ptr.
     The type parameter 'a does not occur in the constructor argument. *)
  datatype 'a::llvm_rep ptr = PTR (the_raw_ptr: llvm_ptr)
  (* Typed null pointer: llvm_null wrapped in PTR. *)
  definition null :: "'a::llvm_rep ptr" where "null = PTR llvm_null"
  

  text ‹We instantiate the type classes for the supported types, 
    i.e., unit, word, ptr, and prod.›
  
  (* unit is made an instance of llvm_repv only; that class has no operations or
     assumptions, so the instance proof is just "standard". *)
  instance unit :: llvm_repv by standard
  
  (* llvm_rep instance for words of arbitrary length ('a::len). *)
  instantiation word :: (len) llvm_rep begin
    (* Build an lconst of width len_of TYPE('a) from uint w and wrap it with llvm_int. *)
    definition "to_val w  llvm_int (lconst (len_of TYPE('a)) (uint w))"
    (* Unwrap with llvm_the_int, read the lint via lint_to_uint, convert with word_of_int. *)
    definition "from_val v  word_of_int (lint_to_uint (llvm_the_int v))"
    (* The structure of an 'a word is llvm_s_int at the word length. *)
    definition [simp]: "struct_of_word (_::'a word itself)  llvm_s_int (len_of TYPE('a))"
    (* The default word is 0. *)
    definition [simp]: "init_word  0::'a word"
    
    
    (* Auxiliary fact for the instance proof: it involves the width condition
       width i = LENGTH('a) and the round trip lint -> word -> lconst back to i. *)
    lemma int_inv_aux: "width i = LENGTH('a)  lconst LENGTH('a) (uint (word_of_int (lint_to_uint i) :: 'a word)) = i"
      by (metis uint_const uint_eq uint_lower_bound uint_upper_bound width_lconst word_of_int_inverse word_ubin.norm_Rep)
    
    (* Discharge the llvm_rep class assumptions for words. *)
    instance
      apply standard
      apply (rule ext)
      apply (auto simp: from_val_word_def to_val_word_def)
      apply (auto simp: llvm_s_int_def llvm_zero_initializer_def llvm_int_def)
      (* Remaining goal mentions a value v: case split on v, then finish with int_inv_aux. *)
      subgoal for v apply (cases v) 
        apply (auto simp: llvm_int_def llvm_the_int_def llvm_s_ptr_def llvm_s_pair_def)
        using int_inv_aux apply (simp add: llvm_vstruct_def) 
      done
      done
      
  end
  
  (* llvm_rep instance for typed pointers. *)
  instantiation ptr :: (llvm_rep) llvm_rep begin
    (* Project out the raw pointer and wrap it with llvm_ptr. *)
    definition "to_val  llvm_ptr o ptr.the_raw_ptr"
    (* Unwrap with llvm_the_ptr and re-wrap with PTR. *)
    definition "from_val v  PTR (llvm_the_ptr v)"
    (* Every typed pointer has structure llvm_s_ptr, independent of the pointee type. *)
    definition [simp]: "struct_of_ptr (_::'a ptr itself)  llvm_s_ptr"
    (* The default pointer is null. *)
    definition [simp]: "init_ptr::'a ptr  null"
  
    (* Discharge the llvm_rep class assumptions for pointers. *)
    instance
      apply standard
      apply (rule ext)
      apply (auto simp: from_val_ptr_def to_val_ptr_def)
      apply (auto simp: llvm_zero_initializer_def llvm_ptr_def llvm_s_ptr_def null_def llvm_null_def)
      (* Remaining goal mentions a value v: solved by case split on v. *)
      subgoal for v apply (cases v)
        by (auto simp: llvm_s_int_def llvm_s_pair_def llvm_ptr_def llvm_the_ptr_def)
      done
      
  end
  
  (* llvm_rep instance for pairs whose components are both llvm_rep. *)
  instantiation prod :: (llvm_rep, llvm_rep) llvm_rep begin
    (* Convert both components with to_val and combine them with llvm_pair. *)
    definition "to_val_prod  λ(a,b). llvm_pair (to_val a) (to_val b)"
    (* Split with llvm_the_pair and convert both components with from_val. *)
    definition "from_val_prod p  case llvm_the_pair p of (a,b)  (from_val a, from_val b)"
    (* The structure of a pair type is llvm_s_pair of the component structures. *)
    definition [simp]: "struct_of_prod (_::('a×'b) itself)  llvm_s_pair (struct_of TYPE('a)) (struct_of TYPE('b))"
    (* The default pair consists of the default values of both components. *)
    definition [simp]: "init_prod ::'a×'b  (init,init)"
    
    (* Discharge the llvm_rep class assumptions for pairs. *)
    instance
      apply standard
      apply (rule ext)
      apply (auto simp: from_val_prod_def to_val_prod_def)
      apply (auto simp: llvm_pair_def llvm_s_pair_def init_zero llvm_zero_initializer_def)
      (* Remaining goal mentions a value v: solved by case split on v and unfolding. *)
      subgoal for v
        apply (cases v)
        apply (auto simp: llvm_s_int_def llvm_s_ptr_def llvm_pair_def llvm_the_pair_def 
          llvm_val.the_val_def llvm_vstruct_def split: prod.splits llvm_val.splits val.split)
        done
      done
      
  end

  (* Simp rule: to_val on an explicit pair (a,b) is llvm_pair of the converted components.
     Obtained by unfolding to_val_prod_def. *)
  lemma to_val_prod_conv[simp]: "to_val (a,b) = llvm_pair (to_val a) (to_val b)"
    unfolding to_val_prod_def by auto
  
  
  text ‹Checked conversion from value›  
  (* Convert v to type 'a inside the monad. First fcheck that the structure of v equals
     struct_of TYPE('a), using STATIC_ERROR ''Type mismatch'' as the error; then return
     from_val v. *)
  definition checked_from_val :: "llvm_val  'a::llvm_rep llM" where
    "checked_from_val v  doM {
      fcheck (STATIC_ERROR ''Type mismatch'') (llvm_vstruct v = struct_of TYPE('a));
      return (from_val v)
    }" 

      
  subsection ‹Instructions›  
  
  text ‹The instructions are arranged in the order as they are described in the 
    LLVM Language Reference Manual 🌐‹https://llvm.org/docs/LangRef.html›.›
    
  
  subsubsection ‹Binary Operations›  
  text ‹We define a generic lifter for binary arithmetic operations.
    It is parameterized by an error condition.
  › (* TODO: Use precondition instead of negated precondition! *)

  (* op_lift_arith2 n ovf f a b lifts a binary operation on lints to words:
       n   - name under which one unit of cost is consumed (cost n 1)
       ovf - error condition on the lint operands; fcheck with OVERFLOW_ERROR is
             applied to its negation
       f   - the operation on lints
     Both operands are converted with word_to_lint; the result of f is converted
     back with lint_to_word. *)
  definition op_lift_arith2 :: "_  _  _  'a::len word  'a word  'a word llM"
    where "op_lift_arith2 n ovf f a b  doM {
    consume (cost n 1);
    let a = word_to_lint a;
    let b = word_to_lint b;
    fcheck (OVERFLOW_ERROR) (¬ovf a b);
    return (lint_to_word (f a b))
  }"
        
  (* Variant of op_lift_arith2 whose error condition is constantly False. *)
  definition "op_lift_arith2' n  op_lift_arith2 n (λ_ _. False)"

  (* Error condition for unsigned division/remainder: the divisor, read with lint_to_uint, is 0. *)
  definition udivrem_is_undef :: "lint  lint  bool" 
    where "udivrem_is_undef a b  lint_to_uint b=0"
  (* Error condition for signed division/remainder: combines "divisor read with
     lint_to_sint is 0" with sdivrem_ovf a b (presumably as a disjunction). *)
  definition sdivrem_is_undef :: "lint  lint  bool" 
    where "sdivrem_is_undef a b  lint_to_sint b=0  sdivrem_ovf a b"
  
  (* add, sub and mul use op_lift_arith2', i.e. no error condition. *)
  definition "ll_add  op_lift_arith2' ''add'' (+)"
  definition "ll_sub  op_lift_arith2' ''sub'' (-)"
  definition "ll_mul  op_lift_arith2' ''mul'' (*)"
  (* Division and remainder use op_lift_arith2 with the matching error condition. *)
  definition "ll_udiv  op_lift_arith2 ''udiv'' udivrem_is_undef (div)"
  definition "ll_urem  op_lift_arith2 ''urem'' udivrem_is_undef (mod)"
  definition "ll_sdiv  op_lift_arith2 ''sdiv'' sdivrem_is_undef (sdiv)"
  definition "ll_srem  op_lift_arith2 ''srem'' sdivrem_is_undef (smod)"
  
  
  subsubsection ‹Compare Operations›
  (* op_lift_cmp n f a b lifts a binary predicate f on lints to a comparison on words.
     It consumes cost n 1, converts both operands with word_to_lint, and returns the
     Boolean result as a 1-bit word via bool_to_lint and lint_to_word. *)
  definition op_lift_cmp :: "_  _  'a::len word  'a word  1 word llM"
    where "op_lift_cmp n f a b  doM {
    consume (cost n 1);
    let a = word_to_lint a;
    let b = word_to_lint b;
    return (lint_to_word (bool_to_lint (f a b)))
  }"
    
  (* Same for typed pointers: f is applied to the pointers directly, without conversion. *)
  definition op_lift_ptr_cmp :: "_  _  'a::llvm_rep ptr  'a ptr  1 word llM"
    where "op_lift_ptr_cmp n f a b  doM {
    consume (cost n 1);
    return (lint_to_word (bool_to_lint (f a b)))
  }"
  
  (* Integer comparisons, each an instance of op_lift_cmp with its own cost name.
     The sle/slt variants use the operators written ≤s and <s; ule/ult use ≤ and <. *)
  definition "ll_icmp_eq   op_lift_cmp ''icmp_eq'' (=)"
  definition "ll_icmp_ne   op_lift_cmp ''icmp_ne'' (≠)"
  definition "ll_icmp_sle  op_lift_cmp ''icmp_sle'' (≤s)"
  definition "ll_icmp_slt  op_lift_cmp ''icmp_slt'' (<s)"
  definition "ll_icmp_ule  op_lift_cmp ''icmp_ule'' (≤)"
  definition "ll_icmp_ult  op_lift_cmp ''icmp_ult'' (<)"

  
  (* For presentation in paper: ll_add a b consumes one unit of ''add'' cost and
     returns a+b. Proved by unfolding the lifting definitions and then using
     lint_word_inv and word_to_lint_plus. *)
  lemma "ll_add a b = doM {
      consume (cost ''add'' 1);
      return (a+b)
    }"
    unfolding ll_add_def op_lift_arith2'_def op_lift_arith2_def
    apply simp
    by (metis lint_word_inv word_to_lint_plus)
    
  
  text ‹Note: There are no pointer compare instructions in LLVM. 
    To compare pointers in LLVM, they have to be cast to integers first.
    However, our abstract memory model cannot assign a bit-width to pointers.
    
    Thus, we model pointer comparison instructions in our semantics, and let the 
    code generator translate them to integer comparisons. 
    
    Up to now, we only model pointer equality. 
    For less-than, suitable preconditions are required, which are consistent with the 
    actual memory layout of LLVM. We could, e.g., adopt the rules from the C standard here.
  ›
  (* Pointer equality and disequality, as instances of op_lift_ptr_cmp. *)
  definition "ll_ptrcmp_eq  op_lift_ptr_cmp ''ptrcmp_eq'' (=)"
  definition "ll_ptrcmp_ne  op_lift_ptr_cmp ''ptrcmp_ne'' (≠)"
  

  
  subsubsection ‹Bitwise Binary Operations›  
  (* Error condition for shifts: relates the shift amount nat (lint_to_uint n) to width a.
     NOTE(review): the comparison operator is not visible here; presumably "at least". *)
  definition "shift_ovf a n  nat (lint_to_uint n)  width a"
  (* Shift operations taking the shift amount as a lint; it is converted to nat
     via lint_to_uint before calling the underlying bitSHL / bitASHR / bitLSHR. *)
  definition "bitSHL' a n  bitSHL a (nat (lint_to_uint n))"
  definition "bitASHR' a n  bitASHR a (nat (lint_to_uint n))"
  definition "bitLSHR' a n  bitLSHR a (nat (lint_to_uint n))"
  
  (* Shifts are lifted with shift_ovf as the error condition. *)
  definition "ll_shl  op_lift_arith2 ''shl'' shift_ovf bitSHL'"  
  definition "ll_lshr  op_lift_arith2 ''lshr'' shift_ovf bitLSHR'"  
  definition "ll_ashr  op_lift_arith2 ''ashr'' shift_ovf bitASHR'"
  
  (* and / or / xor are lifted without an error condition. *)
  definition "ll_and  op_lift_arith2' ''and'' (lliAND)"
  definition "ll_or  op_lift_arith2' ''or'' (lliOR)"
  definition "ll_xor  op_lift_arith2' ''xor'' (lliXOR)"
    

  subsubsection ‹Aggregate Operations›
  text ‹In LLVM, there is an ‹extractvalue› and ‹insertvalue› operation.
    In our shallow embedding, these get instantiated for ‹fst› and ‹snd›.›
    
  
  (* Split a value into its two components inside the monad: fcheck llvm_is_pair v
     with STATIC_ERROR ''Expected pair'' as the error, then return llvm_the_pair v. *)
  definition "checked_split_pair v  doM {
    fcheck (STATIC_ERROR ''Expected pair'') (llvm_is_pair v);
    return (llvm_the_pair v)
  }"

  (* TODO: reinsert costs for products and push it to the abstract level. *)
  (* Extract/insert on pair components. The argument and component types are
     independent type variables in the signatures, so the pair shape and the component
     types are checked in the monad, via checked_split_pair and checked_from_val.
     The consume steps are currently disabled (wrapped in cancel cartouches), so these
     operations consume no cost.
       ll_extract_fst / ll_extract_snd - split to_val p, convert the first / second component
       ll_insert_fst / ll_insert_snd   - split to_val p, rebuild the pair with to_val x in
                                         the first / second position, convert back *)
  definition ll_extract_fst :: "'t::llvm_rep  't1::llvm_rep llM" where "ll_extract_fst p = doM { ⌦‹consume (cost ''extract_fst'' 1);› (a,b)  checked_split_pair (to_val p); checked_from_val a }"
  definition ll_extract_snd :: "'t::llvm_rep  't2::llvm_rep llM" where "ll_extract_snd p = doM { ⌦‹consume (cost ''extract_snd'' 1);› (a,b)  checked_split_pair (to_val p); checked_from_val b }"
  definition ll_insert_fst :: "'t::llvm_rep  't1::llvm_rep  't llM" where "ll_insert_fst p x = doM { ⌦‹consume (cost ''insert_fst'' 1);› (a,b)  checked_split_pair (to_val p); checked_from_val (llvm_pair (to_val x) b) }" 
  definition ll_insert_snd :: "'t::llvm_rep  't2::llvm_rep  't llM" where "ll_insert_snd p x = doM { ⌦‹consume (cost ''insert_snd'' 1);› (a,b)  checked_split_pair (to_val p); checked_from_val (llvm_pair a (to_val x)) }" 
    
  (*  
  definition ll_extract_fst :: "('a::llvm_rep × 'b::llvm_rep) ⇒ 'a llM" where "ll_extract_fst ab ≡ return (fst ab)"
  definition ll_extract_snd :: "('a::llvm_rep × 'b::llvm_rep) ⇒ 'b llM" where "ll_extract_snd ab ≡ return (snd ab)"
  definition ll_insert_fst :: "('a::llvm_rep × 'b::llvm_rep) ⇒ 'a ⇒ ('a×'b) llM" where "ll_insert_fst ab a ≡ return (a,snd ab)"
  definition ll_insert_snd :: "('a::llvm_rep × 'b::llvm_rep) ⇒ 'b ⇒ ('a×'b) llM" where "ll_insert_snd ab b ≡ return (fst ab,b)"
  *)
    
  subsubsection ‹Memory Access and Addressing Operations›
    
  (* Load through a typed pointer: consume one unit of ''load'' cost, read the raw
     value with llvm_load from the underlying raw pointer, then convert it to 'a
     with checked_from_val. *)
  definition ll_load :: "'a::llvm_rep ptr  'a llM" where
    "ll_load p  doM {
      consume (cost ''load'' 1);
      r  llvm_load (the_raw_ptr p);
      checked_from_val r
    }"
    
  (* Store through a typed pointer: consume one unit of ''store'' cost, then pass
     to_val v and the underlying raw pointer to llvm_store. *)
  definition ll_store :: "'a::llvm_rep  'a ptr  unit llM" where
    "ll_store v p  doM {
      consume (cost ''store'' 1);
      llvm_store (to_val v) (the_raw_ptr p)
    }"

  text ‹Note that LLVM itself does not have malloc and free instructions.
    However, these are primitive instructions in our abstract memory model, 
    such that we have to model them in our semantics.
    
    The code generator will map them to the C standard library 
    functions ‹calloc› and ‹free›.
  ›
    
  (* Allocate a block for n elements of type 'a (the type is passed as TYPE('a)).
     The cost unat n is consumed before the emptiness check, so the cost is consumed
     even when the check fails. The check uses MEM_ERROR and requires unat n > 0.
     llvm_allocn is called with to_val (init::'a) and unat n; the raw result is
     wrapped in PTR. *)
  definition ll_malloc :: "'a::llvm_rep itself  _::len word  'a ptr llM" where
    "ll_malloc TYPE('a) n = doM {
      consume (cost ''malloc'' (unat n)); ― ‹DESIGN CHOICE: malloc consumes n›
      fcheck MEM_ERROR (unat n > 0); ― ‹Disallow empty malloc›
      r  llvm_allocn (to_val (init::'a)) (unat n);
      return (PTR r)
    }"
        
  (* Free through a typed pointer: consume one unit of ''free'' cost, then call
     llvm_free on the underlying raw pointer. *)
  definition ll_free :: "'a::llvm_rep ptr  unit llM" 
    where "ll_free p  doM {
            consume (cost ''free'' 1); ― ‹DESIGN CHOICE: consume 1 ›
            llvm_free (the_raw_ptr p)
          }"


  text ‹As for the aggregate operations, the ‹getelementptr› instruction is instantiated 
    for pointer indexing, fst, and snd. ›

  ― ‹pointer arithmetic, cost 1 each›

  (* Pointer indexing: consume one unit of ''ofs_ptr'' cost, then call
     llvm_checked_idx_ptr on the raw pointer. The offset word is interpreted as a
     signed integer (sint ofs). The raw result is wrapped in PTR, so the pointee
     type is unchanged. *)
  definition ll_ofs_ptr :: "'a::llvm_rep ptr  _::len word  'a ptr llM" where "ll_ofs_ptr p ofs = doM {
    consume (cost ''ofs_ptr'' 1);
    r  llvm_checked_idx_ptr (the_raw_ptr p) (sint ofs);
    return (PTR r)
  }"  

  (* Pointer to the first component of a pair: consume one unit of ''gep_fst'' cost,
     fcheck (with a STATIC_ERROR) that struct_of TYPE('p) satisfies llvm_is_s_pair,
     then call llvm_checked_gep with PFST and wrap the result in PTR.
     Note: the result pointee type 'a is a separate type variable in the signature
     and is not related to 'p by this definition. *)
  definition ll_gep_fst :: "'p::llvm_rep ptr  'a::llvm_rep ptr llM" where "ll_gep_fst p = doM {
    consume (cost ''gep_fst'' 1);
    fcheck (STATIC_ERROR ''gep_fst: Expected pair type'') (llvm_is_s_pair (struct_of TYPE('p)));
    r  llvm_checked_gep (the_raw_ptr p) PFST;
    return (PTR r)
  }"

  (* Pointer to the second component of a pair: same as ll_gep_fst, but with cost name
     ''gep_snd'', its own error message, and PSND instead of PFST. *)
  definition ll_gep_snd :: "'p::llvm_rep ptr  'b::llvm_rep ptr llM" where "ll_gep_snd p = doM {
    consume (cost ''gep_snd'' 1);
    fcheck (STATIC_ERROR ''gep_snd: Expected pair type'') (llvm_is_s_pair (struct_of TYPE('p)));
    r  llvm_checked_gep (the_raw_ptr p) PSND;
    return (PTR r)
  }"

  subsubsection ‹Conversion Operations›
  (* Checked width conversions on lints; i is the value and w the target width.
     Each first checks the width relation with fcheck (STATIC_ERROR on failure) and
     then applies the underlying operation. *)
  (* Truncation requires the source to be strictly wider than the target: width i > w. *)
  definition "llb_trunc i w  doM {
    fcheck (STATIC_ERROR ''Trunc must go to smaller type'') (width i > w);
    return (trunc w i)
  }"
  
  (* Sign extension requires the source to be strictly narrower: width i < w. *)
  definition "llb_sext i w  doM {
    fcheck (STATIC_ERROR ''Sext must go to greater type'') (width i < w);
    return (sext w i)
  }"
  
  (* Zero extension requires the source to be strictly narrower: width i < w. *)
  definition "llb_zext i w  doM {
    fcheck (STATIC_ERROR ''Zext must go to greater type'') (width i < w);
    return (zext w i)
  }"
  
  (* op_lift_iconv n f a _ lifts a monadic lint conversion f to words:
       n - name under which one unit of cost is consumed
       f - conversion taking the lint value and the target width
       a - source word
       _ - 'b word itself argument; only its type is used, to fix the target
           width LENGTH('b)
     The result of f is converted to a 'b word with lint_to_word. *)
  definition op_lift_iconv :: "_  _  'a::len word  'b::len word itself   'b word llM"
    where "op_lift_iconv n f a _  doM {
    consume (cost n 1);
    let a = word_to_lint a;
    let w = LENGTH('b);
    r  f a w;
    return (lint_to_word r)
  }"
  
  (* The three conversion instructions, lifted from the checked lint conversions. *)
  definition "ll_trunc  op_lift_iconv ''trunc'' llb_trunc"
  definition "ll_sext  op_lift_iconv ''sext'' llb_sext"
  definition "ll_zext  op_lift_iconv ''zext'' llb_zext"
  
    
        
        
  subsection ‹Control Flow›  

  text ‹Our shallow embedding uses a structured control flow, which allows
    only sequential composition, if-then-else, and function calls.
    
    The code generator then maps sequential composition to basic blocks, 
    and if-then-else to a control flow graph with conditional branching.
    Function calls are mapped to LLVM function calls.  
   ›
  
  text ‹We use the to-Boolean conversion ‹to_bool› from word-lib. We re-state its semantics here.›
    
  (* On 1-bit words, to_bool agrees with lint_to_bool applied to the word_to_lint
     conversion of the word. *)
  lemma to_bool_as_lint_to_bool:
    "to_bool (w::1 word) = lint_to_bool (word_to_lint w)"
    unfolding to_bool_def word_to_lint_def
    apply (clarsimp simp: ltrue_def lfalse_def lint_to_bool_def)
    apply transfer
    apply auto
    done
  
  (* Simp rule characterising to_bool on 1-bit words by comparison with 0; it is an
     instance of the existing rule to_bool_neq_0. *)
  lemma to_bool_eq[simp]: "to_bool (w::1 word)  w0"
    by (rule to_bool_neq_0)
  
  (* Conditional on a 1-bit word: consume one unit of ''if'' cost, then run t if
     to_bool b holds and e otherwise. The result type only needs to be llvm_repv. *)
  definition llc_if :: "1 word  'a::llvm_repv llM  'a llM  'a llM" where
    "llc_if b t e  doM {
      consume (cost ''if'' 1);
      if to_bool b then t else e
    }"
  
  (* Monotonicity rule for llc_if in both branches, registered for the
     partial_function package. *)
  lemma llc_if_mono[partial_function_mono]:      
    "M_mono F; M_mono G  M_mono (λf. llc_if b (F f) (G f))"
    unfolding llc_if_def 
    by pf_mono_prover  

  subsubsection ‹Function Call›

  (* Function call wrapper: consume one unit of ''call'' cost, then run f. *)
  definition ll_call :: "'a llM  'a llM" where 
    "ll_call f  doM { consume (cost ''call'' 1) ; f  }"

  (* Monotonicity rule for ll_call, registered for the partial_function package. *)
  lemma ll_call_mono[partial_function_mono]: "M_mono f  M_mono (λx. ll_call (f x))"
    unfolding ll_call_def
    by pf_mono_prover
  
  subsubsection ‹Recursion with Time for Call›  
    
  (* REC' is REC with every recursive call wrapped in ll_call: the function handed
     to the body F is λx. ll_call (D x) instead of D itself. *)
  definition "REC' F x = REC (λD x. F (λx. ll_call (D x)) x) x"
  
  (* Unfolding rule: given the defining equation DEF and monotonicity MONO of the body,
     f equals the body F applied to the ll_call-wrapped f. Derived from REC_unfold. *)
  lemma REC'_unfold:
    assumes DEF: "f  REC' F"
    assumes MONO: "x. M.mono_body (λfa. F (λx. ll_call (fa x)) x)"
    shows "f = F (λx. ll_call (f x))" 
    unfolding DEF REC'_def
    apply (rewrite REC_unfold[OF reflexive MONO])
    by rule
    
  (* Monotonicity rule for REC' in an outer parameter D of the body B, registered for
     the partial_function package. It needs monotonicity of B in both arguments. *)
  lemma REC'_mono[partial_function_mono]:
    assumes MONO: "D x. M.mono_body (λE. B D (λx. ll_call (E x)) x)"
    assumes 1: "E x. M_mono (λD. B D (λx. ll_call (E x)) x)"
    shows "M_mono (λD. REC' (B D) x)"
    unfolding REC'_def
    using assms
    by pf_mono_prover
    
    
  
  (* Scratch area. Its whole body is wrapped in a cancel cartouche, so none of it is
     processed. The cancelled text sketches repeated unfolding of REC'_unfold on an
     example body F and ends in "sorry". *)
  notepad  (* TODO: cleanup *)
  begin
    ⌦‹
  
    assume MONO: "⋀x. M.mono_body (λfa. F (λx. ll_call (fa x)) x)"
    
      and DEF: "f ≡ REC' F"
      and F: " F  ≡ λD x. (if x>0 then D (x-1) else return 0 )"
    have "P (ll_call (f x))"
      apply(rewrite REC'_unfold[OF DEF MONO])
      apply(rewrite REC'_unfold[OF DEF MONO])
      apply(rewrite REC'_unfold[OF DEF MONO])
      apply(rewrite REC'_unfold[OF DEF MONO])
      unfolding F ll_call_def sorry
  ›
  end

  subsubsection ‹While-Combinator›
  text ‹
    Note that we also include the while combinator at this point, as the direct translation 
    of while to a control flow graph is the default translation mode of the code generator. 
    
    As an optional feature, while can be translated to 
    a tail recursive function call, which the preprocessor can do automatically.
  ›
    
  (* While loop over a state of type 'a:
       b  - loop condition, a monadic computation yielding a 1-bit word
       f  - loop body, mapping a state to a new state
       s0 - initial state
     Defined by recursion via REC': evaluate b on the current state σ; via llc_if,
     either run f σ and continue with the recursive call mwhile, or return σ.
     The whole recursion is additionally wrapped in an outer ll_call. *)
  definition llc_while :: "('a::llvm_repv  1 word llM)  ('a  'a llM)  'a  'a llM" where
    "llc_while b f s0  ll_call (REC' (λmwhile σ. doM {
              ctd  b σ;
              llc_if ctd (f σ  mwhile) (return σ)
            }) s0)" 

  (*          
  lemma gen_code_thm_llc_while:
    assumes "f ≡ llc_while b body"
    shows "f s = ll_call (doM { ctd ← b s; llc_if ctd (doM { s←body s; f s}) (return s)})"
    unfolding assms
    unfolding llc_while_def llc_if_def
    apply (rewrite REC'_unfold[OF reflexive])
    apply pf_mono_prover
    by simp
  *)

  (* 'Definition' of llc_while for presentation in paper: *)  
  (* One-step unfolding of llc_while: an ll_call around "evaluate b s, then via llc_if
     either run c s and loop again, or return s". Proved by unfolding the definitions,
     rewriting once with REC'_unfold, and discharging monotonicity with pf_mono_prover. *)
  lemma "c. llc_while b c s  ll_call (doM {
     x  b s;
     llc_if x (doM {sc s; llc_while b c s}) (return s)
   })"
    unfolding llc_while_def llc_if_def
    apply (rewrite REC'_unfold[OF reflexive])
    apply pf_mono_prover
    by simp
    
   
  (* Monotonicity rule for llc_while when both the condition b and the body c depend
     monotonically on an outer parameter; registered for the partial_function package. *)
  lemma llc_while_mono[partial_function_mono]:      
    assumes "x. M_mono (λf. b f x)"
    assumes "x. M_mono (λf. c f x)"
    shows "M_mono (λD. llc_while (b D) (c D) σ)"
    using assms unfolding llc_while_def by pf_mono_prover
      
    
       
end