(* Oracle for prepass (RTL-level) instruction scheduling.
   Given a straight-line sequence of RTL instructions, it builds a
   scheduling problem (latency constraints and per-instruction resource
   usages), solves it with a scheduler brought into scope by the opens
   below (presumably InstructionScheduler), and returns the new order of
   the instructions. *)
open AST open RTL open Maps open InstructionScheduler open Registers open PrepassSchedulingOracleDeps

(** Whether memory dependencies are left to the alias-based analysis
    (see the commented-out get_alias_dependencies further down) instead
    of the conservative memory ordering built by [get_simple_dependencies].
    Hard-wired to [false]; when it returns true, the memory helpers of
    [get_simple_dependencies] record nothing at all. *)
let use_alias_analysis () = false

(** Width in bytes of a memory access performed with the given chunk. *)
let length_of_chunk = function
  | Mint8signed | Mint8unsigned -> 1
  | Mint16signed | Mint16unsigned -> 2
  | Mint32 | Mfloat32 | Many32 -> 4
  | Mint64 | Mfloat64 | Many64 -> 8;;

(** Build the latency constraints between the instructions of [seqa].

    Each element of [seqa] is an instruction paired with a set of extra
    registers that it is treated as reading; those are processed before
    the operands of the instruction itself.

    Every constraint goes from an earlier index j to a later index i
    with a non-negative latency l (presumably meaning that i may start
    no earlier than l cycles after j; the exact semantics belong to the
    scheduler).
    - Registers: read-after-write uses the latency declared by the
      writer, write-after-write uses 1, write-after-read uses 0.
    - Memory is treated as one single location: a load waits for the
      last store (1), a store waits for the last store (1) and for every
      load since that store (0).
    - Stores, trapping operations and trapping loads are ordered after
      the last branch (1); calls, builtins, conditions, jump tables and
      returns are ordered the same way and become the new last branch.
    - Non-pipelined units: an instruction using a unit waits for the
      previous user of that unit, with the latency given by opweights.

    @raise Failure on Icall, Itailcall, Ibuiltin, Ijumptable and Ireturn
    (after their constraints have been recorded), and on a self-loop with
    a positive latency.  [schedule_sequence] catches [Failure].
    @return the constraints, most recently created first. *)
