diff options
-rw-r--r-- | backend/Tunneling.v | 135 | ||||
-rw-r--r-- | backend/Tunnelingproof.v | 253 |
2 files changed, 328 insertions, 60 deletions
diff --git a/backend/Tunneling.v b/backend/Tunneling.v index da1ce45a..265e06ba 100644 --- a/backend/Tunneling.v +++ b/backend/Tunneling.v @@ -12,6 +12,7 @@ (** Branch tunneling (optimization of branches to branches). *) +Require Import FunInd. Require Import Coqlib Maps UnionFind. Require Import AST. Require Import LTL. @@ -21,8 +22,8 @@ Require Import LTL. so that they jump directly to the end of the branch sequence. For example: << - L1: nop L2; L1: nop L3; - L2; nop L3; becomes L2: nop L3; + L1: branch L2; L1: branch L3; + L2: branch L3; becomes L2: branch L3; L3: instr; L3: instr; L4: if (cond) goto L1; L4: if (cond) goto L3; >> @@ -33,70 +34,156 @@ Require Import LTL. computations or useless moves), therefore there are more opportunities for tunneling after allocation than before. Symmetrically, prior tunneling helps linearization to produce - better code, e.g. by revealing that some [nop] instructions are - dead code (as the "nop L3" in the example above). + better code, e.g. by revealing that some [branch] instructions are + dead code (as the "branch L3" in the example above). *) (** The naive implementation of branch tunneling would replace any branch to a node [pc] by a branch to the node [branch_target f pc], defined as follows: << - branch_target f pc = branch_target f pc' if f(pc) = nop pc' + branch_target f pc = branch_target f pc' if f(pc) = branch pc' = pc otherwise >> However, this definition can fail to terminate if the program can contain loops consisting only of branches, as in << - L1: nop L1; + L1: branch L1; >> or -<< L1: nop L2; - L2: nop L1; +<< + L1: branch L2; + L2: branch L1; >> Coq warns us of this fact by not accepting the definition of [branch_target] above. - To handle this problem, we proceed in two passes. The first pass - populates a union-find data structure, adding equalities [pc = pc'] - for every instruction [pc: nop pc'] in the function. 
*) + To handle this problem, we proceed in two passes: + +- The first pass populates a union-find data structure, adding equalities + between PCs of blocks that are connected by branches and no other + computation. + +- The second pass rewrites the code, replacing every branch to a node [pc] + by a branch to the canonical representative of the equivalence class of [pc]. +*) + +(** * Construction of the union-find data structure *) Module U := UnionFind.UF(PTree). -Definition record_goto (uf: U.t) (pc: node) (b: bblock) : U.t := +(** We start populating the union-find data structure by adding + equalities [pc = pc'] for every block [pc: branch pc'] in the function. *) + +Definition record_branch (uf: U.t) (pc: node) (b: bblock) : U.t := match b with | Lbranch s :: _ => U.union uf pc s | _ => uf end. +Definition record_branches (f: LTL.function) : U.t := + PTree.fold record_branch f.(fn_code) U.empty. + +(** An additional optimization opportunity comes from conditional branches. + Consider a block [pc: cond ifso ifnot]. If the [ifso] case + and the [ifnot] case jump to the same block [pc'] + (modulo intermediate branches), the block can be simplified into + [pc: branch pc'], and the equality [pc = pc'] can be added to the + union-find data structure. *) + +(** In rare cases, the extra equation [pc = pc'] introduced by the + simplification of a conditional branch can trigger further simplifications + of other conditional branches. We therefore iterate the analysis + until no optimizable conditional branch remains. *) + +(** The code [c] (first component of the [st] triple) starts identical + to the code [fn.(fn_code)] of the current function, but each time a + conditional branch at [pc] is optimized, we remove the block at + [pc] from the code [c]. This guarantees termination of the + iteration. 
*) + +Definition record_cond (st: code * U.t * bool) (pc: node) (b: bblock) : code * U.t * bool := + match b with + | Lcond cond args s1 s2 :: _ => + let '(c, u, _) := st in + if peq (U.repr u s1) (U.repr u s2) + then (PTree.remove pc c, U.union u pc s1, true) + else st + | _ => + st + end. + +Definition record_conds_1 (cu: code * U.t) : code * U.t * bool := + let (c, u) := cu in PTree.fold record_cond c (c, u, false). + +Definition measure_state (cu: code * U.t) : nat := + PTree_Properties.cardinal (fst cu). + +Function record_conds (cu: code * U.t) {measure measure_state cu} : U.t := + let (cu', changed) := record_conds_1 cu in + if changed then record_conds cu' else snd cu. +Proof. + intros [c0 u0] [c1 u1]. + set (P := fun (c: code) (s: code * U.t * bool) => + (forall pc, c!pc = None -> (fst (fst s))!pc = c0!pc) /\ + (PTree_Properties.cardinal (fst (fst s)) + + (if snd s then 1 else 0) + <= PTree_Properties.cardinal c0)%nat). + assert (A: P c0 (PTree.fold record_cond c0 (c0, u0, false))). + { apply PTree_Properties.fold_rec; unfold P. + - intros. destruct H0; split; auto. intros. rewrite <- H in H2. auto. + - simpl; split; intros. auto. simpl; lia. + - intros cd [[c u] changed] pc b NONE SOME [HR1 HR2]. simpl. split. + + intros p EQ. rewrite PTree.gsspec in EQ. destruct (peq p pc); try discriminate. + unfold record_cond. destruct b as [ | [] b ]; auto. + destruct (peq (U.repr u s1) (U.repr u s2)); auto. + simpl. rewrite PTree.gro by auto. auto. + + unfold record_cond. destruct b as [ | [] b ]; auto. + destruct (peq (U.repr u s1) (U.repr u s2)); auto. + simpl in *. + assert (SOME': c!pc = Some (Lcond cond args s1 s2 :: b)). + { rewrite HR1 by auto. auto. } + generalize (PTree_Properties.cardinal_remove SOME'). + destruct changed; lia. + } + unfold record_conds_1, measure_state; intros. + destruct A as [_ A]. rewrite teq in A. simpl in *. + lia. +Qed. + Definition record_gotos (f: LTL.function) : U.t := - PTree.fold record_goto f.(fn_code) U.empty. 
+ record_conds (f.(fn_code), record_branches f). + +(** * Code transformation *) -(** The second pass rewrites all LTL instructions, replacing every +(** The code transformation rewrites all LTL instructions, replacing every successor [s] of every instruction by the canonical representative - of its equivalence class in the union-find data structure. *) + of its equivalence class in the union-find data structure. + Additionally, [Lcond] conditional branches are turned into [Lbranch] + unconditional branches whenever possible. *) -Definition tunnel_instr (uf: U.t) (i: instruction) : instruction := +Definition tunnel_instr (u: U.t) (i: instruction) : instruction := match i with - | Lbranch s => Lbranch (U.repr uf s) + | Lbranch s => Lbranch (U.repr u s) | Lcond cond args s1 s2 => - let s1' := U.repr uf s1 in let s2' := U.repr uf s2 in + let s1' := U.repr u s1 in let s2' := U.repr u s2 in if peq s1' s2' then Lbranch s1' else Lcond cond args s1' s2' - | Ljumptable arg tbl => Ljumptable arg (List.map (U.repr uf) tbl) + | Ljumptable arg tbl => Ljumptable arg (List.map (U.repr u) tbl) | _ => i end. -Definition tunnel_block (uf: U.t) (b: bblock) : bblock := - List.map (tunnel_instr uf) b. +Definition tunnel_block (u: U.t) (b: bblock) : bblock := + List.map (tunnel_instr u) b. Definition tunnel_function (f: LTL.function) : LTL.function := - let uf := record_gotos f in + let u := record_gotos f in mkfunction (fn_sig f) (fn_stacksize f) - (PTree.map1 (tunnel_block uf) (fn_code f)) - (U.repr uf (fn_entrypoint f)). + (PTree.map1 (tunnel_block u) (fn_code f)) + (U.repr u (fn_entrypoint f)). Definition tunnel_fundef (f: LTL.fundef) : LTL.fundef := transf_fundef tunnel_function f. diff --git a/backend/Tunnelingproof.v b/backend/Tunnelingproof.v index d514c16f..68913fc9 100644 --- a/backend/Tunnelingproof.v +++ b/backend/Tunnelingproof.v @@ -12,6 +12,7 @@ (** Correctness proof for the branch tunneling optimization. *) +Require Import FunInd. Require Import Coqlib Maps UnionFind. 
Require Import AST Linking. Require Import Values Memory Events Globalenvs Smallstep. @@ -29,12 +30,21 @@ Qed. (** * Properties of the branch map computed using union-find. *) -Definition measure_edge (u: U.t) (pc s: node) (f: node -> nat) : node -> nat := +Section BRANCH_MAP_CORRECT. + +Variable fn: LTL.function. + +Definition measure_branch (u: U.t) (pc s: node) (f: node -> nat) : node -> nat := fun x => if peq (U.repr u s) pc then f x else if peq (U.repr u x) pc then (f x + f s + 1)%nat else f x. -Definition branch_map_correct (c: code) (u: U.t) (f: node -> nat): Prop := +Definition measure_cond (u: U.t) (pc s1 s2: node) (f: node -> nat) : node -> nat := + fun x => if peq (U.repr u s1) pc then f x + else if peq (U.repr u x) pc then (f x + Nat.max (f s1) (f s2) + 1)%nat + else f x. + +Definition branch_map_correct_1 (c: code) (u: U.t) (f: node -> nat): Prop := forall pc, match c!pc with | Some(Lbranch s :: b) => @@ -43,59 +53,209 @@ Definition branch_map_correct (c: code) (u: U.t) (f: node -> nat): Prop := U.repr u pc = pc end. -Lemma record_gotos_correct_1: - forall fn, { f | branch_map_correct fn.(fn_code) (record_gotos fn) f }. +Lemma record_branch_correct: + forall c u f pc b, + branch_map_correct_1 (PTree.remove pc c) u f -> + c!pc = Some b -> + { f' | branch_map_correct_1 c (record_branch u pc b) f' }. Proof. - intros. - unfold record_gotos. apply PTree_Properties.fold_ind. - -- (* base case *) - intros m EMPTY. exists (fun _ => O). - red; intros. rewrite EMPTY. apply U.repr_empty. - -- (* inductive case *) - intros m u pc bb GET1 GET2 [f BMC]. + intros c u f pc b BMC GET1. assert (PC: U.repr u pc = pc). { specialize (BMC pc). rewrite PTree.grs in BMC. auto. } - assert (DFL: { f | branch_map_correct m u f }). + assert (DFL: { f | branch_map_correct_1 c u f }). { exists f. intros p. destruct (peq p pc). - - subst p. rewrite GET1. destruct bb as [ | [] bb ]; auto. + - subst p. rewrite GET1. destruct b as [ | [] b ]; auto. - specialize (BMC p). 
rewrite PTree.gro in BMC by auto. exact BMC. } - unfold record_goto. destruct bb as [ | [] bb ]; auto. - exists (measure_edge u pc s f). intros p. destruct (peq p pc). -+ subst p. rewrite GET1. unfold measure_edge. + unfold record_branch. destruct b as [ | [] b ]; auto. + exists (measure_branch u pc s f). intros p. destruct (peq p pc). ++ subst p. rewrite GET1. unfold measure_branch. rewrite (U.repr_union_2 u pc s); auto. rewrite U.repr_union_3. destruct (peq (U.repr u s) pc); auto. rewrite PC, peq_true. right; split; auto. lia. + specialize (BMC p). rewrite PTree.gro in BMC by auto. assert (U.repr u p = p -> U.repr (U.union u pc s) p = p). { intro. rewrite <- H at 2. apply U.repr_union_1. congruence. } - destruct (m!p) as [ [ | [] b ] | ]; auto. + destruct (c!p) as [ [ | [] _ ] | ]; auto. destruct BMC as [A | [A B]]. auto. right; split. apply U.sameclass_union_2; auto. - unfold measure_edge. destruct (peq (U.repr u s) pc). auto. + unfold measure_branch. destruct (peq (U.repr u s) pc). auto. rewrite A. destruct (peq (U.repr u s0) pc); lia. Qed. -Definition branch_target (f: function) (pc: node) : node := - U.repr (record_gotos f) pc. +Lemma record_branches_correct: + { f | branch_map_correct_1 fn.(fn_code) (record_branches fn) f }. +Proof. + unfold record_branches. apply PTree_Properties.fold_ind. +- (* base case *) + intros m EMPTY. exists (fun _ => O). + red; intros. rewrite EMPTY. apply U.repr_empty. +- (* inductive case *) + intros m u pc bb GET1 GET2 [f BMC]. eapply record_branch_correct; eauto. +Qed. + +Definition branch_map_correct_2 (c: code) (u: U.t) (f: node -> nat): Prop := + forall pc, + match fn.(fn_code)!pc with + | Some(Lbranch s :: b) => + U.repr u pc = pc \/ (U.repr u pc = U.repr u s /\ f s < f pc)%nat + | Some(Lcond cond args s1 s2 :: b) => + U.repr u pc = pc \/ (c!pc = None /\ U.repr u pc = U.repr u s1 /\ U.repr u pc = U.repr u s2 /\ f s1 < f pc /\ f s2 < f pc)%nat + | _ => + U.repr u pc = pc + end. 
-Definition count_gotos (f: function) (pc: node) : nat := - proj1_sig (record_gotos_correct_1 f) pc. +Lemma record_cond_correct: + forall c u changed f pc b, + branch_map_correct_2 c u f -> + fn.(fn_code)!pc = Some b -> + c!pc <> None -> + let '(c1, u1, _) := record_cond (c, u, changed) pc b in + { f' | branch_map_correct_2 c1 u1 f' }. +Proof. + intros c u changed f pc b BMC GET1 GET2. + assert (DFL: { f' | branch_map_correct_2 c u f' }). + { exists f; auto. } + unfold record_cond. destruct b as [ | [] b ]; auto. + destruct (peq (U.repr u s1) (U.repr u s2)); auto. + exists (measure_cond u pc s1 s2 f). + assert (PC: U.repr u pc = pc). + { specialize (BMC pc). rewrite GET1 in BMC. intuition congruence. } + intro p. destruct (peq p pc). +- subst p. rewrite GET1. unfold measure_cond. + rewrite U.repr_union_2 by auto. rewrite <- e, PC, peq_true. + destruct (peq (U.repr u s1) pc); auto. + right; repeat split. + + apply PTree.grs. + + rewrite U.repr_union_3. auto. + + rewrite U.repr_union_1 by congruence. auto. + + lia. + + lia. +- assert (P: U.repr u p = p -> U.repr (U.union u pc s1) p = p). + { intros. rewrite U.repr_union_1 by congruence. auto. } + specialize (BMC p). destruct (fn_code fn)!p as [ [ | [] bb ] | ]; auto. + + destruct BMC as [A | (A & B)]; auto. right; split. + * apply U.sameclass_union_2; auto. + * unfold measure_cond. rewrite <- A. + destruct (peq (U.repr u s1) pc). auto. + destruct (peq (U.repr u p) pc); lia. + + destruct BMC as [A | (A & B & C & D & E)]; auto. right; split; [ | split; [ | split]]. + * rewrite PTree.gro by auto. auto. + * apply U.sameclass_union_2; auto. + * apply U.sameclass_union_2; auto. + * unfold measure_cond. rewrite <- B, <- C. + destruct (peq (U.repr u s1) pc). auto. + destruct (peq (U.repr u p) pc); lia. +Qed. + +Definition code_compat (c: code) : Prop := + forall pc b, c!pc = Some b -> fn.(fn_code)!pc = Some b. + +Definition code_invariant (c0 c1 c2: code) : Prop := + forall pc, c0!pc = None -> c1!pc = c2!pc. 
+ +Lemma record_conds_1_correct: + forall c u f, + branch_map_correct_2 c u f -> + code_compat c -> + let '(c', u', _) := record_conds_1 (c, u) in + (code_compat c' * { f' | branch_map_correct_2 c' u' f' })%type. +Proof. + intros c0 u0 f0 BMC0 COMPAT0. + unfold record_conds_1. + set (x := PTree.fold record_cond c0 (c0, u0, false)). + set (P := fun (cd: code) (cuc: code * U.t * bool) => + (code_compat (fst (fst cuc)) * + code_invariant cd (fst (fst cuc)) c0 * + { f | branch_map_correct_2 (fst (fst cuc)) (snd (fst cuc)) f })%type). + assert (REC: P c0 x). + { unfold x; apply PTree_Properties.fold_ind. + - intros cd EMPTY. split; [split|]; simpl. + + auto. + + red; auto. + + exists f0; auto. + - intros cd [[c u] changed] pc b GET1 GET2 [[COMPAT INV] [f BMC]]. simpl in *. + split; [split|]. + + unfold record_cond; destruct b as [ | [] b]; simpl; auto. + destruct (peq (U.repr u s1) (U.repr u s2)); simpl; auto. + red; intros. rewrite PTree.grspec in H. destruct (PTree.elt_eq pc0 pc). discriminate. auto. + + assert (DFL: code_invariant cd c c0). + { intros p GET. apply INV. rewrite PTree.gro by congruence. auto. } + unfold record_cond; destruct b as [ | [] b]; simpl; auto. + destruct (peq (U.repr u s1) (U.repr u s2)); simpl; auto. + intros p GET. rewrite PTree.gro by congruence. apply INV. rewrite PTree.gro by congruence. auto. + + assert (GET3: c!pc = Some b). + { rewrite <- GET2. apply INV. apply PTree.grs. } + assert (X: fn.(fn_code)!pc = Some b) by auto. + assert (Y: c!pc <> None) by congruence. + generalize (record_cond_correct c u changed f pc b BMC X Y). + destruct (record_cond (c, u, changed) pc b) as [[c1 u1] changed1]; simpl. + auto. + } + destruct x as [[c1 u1] changed1]; destruct REC as [[COMPAT1 INV1] BMC1]; auto. +Qed. 
+ +Definition branch_map_correct (u: U.t) (f: node -> nat): Prop := + forall pc, + match fn.(fn_code)!pc with + | Some(Lbranch s :: b) => + U.repr u pc = pc \/ (U.repr u pc = U.repr u s /\ f s < f pc)%nat + | Some(Lcond cond args s1 s2 :: b) => + U.repr u pc = pc \/ (U.repr u pc = U.repr u s1 /\ U.repr u pc = U.repr u s2 /\ f s1 < f pc /\ f s2 < f pc)%nat + | _ => + U.repr u pc = pc + end. + +Lemma record_conds_correct: + forall cu, + { f | branch_map_correct_2 (fst cu) (snd cu) f } -> + code_compat (fst cu) -> + { f | branch_map_correct (record_conds cu) f }. +Proof. + intros cu0. functional induction (record_conds cu0); intros. +- destruct cu as [c u], cu' as [c' u'], H as [f BMC]. + generalize (record_conds_1_correct c u f BMC H0). + rewrite e. intros [U V]. apply IHt; auto. +- destruct cu as [c u], H as [f BMC]. + exists f. intros pc. specialize (BMC pc); simpl in *. + destruct (fn_code fn)!pc as [ [ | [] b ] | ]; tauto. +Qed. + +Lemma record_gotos_correct_1: + { f | branch_map_correct (record_gotos fn) f }. +Proof. + apply record_conds_correct; simpl. +- destruct record_branches_correct as [f BMC]. + exists f. intros pc. specialize (BMC pc); simpl in *. + destruct (fn_code fn)!pc as [ [ | [] b ] | ]; auto. +- red; auto. +Qed. + +Definition branch_target (pc: node) : node := + U.repr (record_gotos fn) pc. + +Definition count_gotos (pc: node) : nat := + proj1_sig record_gotos_correct_1 pc. 
Theorem record_gotos_correct: - forall f pc, - match f.(fn_code)!pc with + forall pc, + match fn.(fn_code)!pc with | Some(Lbranch s :: b) => - branch_target f pc = pc \/ - (branch_target f pc = branch_target f s /\ count_gotos f s < count_gotos f pc)%nat - | _ => branch_target f pc = pc + branch_target pc = pc \/ + (branch_target pc = branch_target s /\ count_gotos s < count_gotos pc)%nat + | Some(Lcond cond args s1 s2 :: b) => + branch_target pc = pc \/ + (branch_target pc = branch_target s1 /\ branch_target pc = branch_target s2 + /\ count_gotos s1 < count_gotos pc /\ count_gotos s2 < count_gotos pc)%nat + | _ => + branch_target pc = pc end. Proof. - intros. unfold count_gotos. destruct (record_gotos_correct_1 f) as [m P]; simpl. + intros. unfold count_gotos. destruct record_gotos_correct_1 as [f P]; simpl. apply P. Qed. +End BRANCH_MAP_CORRECT. + (** * Preservation of semantics *) Section PRESERVATION. @@ -187,13 +347,21 @@ Inductive match_states: state -> state -> Prop := (MEM: Mem.extends m tm), match_states (Block s f sp bb ls m) (Block ts (tunnel_function f) sp (tunneled_block f bb) tls tm) - | match_states_interm: + | match_states_interm_branch: forall s f sp pc bb ls m ts tls tm (STK: list_forall2 match_stackframes s ts) (LS: locmap_lessdef ls tls) (MEM: Mem.extends m tm), match_states (Block s f sp (Lbranch pc :: bb) ls m) (State ts (tunnel_function f) sp (branch_target f pc) tls tm) + | match_states_interm_cond: + forall s f sp cond args pc1 pc2 bb ls m ts tls tm + (STK: list_forall2 match_stackframes s ts) + (LS: locmap_lessdef ls tls) + (MEM: Mem.extends m tm) + (SAME: branch_target f pc1 = branch_target f pc2), + match_states (Block s f sp (Lcond cond args pc1 pc2 :: bb) ls m) + (State ts (tunnel_function f) sp (branch_target f pc1) tls tm) | match_states_call: forall s f ls m ts tls tm (STK: list_forall2 match_stackframes s ts) @@ -346,6 +514,7 @@ Definition measure (st: state) : nat := match st with | State s f sp pc ls m => (count_gotos f pc * 
2)%nat | Block s f sp (Lbranch pc :: _) ls m => (count_gotos f pc * 2 + 1)%nat + | Block s f sp (Lcond _ _ pc1 pc2 :: _) ls m => (Nat.max (count_gotos f pc1) (count_gotos f pc2) * 2 + 1)%nat | Block s f sp bb ls m => 0%nat | Callstate s f ls m => 0%nat | Returnstate s ls m => 0%nat @@ -380,10 +549,16 @@ Proof. generalize (record_gotos_correct f pc). rewrite H. destruct bb; auto. destruct i; auto. ++ (* Lbranch *) intros [A | [B C]]. auto. right. split. simpl. lia. split. auto. rewrite B. econstructor; eauto. ++ (* Lcond *) + intros [A | (B & C & D & E)]. auto. + right. split. simpl. lia. + split. auto. + rewrite B. econstructor; eauto. congruence. - (* Lop *) exploit eval_operation_lessdef. apply reglist_lessdef; eauto. eauto. eauto. @@ -450,18 +625,24 @@ Proof. - (* Lbranch (eliminated) *) right; split. simpl. lia. split. auto. constructor; auto. -- (* Lcond *) +- (* Lcond (preserved) *) simpl tunneled_block. set (s1 := U.repr (record_gotos f) pc1). set (s2 := U.repr (record_gotos f) pc2). destruct (peq s1 s2). + left; econstructor; split. - eapply exec_Lbranch. - destruct b. -* constructor; eauto using locmap_undef_regs_lessdef_1. -* rewrite e. constructor; eauto using locmap_undef_regs_lessdef_1. + eapply exec_Lbranch. + set (pc := if b then pc1 else pc2). + replace s1 with (branch_target f pc) by (unfold pc; destruct b; auto). + constructor; eauto using locmap_undef_regs_lessdef_1. + left; econstructor; split. eapply exec_Lcond; eauto. eapply eval_condition_lessdef; eauto using reglist_lessdef. destruct b; econstructor; eauto using locmap_undef_regs_lessdef. +- (* Lcond (eliminated) *) + right; split. simpl. destruct b; lia. + split. auto. + set (pc := if b then pc1 else pc2). + replace (branch_target f pc1) with (branch_target f pc) by (unfold pc; destruct b; auto). + econstructor; eauto. - (* Ljumptable *) assert (tls (R arg) = Vint n). |