aboutsummaryrefslogtreecommitdiffstats
path: root/src/hls/Schedule.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/hls/Schedule.ml')
-rw-r--r--src/hls/Schedule.ml549
1 files changed, 549 insertions, 0 deletions
diff --git a/src/hls/Schedule.ml b/src/hls/Schedule.ml
new file mode 100644
index 0000000..b9ee741
--- /dev/null
+++ b/src/hls/Schedule.ml
@@ -0,0 +1,549 @@
+(*
+ * Vericert: Verified high-level synthesis.
+ * Copyright (C) 2020 Yann Herklotz <yann@yannherklotz.com>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ *)
+
+open Printf
+open Clflags
+open Camlcoq
+open Datatypes
+open Coqlib
+open Maps
+open AST
+open Kildall
+open Op
+open RTLBlockInstr
+open RTLBlock
+open HTL
+open Verilog
+open HTLgen
+open HTLMonad
+open HTLMonadExtra
+
+module SS = Set.Make(P)
+
+module IMap = Map.Make (struct
+ type t = int
+
+ let compare = compare
+end)
+
+type dfg = { nodes : instr list; edges : (int * int) list }
+(** The DFG type defines a list of instructions with their data dependencies as [edges], which are
+ the pairs of integers that represent the index of the instruction in the [nodes]. The edges
+ always point from left to right. *)
+
+let print_list f out_chan a =
+ fprintf out_chan "[ ";
+ List.iter (fprintf out_chan "%a " f) a;
+ fprintf out_chan "]"
+
+let print_tuple out_chan a =
+ let l, r = a in
+ fprintf out_chan "(%d,%d)" l r
+
+let print_dfg out_chan dfg =
+ fprintf out_chan "{ nodes = %a, edges = %a }"
+ (print_list PrintRTLBlockInstr.print_bblock_body)
+ dfg.nodes (print_list print_tuple) dfg.edges
+
+let read_process command =
+ let buffer_size = 2048 in
+ let buffer = Buffer.create buffer_size in
+ let string = Bytes.create buffer_size in
+ let in_channel = Unix.open_process_in command in
+ let chars_read = ref 1 in
+ while !chars_read <> 0 do
+ chars_read := input in_channel string 0 buffer_size;
+ Buffer.add_substring buffer (Bytes.to_string string) 0 !chars_read
+ done;
+ ignore (Unix.close_process_in in_channel);
+ Buffer.contents buffer
+
+(** Add a dependency if it uses a register that was written to previously. *)
+let add_dep i tree deps curr =
+ match PTree.get curr tree with None -> deps | Some ip -> (ip, i) :: deps
+
+(** This function calculates the dependencies of each instruction. The nodes correspond to previous
+ registers that were allocated and show which instruction caused it.
+
+ This function only gathers the RAW constraints, and will therefore only be active for operations
+ that modify registers, which is this case only affects loads and operations. *)
+let accumulate_RAW_deps dfg curr =
+ let i, dst_map, { edges; nodes } = dfg in
+ let acc_dep_instruction rs dst =
+ ( i + 1,
+ PTree.set dst i dst_map,
+ {
+ nodes;
+ edges = List.append (List.fold_left (add_dep i dst_map) [] rs) edges;
+ } )
+ in
+ let acc_dep_instruction_nodst rs =
+ ( i + 1,
+ dst_map,
+ {
+ nodes;
+ edges = List.append (List.fold_left (add_dep i dst_map) [] rs) edges;
+ } )
+ in
+ match curr with
+ | RBop (op, _, rs, dst) -> acc_dep_instruction rs dst
+ | RBload (op, _mem, _addr, rs, dst) -> acc_dep_instruction rs dst
+ | RBstore (op, _mem, _addr, rs, src) -> acc_dep_instruction_nodst (src :: rs)
+ | _ -> (i + 1, dst_map, { edges; nodes })
+
+(** Finds the next write to the [dst] register. This is a small optimisation so that only one
+ dependency is generated for a data dependency. *)
+let rec find_next_dst_write i dst i' curr =
+ let check_dst dst' curr' =
+ if dst = dst' then Some (i, i')
+ else find_next_dst_write i dst (i' + 1) curr'
+ in
+ match curr with
+ | [] -> None
+ | RBop (_, _, _, dst') :: curr' -> check_dst dst' curr'
+ | RBload (_, _, _, _, dst') :: curr' -> check_dst dst' curr'
+ | _ :: curr' -> find_next_dst_write i dst (i' + 1) curr'
+
+let rec find_all_next_dst_read i dst i' curr =
+ let check_dst rs curr' =
+ if List.exists (fun x -> x = dst) rs
+ then (i, i') :: find_all_next_dst_read i dst (i' + 1) curr'
+ else find_all_next_dst_read i dst (i' + 1) curr'
+ in
+ match curr with
+ | [] -> []
+ | RBop (_, _, rs, _) :: curr' -> check_dst rs curr'
+ | RBload (_, _, _, rs, _) :: curr' -> check_dst rs curr'
+ | RBstore (_, _, _, rs, src) :: curr' -> check_dst (src :: rs) curr'
+ | RBnop :: curr' -> find_all_next_dst_read i dst (i' + 1) curr'
+ | RBsetpred (_, rs, _) :: curr' -> check_dst rs curr'
+
+let drop i lst =
+ let rec drop' i' lst' =
+ match lst' with
+ | _ :: ls -> if i' = i then ls else drop' (i' + 1) ls
+ | [] -> []
+ in
+ if i = 0 then lst else drop' 1 lst
+
+let take i lst =
+ let rec take' i' lst' =
+ match lst' with
+ | l :: ls -> if i' = i then [ l ] else l :: take' (i' + 1) ls
+ | [] -> []
+ in
+ if i = 0 then [] else take' 1 lst
+
+let rec next_store i = function
+ | [] -> None
+ | RBstore (_, _, _, _, _) :: _ -> Some i
+ | _ :: rst -> next_store (i + 1) rst
+
+let rec next_load i = function
+ | [] -> None
+ | RBload (_, _, _, _, _) :: _ -> Some i
+ | _ :: rst -> next_load (i + 1) rst
+
+let accumulate_RAW_mem_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match curr with
+ | RBload (_, _, _, _, _) -> (
+ match next_store 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAR_mem_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match curr with
+ | RBstore (_, _, _, _, _) -> (
+ match next_load 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAW_mem_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match curr with
+ | RBstore (_, _, _, _, _) -> (
+ match next_store 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+(** Predicate dependencies. *)
+
+let rec in_predicate p p' =
+ match p' with
+ | Pvar p'' -> Nat.to_int p = Nat.to_int p''
+ | Pnot p'' -> in_predicate p p''
+ | Pand (p1, p2) -> in_predicate p p1 || in_predicate p p2
+ | Por (p1, p2) -> in_predicate p p1 || in_predicate p p2
+
+let rec get_predicate = function
+ | RBop (p, _, _, _) -> p
+ | RBload (p, _, _, _, _) -> p
+ | RBstore (p, _, _, _, _) -> p
+ | _ -> None
+
+let rec next_setpred p i = function
+ | [] -> None
+ | RBsetpred (_, _, p') :: rst ->
+ if in_predicate p' p then
+ Some i
+ else
+ next_setpred p (i + 1) rst
+ | _ :: rst -> next_setpred p (i + 1) rst
+
+let rec next_preduse p i instr=
+ let next p' rst =
+ if in_predicate p p' then
+ Some i
+ else
+ next_preduse p (i + 1) rst
+ in
+ match instr with
+ | [] -> None
+ | RBload (Some p', _, _, _, _) :: rst -> next p' rst
+ | RBstore (Some p', _, _, _, _) :: rst -> next p' rst
+ | RBop (Some p', _, _, _) :: rst -> next p' rst
+ | _ :: rst -> next_load (i + 1) rst
+
+let accumulate_RAW_pred_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match get_predicate curr with
+ | Some p -> (
+ match next_setpred p 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAR_pred_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match curr with
+ | RBsetpred (_, _, p) -> (
+ match next_preduse p 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAW_pred_deps dfg curr =
+ let i, { nodes; edges } = dfg in
+ match curr with
+ | RBsetpred (_, _, p) -> (
+ match next_setpred (Pvar p) 0 (take i nodes |> List.rev) with
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+(** This function calculates the WAW dependencies, which happen when two writes are ordered one
+ after another and therefore have to be kept in that order. This accumulation might be redundant
+ if register renaming is done before hand, because then these dependencies can be avoided. *)
+let accumulate_WAW_deps dfg curr =
+ let i, { edges; nodes } = dfg in
+ let dst_dep dst =
+ match find_next_dst_write i dst (i + 1) (drop (i + 1) nodes) with
+ | Some d -> (i + 1, { nodes; edges = d :: edges })
+ | _ -> (i + 1, { nodes; edges })
+ in
+ match curr with
+ | RBop (_, _, _, dst) -> dst_dep dst
+ | RBload (_, _, _, _, dst) -> dst_dep dst
+ | RBstore (_, _, _, _, _) -> (
+ match next_store (i + 1) (drop (i + 1) nodes) with
+ | None -> (i + 1, { nodes; edges })
+ | Some i' -> (i + 1, { nodes; edges = (i, i') :: edges }) )
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAR_deps dfg curr =
+ let i, { edges; nodes } = dfg in
+ let dst_dep dst =
+ let dep_list = find_all_next_dst_read i dst 0 (take i nodes |> List.rev)
+ |> List.map (function (d, d') -> (i - d' - 1, d))
+ in
+ (i + 1, { nodes; edges = List.append dep_list edges })
+ in
+ match curr with
+ | RBop (_, _, _, dst) -> dst_dep dst
+ | RBload (_, _, _, _, dst) -> dst_dep dst
+ | _ -> (i + 1, { nodes; edges })
+
+let assigned_vars vars = function
+ | RBnop -> vars
+ | RBop (_, _, _, dst) -> dst :: vars
+ | RBload (_, _, _, _, dst) -> dst :: vars
+ | RBstore (_, _, _, _, _) -> vars
+ | RBsetpred (_, _, _) -> vars
+
+let get_pred = function
+ | RBnop -> None
+ | RBop (op, _, _, _) -> op
+ | RBload (op, _, _, _, _) -> op
+ | RBstore (op, _, _, _, _) -> op
+ | RBsetpred (_, _, _) -> None
+
+let independant_pred p p' =
+ match sat_pred_temp (Nat.of_int 100000) (Pand (p, p')) with
+ | Some None -> true
+ | _ -> false
+
+let check_dependent op1 op2 =
+ match op1, op2 with
+ | Some p, Some p' -> not (independant_pred p p')
+ | _, _ -> true
+
+let remove_unnecessary_deps dfg =
+ let { edges; nodes } = dfg in
+ let is_dependent = function (i1, i2) ->
+ let instr1 = List.nth nodes i1 in
+ let instr2 = List.nth nodes i2 in
+ check_dependent (get_pred instr1) (get_pred instr2)
+ in
+ { edges = List.filter is_dependent edges; nodes }
+
+(** All the nodes in the DFG have to come after the source of the basic block, and should terminate
+ before the sink of the basic block. After that, there should be constraints for data
+ dependencies between nodes. *)
+let gather_bb_constraints debug bb =
+ let _, _, dfg =
+ List.fold_left accumulate_RAW_deps
+ (0, PTree.empty, { nodes = bb.bb_body; edges = [] })
+ bb.bb_body
+ in
+ if debug then printf "DFG : %a\n" print_dfg dfg else ();
+ let _, dfg1 = List.fold_left accumulate_WAW_deps (0, dfg) bb.bb_body in
+ if debug then printf "DFG': %a\n" print_dfg dfg1 else ();
+ let _, dfg2 = List.fold_left accumulate_WAR_deps (0, dfg1) bb.bb_body in
+ if debug then printf "DFG'': %a\n" print_dfg dfg2 else ();
+ let _, dfg3 =
+ List.fold_left accumulate_RAW_mem_deps (0, dfg2) bb.bb_body
+ in
+ if debug then printf "DFG''': %a\n" print_dfg dfg3 else ();
+ let _, dfg4 =
+ List.fold_left accumulate_WAR_mem_deps (0, dfg3) bb.bb_body
+ in
+ if debug then printf "DFG'''': %a\n" print_dfg dfg4 else ();
+ let _, dfg5 =
+ List.fold_left accumulate_WAW_mem_deps (0, dfg4) bb.bb_body
+ in
+ let _, dfg6 =
+ List.fold_left accumulate_RAW_pred_deps (0, dfg5) bb.bb_body
+ in
+ let _, dfg7 =
+ List.fold_left accumulate_WAR_pred_deps (0, dfg6) bb.bb_body
+ in
+ let _, dfg8 =
+ List.fold_left accumulate_WAW_pred_deps (0, dfg7) bb.bb_body
+ in
+ let dfg9 = remove_unnecessary_deps dfg8 in
+ if debug then printf "DFG''''': %a\n" print_dfg dfg9 else ();
+ (List.length bb.bb_body, dfg9, successors_instr bb.bb_exit)
+
+let gen_bb_name s i = sprintf "bb%d%s" (P.to_int i) s
+
+let gen_bb_name_ssrc = gen_bb_name "ssrc"
+
+let gen_bb_name_ssnk = gen_bb_name "ssnk"
+
+let gen_var_name s c i = sprintf "v%d%s_%d" (P.to_int i) s c
+
+let gen_var_name_b = gen_var_name "b"
+
+let gen_var_name_e = gen_var_name "e"
+
+let print_lt0 = sprintf "%s - %s <= 0;\n"
+
+let print_bb_order i c = if P.to_int c < P.to_int i then
+ print_lt0 (gen_bb_name_ssnk i) (gen_bb_name_ssrc c) else
+ ""
+
+let print_src_order i c =
+ print_lt0 (gen_bb_name_ssrc i) (gen_var_name_b c i)
+ ^ print_lt0 (gen_var_name_e c i) (gen_bb_name_ssnk i)
+ ^ sprintf "%s - %s = 1;\n" (gen_var_name_e c i) (gen_var_name_b c i)
+
+let print_src_type i c =
+ sprintf "int %s;\n" (gen_var_name_e c i)
+ ^ sprintf "int %s;\n" (gen_var_name_b c i)
+
+let print_data_dep_order c (i, j) =
+ print_lt0 (gen_var_name_e i c) (gen_var_name_b j c)
+
+let gather_cfg_constraints (completed, (bvars, constraints, types)) c curr =
+ if List.exists (P.eq curr) completed then
+ (completed, (bvars, constraints, types))
+ else
+ match PTree.get curr c with
+ | None -> assert false
+ | Some (num_iters, dfg, next) ->
+ let constraints' =
+ constraints
+ ^ String.concat "" (List.map (print_bb_order curr) next)
+ ^ String.concat ""
+ (List.map (print_src_order curr)
+ (List.init num_iters (fun x -> x)))
+ ^ String.concat "" (List.map (print_data_dep_order curr) dfg.edges)
+ in
+ let types' =
+ types
+ ^ String.concat ""
+ (List.map (print_src_type curr)
+ (List.init num_iters (fun x -> x)))
+ ^ sprintf "int %s;\n" (gen_bb_name_ssrc curr)
+ ^ sprintf "int %s;\n" (gen_bb_name_ssnk curr)
+ in
+ let bvars' =
+ List.append
+ (List.map
+ (fun x -> gen_var_name_b x curr)
+ (List.init num_iters (fun x -> x)))
+ bvars
+ in
+ (curr :: completed, (bvars', constraints', types'))
+
+let rec intersperse s = function
+ | [] -> []
+ | [ a ] -> [ a ]
+ | x :: xs -> x :: s :: intersperse s xs
+
+let update_schedule v = function Some l -> Some (v :: l) | None -> Some [ v ]
+
+let parse_soln tree s =
+ let r = Str.regexp "v\\([0-9]+\\)b_\\([0-9]+\\)[ ]+\\([0-9]+\\)" in
+ if Str.string_match r s 0 then
+ IMap.update
+ (Str.matched_group 1 s |> int_of_string)
+ (update_schedule
+ ( Str.matched_group 2 s |> int_of_string,
+ Str.matched_group 3 s |> int_of_string ))
+ tree
+ else tree
+
+let solve_constraints vars constraints types =
+ let oc = open_out "lpsolve.txt" in
+ fprintf oc "min: ";
+ List.iter (fprintf oc "%s") (intersperse " + " vars);
+ fprintf oc ";\n";
+ fprintf oc "%s" constraints;
+ fprintf oc "%s" types;
+ close_out oc;
+ Str.split (Str.regexp_string "\n") (read_process "lp_solve lpsolve.txt")
+ |> drop 3
+ |> List.fold_left parse_soln IMap.empty
+
+let find_min = function
+ | [] -> assert false
+ | l :: ls ->
+ let rec find_min' current = function
+ | [] -> current
+ | l' :: ls' ->
+ if snd l' < current then find_min' (snd l') ls'
+ else find_min' current ls'
+ in
+ find_min' (snd l) ls
+
+let find_max = function
+ | [] -> assert false
+ | l :: ls ->
+ let rec find_max' current = function
+ | [] -> current
+ | l' :: ls' ->
+ if snd l' > current then find_max' (snd l') ls'
+ else find_max' current ls'
+ in
+ find_max' (snd l) ls
+
+let ( >>= ) = bind
+
+let combine_bb_schedule schedule s =
+ let i, st = s in
+ IMap.update st (update_schedule i) schedule
+
+let compare_tuple (a, _) (b, _) = compare a b
+
+(** Should generate the [RTLPar] code based on the input [RTLBlock] description. *)
+let transf_rtlpar c (schedule : (int * int) list IMap.t) =
+ let f i bb : RTLPar.bblock =
+ match bb with
+ | { bb_body = []; bb_exit = c } ->
+ { bb_body = [];
+ bb_exit = c
+ }
+ | { bb_body = bb_body'; bb_exit = ctrl_flow } ->
+ let i_sched =
+ try IMap.find (P.to_int i) schedule
+ with Not_found -> (
+ printf "Could not find %d\n" (P.to_int i);
+ IMap.iter
+ (fun d -> printf "%d: %a\n" d (print_list print_tuple))
+ schedule;
+ assert false
+ )
+ in
+ let min_state = find_min i_sched in
+ let max_state = find_max i_sched in
+ let i_sched_tree =
+ List.fold_left combine_bb_schedule IMap.empty i_sched
+ in
+ (*printf "--------------- curr: %d, max: %d, min: %d, next: %d\n" (P.to_int i) max_state min_state (P.to_int i - max_state + min_state - 1);
+ printf "HIIIII: %d orig: %d\n" (P.to_int i - max_state + min_state - 1) (P.to_int i);*)
+ { bb_body = (IMap.to_seq i_sched_tree |> List.of_seq |> List.sort compare_tuple |> List.map snd
+ |> List.map (List.map (fun x -> List.nth bb_body' x)));
+ bb_exit = ctrl_flow
+ }
+ in
+ PTree.map f c
+
+let second = function (_, a, _) -> a
+
+let schedule entry (c : RTLBlock.bb RTLBlockInstr.code) =
+ let debug = false in
+ let c' = PTree.map1 (gather_bb_constraints false) c in
+ (*let _ = if debug then PTree.map (fun r o -> printf "##### %d #####\n%a\n\n" (P.to_int r) print_dfg (second o)) c' else PTree.empty in*)
+ let _, (vars, constraints, types) =
+ List.map fst (PTree.elements c') |>
+ List.fold_left (fun compl ->
+ gather_cfg_constraints compl c') ([], ([], "", ""))
+ in
+ let schedule' = solve_constraints vars constraints types in
+ (*IMap.iter (fun a b -> printf "##### %d #####\n%a\n\n" a (print_list print_tuple) b) schedule';*)
+ (*printf "Schedule: %a\n" (fun a x -> IMap.iter (fun d -> fprintf a "%d: %a\n" d (print_list print_tuple)) x) schedule';*)
+ transf_rtlpar c schedule'
+
+let rec find_reachable_states c e =
+ match PTree.get e c with
+ | Some { bb_exit = ex; _ } ->
+ e :: List.fold_left (fun x a -> List.concat [x; find_reachable_states c a]) []
+ (successors_instr ex |> List.filter (fun x -> P.lt x e))
+ | None -> assert false
+
+let add_to_tree c nt i =
+ match PTree.get i c with
+ | Some p -> PTree.set i p nt
+ | None -> assert false
+
+let schedule_fn (f : RTLBlock.coq_function) : RTLPar.coq_function =
+ let scheduled = schedule f.fn_entrypoint f.fn_code in
+ let reachable = find_reachable_states scheduled f.fn_entrypoint
+ |> List.to_seq |> SS.of_seq |> SS.to_seq |> List.of_seq in
+ { fn_sig = f.fn_sig;
+ fn_params = f.fn_params;
+ fn_stacksize = f.fn_stacksize;
+ fn_code = List.fold_left (add_to_tree scheduled) PTree.empty reachable;
+ fn_entrypoint = f.fn_entrypoint
+ }