let get_simple_dependencies (opweights : opweights) (seqa : (instruction*Regset.t) array) =
  (* Per register: indices of the instructions reading it since its last write. *)
  let last_reg_reads : int list PTree.t ref = ref PTree.empty
  (* Per register: index of its last writer and the latency that writer declared. *)
  and last_reg_write : (int*int) PTree.t ref = ref PTree.empty
  (* Loads since the last store. *)
  and last_mem_reads : int list ref = ref []
  (* Last store, if any. *)
  and last_mem_write : int option ref = ref None
  (* Last instruction registered through set_branch, if any. *)
  and last_branch : int option ref = ref None
  (* Per non-pipelined unit: index of its last user, -1 when none yet. *)
  and last_non_pipelined_op : int array = Array.make opweights.nr_non_pipelined_units ( -1 )
  (* Result accumulator, most recent constraint first. *)
  and latency_constraints : latency_constraint list ref = ref [] in
  let add_constraint instr_from instr_to latency =
    assert (instr_from <= instr_to);
    assert (latency >= 0);
    if instr_from = instr_to
    then
      (* A zero-latency self-loop appears when an instruction reads and
         writes the same register: the write-after-read step then finds
         the instruction itself among the readers.  It is dropped.
         A self-loop with positive latency can never be satisfied. *)
      (if latency = 0 then () else failwith "PrepassSchedulingOracle.get_dependencies: negative self-loop")
    else
      latency_constraints :=
        { instr_from = instr_from; instr_to = instr_to; latency = latency }:: !latency_constraints
  (* Readers of reg recorded since its last write, or the empty list. *)
  and get_last_reads reg =
    match PTree.get reg !last_reg_reads with Some l -> l | None -> [] in
  (* Record a memory read by instruction i. *)
  let add_input_mem i =
    if not (use_alias_analysis ()) then begin
      begin (* Read after write *)
        match !last_mem_write with
        | None -> ()
        | Some j -> add_constraint j i 1
      end;
      last_mem_reads := i :: !last_mem_reads
    end
  (* Record a memory write by instruction i; it becomes the last store
     and the list of pending loads is reset. *)
  and add_output_mem i =
    if not (use_alias_analysis ()) then begin
      begin (* Write after write *)
        match !last_mem_write with
        | None -> ()
        | Some j -> add_constraint j i 1
      end;
      (* Write after read *)
      List.iter (fun j -> add_constraint j i 0) !last_mem_reads;
      last_mem_write := Some i;
      last_mem_reads := []
    end
  (* Record a read of reg by instruction i.  The read must respect the
     latency declared by the last writer of reg. *)
  and add_input_reg i reg =
    begin (* Read after write *)
      match PTree.get reg !last_reg_write with
      | None -> ()
      | Some (j, latency) -> add_constraint j i latency
    end;
    last_reg_reads := PTree.set reg (i :: get_last_reads reg) !last_reg_reads
  (* Record a write of reg by instruction i.  The latency argument is the
     one later readers of reg will have to respect.  The reader list of
     reg is cleared, so this must run after the reads of instruction i. *)
  and add_output_reg i latency reg =
    begin (* Write after write *)
      match PTree.get reg !last_reg_write with
      | None -> ()
      | Some (j, _) -> add_constraint j i 1
    end;
    begin (* Write after read *)
      List.iter (fun j -> add_constraint j i 0) (get_last_reads reg)
    end;
    last_reg_write := PTree.set reg (i, latency) !last_reg_write;
    last_reg_reads := PTree.remove reg !last_reg_reads
  in
  let add_input_regs i regs = List.iter (add_input_reg i) regs in
  (* Builtin results are given a fixed latency of 10. *)
  let rec add_builtin_res i (res : reg builtin_res) =
    match res with
    | BR r -> add_output_reg i 10 r
    | BR_none -> ()
    | BR_splitlong (hi, lo) -> add_builtin_res i hi; add_builtin_res i lo
  in
  (* Registers used by builtin arguments are reads; stack and global
     loads inside the arguments are memory reads. *)
  let rec add_builtin_arg i (ba : reg builtin_arg) =
    match ba with
    | BA r -> add_input_reg i r
    | BA_int _ | BA_long _ | BA_float _ | BA_single _ -> ()
    | BA_loadstack(_,_) -> add_input_mem i
    | BA_addrstack _ -> ()
    | BA_loadglobal(_, _, _) -> add_input_mem i
    | BA_addrglobal _ -> ()
    | BA_splitlong(hi, lo) -> add_builtin_arg i hi; add_builtin_arg i lo
    | BA_addptr(a1, a2) -> add_builtin_arg i a1; add_builtin_arg i a2
  in
  (* Order instruction i after the last branch, with latency 1. *)
  let irreversible_action i =
    match !last_branch with
    | None -> ()
    | Some j -> add_constraint j i 1 in
  (* As irreversible_action, and make i the new last branch. *)
  let set_branch i = irreversible_action i; last_branch := Some i in
  (* resources holds, per non-pipelined unit, the latency to impose after
     the previous user of that unit, or a negative value when instruction
     i does not use the unit.  Array.iter2 requires resources to have
     nr_non_pipelined_units entries. *)
  let add_non_pipelined_resources i resources =
    Array.iter2
      (fun latency last ->
        if latency >= 0 && last >= 0 then add_constraint last i latency)
      resources last_non_pipelined_op;
    Array.iteri
      (fun rsc latency ->
        if latency >= 0 then last_non_pipelined_op.(rsc) <- i)
      resources
  in
  Array.iteri begin fun i (insn, other_uses) ->
    (* The extra registers are recorded as reads first. *)
    List.iter (fun use -> add_input_reg i use) (Regset.elements other_uses);
    match insn with
    | Inop _ -> ()
    | Iop(op, inputs, output, _) ->
       add_non_pipelined_resources i
         (opweights.non_pipelined_resources_of_op op (List.length inputs));
       (if Op.is_trapping_op op then irreversible_action i);
       add_input_regs i inputs;
       add_output_reg i (opweights.latency_of_op op (List.length inputs)) output
    | Iload(trap, chunk, addressing, addr_regs, output, _) ->
       (* Only trapping loads are kept after the last branch. *)
       (if trap=TRAP then irreversible_action i);
       add_input_mem i;
       add_input_regs i addr_regs;
       add_output_reg i
         (opweights.latency_of_load trap chunk addressing (List.length addr_regs)) output
    | Istore(chunk, addressing, addr_regs, input, _) ->
       irreversible_action i;
       add_input_regs i addr_regs;
       add_input_reg i input;
       add_output_mem i
    (* The following instructions are not accepted in a scheduled
       sequence: their constraints are recorded, then Failure is raised. *)
    | Icall(signature, ef, inputs, output, _) ->
       set_branch i;
       (match ef with
        | Datatypes.Coq_inl r -> add_input_reg i r
        | Datatypes.Coq_inr symbol -> ()
       );
       add_input_mem i;
       add_input_regs i inputs;
       add_output_reg i (opweights.latency_of_call signature ef) output;
       add_output_mem i;
       failwith "Icall"
    | Itailcall(signature, ef, inputs) ->
       set_branch i;
       (match ef with
        | Datatypes.Coq_inl r -> add_input_reg i r
        | Datatypes.Coq_inr symbol -> ()
       );
       add_input_mem i;
       add_input_regs i inputs;
       failwith "Itailcall"
    | Ibuiltin(ef, builtin_inputs, builtin_output, _) ->
       set_branch i;
       add_input_mem i;
       List.iter (add_builtin_arg i) builtin_inputs;
       add_builtin_res i builtin_output;
       add_output_mem i;
       failwith "Ibuiltin"
    | Icond(cond, inputs, _, _, _) ->
       (* Conditions are accepted; they are conservatively treated as
          reading memory. *)
       set_branch i;
       add_input_mem i;
       add_input_regs i inputs
    | Ijumptable(input, _) ->
       set_branch i;
       add_input_reg i input;
       failwith "Ijumptable"
    | Ireturn(Some input) ->
       set_branch i;
       add_input_reg i input;
       failwith "Ireturn"
    | Ireturn(None) ->
       set_branch i;
       failwith "Ireturn none"
  end seqa;
  !latency_constraints;;

