aboutsummaryrefslogtreecommitdiffstats
path: root/backend/Linearizeaux.ml
diff options
context:
space:
mode:
Diffstat (limited to 'backend/Linearizeaux.ml')
-rw-r--r--backend/Linearizeaux.ml337
1 files changed, 316 insertions, 21 deletions
diff --git a/backend/Linearizeaux.ml b/backend/Linearizeaux.ml
index a6964233..a813ac96 100644
--- a/backend/Linearizeaux.ml
+++ b/backend/Linearizeaux.ml
@@ -122,7 +122,11 @@ let enumerate_aux_flat f reach =
* rather than a branch (ifso).
*
* The enumeration below takes advantage of this - preferring to layout nodes
- * following the fallthroughs of the Lcond branches
+ * following the fallthroughs of the Lcond branches.
+ *
+ * It is slightly adapted from the work of Petris and Hansen 90 on intraprocedural
+ * code positioning - only we do it on a broader grain, since we don't have the exact
+ * frequencies (we only know which branch is the preferred one)
*)
let get_some = function
@@ -136,29 +140,320 @@ let rec last_element = function
| e :: [] -> e
| e' :: e :: l -> last_element (e::l)
-let dfs code entrypoint =
+let print_plist l =
+ let rec f = function
+ | [] -> ()
+ | n :: l -> Printf.printf "%d, " (P.to_int n); f l
+ in begin
+ Printf.printf "[";
+ f l;
+ Printf.printf "]"
+ end
+
+let forward_sequences code entry =
let visited = ref (PTree.map (fun n i -> false) code) in
- let rec dfs_list code = function
+ (* returns the list of traversed nodes, and a list of nodes to start traversing next *)
+ let rec traverse_fallthrough code node =
+ (* Printf.printf "Traversing %d..\n" (P.to_int node); *)
+ if not (get_some @@ PTree.get node !visited) then begin
+ visited := PTree.set node true !visited;
+ match PTree.get node code with
+ | None -> failwith "No such node"
+ | Some bb ->
+ let ln, rem = match (last_element bb) with
+ | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
+ | Lbuiltin _ -> assert false
+ | Ltailcall _ | Lreturn -> ([], [])
+ | Lbranch n -> let ln, rem = traverse_fallthrough code n in (ln, rem)
+ | Lcond (_, _, ifso, ifnot) -> let ln, rem = traverse_fallthrough code ifnot in (ln, [ifso] @ rem)
+ | Ljumptable(_, ln) -> match ln with
+ | [] -> ([], [])
+ | n :: ln -> let lln, rem = traverse_fallthrough code n in (lln, ln @ rem)
+ in ([node] @ ln, rem)
+ end
+ else ([], [])
+ in let rec f code = function
| [] -> []
| node :: ln ->
- let node_dfs =
- if not (get_some @@ PTree.get node !visited) then begin
- visited := PTree.set node true !visited;
- match PTree.get node code with
- | None -> failwith "No such node"
- | Some bb -> [node] @ match (last_element bb) with
- | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
- | Lbuiltin _ -> assert false
- | Ltailcall _ | Lreturn -> []
- | Lbranch n -> dfs_list code [n]
- | Lcond (_, _, ifso, ifnot) -> dfs_list code [ifnot; ifso]
- | Ljumptable(_, ln) -> dfs_list code ln
- end
- else []
- in node_dfs @ (dfs_list code ln)
- in dfs_list code [entrypoint]
-
-let enumerate_aux_trace f reach = dfs f.fn_code f.fn_entrypoint
+ let fs, rem_from_node = traverse_fallthrough code node
+ in [fs] @ ((f code rem_from_node) @ (f code ln))
+ in (f code [entry])
+
+module PInt = struct
+ type t = P.t
+ let compare x y = compare (P.to_int x) (P.to_int y)
+end
+
+module PSet = Set.Make(PInt)
+
+module LPInt = struct
+ type t = P.t list
+ let rec compare x y =
+ match x with
+ | [] -> ( match y with
+ | [] -> 0
+ | _ -> 1 )
+ | e :: l -> match y with
+ | [] -> -1
+ | e' :: l' ->
+ let e_cmp = PInt.compare e e' in
+ if e_cmp == 0 then compare l l' else e_cmp
+end
+
+module LPSet = Set.Make(LPInt)
+
+let iter_lpset f s = Seq.iter f (LPSet.to_seq s)
+
+let first_of = function
+ | [] -> None
+ | e :: l -> Some e
+
+let rec last_of = function
+ | [] -> None
+ | e :: l -> (match l with [] -> Some e | e :: l -> last_of l)
+
+let can_be_merged code s s' =
+ let last_s = get_some @@ last_of s in
+ let first_s' = get_some @@ first_of s' in
+ match get_some @@ PTree.get last_s code with
+ | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
+ | Lbuiltin _ | Ltailcall _ | Lreturn -> false
+ | Lbranch n -> n == first_s'
+ | Lcond (_, _, ifso, ifnot) -> ifnot == first_s'
+ | Ljumptable (_, ln) ->
+ match ln with
+ | [] -> false
+ | n :: ln -> n == first_s'
+
+let merge s s' = Some s
+
+let try_merge code (fs: (BinNums.positive list) list) =
+ let seqs = ref (LPSet.of_list fs) in
+ let oldLength = ref (LPSet.cardinal !seqs) in
+ let continue = ref true in
+ let found = ref false in
+ while !continue do
+ begin
+ found := false;
+ iter_lpset (fun s ->
+ if !found then ()
+ else iter_lpset (fun s' ->
+ if (!found || s == s') then ()
+ else if (can_be_merged code s s') then
+ begin
+ seqs := LPSet.remove s !seqs;
+ seqs := LPSet.remove s' !seqs;
+ seqs := LPSet.add (get_some (merge s s')) !seqs;
+ found := true;
+ end
+ else ()
+ ) !seqs
+ ) !seqs;
+ if !oldLength == LPSet.cardinal !seqs then
+ continue := false
+ else
+ oldLength := LPSet.cardinal !seqs
+ end
+ done;
+ !seqs
+
+(** Code adapted from Duplicateaux.get_loop_headers
+ *
+ * Getting loop branches with a DFS visit :
+ * Each node is either Unvisited, Visited, or Processed
+ * pre-order: node becomes Processed
+ * post-order: node becomes Visited
+ *
+ * If we come accross an edge to a Processed node, it's a loop!
+ *)
+type pos = BinNums.positive
+
+module PP = struct
+ type t = pos * pos
+ let compare a b =
+ let ax, ay = a in
+ let bx, by = b in
+ let dx = compare ax bx in
+ if (dx == 0) then compare ay by
+ else dx
+end
+
+module PPMap = Map.Make(PP)
+
+type vstate = Unvisited | Processed | Visited
+
+let get_loop_edges code entry =
+ let visited = ref (PTree.map (fun n i -> Unvisited) code) in
+ let is_loop_edge = ref PPMap.empty
+ in let rec dfs_visit code from = function
+ | [] -> ()
+ | node :: ln ->
+ match (get_some @@ PTree.get node !visited) with
+ | Visited -> ()
+ | Processed -> begin
+ let from_node = get_some from in
+ is_loop_edge := PPMap.add (from_node, node) true !is_loop_edge;
+ visited := PTree.set node Visited !visited
+ end
+ | Unvisited -> begin
+ visited := PTree.set node Processed !visited;
+ let bb = get_some @@ PTree.get node code in
+ let next_visits = (match (last_element bb) with
+ | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
+ | Lbuiltin _ -> assert false
+ | Ltailcall _ | Lreturn -> []
+ | Lbranch n -> [n]
+ | Lcond (_, _, ifso, ifnot) -> [ifso; ifnot]
+ | Ljumptable(_, ln) -> ln
+ ) in dfs_visit code (Some node) next_visits;
+ visited := PTree.set node Visited !visited;
+ dfs_visit code from ln
+ end
+ in begin
+ dfs_visit code None [entry];
+ !is_loop_edge
+ end
+
+let ppmap_is_true pp ppmap = PPMap.mem pp ppmap && PPMap.find pp ppmap
+
+module Int = struct
+ type t = int
+ let compare x y = compare x y
+end
+
+module ISet = Set.Make(Int)
+
+let print_iset s = begin
+ Printf.printf "{";
+ ISet.iter (fun e -> Printf.printf "%d, " e) s;
+ Printf.printf "}"
+end
+
+let print_depmap dm = begin
+ Printf.printf "[|";
+ Array.iter (fun s -> print_iset s; Printf.printf ", ") dm;
+ Printf.printf "|]\n"
+end
+
+let construct_depmap code entry fs =
+ let is_loop_edge = get_loop_edges code entry in
+ let visited = ref (PTree.map (fun n i -> false) code) in
+ let depmap = Array.map (fun e -> ISet.empty) fs in
+ let find_index_of_node n =
+ let index = ref 0 in
+ begin
+ Array.iteri (fun i s ->
+ match List.find_opt (fun e -> e == n) s with
+ | Some _ -> index := i
+ | None -> ()
+ ) fs;
+ !index
+ end
+ in let check_and_update_depmap from target =
+ (* Printf.printf "From %d to %d\n" (P.to_int from) (P.to_int target); *)
+ if not (ppmap_is_true (from, target) is_loop_edge) then
+ let in_index_fs = find_index_of_node from in
+ let out_index_fs = find_index_of_node target in
+ if out_index_fs != in_index_fs then
+ depmap.(out_index_fs) <- ISet.add in_index_fs depmap.(out_index_fs)
+ else ()
+ else ()
+ in let rec dfs_visit code = function
+ | [] -> ()
+ | node :: ln ->
+ begin
+ match (get_some @@ PTree.get node !visited) with
+ | true -> ()
+ | false -> begin
+ visited := PTree.set node true !visited;
+ let bb = get_some @@ PTree.get node code in
+ let next_visits =
+ match (last_element bb) with
+ | Ltailcall _ | Lreturn -> []
+ | Lbranch n -> (check_and_update_depmap node n; [n])
+ | Lcond (_, _, ifso, ifnot) -> begin
+ check_and_update_depmap node ifso;
+ check_and_update_depmap node ifnot;
+ [ifso; ifnot]
+ end
+ | Ljumptable(_, ln) -> begin
+ List.iter (fun n -> check_and_update_depmap node n) ln;
+ ln
+ end
+ (* end of bblocks should not be another value than one of the above *)
+ | _ -> failwith "last_element gave an invalid output"
+ in dfs_visit code next_visits
+ end;
+ dfs_visit code ln
+ end
+ in begin
+ dfs_visit code [entry];
+ depmap
+ end
+
+let print_sequence s =
+ Printf.printf "[";
+ List.iter (fun n -> Printf.printf "%d, " (P.to_int n)) s;
+ Printf.printf "]\n"
+
+let print_ssequence ofs =
+ Printf.printf "[";
+ List.iter (fun s -> print_sequence s) ofs;
+ Printf.printf "]\n"
+
+let order_sequences code entry fs =
+ let fs_a = Array.of_list fs in
+ let depmap = construct_depmap code entry fs_a in
+ let fs_evaluated = Array.map (fun e -> false) fs_a in
+ let ordered_fs = ref [] in
+ let evaluate s_id =
+ begin
+ assert (not fs_evaluated.(s_id));
+ ordered_fs := fs_a.(s_id) :: !ordered_fs;
+ fs_evaluated.(s_id) <- true;
+ Array.iteri (fun i deps ->
+ depmap.(i) <- ISet.remove s_id deps
+ ) depmap
+ end
+ in let select_next () =
+ let selected_id = ref None in
+ begin
+ Array.iteri (fun i deps ->
+ begin
+ (* Printf.printf "Deps: "; print_iset deps; Printf.printf "\n"; *)
+ match !selected_id with
+ | None -> if (deps == ISet.empty && not fs_evaluated.(i)) then selected_id := Some i
+ | Some id -> ()
+ end
+ ) depmap;
+ match !selected_id with
+ | Some id -> id
+ | None -> begin
+ Array.iteri (fun i deps ->
+ match !selected_id with
+ | None -> if not fs_evaluated.(i) then selected_id := Some i
+ | Some id -> ()
+ ) depmap;
+ get_some !selected_id
+ end
+ end
+ in begin
+ (* Printf.printf "depmap: "; print_depmap depmap; *)
+ (* Printf.printf "forward sequences identified: "; print_ssequence fs; *)
+ while List.length !ordered_fs != List.length fs do
+ let next_id = select_next () in
+ evaluate next_id
+ done;
+ (* Printf.printf "forward sequences ordered: "; print_ssequence (List.rev (!ordered_fs)); *)
+ List.rev (!ordered_fs)
+ end
+
+let enumerate_aux_trace f reach =
+ let code = f.fn_code in
+ let entry = f.fn_entrypoint in
+ let fs = forward_sequences code entry in
+ let ofs = order_sequences code entry fs in
+ List.flatten ofs
let enumerate_aux f reach =
if !Clflags.option_ftracelinearize then enumerate_aux_trace f reach