aboutsummaryrefslogtreecommitdiffstats
path: root/aarch64/PrepassSchedulingOracle.ml
diff options
context:
space:
mode:
Diffstat (limited to 'aarch64/PrepassSchedulingOracle.ml')
-rw-r--r--aarch64/PrepassSchedulingOracle.ml297
1 files changed, 297 insertions, 0 deletions
diff --git a/aarch64/PrepassSchedulingOracle.ml b/aarch64/PrepassSchedulingOracle.ml
new file mode 100644
index 00000000..d7e80cd9
--- /dev/null
+++ b/aarch64/PrepassSchedulingOracle.ml
@@ -0,0 +1,297 @@
+(* The Compcert verified compiler *)
+(* *)
+(* Sylvain Boulmé Grenoble-INP, VERIMAG *)
+(* David Monniaux CNRS, VERIMAG *)
+(* Cyril Six Kalray *)
+(* Léo Gourdin UGA, VERIMAG *)
+(* Nicolas Nardino ENS-Lyon, VERIMAG *)
+(* *)
+(* *)
+(* *************************************************************)
+
+open AST
+open BTL
+open Maps
+open InstructionScheduler
+open Registers
+open PrepassSchedulingOracleDeps
+open PrintBTL
+open DebugPrint
+
+let use_alias_analysis () = false
+
+let build_constraints_and_resources (opweights : opweights) seqa btl =
+ let last_reg_reads : int list PTree.t ref = ref PTree.empty
+ and last_reg_write : (int * int) PTree.t ref = ref PTree.empty
+ and last_mem_reads : int list ref = ref []
+ and last_mem_write : int option ref = ref None
+ and last_branch : int option ref = ref None
+ and last_non_pipelined_op : int array =
+ Array.make opweights.nr_non_pipelined_units (-1)
+ and latency_constraints : latency_constraint list ref = ref [] in
+ let add_constraint instr_from instr_to latency =
+ assert (instr_from <= instr_to);
+ assert (latency >= 0);
+ if instr_from = instr_to then
+ if latency = 0 then ()
+ else
+ failwith "PrepassSchedulingOracle.get_dependencies: negative self-loop"
+ else
+ latency_constraints :=
+ { instr_from; instr_to; latency } :: !latency_constraints
+ and get_last_reads reg =
+ match PTree.get reg !last_reg_reads with Some l -> l | None -> []
+ in
+ let add_input_mem i =
+ if not (use_alias_analysis ()) then (
+ (* Read after write *)
+ (match !last_mem_write with None -> () | Some j -> add_constraint j i 1);
+ last_mem_reads := i :: !last_mem_reads)
+ and add_output_mem i =
+ if not (use_alias_analysis ()) then (
+ (* Write after write *)
+ (match !last_mem_write with None -> () | Some j -> add_constraint j i 1);
+ (* Write after read *)
+ List.iter (fun j -> add_constraint j i 0) !last_mem_reads;
+ last_mem_write := Some i;
+ last_mem_reads := [])
+ and add_input_reg i reg =
+ (* Read after write *)
+ (match PTree.get reg !last_reg_write with
+ | None -> ()
+ | Some (j, latency) -> add_constraint j i latency);
+ last_reg_reads := PTree.set reg (i :: get_last_reads reg) !last_reg_reads
+ and add_output_reg i latency reg =
+ (* Write after write *)
+ (match PTree.get reg !last_reg_write with
+ | None -> ()
+ | Some (j, _) -> add_constraint j i 1);
+ (* Write after read *)
+ List.iter (fun j -> add_constraint j i 0) (get_last_reads reg);
+ last_reg_write := PTree.set reg (i, latency) !last_reg_write;
+ last_reg_reads := PTree.remove reg !last_reg_reads
+ in
+ let add_input_regs i regs = List.iter (add_input_reg i) regs in
+ let rec add_builtin_res i (res : reg builtin_res) =
+ match res with
+ | BR r -> add_output_reg i 10 r
+ | BR_none -> ()
+ | BR_splitlong (hi, lo) ->
+ add_builtin_res i hi;
+ add_builtin_res i lo
+ in
+ let rec add_builtin_arg i (ba : reg builtin_arg) =
+ match ba with
+ | BA r -> add_input_reg i r
+ | BA_int _ | BA_long _ | BA_float _ | BA_single _ -> ()
+ | BA_loadstack (_, _) -> add_input_mem i
+ | BA_addrstack _ -> ()
+ | BA_loadglobal (_, _, _) -> add_input_mem i
+ | BA_addrglobal _ -> ()
+ | BA_splitlong (hi, lo) ->
+ add_builtin_arg i hi;
+ add_builtin_arg i lo
+ | BA_addptr (a1, a2) ->
+ add_builtin_arg i a1;
+ add_builtin_arg i a2
+ and irreversible_action i =
+ match !last_branch with None -> () | Some j -> add_constraint j i 1
+ in
+ let set_branch i =
+ irreversible_action i;
+ last_branch := Some i
+ and add_non_pipelined_resources i resources =
+ Array.iter2
+ (fun latency last ->
+ if latency >= 0 && last >= 0 then add_constraint last i latency)
+ resources last_non_pipelined_op;
+ Array.iteri
+ (fun rsc latency -> if latency >= 0 then last_non_pipelined_op.(rsc) <- i)
+ resources
+ in
+ Array.iteri
+ (fun i (inst, other_uses) ->
+ List.iter (fun use -> add_input_reg i use) (Regset.elements other_uses);
+ match inst with
+ | Bnop _ -> ()
+ | Bop (op, lr, rd, _) ->
+ add_non_pipelined_resources i
+ (opweights.non_pipelined_resources_of_op op (List.length lr));
+ if Op.is_trapping_op op then irreversible_action i;
+ add_input_regs i lr;
+ add_output_reg i (opweights.latency_of_op op (List.length lr)) rd
+ | Bload (trap, chk, addr, lr, rd, _) ->
+ if trap = TRAP then irreversible_action i;
+ add_input_mem i;
+ add_input_regs i lr;
+ add_output_reg i
+ (opweights.latency_of_load trap chk addr (List.length lr))
+ rd
+ | Bstore (chk, addr, lr, src, _) ->
+ irreversible_action i;
+ add_input_regs i lr;
+ add_input_reg i src;
+ add_output_mem i
+ | Bcond (cond, lr, BF (Bgoto s, _), ibnot, _) ->
+ set_branch i;
+ add_input_mem i;
+ add_input_regs i lr
+ | Bcond (_, _, _, _, _) ->
+ failwith "build_constraints_and_resources: invalid Bcond"
+ | BF (Bcall (signature, ef, lr, rd, _), _) ->
+ set_branch i;
+ (match ef with
+ | Datatypes.Coq_inl r -> add_input_reg i r
+ | Datatypes.Coq_inr symbol -> ());
+ add_input_mem i;
+ add_input_regs i lr;
+ add_output_reg i (opweights.latency_of_call signature ef) rd;
+ add_output_mem i;
+ failwith "build_constraints_and_resources: invalid Bcall"
+ | BF (Btailcall (signature, ef, lr), _) ->
+ set_branch i;
+ (match ef with
+ | Datatypes.Coq_inl r -> add_input_reg i r
+ | Datatypes.Coq_inr symbol -> ());
+ add_input_mem i;
+ add_input_regs i lr;
+ failwith "build_constraints_and_resources: invalid Btailcall"
+ | BF (Bbuiltin (ef, lr, rd, _), _) ->
+ set_branch i;
+ add_input_mem i;
+ List.iter (add_builtin_arg i) lr;
+ add_builtin_res i rd;
+ add_output_mem i;
+ failwith "build_constraints_and_resources: invalid Bbuiltin"
+ | BF (Bjumptable (lr, _), _) ->
+ set_branch i;
+ add_input_reg i lr;
+ failwith "build_constraints_and_resources: invalid Bjumptable"
+ | BF (Breturn (Some r), _) ->
+ set_branch i;
+ add_input_reg i r;
+ failwith "build_constraints_and_resources: invalid Breturn Some"
+ | BF (Breturn None, _) ->
+ set_branch i;
+ failwith "build_constraints_and_resources: invalid Breturn None"
+ | BF (Bgoto _, _) ->
+ failwith "build_constraints_and_resources: invalid Bgoto"
+ | Bseq (_, _) -> failwith "build_constraints_and_resources: Bseq")
+ seqa;
+ !latency_constraints
+
+let resources_of_instruction (opweights : opweights) = function
+ | Bnop _ -> Array.map (fun _ -> 0) opweights.pipelined_resource_bounds
+ | Bop (op, inputs, output, _) ->
+ opweights.resources_of_op op (List.length inputs)
+ | Bload (trap, chunk, addressing, addr_regs, output, _) ->
+ opweights.resources_of_load trap chunk addressing (List.length addr_regs)
+ | Bstore (chunk, addressing, addr_regs, input, _) ->
+ opweights.resources_of_store chunk addressing (List.length addr_regs)
+ | BF (Bcall (signature, ef, inputs, output, _), _) ->
+ opweights.resources_of_call signature ef
+ | BF (Bbuiltin (ef, builtin_inputs, builtin_output, _), _) ->
+ opweights.resources_of_builtin ef
+ | Bcond (cond, args, _, _, _) ->
+ opweights.resources_of_cond cond (List.length args)
+ | BF (Btailcall _, _) | BF (Bjumptable _, _) | BF (Breturn _, _) ->
+ opweights.pipelined_resource_bounds
+ | BF (Bgoto _, _) | Bseq (_, _) ->
+ failwith "resources_of_instruction: invalid btl instruction"
+
+let print_sequence pp seqa =
+ Array.iteri
+ (fun i (inst, other_uses) ->
+ debug "i=%d\n inst = " i;
+ print_btl_inst pp inst;
+ debug "\n other_uses=";
+ print_regset other_uses;
+ debug "\n")
+ seqa
+
+let length_of_chunk = function
+ | Mint8signed | Mint8unsigned -> 1
+ | Mint16signed | Mint16unsigned -> 2
+ | Mint32 | Mfloat32 | Many32 -> 4
+ | Mint64 | Mfloat64 | Many64 -> 8
+
+let define_problem (opweights : opweights) (live_entry_regs : Regset.t)
+ (typing : RTLtyping.regenv) reference_counting seqa btl =
+ let simple_deps = build_constraints_and_resources opweights seqa btl in
+ {
+ max_latency = -1;
+ resource_bounds = opweights.pipelined_resource_bounds;
+ live_regs_entry = live_entry_regs;
+ typing;
+ reference_counting = Some reference_counting;
+ instruction_usages =
+ Array.map (resources_of_instruction opweights) (Array.map fst seqa);
+ latency_constraints = simple_deps;
+ }
+
+let zigzag_scheduler problem early_ones =
+ let nr_instructions = get_nr_instructions problem in
+ assert (nr_instructions = Array.length early_ones);
+ match list_scheduler problem with
+ | Some fwd_schedule ->
+ let fwd_makespan = fwd_schedule.(Array.length fwd_schedule - 1) in
+ let constraints' = ref problem.latency_constraints in
+ Array.iteri
+ (fun i is_early ->
+ if is_early then
+ constraints' :=
+ {
+ instr_from = i;
+ instr_to = nr_instructions;
+ latency = fwd_makespan - fwd_schedule.(i);
+ }
+ :: !constraints')
+ early_ones;
+ validated_scheduler reverse_list_scheduler
+ { problem with latency_constraints = !constraints' }
+ | None -> None
+
+let prepass_scheduler_by_name name problem seqa =
+ match name with
+ | "zigzag" ->
+ let early_ones =
+ Array.map
+ (fun (inst, _) ->
+ match inst with Bcond (_, _, _, _, _) -> true | _ -> false)
+ seqa
+ in
+ zigzag_scheduler problem early_ones
+ | _ -> scheduler_by_name name problem
+
+let schedule_sequence seqa btl (live_regs_entry : Registers.Regset.t)
+ (typing : RTLtyping.regenv) reference =
+ let opweights = OpWeights.get_opweights () in
+ try
+ if Array.length seqa <= 1 then None
+ else
+ let nr_instructions = Array.length seqa in
+ if !Clflags.option_debug_compcert > 6 then
+ Printf.printf "prepass scheduling length = %d\n" nr_instructions;
+ let problem =
+ define_problem opweights live_regs_entry typing reference seqa btl
+ in
+ if !Clflags.option_debug_compcert > 7 then (
+ print_sequence stdout seqa;
+ print_problem stdout problem);
+ match
+ prepass_scheduler_by_name !Clflags.option_fprepass_sched problem seqa
+ with
+ | None ->
+ Printf.printf "no solution in prepass scheduling\n";
+ None
+ | Some solution ->
+ let positions = Array.init nr_instructions (fun i -> i) in
+ Array.sort
+ (fun i j ->
+ let si = solution.(i) and sj = solution.(j) in
+ if si < sj then -1 else if si > sj then 1 else i - j)
+ positions;
+ Some positions
+ with Failure s ->
+ Printf.printf "failure in prepass scheduling: %s\n" s;
+ None