(** Pipelined resource usage vector of an instruction, as given by
    [opweights].  A nop uses nothing (all zeros); tail calls, jump tables
    and returns are charged the whole [pipelined_resource_bounds] vector. *)
let resources_of_instruction (opweights : opweights) = function
  | Inop _ -> Array.map (fun _ -> 0) opweights.pipelined_resource_bounds
  | Iop(op, inputs, output, _) ->
     opweights.resources_of_op op (List.length inputs)
  | Iload(trap, chunk, addressing, addr_regs, output, _) ->
     opweights.resources_of_load trap chunk addressing (List.length addr_regs)
  | Istore(chunk, addressing, addr_regs, input, _) ->
     opweights.resources_of_store chunk addressing (List.length addr_regs)
  | Icall(signature, ef, inputs, output, _) ->
     opweights.resources_of_call signature ef
  | Ibuiltin(ef, builtin_inputs, builtin_output, _) ->
     opweights.resources_of_builtin ef
  | Icond(cond, args, _, _ , _) ->
     opweights.resources_of_cond cond (List.length args)
  | Itailcall _ | Ijumptable _ | Ireturn _ -> opweights.pipelined_resource_bounds

(** Debug printer: prints every instruction of [seqa], labelled with its
    index in the array. *)
let print_sequence pp (seqa : instruction array) =
  Array.iteri
    ( fun i (insn : instruction) ->
      PrintRTL.print_instruction pp (i, insn)) seqa;;

(* Symbolic execution over hash-consed terms.  It computes the address
   of each load and store; in this file its results only feed the
   alias-based analysis, which is currently disabled. *)

(** Identifier of a hash-consed term, unique within its table. *)
type unique_id = int

(** Shape of a symbolic value, parameterised by the type of its arguments.
    [STop (op, args)]: result of operation [op] applied to [args].
    [STinitial_reg n]: value of register [n] on entry to the sequence.
    [STother i]: opaque value defined by instruction [i] (a load, a call
    or a builtin). *)
type 'a symbolic_term_node =
  | STop of Op.operation * 'a list
  | STinitial_reg of int
  | STother of int;;

(** A hash-consed term: its identifier and its node, whose arguments are
    themselves full terms. *)
type symbolic_term = {
    hash_id : unique_id;
    hash_ct : symbolic_term symbolic_term_node
  };;

(** Debug printer for terms: initial registers print as x followed by the
    register number, opaque values as y followed by the instruction index. *)
let rec print_term channel term =
  match term.hash_ct with
  | STop(op, args) -> PrintOp.print_operation print_term channel (op, args)
  | STinitial_reg n -> Printf.fprintf channel "x%d" n
  | STother n -> Printf.fprintf channel "y%d" n;;

(** Hash-consing table.  [st_table] maps a node whose arguments have been
    replaced by their identifiers to the shared term; [st_next_id] is the
    identifier given to the next new term. *)
