aboutsummaryrefslogtreecommitdiffstats
path: root/src/hls/Schedule.ml
diff options
context:
space:
mode:
authorYann Herklotz <git@yannherklotz.com>2020-10-23 13:24:49 +0100
committerYann Herklotz <git@yannherklotz.com>2020-10-23 13:24:49 +0100
commitb8e2a41b7023954f7de91c9d3835804a11908386 (patch)
tree01d89991c9710cc4b753fe75e99c4f1d2c9929b9 /src/hls/Schedule.ml
parent65d1b8ad3f5991561a7b9a633459d2b6950b4c8a (diff)
downloadvericert-kvx-b8e2a41b7023954f7de91c9d3835804a11908386.tar.gz
vericert-kvx-b8e2a41b7023954f7de91c9d3835804a11908386.zip
Fix scheduling for loads and stores with WAR dependencies
Diffstat (limited to 'src/hls/Schedule.ml')
-rw-r--r--src/hls/Schedule.ml136
1 files changed, 120 insertions, 16 deletions
diff --git a/src/hls/Schedule.ml b/src/hls/Schedule.ml
index e521214..1ebcd41 100644
--- a/src/hls/Schedule.ml
+++ b/src/hls/Schedule.ml
@@ -89,9 +89,18 @@ let accumulate_RAW_deps dfg curr =
edges = List.append (List.fold_left (add_dep i dst_map) [] rs) edges;
} )
in
+ let acc_dep_instruction_nodst rs =
+ ( i + 1,
+ dst_map,
+ {
+ nodes;
+ edges = List.append (List.fold_left (add_dep i dst_map) [] rs) edges;
+ } )
+ in
match curr with
| RBop (_, rs, dst) -> acc_dep_instruction rs dst
| RBload (_mem, _addr, rs, dst) -> acc_dep_instruction rs dst
+ | RBstore (_mem, _addr, rs, src) -> acc_dep_instruction_nodst (src :: rs)
| _ -> (i + 1, dst_map, { edges; nodes })
(** Finds the next write to the [dst] register. This is a small optimisation so that only one
@@ -107,6 +116,19 @@ let rec find_next_dst_write i dst i' curr =
| RBload (_, _, _, dst') :: curr' -> check_dst dst' curr'
| _ :: curr' -> find_next_dst_write i dst (i' + 1) curr'
+let rec find_all_next_dst_read i dst i' curr = (* Scans [curr] and returns one pair per instruction that reads [dst]: the caller-supplied [i] unchanged, paired with a running offset that grows by one per element of [curr]. *)
+ let check_dst rs curr' = (* [rs] is the list of registers compared against [dst] for the instruction at the current offset *)
+ if List.exists (fun x -> x = dst) rs
+ then (i, i') :: find_all_next_dst_read i dst (i' + 1) curr' (* match: record the pair, then keep scanning so that every reader is collected *)
+ else find_all_next_dst_read i dst (i' + 1) curr'
+ in
+ match curr with
+ | [] -> []
+ | RBop (_, rs, _) :: curr' -> check_dst rs curr'
+ | RBload (_, _, rs, _) :: curr' -> check_dst rs curr'
+ | RBstore (_, _, rs, src) :: curr' -> check_dst (src :: rs) curr' (* for a store, [src] is checked together with [rs] *)
+ | RBnop :: curr' -> find_all_next_dst_read i dst (i' + 1) curr' (* nothing to compare; only the offset advances *)
+
let drop i lst =
let rec drop' i' lst' =
match lst' with
@@ -123,6 +145,43 @@ let take i lst =
in
if i = 0 then [] else take' 1 lst
+let rec next_store i = function (* Returns [Some k] for the first [RBstore] in the list, where [k] is [i] plus the number of elements skipped before it. *)
+ | [] -> None (* reached the end without meeting a store *)
+ | RBstore (_, _, _, _) :: _ -> Some i
+ | _ :: rst -> next_store (i + 1) rst
+
+let rec next_load i = function (* Returns [Some k] for the first [RBload] in the list, where [k] is [i] plus the number of elements skipped before it. *)
+ | [] -> None (* reached the end without meeting a load *)
+ | RBload (_, _, _, _) :: _ -> Some i
+ | _ :: rst -> next_load (i + 1) rst
+
+let accumulate_RAW_mem_deps dfg curr = (* Fold step: when [curr] is a load, adds a single edge from a store found among the earlier nodes to this load; any other instruction only advances the counter. *)
+ let i, { nodes; edges } = dfg in (* the accumulator pairs a running index [i] with the graph; every branch returns [i + 1] *)
+ match curr with
+ | RBload (_, _, _, _) -> (
+ match next_store 0 (take i nodes |> List.rev) with (* NOTE(review): presumably [take i nodes] is the first [i] nodes, reversed so the search begins at the nearest earlier one -- confirm against [take] *)
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) ) (* [d] is an offset into the reversed list, so [i - d - 1] converts it back to an index into [nodes] *)
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAR_mem_deps dfg curr = (* Fold step: when [curr] is a store, adds a single edge from a load found among the earlier nodes to this store; any other instruction only advances the counter. *)
+ let i, { nodes; edges } = dfg in (* the accumulator pairs a running index [i] with the graph; every branch returns [i + 1] *)
+ match curr with
+ | RBstore (_, _, _, _) -> (
+ match next_load 0 (take i nodes |> List.rev) with (* NOTE(review): presumably [take i nodes] is the first [i] nodes, reversed so the search begins at the nearest earlier one -- confirm against [take] *)
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) ) (* [d] is an offset into the reversed list, so [i - d - 1] converts it back to an index into [nodes] *)
+ | _ -> (i + 1, { nodes; edges })
+
+let accumulate_WAW_mem_deps dfg curr = (* Fold step: when [curr] is a store, adds a single edge from another store found among the earlier nodes to this store; any other instruction only advances the counter. *)
+ let i, { nodes; edges } = dfg in (* the accumulator pairs a running index [i] with the graph; every branch returns [i + 1] *)
+ match curr with
+ | RBstore (_, _, _, _) -> (
+ match next_store 0 (take i nodes |> List.rev) with (* NOTE(review): presumably [take i nodes] is the first [i] nodes, reversed so the search begins at the nearest earlier one -- confirm against [take] *)
+ | None -> (i + 1, { nodes; edges })
+ | Some d -> (i + 1, { nodes; edges = (i - d - 1, i) :: edges }) ) (* [d] is an offset into the reversed list, so [i - d - 1] converts it back to an index into [nodes] *)
+ | _ -> (i + 1, { nodes; edges })
+
(** This function calculates the WAW dependencies, which happen when two writes are ordered one
after another and therefore have to be kept in that order. This accumulation might be redundant
if register renaming is done beforehand, because then these dependencies can be avoided. *)
@@ -136,18 +195,23 @@ let accumulate_WAW_deps dfg curr =
match curr with
| RBop (_, _, dst) -> dst_dep dst
| RBload (_, _, _, dst) -> dst_dep dst
+ | RBstore (_, _, _, _) -> (
+ match next_store (i + 1) (drop (i + 1) nodes) with
+ | None -> (i + 1, { nodes; edges })
+ | Some i' -> (i + 1, { nodes; edges = (i, i') :: edges }) )
| _ -> (i + 1, { nodes; edges })
let accumulate_WAR_deps dfg curr = (* Fold step: for an instruction that writes a register, adds edges relating it to earlier instructions that read that register; other instructions only advance the counter. *)
let i, { edges; nodes } = dfg in
let dst_dep dst =
- match find_next_dst_write i dst 0 (take i nodes |> List.rev) with
- | Some (d, d') -> (i + 1, { nodes; edges = (i - d', d) :: edges })
- | _ -> (i + 1, { nodes; edges })
+ let dep_list = find_all_next_dst_read i dst 0 (take i nodes |> List.rev) (* collects every reader of [dst] in the reversed list, not just one *)
+ |> List.map (function (d, d') -> (i - d' - 1, d)) (* the second component is an offset into the reversed list; [i - offset - 1] turns it back into an index into [nodes], and the first component is [i] itself *)
+ in
+ (i + 1, { nodes; edges = List.append dep_list edges })
in
match curr with
- | RBop (_, rs, dst) -> dst_dep dst
- | RBload (_, _, rs, dst) -> dst_dep dst
+ | RBop (_, _, dst) -> dst_dep dst
+ | RBload (_, _, _, dst) -> dst_dep dst
| _ -> (i + 1, { nodes; edges })
let assigned_vars vars = function
@@ -165,11 +229,26 @@ let gather_bb_constraints bb =
(0, PTree.empty, { nodes = bb.bb_body; edges = [] })
bb.bb_body
in
+ printf "DFG : %a\n" print_dfg dfg;
let _, dfg' = List.fold_left accumulate_WAW_deps (0, dfg) bb.bb_body in
+ printf "DFG': %a\n" print_dfg dfg';
let _, dfg'' = List.fold_left accumulate_WAR_deps (0, dfg') bb.bb_body in
+ printf "DFG'': %a\n" print_dfg dfg'';
+ let _, dfg''' =
+ List.fold_left accumulate_RAW_mem_deps (0, dfg'') bb.bb_body
+ in
+ printf "DFG''': %a\n" print_dfg dfg''';
+ let _, dfg'''' =
+ List.fold_left accumulate_WAR_mem_deps (0, dfg''') bb.bb_body
+ in
+ printf "DFG'''': %a\n" print_dfg dfg'''';
+ let _, dfg''''' =
+ List.fold_left accumulate_WAW_mem_deps (0, dfg'''') bb.bb_body
+ in
+ printf "DFG''''': %a\n" print_dfg dfg''''';
match bb.bb_exit with
| None -> assert false
- | Some e -> (List.length bb.bb_body, dfg'', successors_instr e)
+ | Some e -> (List.length bb.bb_body, dfg''''', successors_instr e)
let gen_bb_name s i = sprintf "bb%d%s" (P.to_int i) s
@@ -315,6 +394,7 @@ let translate_control_flow r curr_st instr =
state.st_controllogic;
} )
| RBreturn ret -> (
+ let new_state state = state.st_freshstate in
match ret with
| None ->
let fin =
@@ -329,7 +409,13 @@ let translate_control_flow r curr_st instr =
( (),
{
state with
- st_datapath = PTree.set curr_st fin state.st_datapath;
+ st_datapath =
+ PTree.set (new_state state) fin state.st_datapath;
+ st_controllogic =
+ PTree.set curr_st
+ (state_goto state.st_st (new_state state))
+ state.st_controllogic;
+ st_freshstate = P.succ state.st_freshstate;
} )
| Some ret' ->
let fin =
@@ -343,7 +429,13 @@ let translate_control_flow r curr_st instr =
( (),
{
state with
- st_datapath = PTree.set curr_st fin state.st_datapath;
+ st_datapath =
+ PTree.set (new_state state) fin state.st_datapath;
+ st_controllogic =
+ PTree.set curr_st
+ (state_goto state.st_st (new_state state))
+ state.st_controllogic;
+ st_freshstate = P.succ state.st_freshstate;
} ) )
| _ ->
error
@@ -388,7 +480,9 @@ let translate_instr r curr_st instr =
{
state with
st_datapath =
- PTree.set curr_st (Vnonblock (dst, Vvar src)) state.st_datapath;
+ PTree.set curr_st
+ (Vseq (prev_instr state, Vnonblock (dst, Vvar src)))
+ state.st_datapath;
} )
let combine_bb_schedule schedule s =
@@ -410,7 +504,15 @@ let transf_htl r c (schedule : (int * int) list IMap.t) =
match bb with
| { bb_body = []; bb_exit = Some c } -> translate_control_flow r i c
| { bb_body = bb_body'; bb_exit = Some ctrl_flow } ->
- let i_sched = IMap.find (P.to_int i) schedule in
+ let i_sched =
+ try IMap.find (P.to_int i) schedule
+ with Not_found ->
+ printf "Could not find %d\n" (P.to_int i);
+ IMap.iter
+ (fun d -> printf "%d: %a\n" d (print_list print_tuple))
+ schedule;
+ assert false
+ in
let min_state = find_min i_sched in
let max_state = find_max i_sched in
let i_sched_tree =
@@ -420,7 +522,7 @@ let transf_htl r c (schedule : (int * int) list IMap.t) =
(add_schedules r bb_body' min_state (P.to_int i))
(IMap.to_seq i_sched_tree |> List.of_seq)
>>= fun _ ->
- translate_control_flow r (P.of_int (P.to_int i - max_state - 1)) ctrl_flow
+ translate_control_flow r (P.of_int (P.to_int i - max_state)) ctrl_flow
| _ ->
coqstring_of_camlstring "Illegal state reached in scheduler"
|> Errors.msg |> error
@@ -432,9 +534,9 @@ let schedule entry r (c : code) =
let _, (vars, constraints, types) =
gather_cfg_constraints ([], ([], "", "")) c' entry
in
- let schedule = solve_constraints vars constraints types in
- IMap.iter (fun d -> printf "%d: %a\n" d (print_list print_tuple)) schedule;
- transf_htl r (PTree.elements c) schedule
+ let schedule' = solve_constraints vars constraints types in
+ (*printf "Schedule: %a\n" (fun a x -> IMap.iter (fun d -> fprintf a "%d: %a\n" d (print_list print_tuple)) x) schedule';*)
+ transf_htl r (PTree.elements c) schedule'
let transl_module' (f : RTLBlock.coq_function) : HTL.coq_module mon =
create_reg (Some Voutput) (Nat.of_int 1) >>= fun fin ->
@@ -473,11 +575,13 @@ let transl_module' (f : RTLBlock.coq_function) : HTL.coq_module mon =
let max_state f = (* Builds a state record from [init_state] seeded with the id 10000; the argument [f] is not referenced in this body. *)
let st = P.of_int 10000 in
- { (init_state st) with
+ {
+ (init_state st) with
st_st = st;
st_freshreg = P.succ st;
st_freshstate = P.of_int 10000;
- st_scldecls = AssocMap.AssocMap.set st (None, Nat.of_int 32) ((init_state st).st_scldecls);
+ st_scldecls =
+ AssocMap.AssocMap.set st (None, Nat.of_int 32) (init_state st).st_scldecls; (* binds [st] to [(None, Nat.of_int 32)] on top of the initial declarations -- NOTE(review): 32 is presumably a bit width; confirm *)
}
let transl_module (f : RTLBlock.coq_function) : HTL.coq_module Errors.res =