aboutsummaryrefslogtreecommitdiffstats
path: root/src/hls/IfConversion.v
blob: f8d404c4dc1d60cc04437ef788ea79424399a895 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
(*
 * Vericert: Verified high-level synthesis.
 * Copyright (C) 2021 Yann Herklotz <yann@yannherklotz.com>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 *)

Require Import compcert.common.AST.
Require Import compcert.common.Errors.
Require Import compcert.common.Globalenvs.
Require Import compcert.lib.Integers.
Require Import compcert.lib.Maps.

Require Import vericert.common.Vericertlib.
Require Import vericert.hls.RTLBlockInstr.
Require Import vericert.hls.RTLBlock.

(*|
=============
If conversion
=============

This is a verified transformation from RTLBlock back to RTLBlock that performs if-conversion: a
block ending in a conditional branch is merged with its two successor blocks into a single
predicated block, which makes basic blocks larger.
|*)

(** Guard predicate [p] conjoined with an optional existing guard [optp].
    With no existing guard the result is just [p]; otherwise it is
    [Pand p existing]. *)
Definition combine_pred (p: pred_op) (optp: option pred_op) :=
  match optp with
  | None => p
  | Some existing => Pand p existing
  end.

(** Place instruction [i] under the additional guard [p].
    [RBop], [RBload] and [RBstore] get their optional guard replaced by
    [Some (combine_pred p guard)].  Every other instruction, including
    [RBsetpred], is returned unchanged and therefore stays unguarded. *)
Definition map_if_convert (p: pred_op) (i: instr) :=
  match i with
  | RBop guard op args dst =>
    RBop (Some (combine_pred p guard)) op args dst
  | RBload guard chunk addr args dst =>
    RBload (Some (combine_pred p guard)) chunk addr args dst
  | RBstore guard chunk addr args src =>
    RBstore (Some (combine_pred p guard)) chunk addr args src
  | _ => i
  end.

(** If-convert [bb] using predicate register [p].
    When [bb] ends in [RBcond cond args n1 n2] and both [n1] and [n2] are
    present in [c], the result is a single block containing:
    - the body of [bb];
    - [RBsetpred cond args p], which records the branch condition in [p];
    - the body of [n1], guarded by [Pvar p];
    - the body of [n2], guarded by [Pnot (Pvar p)].
    It exits through [RBpred_cf], choosing the exit of [n1] or [n2] from [p].
    In every other case [bb] is returned unchanged.  The blocks at [n1]
    and [n2] are not removed from [c] here. *)
Definition if_convert_block (c: code) (p: predicate) (bb: bblock) : bblock :=
  match bb_exit bb with
  | RBcond cond args n1 n2 =>
    match PTree.get n1 c, PTree.get n2 c with
    | Some bt, Some bf =>
      let guarded_t := List.map (map_if_convert (Pvar p)) (bb_body bt) in
      let guarded_f := List.map (map_if_convert (Pnot (Pvar p))) (bb_body bf) in
      let body :=
        List.concat (bb_body bb :: (RBsetpred cond args p :: guarded_t)
                                :: guarded_f :: nil) in
      mk_bblock body (RBpred_cf (Pvar p) (bb_exit bt) (bb_exit bf))
    | _, _ => bb
    end
  | _ => bb
  end.

(** [true] exactly when the control-flow instruction [cfi] is a
    conditional branch [RBcond]. *)
Definition is_cond_cfi (cfi: cf_instr) : bool :=
  match cfi with
  | RBcond _ _ _ _ => true
  | _ => false
  end.

(** [any f a] holds when [f] returns [true] on at least one element of [a].
    It is [false] on the empty list. *)
Fixpoint any {A: Type} (f: A -> bool) (a: list A) :=
  match a with
  | nil => false
  | y :: ys => f y || any f ys
  end.

(** [all f a] holds when [f] returns [true] on every element of [a].
    It is [true] on the empty list. *)
Fixpoint all {A: Type} (f: A -> bool) (a: list A) :=
  match a with
  | nil => true
  | y :: ys => f y && all f ys
  end.

(** Successors of the block at node [n] whose node number is greater than
    or equal to [n]; these edges are treated as backedges.
    NOTE(review): this presumably relies on the CompCert/RTLgen numbering
    convention, in which forward edges go to smaller node numbers, so every
    cycle contains an edge to a node that is not smaller -- TODO confirm.
    The comparison is non-strict so that a self-loop [n -> n] is reported
    too.  With a strict comparison a conditional block branching to itself
    would not be excluded from if-conversion. *)
Definition find_backedge (nb: node * bblock) :=
  let (n, b) := nb in
  let succs := successors_instr b.(bb_exit) in
  filter (fun x => Pos.leb n x) succs.

(** Targets of all backedges in [c], found by applying [find_backedge] to
    every block.  The result may contain duplicates. *)
Definition find_all_backedges (c: code) : list node :=
  let per_block := List.map find_backedge (PTree.elements c) in
  List.concat per_block.

(** [true] when node [entry] occurs in the list [be] of backedge targets. *)
Definition has_backedge (entry: node) (be: list node) :=
  any (Pos.eqb entry) be.

(** Candidate blocks for if-conversion.  A block is selected when:
    - it ends in a conditional branch;
    - it is not the target of any backedge;
    - none of its successors is the target of a backedge. *)
Definition find_blocks_with_cond (c: code) : list (node * bblock) :=
  let backedges := find_all_backedges c in
  let not_target n := negb (has_backedge n backedges) in
  List.filter (fun x => is_cond_cfi (snd x).(bb_exit) &&
                        not_target (fst x) &&
                        all not_target (successors_instr (snd x).(bb_exit)))
              (PTree.elements c).

(** Fold step for if-conversion.  The accumulator [p] pairs the next
    predicate number with the current code.  Block [bb] at node [n] is
    if-converted using predicate [Pos.of_nat fresh], the result is stored
    back at [n], and the counter is incremented. *)
Definition if_convert_code (p: nat * code) (nb: node * bblock) :=
  let (n, bb) := nb in
  let (fresh, c) := p in
  (S fresh, PTree.set n (if_convert_block c (Pos.of_nat fresh) bb) c).

(** If-convert every candidate block of [f].  Candidates come from the
    original code, and predicates are numbered 1, 2, ... in the order the
    candidates are visited.
    NOTE(review): this assumes those predicate numbers are not already in
    use in [f] -- TODO confirm.  Signature, parameters, stack size and
    entry point are unchanged. *)
Definition transf_function (f: function) : function :=
  let blocks := find_blocks_with_cond f.(fn_code) in
  let (_, c) := List.fold_left if_convert_code blocks (1%nat, f.(fn_code)) in
  mkfunction f.(fn_sig) f.(fn_params) f.(fn_stacksize) c f.(fn_entrypoint).

(** Apply [transf_function] to internal functions via the previously
    defined generic [transf_fundef] (not this definition itself). *)
Definition transf_fundef : fundef -> fundef :=
  fun fd => transf_fundef transf_function fd.

(** If-convert every function definition of program [p]. *)
Definition transf_program : program -> program :=
  fun p => transform_program transf_fundef p.