type symbolic_term_table = {
    st_table : (unique_id symbolic_term_node, symbolic_term) Hashtbl.t;
    mutable st_next_id : unique_id
  };;

(** A fresh, empty hash-consing table. *)
let hash_init () = {
    st_table = Hashtbl.create 20;
    st_next_id = 0
  };;

(** Replace the argument terms of a node by their identifiers, giving the
    key under which the node is stored in a table. *)
let ground_to_id = function
  | STop(op, l) -> STop(op, List.map (fun t -> t.hash_id) l)
  | STinitial_reg r -> STinitial_reg r
  | STother i -> STother i;;

(** Return the shared term for [term] in [table]: the existing one when a
    node with the same shape and the same argument identifiers is already
    stored, otherwise a new term with a fresh identifier (which is then
    stored).
    @raise Failure when identifiers are exhausted (max_int reached). *)
let hash_node (table : symbolic_term_table) (term : symbolic_term symbolic_term_node) : symbolic_term =
  let grounded = ground_to_id term in
  match Hashtbl.find_opt table.st_table grounded with
  | Some x -> x
  | None ->
     let term' = {
         hash_id = table.st_next_id;
         hash_ct = term
       } in
     (if table.st_next_id = max_int then failwith "hash: max_int");
     table.st_next_id <- table.st_next_id + 1;
     Hashtbl.add table.st_table grounded term';
     term';;

(** A memory access covering [length] bytes starting at the address
    [base] plus [offset]. *)
type access = {
    base : symbolic_term;
    offset : int64;
    length : int
  };;

(** Equality of hash-consed terms by identifier.  Only meaningful for
    terms created in the same table. *)
let term_equal a b = (a.hash_id = b.hash_id);;

(** Symbolic description of an access with [chunk] and [addressing] on
    the address registers [args]; [get_reg] maps a register to its current
    term.  Only [Aindexed] with a single register is analysed; any other
    addressing mode yields [None], meaning unknown. *)
let access_of_addressing get_reg chunk addressing args =
  match addressing, args with
  | (Op.Aindexed ofs), [reg] -> Some
                                  { base = get_reg reg;
                                    offset = Camlcoq.camlint64_of_ptrofs ofs;
                                    length = length_of_chunk chunk
                                  }
  | _, _ -> None ;;
(* TODO: global *)

(** Execute [seqa] symbolically.  Returns an array indexed like [seqa]
    whose entry is the access of the load or store at that index, or
    [None] for other instructions and for unanalysed addresses.
    @raise Failure on a builtin whose result is BR_splitlong. *)
let symbolic_execution (seqa : instruction array) =
  let regs = ref PTree.empty
  and table = hash_init() in
  let assign reg term = regs := PTree.set reg term !regs
  and hash term = hash_node table term in
  (* A register not yet assigned in the sequence holds its entry value. *)
  let get_reg reg =
    match PTree.get reg !regs with
    | None -> hash (STinitial_reg (Camlcoq.P.to_int reg))
    | Some x -> x in
  let targets = Array.make (Array.length seqa) None in
  Array.iteri
    begin
      fun i insn ->
      match insn with
      (* A move just shares the term of its source. *)
      | Iop(Op.Omove, [input], output, _) ->
         assign output (get_reg input)
      | Iop(op, inputs, output, _) ->
         assign output (hash (STop(op, List.map get_reg inputs)))
      (* The address is computed from the register values before the
         destination is overwritten. *)
      | Iload(trap, chunk, addressing, args, output, _) ->
         let access = access_of_addressing get_reg chunk addressing args in
         targets.(i) <- access;
         assign output (hash (STother(i)))
      | Icall(_, _, _, output, _)
      | Ibuiltin(_, _, BR output, _) ->
         assign output (hash (STother(i)))
      | Istore(chunk, addressing, args, va, _) ->
         let access = access_of_addressing get_reg chunk addressing args in
         targets.(i) <- access
      | Inop _ -> ()
      | Ibuiltin(_, _, BR_none, _) -> ()
      | Ibuiltin(_, _, BR_splitlong _, _) -> failwith "BR_splitlong"
      | Itailcall (_, _, _)
        |Icond (_, _, _, _, _)
        |Ijumptable (_, _)
        |Ireturn _ -> ()
    end seqa;
  targets;;

(** Debug printer for an optional access: any when unknown, otherwise
    base plus offset (the length is not printed). *)
