aboutsummaryrefslogtreecommitdiffstats
path: root/aarch64/PrepassSchedulingOracle.ml
blob: d7e80cd93b401b40e64fcdc4039c32854ae78025 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
(*             The Compcert verified compiler                  *)
(*                                                             *)
(*           Sylvain Boulmé     Grenoble-INP, VERIMAG          *)
(*           David Monniaux     CNRS, VERIMAG                  *)
(*           Cyril Six          Kalray                         *)
(*           Léo Gourdin        UGA, VERIMAG                   *)
(*           Nicolas Nardino    ENS-Lyon, VERIMAG              *)
(*                                                             *)
(*                                                             *)
(* *************************************************************)

open AST
open BTL
open Maps
open InstructionScheduler
open Registers
open PrepassSchedulingOracleDeps
open PrintBTL
open DebugPrint

(** Switch consulted by the memory-dependency helpers of
    [build_constraints_and_resources]: they only record memory
    constraints when this returns [false]. It currently always returns
    [false]. *)
let use_alias_analysis () : bool = false

(** [build_constraints_and_resources opweights seqa btl] scans the
    instruction sequence [seqa] in increasing index order and returns the
    latency constraints collected between its instructions.

    Each element of [seqa] is a pair [(inst, other_uses)]; the registers of
    [other_uses] are treated as extra inputs of the instruction. A
    constraint [{ instr_from; instr_to; latency }] is recorded for register
    and memory read-after-write, write-after-write and write-after-read
    pairs, for ordering against the last recorded branch, and for reuse of
    a non-pipelined unit. The returned list holds the most recently added
    constraint first.

    [btl] is not used in the body.

    @raise Failure on [Bseq], on any [BF] instruction, on a [Bcond] whose
    third component is not [BF (Bgoto _, _)], and when a constraint with a
    non-zero latency would link an instruction to itself.
    @raise Assert_failure if a constraint goes backwards in the sequence
    or carries a negative latency. *)
