(* Theory Gabow_SCC_Code *)

section ‹Code Generation for SCC Computation›
theory Gabow_SCC_Code
imports 
  Gabow_SCC
  Gabow_Skeleton_Code
  "../ds/Modest_MDP"
begin


subsection ‹Refining to LLVM Data Structures›

text ‹This section is mainly Sepref boilerplate.›

(* Definition locale for the SCC implementation. It extends
   fr_graph_impl_def_loc with exactly the same parameters
   (E, V0, N, D, E_assn, succi, ni) and adds no assumptions of its own
   (there is no assumes clause). *)
locale fr_graph_scc_impl_def_loc = fr_graph_impl_def_loc E V0 N D E_assn succi ni
  for E and V0 and N and D and E_assn and succi and ni
begin
  (* Local alias for node_state.am_assn.
     NOTE(review): not referenced in the visible code; later items use
     node_state.am_assn directly. *)
  definition "node_state_am_assn = node_state.am_assn"
end



(* Implementation locale for the SCC computation. It combines
   fr_graph_impl_loc and fr_graph_scc_impl_def_loc over the same parameters
   and additionally assumes D_VAL: D = N.
   Most items below follow one pattern:
   - a sepref_definition produces an LLVM implementation inside the locale;
   - concrete_definition (in -) lifts it to a locale-independent
     [llvm_code] constant;
   - the resulting refinement theorem is registered with [sepref_fr_rules]
     so that later sepref runs can use it.
   NOTE(review): in this copy many logical symbols inside quoted terms
   appear to be missing (e.g. the implication in inc_data_ll_aux or the
   bind arrows in compute_SCC_inner_while_body2). Confirm the file's
   symbol encoding. *)
locale fr_graph_scc_impl_loc = fr_graph_impl_loc E V0 N D E_assn succi ni + 
  fr_graph_scc_impl_def_loc E V0 N D E_assn succi ni
  for E and V0 and N and D and E_assn and succi and ni +
  assumes D_VAL: "D = N"