let print_access channel = function
  | None -> Printf.fprintf channel "any"
  | Some x -> Printf.fprintf channel "%a + %Ld" print_term x.base x.offset;;

(** Debug printer: one line per load or store of [seqa], with its index
    and symbolic address. *)
let print_targets channel seqa =
  let targets = symbolic_execution seqa in
  Array.iteri
    (fun i insn ->
      match insn with
      | Iload _ -> Printf.fprintf channel "%d: load %a\n" i print_access targets.(i)
      | Istore _ -> Printf.fprintf channel "%d: store %a\n" i print_access targets.(i)
      | _ -> ()
    ) seqa;;

(** Conservative overlap test between two accesses.  It returns [false]
    only when the accesses are known to be disjoint.  An unknown access
    ([None]) may overlap anything.  Accesses are compared as half-open
    byte intervals running from offset to offset plus length.
    In this file it is only referenced by the commented-out
    get_alias_dependencies below. *)
let may_overlap a0 b0 =
  match a0, b0 with
  | (None, _) | (_ , None) -> true
  | (Some a), (Some b) ->
     (* Same base term: compare the offsets directly. *)
     if term_equal a.base b.base
     then (max a.offset b.offset) <
            (min (Int64.add (Int64.of_int a.length) a.offset)
                 (Int64.add (Int64.of_int b.length) b.offset))
     else match a.base.hash_ct, b.base.hash_ct with
          (* Both bases are symbol addresses: disjoint when the symbols
             differ, otherwise compare offsets including the symbol
             offsets. *)
          | STop(Op.Oaddrsymbol(ida, ofsa),[]),
            STop(Op.Oaddrsymbol(idb, ofsb),[]) ->
             (ida=idb) &&
               let ao = Int64.add a.offset (Camlcoq.camlint64_of_ptrofs ofsa)
               and bo = Int64.add b.offset (Camlcoq.camlint64_of_ptrofs ofsb) in
               (max ao bo) <
                 (min (Int64.add (Int64.of_int a.length) ao)
                      (Int64.add (Int64.of_int b.length) bo))
          (* A stack address and a symbol address are assumed disjoint. *)
          | STop(Op.Oaddrstack _, []),
            STop(Op.Oaddrsymbol _, [])
          | STop(Op.Oaddrsymbol _, []),
            STop(Op.Oaddrstack _, []) -> false
          (* Both bases are stack addresses: compare offsets including the
             stack offsets. *)
          | STop(Op.Oaddrstack(ofsa),[]),
            STop(Op.Oaddrstack(ofsb),[]) ->
             let ao = Int64.add a.offset (Camlcoq.camlint64_of_ptrofs ofsa)
             and bo = Int64.add b.offset (Camlcoq.camlint64_of_ptrofs ofsb) in
             (max ao bo) <
               (min (Int64.add (Int64.of_int a.length) ao)
                    (Int64.add (Int64.of_int b.length) bo))
          (* Any other pair of bases: nothing is known. *)
          | _ -> true;;

(* Alias-based memory dependencies, currently disabled: see
   use_alias_analysis and the commented-out use in define_problem. *)
(*
(* TODO suboptimal quadratic algorithm *)
let get_alias_dependencies seqa =
  let targets = symbolic_execution seqa
  and deps = ref [] in
  let add_constraint instr_from instr_to latency =
    deps := { instr_from = instr_from;
              instr_to = instr_to;
              latency = latency
            }:: !deps in
  for i=0 to (Array.length seqa)-1
  do
    for j=0 to i-1
    do
      match seqa.(j), seqa.(i) with
      | (Istore _), ((Iload _) | (Istore _)) ->
         if may_overlap targets.(j) targets.(i)
         then add_constraint j i 1
      | (Iload _), (Istore _) ->
         if may_overlap targets.(j) targets.(i)
         then add_constraint j i 0
      | (Istore _ | Iload _), (Icall _ | Ibuiltin _)
        | (Icall _ | Ibuiltin _), (Icall _ | Ibuiltin _ | Iload _ | Istore _) ->
         add_constraint j i 1
      | (Inop _ | Iop _), _
        | _, (Inop _ | Iop _)
        | (Iload _), (Iload _) -> ()
    done
  done;
  !deps;;
*)

(** Assemble the scheduling problem for [seqa].  Resource bounds and
    usages come from [opweights]; the latency constraints are only those
    of [get_simple_dependencies], because alias analysis is disabled.
    [live_entry_regs], [typing] and [reference_counting] are passed
    through unchanged.  [max_latency] is set to -1 (presumably meaning no
    bound; check InstructionScheduler).
    @raise Failure propagated from [get_simple_dependencies]. *)
