aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/smt2-bmc/smtio.py
blob: d9f165c650bc90b9147fbeba7cf74ee66743dfdf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/python3

import sys
import subprocess

class smtio:
    """Thin driver around an SMT-LIB 2 solver running as a child process.

    Statements are written to the solver's stdin one at a time and complete
    S-expression responses are read back from its stdout (stderr is merged
    into stdout).  Helper methods query values with ``get-value`` and convert
    the returned bit-vector literals to hex or binary digit strings.
    """

    def __init__(self, solver="yices", debug_print=False, debug_file=None):
        """Start the solver process.

        Args:
            solver: ``"yices"`` (runs ``yices-smt2 --incremental``) or
                ``"z3"`` (runs ``z3 -smt2 -in``).
            debug_print: when true, echo every statement sent (``> ...``) and
                every line received (``< ...``) to stdout.
            debug_file: optional writable text file object; every statement
                sent to the solver is also written to it and flushed.

        Raises:
            ValueError: if ``solver`` is not one of the supported names.
        """
        if solver == "yices":
            popen_vargs = ['yices-smt2', '--incremental']
        elif solver == "z3":
            popen_vargs = ['z3', '-smt2', '-in']
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on popen_vargs.
            raise ValueError("Unknown SMT solver: %r" % (solver,))

        self.debug_print = debug_print
        self.debug_file = debug_file
        self.p = subprocess.Popen(popen_vargs, stdin=subprocess.PIPE,
                                  stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    def write(self, stmt):
        """Send one statement to the solver (whitespace-trimmed, newline-terminated)."""
        stmt = stmt.strip()
        if self.debug_print:
            print("> %s" % stmt)
        if self.debug_file:
            print(stmt, file=self.debug_file)
            self.debug_file.flush()
        self.p.stdin.write(bytes(stmt + "\n", "ascii"))
        self.p.stdin.flush()

    def read(self):
        """Read one complete response from the solver and return it as a string.

        Lines are consumed until the parentheses seen so far are balanced.
        The stripped lines are joined with a single space so that tokens
        split across lines are not fused together.

        Exits the process with status 1 (after printing a message to stderr)
        if the solver's output ends in the middle of a response or if the
        response starts with ``(error``.
        """
        lines = []
        depth = 0

        while True:
            raw = self.p.stdout.readline()
            line = raw.decode("ascii").strip()
            depth += line.count("(") - line.count(")")
            lines.append(line)
            if self.debug_print:
                print("< %s" % line)
            if depth == 0:
                break
            if not raw:
                # readline() returns b"" only at end-of-file, i.e. the solver
                # closed its output while a response was still open.  Using
                # EOF rather than poll() avoids both misreporting a running
                # solver and discarding output still buffered in the pipe.
                print("SMT Solver terminated unexpectedly: %s" % " ".join(lines).strip(),
                      file=sys.stderr)
                sys.exit(1)

        stmt = " ".join(lines)
        if stmt.startswith("(error"):
            print("SMT Solver Error: %s" % stmt, file=sys.stderr)
            sys.exit(1)

        return stmt

    def parse(self, stmt):
        """Parse the first S-expression in ``stmt``.

        Returns:
            A ``str`` for an atom, or a (possibly nested) ``list`` for a
            parenthesised expression.  ``|quoted symbols|`` are returned as a
            single string including the surrounding bars.

        Raises:
            ValueError: if the input ends before the expression is complete.
        """
        whitespace = " \t\r\n"
        delimiters = "()|" + whitespace
        length = len(stmt)

        def skip_ws(pos):
            while pos < length and stmt[pos] in whitespace:
                pos += 1
            return pos

        def worker(pos):
            # Returns (expression, index just past the expression).
            pos = skip_ws(pos)
            if pos >= length:
                raise ValueError("Unexpected end of S-expression: %r" % (stmt,))

            ch = stmt[pos]

            if ch == '(':
                items = []
                pos = skip_ws(pos + 1)
                while True:
                    if pos >= length:
                        raise ValueError("Unterminated list in S-expression: %r" % (stmt,))
                    if stmt[pos] == ')':
                        return items, pos + 1
                    item, pos = worker(pos)
                    items.append(item)
                    # Skip whitespace here so that blanks before ')' do not
                    # produce a spurious empty-string element.
                    pos = skip_ws(pos)

            if ch == '|':
                end = stmt.find('|', pos + 1)
                if end < 0:
                    raise ValueError("Unterminated quoted symbol in S-expression: %r" % (stmt,))
                return stmt[pos:end + 1], end + 1

            end = pos
            while end < length and stmt[end] not in delimiters:
                end += 1
            return stmt[pos:end], end

        return worker(0)[0]

    def bv2hex(self, v):
        """Convert a solver value to a string of lowercase hex digits.

        Accepts the same inputs as ``bv2bin``.  The result has one digit per
        started group of four bits, so leading zero digits are preserved.
        ``"true"`` / ``"false"`` give ``"1"`` / ``"0"``.
        """
        bits = self.bv2bin(v)
        # As before, any character that is not a literal "1" counts as 0.
        bits = "".join("1" if b == "1" else "0" for b in bits)
        if not bits:
            return ""
        ndigits = (len(bits) + 3) // 4
        return "%0*x" % (ndigits, int(bits, 2))

    def bv2bin(self, v):
        """Convert a solver value to a string of binary digits.

        Handles ``"true"`` / ``"false"``, ``#b...`` literals, ``#x...``
        literals (four bits per hex digit) and the parsed list form of
        ``(_ bvN W)`` (value ``N`` zero-padded to ``W`` bits).  Any other
        input falls back to dropping the first two elements, as before.
        """
        if v == "true": return "1"
        if v == "false": return "0"
        if isinstance(v, list):
            if (len(v) == 3 and v[0] == "_" and isinstance(v[1], str)
                    and v[1].startswith("bv") and isinstance(v[2], str)):
                value, width = int(v[1][2:]), int(v[2])
                return bin(value)[2:].zfill(width) if width > 0 else ""
            return v[2:]
        if v.startswith("#x"):
            return "".join(format(int(digit, 16), "04b") for digit in v[2:])
        return v[2:]

    def get(self, expr):
        """Ask the solver for the value of ``expr`` and return the parsed value.

        Sends ``(get-value (expr))`` and returns the value part of the first
        ``(term value)`` pair in the response.
        """
        self.write("(get-value (%s))" % (expr))
        return self.parse(self.read())[0][1]

    def get_net(self, mod_name, net_name, state_name):
        """Return the value of ``(|<mod_name>_n <net_name>| <state_name>)``."""
        return self.get("(|%s_n %s| %s)" % (mod_name, net_name, state_name))

    def get_net_bool(self, mod_name, net_name, state_name):
        """Return a Boolean net's value as ``1`` or ``0``."""
        v = self.get_net(mod_name, net_name, state_name)
        assert v in ["true", "false"]
        return 1 if v == "true" else 0

    def get_net_hex(self, mod_name, net_name, state_name):
        """Return a net's value as a hex digit string (see ``bv2hex``)."""
        return self.bv2hex(self.get_net(mod_name, net_name, state_name))

    def get_net_bin(self, mod_name, net_name, state_name):
        """Return a net's value as a binary digit string (see ``bv2bin``)."""
        return self.bv2bin(self.get_net(mod_name, net_name, state_name))

    def wait(self):
        """Block until the solver process exits."""
        self.p.wait()