aboutsummaryrefslogtreecommitdiffstats
path: root/src/hls/Schedule.ml
blob: b9ee74152a27e0162ef326d50b7d54ad7d7145c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
(*
 * Vericert: Verified high-level synthesis.
 * Copyright (C) 2020 Yann Herklotz <yann@yannherklotz.com>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 *)

open Printf
open Clflags
open Camlcoq
open Datatypes
open Coqlib
open Maps
open AST
open Kildall
open Op
open RTLBlockInstr
open RTLBlock
open HTL
open Verilog
open HTLgen
open HTLMonad
open HTLMonadExtra

(** Sets of [P.t] values.  [schedule_fn] round-trips the list of reachable block labels through
   this set to remove duplicates. *)
module SS = Set.Make(P)

(** Finite maps keyed by native [int]s, ordered with the [compare] function that is in scope at
   this point of the file. *)
module IMap = Map.Make (struct
  type t = int

  let compare (a : t) (b : t) = compare a b
end)

type dfg = { nodes : instr list; edges : (int * int) list }
(** A data-flow graph over a list of instructions.  [nodes] holds the instructions and [edges]
   holds their dependencies as pairs of integers, where each integer is the index of an
   instruction in [nodes].  The edges always point from left to right: the first component is the
   instruction that has to come first, the second is the instruction that depends on it. *)

(** [print_list f out_chan a] writes the elements of [a] to [out_chan] between ["[ "] and ["]"],
   printing each element with the [%a]-style printer [f] followed by a single space. *)
let print_list f out_chan a =
  let print_elem x = fprintf out_chan "%a " f x in
  fprintf out_chan "[ ";
  List.iter print_elem a;
  fprintf out_chan "]"

(** Print a pair of integers to [out_chan] as [(l,r)]. *)
let print_tuple out_chan (l, r) = fprintf out_chan "(%d,%d)" l r

(** Print a [dfg] to [out_chan]: first its instructions, then its edges. *)
let print_dfg out_chan { nodes; edges } =
  let print_nodes = print_list PrintRTLBlockInstr.print_bblock_body in
  let print_edges = print_list print_tuple in
  fprintf out_chan "{ nodes = %a, edges = %a }" print_nodes nodes print_edges edges

(** [read_process command] runs [command] with [Unix.open_process_in] and returns everything it
   wrote to its standard output as a single string.  The exit status of the process is ignored.
   The process channel is closed even when reading raises, and the exception is then re-raised. *)
let read_process command =
  let buffer_size = 2048 in
  let buffer = Buffer.create buffer_size in
  let chunk = Bytes.create buffer_size in
  let in_channel = Unix.open_process_in command in
  (* Read until [input] reports end of file (0 bytes), appending only the bytes that were actually
     read instead of copying the whole scratch buffer on every iteration. *)
  let rec drain () =
    let chars_read = input in_channel chunk 0 buffer_size in
    if chars_read <> 0 then begin
      Buffer.add_subbytes buffer chunk 0 chars_read;
      drain ()
    end
  in
  (try drain ()
   with e ->
     ignore (Unix.close_process_in in_channel);
     raise e);
  ignore (Unix.close_process_in in_channel);
  Buffer.contents buffer

(** [add_dep i tree deps curr] looks up register [curr] in [tree], which maps a register to the
   index of the instruction that last wrote it.  When there is such a writer [ip], the edge
   [(ip, i)] is added in front of [deps]; otherwise [deps] is returned unchanged. *)
let add_dep i tree deps curr =
  match PTree.get curr tree with
  | Some ip -> (ip, i) :: deps
  | None -> deps

(** This function calculates the read-after-write (RAW) register dependencies of each instruction.
   It is meant to be folded over the instructions of a block: the accumulator is the index [i] of
   the current instruction, a map [dst_map] from each register to the index of the instruction
   that last wrote it, and the DFG built so far.

   Every instruction that reads registers (operations, loads, stores and predicate assignments)
   receives an edge from the last writer of each register it reads.  Only operations and loads
   write registers, so only they update [dst_map]. *)