let define_problem (opweights : opweights) (live_entry_regs : Regset.t) (typing : RTLtyping.regenv) reference_counting seqa =
  let simple_deps = get_simple_dependencies opweights seqa in
  { max_latency = -1;
    resource_bounds = opweights.pipelined_resource_bounds;
    live_regs_entry = live_entry_regs;
    typing = typing;
    reference_counting = Some reference_counting;
    instruction_usages = Array.map (resources_of_instruction opweights) (Array.map fst seqa);
    latency_constraints =
      (* if (use_alias_analysis ())
         then (get_alias_dependencies seqa) @ simple_deps
         else *)
      simple_deps };;

(** Two-pass scheduling.  First a forward pass with list_scheduler, then a
    second pass with reverse_list_scheduler, wrapped by
    validated_scheduler.

    The forward makespan is read from the last entry of the forward
    schedule.  For every index i with early_ones.(i) set, a constraint is
    added from i to index nr_instructions (presumably the end-of-schedule
    pseudo-instruction), with latency equal to the forward makespan minus
    the forward time of i.  This keeps each such instruction at least as
    far from the end as it was in the forward schedule.

    Returns None when the forward pass finds no schedule. *)
let zigzag_scheduler problem early_ones =
  let nr_instructions = get_nr_instructions problem in
  assert(nr_instructions = (Array.length early_ones));
  match list_scheduler problem with
  | Some fwd_schedule ->
     let fwd_makespan = fwd_schedule.((Array.length fwd_schedule) - 1) in
     let constraints' = ref problem.latency_constraints in
     Array.iteri (fun i is_early ->
         if is_early then
           constraints' := {
               instr_from = i;
               instr_to = nr_instructions ;
               latency = fwd_makespan - fwd_schedule.(i) } ::!constraints' )
       early_ones;
     validated_scheduler reverse_list_scheduler
       { problem with latency_constraints = !constraints' }
  | None -> None;;

(** Run the scheduler selected by [name].  The zigzag scheduler uses
    [early_ones]; any other name is delegated to scheduler_by_name, which
    does not receive [early_ones]. *)
let prepass_scheduler_by_name name problem early_ones =
  match name with
  | "zigzag" -> zigzag_scheduler problem early_ones
  | _ -> scheduler_by_name name problem

(** Compute a new order for the instructions of [seqa].

    Returns [Some positions], where [positions] lists the original
    indices 0 .. n-1 sorted by scheduled time; ties keep the original
    order.

    Returns [None] in the following cases:
    - the sequence has at most one instruction;
    - the scheduler finds no solution;
    - a Failure is raised, for instance by get_simple_dependencies.
    In the last two cases a message is printed on stdout.

    Icond instructions are flagged as early ones for the zigzag scheduler.
    Debug output is printed when option_debug_compcert exceeds 6 (the
    sequence length) or 7 (the sequence and the problem). *)
let schedule_sequence (seqa : (instruction*Regset.t) array)
      (live_regs_entry : Registers.Regset.t)
      (typing : RTLtyping.regenv)
      reference =
  let opweights = OpWeights.get_opweights () in
  try
    if (Array.length seqa) <= 1
    then None
    else
      begin
        let nr_instructions = Array.length seqa in
        (if !Clflags.option_debug_compcert > 6
         then Printf.printf "prepass scheduling length = %d\n" (Array.length seqa));
        let problem = define_problem opweights live_regs_entry typing reference seqa in
        (if !Clflags.option_debug_compcert > 7
         then (print_sequence stdout (Array.map fst seqa);
               print_problem stdout problem));
        match prepass_scheduler_by_name
                (!Clflags.option_fprepass_sched)
                problem
                (Array.map (fun (ins, _) ->
                     match ins with
                     | Icond _ -> true
                     | _ -> false) seqa) with
        | None -> Printf.printf "no solution in prepass scheduling\n";
                  None
        | Some solution ->
           (* Only the first nr_instructions entries of the solution are
              consulted. *)
           let positions = Array.init nr_instructions (fun i -> i) in
           Array.sort
             (fun i j ->
               let si = solution.(i) and sj = solution.(j) in
               if si < sj then -1
               else if si > sj then 1
               else i - j) positions;
           Some positions
      end
  with (Failure s) ->
    Printf.printf "failure in prepass scheduling: %s\n" s;
    None;;