begin


  (* D is below the maximal signed value representable in size_T. Derived
     from D_N_BOUND, which is provided by the surrounding locale context,
     together with D_VAL. *)
  lemma D_bound: "D < max_snat (LENGTH(size_T))"
    using D_N_BOUND
    by (simp add: D_VAL max_snat_def)

  (* Side condition for the sepref run of build_scc_impl_ll: incrementing
     a value below D stays within a 64-bit signed nat.
     NOTE(review): this is stated with the literal 64 while D_bound uses
     LENGTH(size_T), so the simp proof relies on size_T being 64 bits.
     Confirm. *)
  lemma inc_data_ll_aux: "x < D  Suc x < max_snat 64" using D_bound by simp

  (* Variant of build_scc_impl_fr used as input to sepref. It
     - asserts that the counter i is below the number of nodes reachable
       from V0;
     - performs the pop step on the state s (cf. pop_impl_fr_def in the
       proof of build_scc_impl_fr_alt_def);
     - returns Suc i wrapped in op_bound_val with the bound x < Suc D,
       together with the new state. *)
  definition "build_scc_impl' = (λ s i.
    do {
      ASSERT (i < card (E* `` V0));
      spop_impl_fr s i;
      RETURN (op_bound_val (λx. x<Suc D) (Suc i), s)
    })"


  (* Any value below the number of nodes reachable from V0 is below D.
     Follows from card_reachable_bound and D_VAL. *)
  lemma build_scc_impl_fr_alt_def_aux: "a < card (E* `` V0)  a < D"
    by (auto simp: card_reachable_bound D_VAL)
  (* build_scc_impl_fr agrees with the sepref-friendly build_scc_impl'. *)
  lemma build_scc_impl_fr_alt_def: "build_scc_impl_fr s i = build_scc_impl' s i"
    unfolding build_scc_impl_fr_def GSS_defs.build_scc_impl_def build_scc_impl'_def
    apply(simp add: pop_impl_fr_def build_scc_impl_fr_alt_def_aux)
    done

  sepref_register build_scc_impl_fr

  (* LLVM implementation of build_scc_impl_fr. The state and the counter
     are both passed with the d annotation, and the result is a
     (counter, state) pair. *)
  sepref_definition build_scc_impl_ll is "uncurry (PR_CONST build_scc_impl_fr)" :: "GS_assnd *a data_assnd a data_assn ×a GS_assn"
    unfolding build_scc_impl_fr_alt_def build_scc_impl'_def build_scc_impl_fr_alt_def_aux PR_CONST_def
    supply [simp] = inc_data_ll_aux 
    apply sepref
    done

  concrete_definition (in -) build_scc_impl_ll' [llvm_code] is fr_graph_scc_impl_loc.build_scc_impl_ll_def
  lemmas [sepref_fr_rules] = build_scc_impl_ll.refine[unfolded build_scc_impl_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]


  (* open_is is the identity on (counter, state) pairs. It is registered
     as its own operation, presumably so that sepref has an explicit point
     at which to split the pair.
     NOTE(review): the first line of the registered type reads
     (nat node_state option), while every other occurrence in this locale
     reads (nat  node_state option). A symbol may be missing there.
     Confirm. *)
  definition open_is :: "nat × nat GS  nat × nat GS" where "open_is is = (case is of (i,s)  (i,s))"
  sepref_register open_is :: "nat × nat list × nat list × (nat node_state option) × (nat × nat list) list
      nat × nat list × nat list × (nat  node_state option) × (nat × nat list) list"

  sepref_definition open_is_ll is "RETURN o open_is" :: "(data_assn ×a GS_assn)d a data_assn ×a GS_assn"
    unfolding open_is_def
    apply sepref
    done

  concrete_definition (in -) open_is_ll' [llvm_code] is fr_graph_scc_impl_loc.open_is_ll_def
  lemmas [sepref_fr_rules] = open_is_ll.refine[unfolded open_is_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]


  (* close_is i s = (i, s): the counterpart of open_is, which pairs a
     counter with a state. *)
  definition close_is :: "nat  nat GS  nat × nat GS" where "close_is i s = (i,s)"
  sepref_register close_is :: "nat  (nat list × nat list × (nat  node_state option) × (nat × nat list) list)
      nat × nat list × nat list × (nat  node_state option) × (nat × nat list) list"

  sepref_definition close_is_ll is "uncurry (RETURN oo close_is)" :: "data_assnd *a GS_assnd a data_assn ×a GS_assn"
    unfolding close_is_def
    apply sepref
    done

  concrete_definition (in -) close_is_ll' [llvm_code] is fr_graph_scc_impl_loc.close_is_ll_def
  lemmas [sepref_fr_rules] = close_is_ll.refine[unfolded close_is_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]

  (* One iteration of the inner loop on a (counter, state) pair. The edge
     set is an explicit argument E' (instantiated with E in
     compute_SCC_inner_while_body_alt_def). The step is:
     - select an edge from the end of the path;
     - if none is left (vo = None), build an SCC via build_scc_impl_fr;
     - otherwise, with target v:
       - collapse if v is on the stack,
       - push v if it is not done,
       - leave the state unchanged if v is done.
     NOTE(review): the inner comment "No more outgoing edges from current
     node on path" sits in the branch where an edge WAS selected (vo is not
     None). It appears to belong to the vo = None / build_scc_impl_fr
     branch. *)
  definition "compute_SCC_inner_while_body2 E' = (λ (i,s). 
          do {
            ― ‹Select edge from end of path

            (vo,s)  select_edge_impl s;

            if (vo = None) then ⌦‹TODO: How to properly handle the case distinction? ›
                build_scc_impl_fr s i
            else do {
                let v = the(vo);
                ― ‹No more outgoing edges from current node on path
                ASSERT (v  E*``V0);
                if is_on_stack_impl v s then do {
                  s collapse_impl_fr v s;
                  RETURN (i, s)
                } else if ¬is_done_impl v s then do {
                  ― ‹Edge to new node. Append to path
                  s  push_impl' E' v s;
                  RETURN (i, s)
                } else do {
                  ― ‹Edge to done node. Skip
                  RETURN (i, s)
                }
             }
          })"
  sepref_register compute_SCC_inner_while_body2 :: "(nat × nat) set
      nat × nat list × nat list × (nat  node_state option) × (nat × nat list) list
         (nat × (nat list × nat list × (nat  node_state option) × (nat × nat list) list)) nres"


  (* The locale's compute_SCC_inner_while_body coincides with the
     explicit-graph variant compute_SCC_inner_while_body2 instantiated
     with E. *)
  lemma compute_SCC_inner_while_body_alt_def: "compute_SCC_inner_while_body is = compute_SCC_inner_while_body2 E is"
    unfolding compute_SCC_inner_while_body_def compute_SCC_inner_while_body2_def close_is_def open_is_def push_S_impl_def select_edge_S_impl_def
    apply (cases "is")
    apply clarsimp
    apply (fo_rule arg_cong)
    apply (rule ext)
    apply (auto simp: push_impl_alt_def )
    done


  (* LLVM implementation of the inner-loop body. The graph is passed with
     the k annotation, and the (counter, state) pair with the d
     annotation. *)
  sepref_definition compute_SCC_inner_while_body_ll is "uncurry (PR_CONST compute_SCC_inner_while_body2)" :: "E_assnk *a (data_assn ×a GS_assn)d a data_assn ×a GS_assn"
    unfolding compute_SCC_inner_while_body2_def PR_CONST_def
    apply sepref
    done


  concrete_definition (in -) compute_SCC_inner_while_body_ll' [llvm_code] is fr_graph_scc_impl_loc.compute_SCC_inner_while_body_ll_def
  lemmas [sepref_fr_rules] = compute_SCC_inner_while_body_ll.refine[unfolded compute_SCC_inner_while_body_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]


  (* Raw assertion for the final result: a data_assn counter paired with
     node_state.am_assn N for the node-state map. *)
  definition "output_assn_raw = data_assn ×a (node_state.am_assn N)"

  (* open_iI is the identity on (counter, map) pairs. Like open_is, it is
     registered as its own operation, presumably as a split point for
     sepref. *)
  definition open_iI :: "nat × nat oGS  nat × nat oGS" where "open_iI iI = (case iI of (i,I)  (i,I))"
  sepref_register open_iI :: "nat × (nat  node_state option)  nat × ((nat, node_state) i_map)"

  sepref_definition open_iI_ll is "RETURN o open_iI" :: "output_assn_rawd a data_assn ×a (node_state.am_assn N)"
    unfolding open_iI_def output_assn_raw_def
    apply sepref
    done

  concrete_definition (in -) open_iI_ll' [llvm_code] is fr_graph_scc_impl_loc.open_iI_ll_def
  lemmas [sepref_fr_rules] = open_iI_ll.refine[unfolded open_iI_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]



  (* close_iI i I = (i, I): the counterpart of open_iI. *)
  definition close_iI :: "nat  nat oGS  nat × nat oGS" where "close_iI i I = (i,I)"
  sepref_register close_iI :: "nat  ((nat, node_state) i_map)  nat × (nat  node_state option)"

  sepref_definition close_iI_ll is "uncurry (RETURN oo close_iI)" :: "data_assnd *a (node_state.am_assn N)d a output_assn_raw"
    unfolding close_iI_def output_assn_raw_def
    apply sepref
    done

  concrete_definition (in -) close_iI_ll' [llvm_code] is fr_graph_scc_impl_loc.close_iI_ll_def
  lemmas [sepref_fr_rules] = close_iI_ll.refine[unfolded close_iI_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]


  (* Outer loop of the SCC algorithm, iterating over the explicit node list
     [0..<N] and starting from counter 0 and the empty map. For each node
     v0:
     - if v0 is not yet done in the current map, an initial state is built
       from v0 and the inner loop (compute_SCC_inner_while_body2 E') runs
       until the path is empty. The loop invariant is cscc_invar on the
       abstraction of the state. Afterwards the counter and the I component
       of the final state are kept.
     - if v0 is already done, the accumulator is returned unchanged. *)
  definition "compute_SCC_impl_nfoldli E'  do {
      let i = Map.empty;
      let so = close_iI 0 i;
      so  nfoldli [0..<N] (λ_. True)(λv0 (iI0 :: (nat × (nat  node_state option))). do {
        ASSERT (v0  E*``V0);
        let (i0, I0) = open_iI iI0;
        if ¬is_done_oimpl v0 I0 then do {
          s  initial_impl' E' v0 I0;

          (i,s) WHILEIT (λ (i,s). (λ (SCC,p,D,pE). vE. cscc_invar v0 (oGS_α I0) (SCC,p,D,pE,vE)) (GSS_defs.s_α s i))
            (λ (_, s). ¬path_is_empty_impl s) 
            (compute_SCC_inner_while_body2 E')
            (i0,s);
          let (S,B,I,P) = open_GS s;
          RETURN (close_iI i I)
        } else
          RETURN (close_iI i0 I0)
        }) so;
      RETURN so
    }"
  (* NOTE(review): this registers skeleton_impl_nfoldli, not the
     compute_SCC_impl_nfoldli defined just above. Its declared result type,
     an i_map nres, also differs from the (counter, map) result of
     compute_SCC_impl_nfoldli. It looks like a leftover from the skeleton
     theory. Confirm whether it is needed. *)
  sepref_register skeleton_impl_nfoldli :: "(nat × nat) set  ((nat, node_state) i_map) nres"



  (* The node list [0..<N] refines the initial node set V0 under
     list_set_rel with element relation node_rel' N. The proof relies on
     V0_BOUND, from the locale context, to relate V0 and set [0..<N]. *)
  lemma bounded_list_set_b_rel: "([0..<N], V0)  node_rel' Nlist_set_rel"
  proof -
    have "([0..<N], set [0..<N])  node_rel' Nlist_set_rel"
      apply(rule list_to_set_b_rel'_setI)
      by (auto simp: list_all_length)
    thus ?thesis 
      by (simp add: V0_BOUND)
  qed 


  (* compute_SCC_impl_nfoldli E refines compute_SCC_impl. The list-based
     outer iteration is related to the abstract one by LFOi_refine with
     A = node_rel' N. *)
  lemma skeleton_impl_nfoldli_refine: "compute_SCC_impl_nfoldli E   Id compute_SCC_impl"
    unfolding compute_SCC_impl_nfoldli_def compute_SCC_impl_def open_GS_def open_iI_def close_iI_def initial_S_impl_def
    apply (simp only: Refine_Basic.nres_monad_laws)
    apply (simp del: conc_Id)
    apply (refine_rcg LFOi_refine[where A="node_rel' N"])
    apply refine_dref_type
    apply (vc_solve (nopre) solve: asm_rl I_to_outer 
      simp: bounded_list_set_rel compute_SCC_inner_while_body_alt_def initial_impl_alt bounded_list_set_b_rel)
    apply(auto simp add: reachable_bound) 
    done

  
  (* Rewrites an initial close_iI 0 Map.empty into a form that allocates
     the map with op_am_custom_empty of size BCONST N N.
     NOTE(review): not referenced in the visible code. compute_SCC_impl_ll
     below unfolds fold_am_custom_empty' instead, and
     compute_SCC_impl_nfoldli binds Map.empty to i before applying
     close_iI, so this pattern may not match. Confirm whether it is still
     needed. *)
  lemma fold_nat_am_custom_empty': "(let i = close_iI 0 Map.empty in f i) = (let N'= (BCONST N N); i = close_iI 0 (op_am_custom_empty N') in f i)"
    by simp


  (* Main sepref run. The graph is passed with the k annotation, and the
     result is described by output_assn_raw. Before sepref runs:
     - the nfoldli over [0..<N] is rewritten into a while loop
       (nfoldli_upt_by_while);
     - numeric constants are annotated with annot_snat_const for size_T;
     - reachable_bound is added as a simp rule. *)
  sepref_definition compute_SCC_impl_ll is "PR_CONST compute_SCC_impl_nfoldli" :: "E_assnk a output_assn_raw"
    unfolding compute_SCC_impl_nfoldli_def PR_CONST_def initial_S_impl_def
    unfolding fold_am_custom_empty' nfoldli_upt_by_while
    apply (annot_snat_const "TYPE(size_T)")
    supply [simp] = reachable_bound
    apply sepref
    done


  concrete_definition (in -) compute_SCC_impl_ll' [llvm_code] is fr_graph_scc_impl_loc.compute_SCC_impl_ll_def
  lemmas [sepref_fr_rules] = compute_SCC_impl_ll.refine[unfolded compute_SCC_impl_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]

  (* Result assertion against the abstract result: output_assn_raw
     composed (hr_comp) with SCC_rel. *)
  definition "output_assn  hr_comp output_assn_raw SCC_rel"

end

subsection ‹Exporting LLVM Code›

text ‹For the code export to work, we extract the definition from its locale.›
(* Instantiate the locale-independent compute_SCC_impl_ll' with the
   Modest graph successor implementation modest_graph_succ_ll. The result
   is a top-level [llvm_code] constant. *)
concrete_definition Modest_compute_SCC_impl[llvm_code] is compute_SCC_impl_ll'_def[of modest_graph_succ_ll]


(* Interpret the definition locale for Modest graphs:
   - E_assn is modest_graph_assn N;
   - succi is modest_graph_succ_ll;
   - D is fixed to N.
   NOTE(review): D also appears in the for clause, but the locale
   parameter D is instantiated with "N", so that variable seems unused. *)
interpretation fr_graph_scc_impl_def: fr_graph_scc_impl_def_loc
  where N=N and D="N"
  and E_assn = "(modest_graph_assn N)"
  and succi = modest_graph_succ_ll
  and ni=ni
  for N D ni .


(* Concrete result type passed back to C: a size_t together with a pointer
   to size_t values. It corresponds to the C struct scc_result_t declared
   in the export_llvm command (fields num_sccs and scc_map). *)
type_synonym scc_resulti = "size_t × size_t ptr"

(* C-callable wrapper around Modest_compute_SCC_impl. It loads the graph
   through the pointer Eip, runs the SCC computation with node count ni,
   and stores the resulting pair through the pointer resp. It is marked
   llvm_inline as well as llvm_code. *)
definition Modest_compute_SCC_impl' :: "size_t  modest_graphi ptr  scc_resulti ptr  unit llM"
where [llvm_code, llvm_inline]: "Modest_compute_SCC_impl' ni Eip resp  doM {
  Ei  ll_load Eip;
  res  Modest_compute_SCC_impl ni Ei;
  ll_store res resp
}"


(* Export the wrapper Modest_compute_SCC_impl' as the C function
   compute_SCC, writing the LLVM code to modest_gabow.ll. The defines
   section supplies the C declarations of the types used in the signature:
   - my_size_t and node_t as uint64_t;
   - the graph struct modest_graph_t;
   - the result struct scc_result_t.
   NOTE(review): in this copy the defines text is not enclosed in a
   cartouche, as export_llvm expects. Confirm against the original file. *)
export_llvm Modest_compute_SCC_impl' is "void compute_SCC(my_size_t, modest_graph_t *, scc_result_t *)"  
defines 
  typedef uint64_t my_size_t;
  typedef my_size_t node_t;
  typedef uint64_t shared_nat_t;
  typedef uint64_t *bitset_t;

  typedef struct {
    shared_nat_t *states;
    struct {
      shared_nat_t *transitions;
      node_t *branches;
    };
  } modest_graph_t;

  typedef struct {
    my_size_t num_sccs;
    node_t *scc_map;
  } scc_result_t;

  file "modest_gabow.ll"





subsection ‹Combining the Refinements›


(* Locale tying the generic SCC implementation to Modest graphs with
   initial nodes {0..<N} and D = N. It assumes:
   - ni implements N (size_rel);
   - 2 * N fits into a 64-bit signed nat;
   - every edge of E lies within {0..<N} x {0..<N}. *)
locale Modest_graph_impl_loc = fr_graph_scc_impl_def_loc E "{0..<N}" N N "modest_graph_assn N" modest_graph_succ_ll ni
  for E :: "(nat × nat) set" and N :: nat and ni +
  assumes n_impl: "(ni, N)  size_rel"
  assumes N_BOUND: "2 * N < max_snat LENGTH(64)"
  assumes E_BOUND: "E  {0..<N} × {0..<N}"
begin


  (* The implementation locale holds for this instantiation. The
     unfold_locales goals are closed using:
     - finiteness of the reachable set, via rtrancl_image_subsetI with
       E_BOUND and finite_subset;
     - mop_Modest_graph_succ_hnr;
     - E_BOUND, N_BOUND and n_impl. *)
  sublocale fr_graph_scc_impl_loc E "{0..<N}" N "N" "modest_graph_assn N" modest_graph_succ_ll ni
    apply unfold_locales
    apply(rule rtrancl_image_subsetI[OF E_BOUND, THEN finite_subset] ) 
    apply blast
    apply (rule mop_Modest_graph_succ_hnr)
    apply (rule E_BOUND)
    apply simp
    using N_BOUND apply simp
    apply (rule n_impl)
    apply simp
    done

  (* End-to-end refinement: Modest_compute_SCC_impl ni implements
     compute_SCC_spec. The graph is passed with the k annotation, and the
     result is related via output_assn. The proof:
     - r1: Modest_compute_SCC_impl.refine rewrites the implementation
       constant;
     - r2: the sepref theorem for compute_SCC_impl_ll;
     - r3: the abstract chain skeleton_impl_nfoldli_refine,
       compute_SCC_impl_refine, compute_SCC_correct;
     - r2 and r3 are composed with FCOMP. *)
  lemma Modest_graph_skeleton_impl_refines_spec: "(Modest_compute_SCC_impl ni, λ_. compute_SCC_spec) 
     (hr_comp (modest_graph_assn N) {(E, E)})k a output_assn"
  proof -
  
    note r1 = Modest_compute_SCC_impl.refine[of ni, folded compute_SCC_impl_ll'.refine[OF fr_graph_scc_impl_loc_axioms]]

    note r2 = compute_SCC_impl_ll.refine[unfolded PR_CONST_def]  
    
    note skeleton_impl_nfoldli_refine
    also note compute_SCC_impl_refine
    also note compute_SCC_correct
    finally have "(compute_SCC_impl_nfoldli E, compute_SCC_spec)  SCC_rel nres_rel"
      by (auto simp: nres_rel_def)
    hence r3: "(compute_SCC_impl_nfoldli, (λ _. compute_SCC_spec))  {(E,E)}  SCC_rel nres_rel"
      by auto
    
    from r2[FCOMP r3, unfolded r1, folded output_assn_def] show ?thesis .
  qed

end

  (* Every edge of the abstract graph alpha lies within
     {0..<Ns} x {0..<Ns}. The bound on sources comes from unfolding
     alpha_def and NS_DEF, and the bound on targets from succ_in_bound. *)
  lemma (in modest_graph_invar) α_in_bound: "α  {0..<Ns} × {0..<Ns}"
  proof -
    have "α  {0..<Ns} × UNIV"
      unfolding α_def NS_DEF by auto
    thus ?thesis 
    using succ_in_bound unfolding NS_DEF
    by auto
  qed

  (* If the pure part of the Modest graph assertion holds, the abstract
     edge set E lies within {0..<N} x {0..<N}. Proved via
     modest_graph_rel_def and modest_graph_invar.alpha_in_bound; the main
     theorem uses it to obtain E_BOUND. *)
  lemma Modest_graph_assn_pure_partD: "pure_part (modest_graph_assn N E Ei)  E  {0..<N} × {0..<N}"
    unfolding modest_graph_assn_def
    apply(drule pure_part_hr_compD)
    unfolding modest_graph_rel_def
    using modest_graph_invar.α_in_bound
    by (fastforce simp: in_br_conv )

subsection ‹Main Correctness Theorem›
    
  (* Main correctness theorem, stated as a Hoare triple for
     Modest_compute_SCC_impl ni Ei.
     Precondition:
     - ni implements N (snat_assn);
     - the pure bound 2 * N < max_snat LENGTH(64) holds;
     - Ei implements the graph E (modest_graph_assn N).
     Postcondition:
     - the same input assertions hold again;
     - there is an abstract result r, related to the returned value ri by
       output_assn, with set r = fr_graph.scc_set E {0..<N};
     - for i < j, no E^*-path leads from a node of r ! i to a node of
       r ! j.
     Proof outline:
     - pure facts are extracted from the precondition, including the edge
       bound via Modest_graph_assn_pure_partD;
     - Modest_graph_impl_loc is interpreted for E, N, ni;
     - vcg' closes the goal, using Modest_graph_skeleton_impl_refines_spec
       as a VCG rule. *)
  theorem Modest_graph_SCC_impl_correct_htriple: "llvm_htriple 
    (snat_assn N ni ** (2 * N < max_snat LENGTH(64)) ** modest_graph_assn N E Ei) 
    (Modest_compute_SCC_impl ni Ei) 
    (λri. EXS r. 
      snat_assn N ni 
      ** modest_graph_assn N E Ei
      ** fr_graph_scc_impl_loc.output_assn E {0..<N} N N r ri 
      ** (set r = fr_graph.scc_set E {0..<N})
      ** (i j. i < j  j < length r  r ! i × r ! j  E* = {}))"
  apply(rule htriple_pure_preI)
  apply(drule pure_part_split_conj, clarify)+
  apply(drule Modest_graph_assn_pure_partD)
  apply clarsimp
  proof (goal_cases)
    case 1
    interpret Modest_graph_impl_loc E N ni
      apply unfold_locales 
      by (fact|clarsimp)+
    note [simp] = hr_comp_b_rel_Id

    note [simp] = compute_SCC_spec_def

    note [vcg_rules] = Modest_graph_skeleton_impl_refines_spec[to_hnr, THEN hn_refineD, unfolded hn_ctxt_def, of E Ei, simplified]
    show ?case 
      by vcg' 

  qed
end