let build_constraints_and_resources (opweights : opweights) seqa btl =
  (* Scan state, all indexed by position in [seqa]:
     - last_reg_reads: per register, readers seen since its last write;
     - last_reg_write: per register, last writer and its result latency;
     - last_mem_reads: memory readers seen since the last memory write;
     - last_mem_write: last memory writer, if any;
     - last_branch: last instruction recorded through set_branch;
     - last_non_pipelined_op: per non-pipelined unit, last user (-1: none);
     - latency_constraints: accumulated result, newest first. *)
  let last_reg_reads : int list PTree.t ref = ref PTree.empty
  and last_reg_write : (int * int) PTree.t ref = ref PTree.empty
  and last_mem_reads : int list ref = ref []
  and last_mem_write : int option ref = ref None
  and last_branch : int option ref = ref None
  and last_non_pipelined_op : int array =
    Array.make opweights.nr_non_pipelined_units (-1)
  and latency_constraints : latency_constraint list ref = ref [] in
  (* Record that instr_to must start at least latency after instr_from.
     A zero-latency self-constraint is dropped silently.
     NOTE(review): the failure message names get_dependencies and says
     negative, but given the assertion above it the failing case is a
     strictly positive self-latency -- confirm the intended wording. *)
  let add_constraint instr_from instr_to latency =
    assert (instr_from <= instr_to);
    assert (latency >= 0);
    if instr_from = instr_to then
      if latency = 0 then ()
      else
        failwith "PrepassSchedulingOracle.get_dependencies: negative self-loop"
    else
      latency_constraints :=
        { instr_from; instr_to; latency } :: !latency_constraints
  (* Readers of reg recorded since its last write; empty if none. *)
  and get_last_reads reg =
    match PTree.get reg !last_reg_reads with Some l -> l | None -> []
  in
  (* Instruction i reads memory. Nothing is recorded when
     use_alias_analysis () is true. *)
  let add_input_mem i =
    if not (use_alias_analysis ()) then (
      (* Read after write *)
      (match !last_mem_write with None -> () | Some j -> add_constraint j i 1);
      last_mem_reads := i :: !last_mem_reads)
  (* Instruction i writes memory: it becomes the last writer and the list
     of pending readers is reset. Same alias-analysis guard as above. *)
  and add_output_mem i =
    if not (use_alias_analysis ()) then (
      (* Write after write *)
      (match !last_mem_write with None -> () | Some j -> add_constraint j i 1);
      (* Write after read *)
      List.iter (fun j -> add_constraint j i 0) !last_mem_reads;
      last_mem_write := Some i;
      last_mem_reads := [])
  (* Instruction i reads reg: wait for the latency stored with the last
     writer of reg, then remember i as a reader. *)
  and add_input_reg i reg =
    (* Read after write *)
    (match PTree.get reg !last_reg_write with
    | None -> ()
    | Some (j, latency) -> add_constraint j i latency);
    last_reg_reads := PTree.set reg (i :: get_last_reads reg) !last_reg_reads
  (* Instruction i writes reg with the given result latency: order it
     after the previous writer (latency 1) and after pending readers
     (latency 0), then make it the last writer and clear the readers. *)
  and add_output_reg i latency reg =
    (* Write after write *)
    (match PTree.get reg !last_reg_write with
    | None -> ()
    | Some (j, _) -> add_constraint j i 1);
    (* Write after read *)
    List.iter (fun j -> add_constraint j i 0) (get_last_reads reg);
    last_reg_write := PTree.set reg (i, latency) !last_reg_write;
    last_reg_reads := PTree.remove reg !last_reg_reads
  in
  let add_input_regs i regs = List.iter (add_input_reg i) regs in
  (* Registers written by a builtin result, with a fixed latency of 10.
     NOTE(review): the origin of the constant 10 is not visible here. *)
  let rec add_builtin_res i (res : reg builtin_res) =
    match res with
    | BR r -> add_output_reg i 10 r
    | BR_none -> ()
    | BR_splitlong (hi, lo) ->
        add_builtin_res i hi;
        add_builtin_res i lo
  in
  (* Inputs of a builtin argument: registers are register reads, the two
     load forms are memory reads, constants and addresses add nothing,
     and compound arguments are traversed recursively. *)
  let rec add_builtin_arg i (ba : reg builtin_arg) =
    match ba with
    | BA r -> add_input_reg i r
    | BA_int _ | BA_long _ | BA_float _ | BA_single _ -> ()
    | BA_loadstack (_, _) -> add_input_mem i
    | BA_addrstack _ -> ()
    | BA_loadglobal (_, _, _) -> add_input_mem i
    | BA_addrglobal _ -> ()
    | BA_splitlong (hi, lo) ->
        add_builtin_arg i hi;
        add_builtin_arg i lo
    | BA_addptr (a1, a2) ->
        add_builtin_arg i a1;
        add_builtin_arg i a2
  (* Instruction i must come after the last recorded branch, with
     latency 1. Used below for trapping operations, trapping loads,
     stores and branches. *)
  and irreversible_action i =
    match !last_branch with None -> () | Some j -> add_constraint j i 1
  in
  (* Order i after the previous branch, then make i the last branch. *)
  let set_branch i =
    irreversible_action i;
    last_branch := Some i
  (* resources holds one latency per non-pipelined unit; a negative entry
     means the unit is not used by i. For every unit that i uses and that
     already has a recorded user, add a constraint from that user to i
     with this latency; afterwards i becomes the recorded user of every
     unit it uses. *)
  and add_non_pipelined_resources i resources =
    Array.iter2
      (fun latency last ->
        if latency >= 0 && last >= 0 then add_constraint last i latency)
      resources last_non_pipelined_op;
    Array.iteri
      (fun rsc latency -> if latency >= 0 then last_non_pipelined_op.(rsc) <- i)
      resources
  in
  Array.iteri
    (fun i (inst, other_uses) ->
      (* Registers in other_uses count as inputs of instruction i. *)
      List.iter (fun use -> add_input_reg i use) (Regset.elements other_uses);
      match inst with
      | Bnop _ -> ()
      | Bop (op, lr, rd, _) ->
          add_non_pipelined_resources i
            (opweights.non_pipelined_resources_of_op op (List.length lr));
          if Op.is_trapping_op op then irreversible_action i;
          add_input_regs i lr;
          add_output_reg i (opweights.latency_of_op op (List.length lr)) rd
      | Bload (trap, chk, addr, lr, rd, _) ->
          if trap = TRAP then irreversible_action i;
          add_input_mem i;
          add_input_regs i lr;
          add_output_reg i
            (opweights.latency_of_load trap chk addr (List.length lr))
            rd
      | Bstore (chk, addr, lr, src, _) ->
          irreversible_action i;
          add_input_regs i lr;
          add_input_reg i src;
          add_output_mem i
      (* The only accepted conditional has BF (Bgoto _, _) as its third
         component. It is recorded as a branch and as a memory read, so
         it is ordered after the last memory write. *)
      | Bcond (cond, lr, BF (Bgoto s, _), ibnot, _) ->
          set_branch i;
          add_input_mem i;
          add_input_regs i lr
      | Bcond (_, _, _, _, _) ->
          failwith "build_constraints_and_resources: invalid Bcond"
      (* Every BF case below ends in failwith, so a sequence containing
         one of these instructions is rejected.
         NOTE(review): the bookkeeping done before each failwith only
         touches state local to this call, so it looks unobservable
         unless add_constraint itself raises first -- confirm whether it
         is kept on purpose. *)
      | BF (Bcall (signature, ef, lr, rd, _), _) ->
          set_branch i;
          (match ef with
          | Datatypes.Coq_inl r -> add_input_reg i r
          | Datatypes.Coq_inr symbol -> ());
          add_input_mem i;
          add_input_regs i lr;
          add_output_reg i (opweights.latency_of_call signature ef) rd;
          add_output_mem i;
          failwith "build_constraints_and_resources: invalid Bcall"
      | BF (Btailcall (signature, ef, lr), _) ->
          set_branch i;
          (match ef with
          | Datatypes.Coq_inl r -> add_input_reg i r
          | Datatypes.Coq_inr symbol -> ());
          add_input_mem i;
          add_input_regs i lr;
          failwith "build_constraints_and_resources: invalid Btailcall"
      | BF (Bbuiltin (ef, lr, rd, _), _) ->
          set_branch i;
          add_input_mem i;
          List.iter (add_builtin_arg i) lr;
          add_builtin_res i rd;
          add_output_mem i;
          failwith "build_constraints_and_resources: invalid Bbuiltin"
      | BF (Bjumptable (lr, _), _) ->
          set_branch i;
          add_input_reg i lr;
          failwith "build_constraints_and_resources: invalid Bjumptable"
      | BF (Breturn (Some r), _) ->
          set_branch i;
          add_input_reg i r;
          failwith "build_constraints_and_resources: invalid Breturn Some"
      | BF (Breturn None, _) ->
          set_branch i;
          failwith "build_constraints_and_resources: invalid Breturn None"
      | BF (Bgoto _, _) ->
          failwith "build_constraints_and_resources: invalid Bgoto"
      | Bseq (_, _) -> failwith "build_constraints_and_resources: Bseq")
    seqa;
  !latency_constraints