let accumulate_RAW_deps dfg curr =
  let i, dst_map, { edges; nodes } = dfg in
  (* The DFG extended with an edge from the last writer of every register in [rs] to [i]. *)
  let with_reads rs =
    { nodes; edges = List.append (List.fold_left (add_dep i dst_map) [] rs) edges }
  in
  match curr with
  | RBop (_, _, rs, dst) | RBload (_, _, _, rs, dst) ->
      (i + 1, PTree.set dst i dst_map, with_reads rs)
  | RBstore (_, _, _, rs, src) -> (i + 1, dst_map, with_reads (src :: rs))
  (* A predicate assignment reads [rs] as well, so it must not be scheduled before the
     instructions producing those registers. *)
  | RBsetpred (_, rs, _) -> (i + 1, dst_map, with_reads rs)
  | RBnop -> (i + 1, dst_map, { edges; nodes })

(** Finds the next write to the [dst] register in [curr], whose first element has index [i'].
   Returns [Some (i, j)] where [j] is the index of the first operation or load that writes [dst],
   or [None] if there is no such instruction.  Stopping at the first write is a small optimisation
   so that only one dependency is generated per written register. *)
let rec find_next_dst_write i dst i' curr =
  match curr with
  | [] -> None
  | (RBop (_, _, _, dst') | RBload (_, _, _, _, dst')) :: _ when dst = dst' -> Some (i, i')
  | _ :: curr' -> find_next_dst_write i dst (i' + 1) curr'

