diff options
Diffstat (limited to 'backend/Linearizeaux.ml')
-rw-r--r-- | backend/Linearizeaux.ml | 409 |
1 files changed, 405 insertions, 4 deletions
diff --git a/backend/Linearizeaux.ml b/backend/Linearizeaux.ml index 902724e0..bfa056ca 100644 --- a/backend/Linearizeaux.ml +++ b/backend/Linearizeaux.ml @@ -1,4 +1,4 @@ -(* *********************************************************************) + (* *) (* The Compcert verified compiler *) (* *) @@ -12,7 +12,6 @@ open LTL open Maps -open Camlcoq (* Trivial enumeration, in decreasing order of PC *) @@ -29,6 +28,8 @@ let enumerate_aux f reach = (* More clever enumeration that flattens basic blocks *) +open Camlcoq + module IntSet = Set.Make(struct type t = int let compare = compare end) (* Determine join points: reachable nodes that have > 1 predecessor *) @@ -80,7 +81,7 @@ let basic_blocks f joins = | [] -> assert false | Lbranch s :: _ -> next_in_block blk minpc s | Ltailcall (sig0, ros) :: _ -> end_block blk minpc - | Lcond (cond, args, ifso, ifnot) :: _ -> + | Lcond (cond, args, ifso, ifnot, _) :: _ -> end_block blk minpc; start_block ifso; start_block ifnot | Ljumptable(arg, tbl) :: _ -> end_block blk minpc; List.iter start_block tbl @@ -110,5 +111,405 @@ let flatten_blocks blks = (* Build the enumeration *) -let enumerate_aux f reach = +let enumerate_aux_flat f reach = flatten_blocks (basic_blocks f (join_points f)) + +(** + * Enumeration based on traces as identified by Duplicate.v + * + * The Duplicate phase heuristically identifies the most frequented paths. Each + * Icond is modified so that the preferred condition is a fallthrough (ifnot) + * rather than a branch (ifso). + * + * The enumeration below takes advantage of this - preferring to layout nodes + * following the fallthroughs of the Lcond branches. 
+ * + * It is slightly adapted from the work of Pettis and Hansen 90 on intraprocedural + * code positioning - only we do it on a broader grain, since we don't have the exact + * frequencies (we only know which branch is the preferred one) + *) + +let get_some = function +| None -> failwith "Did not get some" +| Some thing -> thing + +exception EmptyList + +let rec last_element = function + | [] -> raise EmptyList + | e :: [] -> e + | e' :: e :: l -> last_element (e::l) + +let print_plist l = + let rec f = function + | [] -> () + | n :: l -> Printf.printf "%d, " (P.to_int n); f l + in begin + Printf.printf "["; + f l; + Printf.printf "]" + end + +(* adapted from the above join_points function, but with PTree *) +let get_join_points code entry = + let reached = ref (PTree.map (fun n i -> false) code) in + let reached_twice = ref (PTree.map (fun n i -> false) code) in + let rec traverse pc = + if get_some @@ PTree.get pc !reached then begin + if not (get_some @@ PTree.get pc !reached_twice) then + reached_twice := PTree.set pc true !reached_twice + end else begin + reached := PTree.set pc true !reached; + traverse_succs (successors_block @@ get_some @@ PTree.get pc code) + end + and traverse_succs = function + | [] -> () + | [pc] -> traverse pc + | pc :: l -> traverse pc; traverse_succs l + in traverse entry; !reached_twice + +let forward_sequences code entry = + let visited = ref (PTree.map (fun n i -> false) code) in + let join_points = get_join_points code entry in + (* returns the list of traversed nodes, and a list of nodes to start traversing next *) + let rec traverse_fallthrough code node = + (* Printf.printf "Traversing %d..\n" (P.to_int node); *) + if not (get_some @@ PTree.get node !visited) then begin + visited := PTree.set node true !visited; + match PTree.get node code with + | None -> failwith "No such node" + | Some bb -> + let ln, rem = match (last_element bb) with + | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _ + | Lbuiltin _ -> 
assert false + | Ltailcall _ | Lreturn -> begin (* Printf.printf "STOP tailcall/return\n"; *) ([], []) end + | Lbranch n -> + if get_some @@ PTree.get n join_points then ([], [n]) + else let ln, rem = traverse_fallthrough code n in (ln, rem) + | Lcond (_, _, ifso, ifnot, info) -> (match info with + | None -> begin (* Printf.printf "STOP Lcond None\n"; *) ([], [ifso; ifnot]) end + | Some false -> + if get_some @@ PTree.get ifnot join_points then ([], [ifso; ifnot]) + else let ln, rem = traverse_fallthrough code ifnot in (ln, [ifso] @ rem) + | Some true -> + let errstr = Printf.sprintf ("Inconsistency detected in node %d: ifnot is not the preferred branch") (P.to_int node) in + failwith errstr) + | Ljumptable(_, ln) -> begin (* Printf.printf "STOP Ljumptable\n"; *) ([], ln) end + in ([node] @ ln, rem) + end + else ([], []) + in let rec f code = function + | [] -> [] + | node :: ln -> + let fs, rem_from_node = traverse_fallthrough code node + in [fs] @ ((f code rem_from_node) @ (f code ln)) + in (f code [entry]) + +(** Unused code +module PInt = struct + type t = P.t + let compare x y = compare (P.to_int x) (P.to_int y) +end + +module PSet = Set.Make(PInt) + +module LPInt = struct + type t = P.t list + let rec compare x y = + match x with + | [] -> ( match y with + | [] -> 0 + | _ -> 1 ) + | e :: l -> match y with + | [] -> -1 + | e' :: l' -> + let e_cmp = PInt.compare e e' in + if e_cmp == 0 then compare l l' else e_cmp +end + +module LPSet = Set.Make(LPInt) + +let iter_lpset f s = Seq.iter f (LPSet.to_seq s) + +let first_of = function + | [] -> None + | e :: l -> Some e + +let rec last_of = function + | [] -> None + | e :: l -> (match l with [] -> Some e | e :: l -> last_of l) + +let can_be_merged code s s' = + let last_s = get_some @@ last_of s in + let first_s' = get_some @@ first_of s' in + match get_some @@ PTree.get last_s code with + | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _ + | Lbuiltin _ | Ltailcall _ | Lreturn -> false + | Lbranch n 
-> n == first_s' + | Lcond (_, _, ifso, ifnot, info) -> (match info with + | None -> false + | Some false -> ifnot == first_s' + | Some true -> failwith "Inconsistency detected - ifnot is not the preferred branch") + | Ljumptable (_, ln) -> + match ln with + | [] -> false + | n :: ln -> n == first_s' + +let merge s s' = Some s + +let try_merge code (fs: (BinNums.positive list) list) = + let seqs = ref (LPSet.of_list fs) in + let oldLength = ref (LPSet.cardinal !seqs) in + let continue = ref true in + let found = ref false in + while !continue do + begin + found := false; + iter_lpset (fun s -> + if !found then () + else iter_lpset (fun s' -> + if (!found || s == s') then () + else if (can_be_merged code s s') then + begin + seqs := LPSet.remove s !seqs; + seqs := LPSet.remove s' !seqs; + seqs := LPSet.add (get_some (merge s s')) !seqs; + found := true; + end + else () + ) !seqs + ) !seqs; + if !oldLength == LPSet.cardinal !seqs then + continue := false + else + oldLength := LPSet.cardinal !seqs + end + done; + !seqs +*) + +(** Code adapted from Duplicateaux.get_loop_headers + * + * Getting loop branches with a DFS visit: + * Each node is either Unvisited, Visited, or Processed + * pre-order: node becomes Processed + * post-order: node becomes Visited + * + * If we come across an edge to a Processed node, it's a loop! 
+ *) +type pos = BinNums.positive + +module PP = struct + type t = pos * pos + let compare a b = + let ax, ay = a in + let bx, by = b in + let dx = compare ax bx in + if (dx == 0) then compare ay by + else dx +end + +module PPMap = Map.Make(PP) + +type vstate = Unvisited | Processed | Visited + +let get_loop_edges code entry = + let visited = ref (PTree.map (fun n i -> Unvisited) code) in + let is_loop_edge = ref PPMap.empty + in let rec dfs_visit code from = function + | [] -> () + | node :: ln -> + match (get_some @@ PTree.get node !visited) with + | Visited -> () + | Processed -> begin + let from_node = get_some from in + is_loop_edge := PPMap.add (from_node, node) true !is_loop_edge; + visited := PTree.set node Visited !visited + end + | Unvisited -> begin + visited := PTree.set node Processed !visited; + let bb = get_some @@ PTree.get node code in + let next_visits = (match (last_element bb) with + | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _ + | Lbuiltin _ -> assert false + | Ltailcall _ | Lreturn -> [] + | Lbranch n -> [n] + | Lcond (_, _, ifso, ifnot, _) -> [ifso; ifnot] + | Ljumptable(_, ln) -> ln + ) in dfs_visit code (Some node) next_visits; + visited := PTree.set node Visited !visited; + dfs_visit code from ln + end + in begin + dfs_visit code None [entry]; + !is_loop_edge + end + +let ppmap_is_true pp ppmap = PPMap.mem pp ppmap && PPMap.find pp ppmap + +module Int = struct + type t = int + let compare x y = compare x y +end + +module ISet = Set.Make(Int) + +let print_iset s = begin + Printf.printf "{"; + ISet.iter (fun e -> Printf.printf "%d, " e) s; + Printf.printf "}" +end + +let print_depmap dm = begin + Printf.printf "[|"; + Array.iter (fun s -> print_iset s; Printf.printf ", ") dm; + Printf.printf "|]\n" +end + +let construct_depmap code entry fs = + let is_loop_edge = get_loop_edges code entry in + let visited = ref (PTree.map (fun n i -> false) code) in + let depmap = Array.map (fun e -> ISet.empty) fs in + let 
find_index_of_node n = + let index = ref 0 in + begin + Array.iteri (fun i s -> + match List.find_opt (fun e -> e == n) s with + | Some _ -> index := i + | None -> () + ) fs; + !index + end + in let check_and_update_depmap from target = + (* Printf.printf "From %d to %d\n" (P.to_int from) (P.to_int target); *) + if not (ppmap_is_true (from, target) is_loop_edge) then + let in_index_fs = find_index_of_node from in + let out_index_fs = find_index_of_node target in + if out_index_fs != in_index_fs then + depmap.(out_index_fs) <- ISet.add in_index_fs depmap.(out_index_fs) + else () + else () + in let rec dfs_visit code = function + | [] -> () + | node :: ln -> + begin + match (get_some @@ PTree.get node !visited) with + | true -> () + | false -> begin + visited := PTree.set node true !visited; + let bb = get_some @@ PTree.get node code in + let next_visits = + match (last_element bb) with + | Ltailcall _ | Lreturn -> [] + | Lbranch n -> (check_and_update_depmap node n; [n]) + | Lcond (_, _, ifso, ifnot, _) -> begin + check_and_update_depmap node ifso; + check_and_update_depmap node ifnot; + [ifso; ifnot] + end + | Ljumptable(_, ln) -> begin + List.iter (fun n -> check_and_update_depmap node n) ln; + ln + end + (* end of bblocks should not be another value than one of the above *) + | _ -> failwith "last_element gave an invalid output" + in dfs_visit code next_visits + end; + dfs_visit code ln + end + in begin + dfs_visit code [entry]; + depmap + end + +let print_sequence s = + Printf.printf "["; + List.iter (fun n -> Printf.printf "%d, " (P.to_int n)) s; + Printf.printf "]\n" + +let print_ssequence ofs = + Printf.printf "["; + List.iter (fun s -> print_sequence s) ofs; + Printf.printf "]\n" + +let order_sequences code entry fs = + let fs_a = Array.of_list fs in + let depmap = construct_depmap code entry fs_a in + let fs_evaluated = Array.map (fun e -> false) fs_a in + let ordered_fs = ref [] in + let evaluate s_id = + begin + assert (not fs_evaluated.(s_id)); + 
ordered_fs := fs_a.(s_id) :: !ordered_fs; + fs_evaluated.(s_id) <- true; + (* Printf.printf "++++++\n"; + Printf.printf "Scheduling %d\n" s_id; + Printf.printf "Initial depmap: "; print_depmap depmap; *) + Array.iteri (fun i deps -> + depmap.(i) <- ISet.remove s_id deps + ) depmap; + (* Printf.printf "Final depmap: "; print_depmap depmap; *) + end + in let choose_best_of candidates = + let current_best_id = ref None in + let current_best_score = ref None in + begin + List.iter (fun id -> + match !current_best_id with + | None -> begin + current_best_id := Some id; + match fs_a.(id) with + | [] -> current_best_score := None + | n::l -> current_best_score := Some (P.to_int n) + end + | Some b -> begin + match fs_a.(id) with + | [] -> () + | n::l -> let nscore = P.to_int n in + match !current_best_score with + | None -> (current_best_id := Some id; current_best_score := Some nscore) + | Some bs -> if nscore > bs then (current_best_id := Some id; current_best_score := Some nscore) + end + ) candidates; + !current_best_id + end + in let select_next () = + let candidates = ref [] in + begin + Array.iteri (fun i deps -> + begin + (* Printf.printf "Deps of %d: " i; print_iset deps; Printf.printf "\n"; *) + (* FIXME - if we keep it that way (no dependency check), remove all the unneeded stuff *) + if ((* deps == ISet.empty && *) not fs_evaluated.(i)) then + candidates := i :: !candidates + end + ) depmap; + if not (List.length !candidates > 0) then begin + Array.iteri (fun i deps -> + if (not fs_evaluated.(i)) then candidates := i :: !candidates + ) depmap; + end; + get_some (choose_best_of !candidates) + end + in begin + Printf.printf "-------------------------------\n"; + Printf.printf "depmap: "; print_depmap depmap; + Printf.printf "forward sequences identified: "; print_ssequence fs; + while List.length !ordered_fs != List.length fs do + let next_id = select_next () in + evaluate next_id + done; + Printf.printf "forward sequences ordered: "; print_ssequence (List.rev 
(!ordered_fs)); + List.rev (!ordered_fs) + end + +let enumerate_aux_trace f reach = + let code = f.fn_code in + let entry = f.fn_entrypoint in + let fs = forward_sequences code entry in + let ofs = order_sequences code entry fs in + List.flatten ofs + +let enumerate_aux f reach = + if !Clflags.option_ftracelinearize then enumerate_aux_trace f reach + else enumerate_aux_flat f reach |