(** [resources_of_instruction opweights inst] is the resource-usage
    vector of [inst], obtained from the matching [opweights] callback.

    - [Bnop] uses nothing: a fresh all-zero vector with one entry per
      entry of [opweights.pipelined_resource_bounds].
    - Tail calls, jump tables and returns are given the bounds vector
      itself.
    - Operations, loads, stores and conditions are looked up with the
      number of their register arguments.

    @raise Failure on [BF (Bgoto _, _)] and on [Bseq]. *)
let resources_of_instruction (opweights : opweights) inst =
  let arity regs = List.length regs in
  match inst with
  | Bnop _ -> Array.make (Array.length opweights.pipelined_resource_bounds) 0
  | Bop (op, args, _, _) -> opweights.resources_of_op op (arity args)
  | Bload (trap, chunk, addressing, addr_regs, _, _) ->
      opweights.resources_of_load trap chunk addressing (arity addr_regs)
  | Bstore (chunk, addressing, addr_regs, _, _) ->
      opweights.resources_of_store chunk addressing (arity addr_regs)
  | Bcond (cond, args, _, _, _) ->
      opweights.resources_of_cond cond (arity args)
  | BF (Bcall (signature, callee, _, _, _), _) ->
      opweights.resources_of_call signature callee
  | BF (Bbuiltin (ef, _, _, _), _) -> opweights.resources_of_builtin ef
  | BF ((Btailcall _ | Bjumptable _ | Breturn _), _) ->
      opweights.pipelined_resource_bounds
  | BF (Bgoto _, _) | Bseq _ ->
      failwith "resources_of_instruction: invalid btl instruction"

(** [print_sequence pp seqa] dumps every [(inst, other_uses)] entry of
    [seqa] in index order: the index through [debug], the instruction
    through [print_btl_inst pp], and the extra-use register set through
    [print_regset]. *)
let print_sequence pp seqa =
  let print_entry idx (inst, other_uses) =
    debug "i=%d\n inst = " idx;
    print_btl_inst pp inst;
    debug "\n other_uses=";
    print_regset other_uses;
    debug "\n"
  in
  Array.iteri print_entry seqa

(** [length_of_chunk chunk] maps a memory chunk to its size: 1 for the
    8-bit chunks, 2 for the 16-bit chunks, 4 for the 32-bit chunks and 8
    for the 64-bit chunks (presumably a size in bytes). *)
let length_of_chunk chunk =
  match chunk with
  | Mint8signed -> 1
  | Mint8unsigned -> 1
  | Mint16signed -> 2
  | Mint16unsigned -> 2
  | Mint32 -> 4
  | Mfloat32 -> 4
  | Many32 -> 4
  | Mint64 -> 8
  | Mfloat64 -> 8
  | Many64 -> 8

(** [define_problem opweights live_entry_regs typing reference_counting
    seqa btl] assembles the scheduling problem for the sequence [seqa].

    The latency constraints come from [build_constraints_and_resources],
    the per-instruction usages from [resources_of_instruction] applied to
    the instruction of each entry of [seqa], and the resource bounds from
    [opweights.pipelined_resource_bounds]. [max_latency] is set to [-1]
    and [reference_counting] is wrapped in [Some]. *)