(** [find_all_next_dst_read i dst i' curr] returns one pair [(i, j)] for every instruction of
   [curr] that reads register [dst], where [j] is the position of that instruction counted from
   [i'] for the head of [curr]. *)
let rec find_all_next_dst_read i dst i' curr =
  match curr with
  | [] -> []
  | instr :: curr' ->
      (* Registers read by [instr]; a store also reads its source register. *)
      let reads =
        match instr with
        | RBop (_, _, rs, _) | RBload (_, _, _, rs, _) | RBsetpred (_, rs, _) -> rs
        | RBstore (_, _, _, rs, src) -> src :: rs
        | RBnop -> []
      in
      let rest = find_all_next_dst_read i dst (i' + 1) curr' in
      if List.exists (fun x -> x = dst) reads then (i, i') :: rest else rest

(** [drop i lst] removes the first [i] elements of [lst].  The result is empty when [lst] has at
   most [i] elements. *)
let drop i lst =
  let rec skip k = function
    | [] -> []
    | _ :: rest when k = i -> rest
    | _ :: rest -> skip (k + 1) rest
  in
  if i = 0 then lst else skip 1 lst

(** [take i lst] returns the first [i] elements of [lst], or all of [lst] when it has fewer than
   [i] elements. *)
let take i lst =
  let rec collect k acc = function
    | [] -> List.rev acc
    | x :: rest ->
        if k = i then List.rev (x :: acc) else collect (k + 1) (x :: acc) rest
  in
  if i = 0 then [] else collect 1 [] lst

(** [next_store i l] is [Some j] where [j] is the position of the first store in [l], counting
   from [i] for the head of the list, or [None] if [l] contains no store. *)
let rec next_store i = function
  | RBstore _ :: _ -> Some i
  | _ :: rst -> next_store (i + 1) rst
  | [] -> None

(** [next_load i l] is [Some j] where [j] is the position of the first load in [l], counting from
   [i] for the head of the list, or [None] if [l] contains no load. *)
let rec next_load i = function
  | RBload _ :: _ -> Some i
  | _ :: rst -> next_load (i + 1) rst
  | [] -> None

(** Memory read-after-write dependencies: when the current instruction (index [i]) is a load, add
   an edge from the closest preceding store, if any, to it. *)
let accumulate_RAW_mem_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match curr with
    | RBload _ -> (
        (* Search backwards from [i]; [d] is the distance to the closest store. *)
        match next_store 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** Memory write-after-read dependencies: when the current instruction (index [i]) is a store, add
   an edge from the closest preceding load, if any, to it. *)
let accumulate_WAR_mem_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match curr with
    | RBstore _ -> (
        (* Search backwards from [i]; [d] is the distance to the closest load. *)
        match next_load 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** Memory write-after-write dependencies: when the current instruction (index [i]) is a store,
   add an edge from the closest preceding store, if any, to it. *)
let accumulate_WAW_mem_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match curr with
    | RBstore _ -> (
        (* Search backwards from [i]; [d] is the distance to the closest store. *)
        match next_store 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** Predicate dependencies. *)

(** [in_predicate p p'] holds when the predicate variable [p] occurs anywhere in the predicate
   expression [p'].  Variables are compared through [Nat.to_int]. *)
let rec in_predicate p p' =
  match p' with
  | Pvar v -> Nat.to_int p = Nat.to_int v
  | Pnot q -> in_predicate p q
  | Pand (l, r) | Por (l, r) -> in_predicate p l || in_predicate p r

(** The optional predicate carried by an operation, load or store; [None] for every other
   instruction. *)
let get_predicate instr =
  match instr with
  | RBop (p, _, _, _) -> p
  | RBload (p, _, _, _, _) | RBstore (p, _, _, _, _) -> p
  | _ -> None

(** [next_setpred p i l] is [Some j] where [j] is the position (counting from [i] for the head of
   [l]) of the first predicate assignment whose target occurs in the predicate expression [p], or
   [None] if there is none. *)
let rec next_setpred p i = function
  | [] -> None
  | RBsetpred (_, _, p') :: _ when in_predicate p' p -> Some i
  | _ :: rst -> next_setpred p (i + 1) rst

(** [next_preduse p i instr] is [Some j] where [j] is the position (counting from [i] for the head
   of [instr]) of the first load, store or operation that is guarded by a predicate expression in
   which the predicate variable [p] occurs, or [None] if there is none.

   Instructions without a predicate, or whose predicate does not mention [p], are skipped and the
   search continues with the rest of the list. *)
let rec next_preduse p i instr =
  match instr with
  | [] -> None
  | ( RBload (Some p', _, _, _, _)
    | RBstore (Some p', _, _, _, _)
    | RBop (Some p', _, _, _) )
    :: _
    when in_predicate p p' ->
      Some i
  (* Keep looking for a use of [p]; the search must not degrade into a search for the next load. *)
  | _ :: rst -> next_preduse p (i + 1) rst

(** Predicate read-after-write dependencies: when the current instruction (index [i]) is guarded
   by a predicate, add an edge from the closest preceding assignment to a predicate occurring in
   that guard, if any. *)
let accumulate_RAW_pred_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match get_predicate curr with
    | Some p -> (
        (* Search backwards from [i]; [d] is the distance to the closest such assignment. *)
        match next_setpred p 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** Predicate write-after-read dependencies: when the current instruction (index [i]) assigns a
   predicate, add an edge from the preceding instruction reported by [next_preduse] for that
   predicate, if any. *)
let accumulate_WAR_pred_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match curr with
    | RBsetpred (_, _, p) -> (
        (* Search backwards from [i]; [d] is the distance to the instruction found. *)
        match next_preduse p 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** Predicate write-after-write dependencies: when the current instruction (index [i]) assigns a
   predicate, add an edge from the closest preceding assignment to the same predicate, if any. *)
let accumulate_WAW_pred_deps dfg curr =
  let i, { nodes; edges } = dfg in
  let edges' =
    match curr with
    | RBsetpred (_, _, p) -> (
        (* Search backwards from [i]; [d] is the distance to the closest such assignment. *)
        match next_setpred (Pvar p) 0 (List.rev (take i nodes)) with
        | Some d -> (i - d - 1, i) :: edges
        | None -> edges)
    | _ -> edges
  in
  (i + 1, { nodes; edges = edges' })

(** This function calculates the WAW dependencies, which happen when two writes are ordered one
   after another and therefore have to be kept in that order.  For an operation or a load the
   edge goes to the next instruction writing the same destination register; for a store it goes
   to the next store.  This accumulation might be redundant if register renaming is done
   beforehand, because then these dependencies can be avoided. *)
let accumulate_WAW_deps dfg curr =
  let i, { edges; nodes } = dfg in
  (* The instructions that come after the current one. *)
  let later () = drop (i + 1) nodes in
  let new_edge =
    match curr with
    | RBop (_, _, _, dst) | RBload (_, _, _, _, dst) ->
        find_next_dst_write i dst (i + 1) (later ())
    | RBstore _ -> (
        match next_store (i + 1) (later ()) with
        | Some i' -> Some (i, i')
        | None -> None)
    | _ -> None
  in
  match new_edge with
  | Some d -> (i + 1, { nodes; edges = d :: edges })
  | None -> (i + 1, { nodes; edges })

(** Register write-after-read dependencies: when the current instruction (index [i]) is an
   operation or a load writing [dst], add an edge to it from every preceding instruction that
   reads [dst]. *)
let accumulate_WAR_deps dfg curr =
  let i, { edges; nodes } = dfg in
  match curr with
  | RBop (_, _, _, dst) | RBload (_, _, _, _, dst) ->
      let earlier = List.rev (take i nodes) in
      (* [find_all_next_dst_read] reports distances counted backwards from [i]; turn each of them
         into the absolute index of the reading instruction. *)
      let dep_list =
        List.map
          (fun (d, d') -> (i - d' - 1, d))
          (find_all_next_dst_read i dst 0 earlier)
      in
      (i + 1, { nodes; edges = List.append dep_list edges })
  | _ -> (i + 1, { nodes; edges })

(** [assigned_vars vars instr] adds the destination register of [instr] in front of [vars] when
   [instr] is an operation or a load, and returns [vars] unchanged otherwise. *)
let assigned_vars vars instr =
  match instr with
  | RBop (_, _, _, dst) | RBload (_, _, _, _, dst) -> dst :: vars
  | RBnop | RBstore _ | RBsetpred _ -> vars

(** The optional predicate guarding an instruction; always [None] for [RBnop] and [RBsetpred]. *)
let get_pred instr =
  match instr with
  | RBop (op, _, _, _) -> op
  | RBload (op, _, _, _, _) | RBstore (op, _, _, _, _) -> op
  | RBnop | RBsetpred _ -> None

(** [independant_pred p p'] is [true] exactly when [sat_pred_temp], called with a bound of 100000
   on the conjunction of [p] and [p'], answers [Some None]; every other answer yields [false].
   NOTE(review): [Some None] presumably means "proved unsatisfiable" -- confirm against the
   definition of [sat_pred_temp]. *)
let independant_pred p p' =
  let conj = Pand (p, p') in
  match sat_pred_temp (Nat.of_int 100000) conj with
  | Some None -> true
  | Some (Some _) | None -> false

(** Two optional predicates are treated as dependent unless both are present and
   [independant_pred] holds for them. *)
let check_dependent op1 op2 =
  match op1 with
  | None -> true
  | Some p -> (
      match op2 with
      | None -> true
      | Some p' -> not (independant_pred p p'))

(** [remove_unnecessary_deps dfg] drops every edge of [dfg] whose two endpoints are guarded by
   predicates for which [check_dependent] is [false].  The nodes are left untouched.

   The predicate of every node is computed once into an array, so that each edge is checked in
   constant time instead of walking the node list twice per edge. *)
let remove_unnecessary_deps dfg =
  let { edges; nodes } = dfg in
  let preds = Array.map get_pred (Array.of_list nodes) in
  let is_dependent (i1, i2) = check_dependent preds.(i1) preds.(i2) in
  { edges = List.filter is_dependent edges; nodes }

(** Builds the data-flow graph of the basic block [bb] by running every dependency pass over its
   body (register RAW, WAW and WAR, then memory RAW, WAR and WAW, then predicate RAW, WAR and
   WAW) and finally removing the edges rejected by [remove_unnecessary_deps].  When [debug] is
   set, some of the intermediate graphs are printed to standard output.

   Returns the number of instructions in the block, the resulting DFG and the successors of the
   block's exit instruction. *)
let gather_bb_constraints debug bb =
  let body = bb.bb_body in
  (* Fold one accumulation pass over the body, starting at instruction index 0. *)
  let run pass dfg = snd (List.fold_left pass (0, dfg) body) in
  let trace fmt dfg = if debug then printf fmt print_dfg dfg in
  let _, _, dfg =
    List.fold_left accumulate_RAW_deps
      (0, PTree.empty, { nodes = body; edges = [] })
      body
  in
  trace "DFG : %a\n" dfg;
  let dfg1 = run accumulate_WAW_deps dfg in
  trace "DFG': %a\n" dfg1;
  let dfg2 = run accumulate_WAR_deps dfg1 in
  trace "DFG'': %a\n" dfg2;
  let dfg3 = run accumulate_RAW_mem_deps dfg2 in
  trace "DFG''': %a\n" dfg3;
  let dfg4 = run accumulate_WAR_mem_deps dfg3 in
  trace "DFG'''': %a\n" dfg4;
  let dfg9 =
    dfg4
    |> run accumulate_WAW_mem_deps
    |> run accumulate_RAW_pred_deps
    |> run accumulate_WAR_pred_deps
    |> run accumulate_WAW_pred_deps
    |> remove_unnecessary_deps
  in
  trace "DFG''''': %a\n" dfg9;
  (List.length body, dfg9, successors_instr bb.bb_exit)

(** Name of an LP variable attached to basic block [i]: ["bb"], the block number, then the suffix
   [s]. *)
let gen_bb_name s i =
  let n = P.to_int i in
  sprintf "bb%d%s" n s

(** LP variable name built from block [i] with the suffix ["ssrc"]. *)
let gen_bb_name_ssrc i = gen_bb_name "ssrc" i

(** LP variable name built from block [i] with the suffix ["ssnk"]. *)
let gen_bb_name_ssnk i = gen_bb_name "ssnk" i

(** Name of an LP variable attached to instruction number [c] of basic block [i], tagged with
   [s]. *)
let gen_var_name s c i =
  let n = P.to_int i in
  sprintf "v%d%s_%d" n s c

(** LP variable name for instruction [c] of block [i] with the tag ["b"]. *)
let gen_var_name_b c i = gen_var_name "b" c i

(** LP variable name for instruction [c] of block [i] with the tag ["e"]. *)
let gen_var_name_e c i = gen_var_name "e" c i

(** [print_lt0 lhs rhs] renders the LP constraint [lhs - rhs <= 0;] followed by a newline. *)
let print_lt0 lhs rhs = sprintf "%s - %s <= 0;\n" lhs rhs

(** Ordering constraint between block [i] and its successor [c]: the [ssnk] variable of [i] must
   not exceed the [ssrc] variable of [c].  The constraint is only emitted when [c] has a smaller
   number than [i]; otherwise the empty string is returned. *)
let print_bb_order i c =
  if P.to_int c >= P.to_int i then ""
  else print_lt0 (gen_bb_name_ssnk i) (gen_bb_name_ssrc c)

(** Constraints for instruction [c] of block [i]: its [b] variable is at least the block's [ssrc]
   variable, its [e] variable is at most the block's [ssnk] variable, and [e - b = 1]. *)
let print_src_order i c =
  let b = gen_var_name_b c i in
  let e = gen_var_name_e c i in
  String.concat ""
    [
      print_lt0 (gen_bb_name_ssrc i) b;
      print_lt0 e (gen_bb_name_ssnk i);
      sprintf "%s - %s = 1;\n" e b;
    ]

(** Integer declarations for the [e] and [b] variables of instruction [c] in block [i]. *)
let print_src_type i c =
  let declare name = sprintf "int %s;\n" name in
  String.concat "" [ declare (gen_var_name_e c i); declare (gen_var_name_b c i) ]

(** Constraint for the data dependency [(i, j)] inside block [c]: the [e] variable of instruction
   [i] must not exceed the [b] variable of instruction [j]. *)
let print_data_dep_order c (i, j) =
  let finish_i = gen_var_name_e i c in
  let start_j = gen_var_name_b j c in
  print_lt0 finish_i start_j

(** Fold step that adds the LP text for basic block [curr] to the accumulator.

   The accumulator holds the list of blocks already handled and a triple made of the names of the
   [b] variables, the constraints text and the type declarations text.  [c] maps each block to the
   triple produced by [gather_bb_constraints].  A block that is already in [completed] leaves the
   accumulator unchanged; a block that is missing from [c] is a programming error. *)
let gather_cfg_constraints (completed, (bvars, constraints, types)) c curr =
  if List.exists (P.eq curr) completed then
    (completed, (bvars, constraints, types))
  else
    match PTree.get curr c with
    | None -> assert false
    | Some (num_iters, dfg, next) ->
        (* Indices of the instructions of the block. *)
        let instrs = List.init num_iters (fun x -> x) in
        let cat f xs = String.concat "" (List.map f xs) in
        let constraints' =
          String.concat ""
            [
              constraints;
              cat (print_bb_order curr) next;
              cat (print_src_order curr) instrs;
              cat (print_data_dep_order curr) dfg.edges;
            ]
        in
        let types' =
          String.concat ""
            [
              types;
              cat (print_src_type curr) instrs;
              sprintf "int %s;\n" (gen_bb_name_ssrc curr);
              sprintf "int %s;\n" (gen_bb_name_ssnk curr);
            ]
        in
        let bvars' =
          List.append (List.map (fun x -> gen_var_name_b x curr) instrs) bvars
        in
        (curr :: completed, (bvars', constraints', types'))

(** [intersperse s l] inserts [s] between every two consecutive elements of [l]. *)
let rec intersperse s = function
  | x :: (_ :: _ as rest) -> x :: s :: intersperse s rest
  | short -> short

(** Updater for [IMap.update]: puts [v] in front of the list already bound, or starts a new
   one-element list when nothing is bound yet. *)
let update_schedule v slot =
  match slot with
  | None -> Some [ v ]
  | Some l -> Some (v :: l)

(** [parse_soln tree s] parses one line [s] of solver output.  A line that starts with
   [v<block>b_<instr>], some spaces and then a number adds the pair [(instr, number)] to the list
   bound to [block] in [tree]; any other line leaves [tree] unchanged.

   The regular expression is compiled once, when this definition is evaluated, rather than once
   per parsed line. *)
let parse_soln =
  let r = Str.regexp "v\\([0-9]+\\)b_\\([0-9]+\\)[ ]+\\([0-9]+\\)" in
  fun tree s ->
    if Str.string_match r s 0 then
      let block = Str.matched_group 1 s |> int_of_string in
      let instr = Str.matched_group 2 s |> int_of_string in
      let value = Str.matched_group 3 s |> int_of_string in
      IMap.update block (update_schedule (instr, value)) tree
    else tree

(** Writes the LP problem to the file [lpsolve.txt] in the current directory (minimise the sum of
   [vars], subject to [constraints], with the declarations in [types]), runs [lp_solve] on it and
   parses its output, skipping the first three lines, into a map from block number to the list of
   parsed [(instruction, value)] pairs. *)
let solve_constraints vars constraints types =
  let oc = open_out "lpsolve.txt" in
  fprintf oc "min: ";
  fprintf oc "%s" (String.concat " + " vars);
  fprintf oc ";\n";
  fprintf oc "%s" constraints;
  fprintf oc "%s" types;
  close_out oc;
  let output = read_process "lp_solve lpsolve.txt" in
  let lines = Str.split (Str.regexp_string "\n") output in
  List.fold_left parse_soln IMap.empty (drop 3 lines)

(** Smallest second component among the pairs of a non-empty list.  An empty list is a programming
   error. *)
let find_min = function
  | [] -> assert false
  | l :: ls ->
      List.fold_left
        (fun current (_, v) -> if v < current then v else current)
        (snd l) ls

(** Largest second component among the pairs of a non-empty list.  An empty list is a programming
   error. *)
let find_max = function
  | [] -> assert false
  | l :: ls ->
      List.fold_left
        (fun current (_, v) -> if v > current then v else current)
        (snd l) ls

(** Infix alias for the [bind] that is in scope here.  NOTE(review): presumably the monadic bind
   of one of the opened [HTLMonad] modules -- confirm. *)
let ( >>= ) = bind

(** Fold step that groups scheduled instructions by state: the pair [(i, st)] adds [i] to the
   list bound to [st] in [schedule]. *)
let combine_bb_schedule schedule (i, st) =
  let add_instr = update_schedule i in
  IMap.update st add_instr schedule

(** Order two pairs by their first components only. *)
let compare_tuple x y = compare (fst x) (fst y)

(** Should generate the [RTLPar] code based on the input [RTLBlock] description.

   [c] is the code to transform and [schedule] maps a block number to the list of
   [(instruction index, state)] pairs chosen for that block.  A block with an empty body is kept
   as it is.  For any other block the instructions are grouped by state and the groups are emitted
   in increasing state order.  A non-empty block that is missing from [schedule] is a programming
   error: the whole schedule is printed and the function aborts with [assert false]. *)
let transf_rtlpar c (schedule : (int * int) list IMap.t) =
  let f i bb : RTLPar.bblock =
    match bb with
    | { bb_body = []; bb_exit = c } ->
      { bb_body = [];
        bb_exit = c
      }
    | { bb_body = bb_body'; bb_exit = ctrl_flow } ->
        let i_sched =
          try IMap.find (P.to_int i) schedule
          with Not_found -> (
            printf "Could not find %d\n" (P.to_int i);
            IMap.iter
              (fun d -> printf "%d: %a\n" d (print_list print_tuple))
              schedule;
            assert false
          )
        in
        (* Array view of the body, so that every scheduled index is fetched in constant time. *)
        let instrs = Array.of_list bb_body' in
        let i_sched_tree =
          List.fold_left combine_bb_schedule IMap.empty i_sched
        in
        { bb_body = (IMap.to_seq i_sched_tree |> List.of_seq |> List.sort compare_tuple
                     |> List.map (fun (_, idxs) -> List.map (fun x -> instrs.(x)) idxs));
          bb_exit = ctrl_flow
        }
  in
  PTree.map f c

(** Middle component of a triple. *)
let second (_, a, _) = a

(** [schedule entry c] schedules every basic block of [c]: it gathers the constraints of each
   block, combines them into one LP problem, solves it with [solve_constraints] and rebuilds the
   code with [transf_rtlpar].  [entry] is currently unused.

   The local [debug] flag is forwarded to [gather_bb_constraints]; set it to [true] to print the
   intermediate data-flow graphs. *)
let schedule entry (c : RTLBlock.bb RTLBlockInstr.code) =
  let debug = false in
  let c' = PTree.map1 (gather_bb_constraints debug) c in
  (*let _ = if debug then PTree.map (fun r o -> printf "##### %d #####\n%a\n\n" (P.to_int r) print_dfg (second o)) c' else PTree.empty in*)
  let labels = List.map fst (PTree.elements c') in
  let _, (vars, constraints, types) =
    List.fold_left
      (fun compl -> gather_cfg_constraints compl c')
      ([], ([], "", ""))
      labels
  in
  let schedule' = solve_constraints vars constraints types in
  (*IMap.iter (fun a b -> printf "##### %d #####\n%a\n\n" a (print_list print_tuple) b) schedule';*)
  (*printf "Schedule: %a\n" (fun a x -> IMap.iter (fun d -> fprintf a "%d: %a\n" d (print_list print_tuple)) x) schedule';*)
  transf_rtlpar c schedule'

(** [find_reachable_states c e] lists [e] followed by the blocks reachable from it, following
   only successors whose label is smaller than the current one.  The result may contain
   duplicates.  A label that is missing from [c] is a programming error. *)
let rec find_reachable_states c e =
  match PTree.get e c with
  | None -> assert false
  | Some { bb_exit = ex; _ } ->
    let lower = List.filter (fun x -> P.lt x e) (successors_instr ex) in
    let reached =
      List.fold_left
        (fun acc s -> List.concat [ acc; find_reachable_states c s ])
        [] lower
    in
    e :: reached

(** [add_to_tree c nt i] copies the binding of [i] from [c] into [nt].  A label that is missing
   from [c] is a programming error. *)
let add_to_tree c nt i =
  match PTree.get i c with
  | None -> assert false
  | Some p -> PTree.set i p nt

(** Schedules the code of [f] and returns the resulting function.  Only the blocks found by
   [find_reachable_states] from the entry point are kept in the new code; all other fields are
   copied from [f]. *)
let schedule_fn (f : RTLBlock.coq_function) : RTLPar.coq_function =
  let scheduled = schedule f.fn_entrypoint f.fn_code in
  (* Going through the set removes the duplicates produced by the reachability search. *)
  let reachable =
    SS.elements (SS.of_list (find_reachable_states scheduled f.fn_entrypoint))
  in
  let fn_code = List.fold_left (add_to_tree scheduled) PTree.empty reachable in
  { fn_sig = f.fn_sig;
    fn_params = f.fn_params;
    fn_stacksize = f.fn_stacksize;
    fn_code;
    fn_entrypoint = f.fn_entrypoint
  }