let define_problem (opweights : opweights) (live_entry_regs : Regset.t)
    (typing : RTLtyping.regenv) reference_counting seqa btl =
  let latency_constraints =
    build_constraints_and_resources opweights seqa btl
  in
  let instruction_usages =
    Array.map (fun (inst, _) -> resources_of_instruction opweights inst) seqa
  in
  {
    max_latency = -1;
    resource_bounds = opweights.pipelined_resource_bounds;
    live_regs_entry = live_entry_regs;
    typing;
    reference_counting = Some reference_counting;
    instruction_usages;
    latency_constraints;
  }

(** [zigzag_scheduler problem early_ones] schedules [problem] in two
    passes.

    A first pass runs [list_scheduler]. If it yields no schedule the
    result is [None]. Otherwise, for every index [i] flagged in
    [early_ones], a constraint from [i] to index [nr_instructions] is
    added with latency [last - fwd.(i)], where [fwd] is the first-pass
    schedule and [last] its final entry (presumably the makespan, with
    index [nr_instructions] standing for the end of the block). The
    augmented problem is then handed to
    [validated_scheduler reverse_list_scheduler].

    [early_ones] must have one entry per instruction of [problem]. *)
let zigzag_scheduler problem early_ones =
  let nr_instructions = get_nr_instructions problem in
  assert (nr_instructions = Array.length early_ones);
  match list_scheduler problem with
  | None -> None
  | Some fwd_schedule ->
      let fwd_makespan = fwd_schedule.(Array.length fwd_schedule - 1) in
      (* Walk the flags left to right, carrying the current index and
         prepending one constraint per flagged instruction. *)
      let pin_early (idx, (acc : latency_constraint list)) is_early =
        let acc =
          if is_early then
            {
              instr_from = idx;
              instr_to = nr_instructions;
              latency = fwd_makespan - fwd_schedule.(idx);
            }
            :: acc
          else acc
        in
        (idx + 1, acc)
      in
      let _, augmented =
        Array.fold_left pin_early (0, problem.latency_constraints) early_ones
      in
      validated_scheduler reverse_list_scheduler
        { problem with latency_constraints = augmented }

(** [prepass_scheduler_by_name name problem seqa] runs the scheduler
    called [name] on [problem].

    The name [zigzag] selects [zigzag_scheduler], with every [Bcond]
    entry of [seqa] flagged as early; any other name is forwarded to
    [scheduler_by_name]. *)
let prepass_scheduler_by_name name problem seqa =
  if String.equal name "zigzag" then (
    let is_cond_branch = function Bcond _, _ -> true | _ -> false in
    zigzag_scheduler problem (Array.map is_cond_branch seqa))
  else scheduler_by_name name problem

(** [schedule_sequence seqa btl live_regs_entry typing reference] tries
    to reorder the instruction sequence [seqa].

    Returns [Some positions], where [positions] lists the indices of
    [seqa] sorted by scheduled time (ties broken by original index).
    Returns [None] when [seqa] has at most one instruction, when the
    scheduler selected by [Clflags.option_fprepass_sched] finds no
    solution, or when a [Failure] is raised along the way; the last two
    cases print a message on standard output.

    Extra traces are printed when [Clflags.option_debug_compcert] is
    above 6 (sequence length) and above 7 (sequence and problem). *)
let schedule_sequence seqa btl (live_regs_entry : Registers.Regset.t)
    (typing : RTLtyping.regenv) reference =
  let opweights = OpWeights.get_opweights () in
  let nr_instructions = Array.length seqa in
  try
    if nr_instructions <= 1 then None
    else begin
      if !Clflags.option_debug_compcert > 6 then
        Printf.printf "prepass scheduling length = %d\n" nr_instructions;
      let problem =
        define_problem opweights live_regs_entry typing reference seqa btl
      in
      if !Clflags.option_debug_compcert > 7 then begin
        print_sequence stdout seqa;
        print_problem stdout problem
      end;
      match
        prepass_scheduler_by_name !Clflags.option_fprepass_sched problem seqa
      with
      | Some solution ->
          (* Order indices by scheduled time, then by original index. *)
          let by_time i j =
            let ti = solution.(i) and tj = solution.(j) in
            if ti > tj then 1 else if ti < tj then -1 else i - j
          in
          let positions = Array.init nr_instructions (fun idx -> idx) in
          Array.sort by_time positions;
          Some positions
      | None ->
          Printf.printf "no solution in prepass scheduling\n";
          None
    end
  with Failure msg ->
    Printf.printf "failure in prepass scheduling: %s\n" msg;
    None