diff options
Diffstat (limited to 'scripts/smt2-bmc/smtio.py')
-rw-r--r-- | scripts/smt2-bmc/smtio.py | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/scripts/smt2-bmc/smtio.py b/scripts/smt2-bmc/smtio.py new file mode 100644 index 0000000..d9f165c --- /dev/null +++ b/scripts/smt2-bmc/smtio.py @@ -0,0 +1,124 @@ +#!/usr/bin/python3 + +import sys +import subprocess + +class smtio: + def __init__(self, solver="yices", debug_print=False, debug_file=None): + if solver == "yices": + popen_vargs = ['yices-smt2', '--incremental'] + + if solver == "z3": + popen_vargs = ['z3', '-smt2', '-in'] + + self.debug_print = debug_print + self.debug_file = debug_file + self.p = subprocess.Popen(popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + def write(self, stmt): + stmt = stmt.strip() + if self.debug_print: + print("> %s" % stmt) + if self.debug_file: + print(stmt, file=self.debug_file) + self.debug_file.flush() + self.p.stdin.write(bytes(stmt + "\n", "ascii")) + self.p.stdin.flush() + + def read(self): + stmt = [] + count_brackets = 0 + + while True: + line = self.p.stdout.readline().decode("ascii").strip() + count_brackets += line.count("(") + count_brackets -= line.count(")") + stmt.append(line) + if self.debug_print: + print("< %s" % line) + if count_brackets == 0: + break + if not self.p.poll(): + print("SMT Solver terminated unexpectedly: %s" % "".join(stmt)) + sys.exit(1) + + stmt = "".join(stmt) + if stmt.startswith("(error"): + print("SMT Solver Error: %s" % stmt, file=sys.stderr) + sys.exit(1) + + return stmt + + def parse(self, stmt): + def worker(stmt): + if stmt[0] == '(': + expr = [] + cursor = 1 + while stmt[cursor] != ')': + el, le = worker(stmt[cursor:]) + expr.append(el) + cursor += le + return expr, cursor+1 + + if stmt[0] == '|': + expr = "|" + cursor = 1 + while stmt[cursor] != '|': + expr += stmt[cursor] + cursor += 1 + expr += "|" + return expr, cursor+1 + + if stmt[0] in [" ", "\t", "\r", "\n"]: + el, le = worker(stmt[1:]) + return el, le+1 + + expr = "" + cursor = 0 + while stmt[cursor] not in ["(", ")", "|", " ", "\t", "\r", "\n"]: + expr += stmt[cursor] + cursor += 1 + return expr, cursor + return worker(stmt)[0] + + def bv2hex(self, v): + if v == "true": return "1" + if v == "false": return "0" + h = "" + while len(v) > 2: + d = 0 + if len(v) > 0 and v[-1] == "1": d += 1 + if len(v) > 1 and v[-2] == "1": d += 2 + if len(v) > 2 and v[-3] == "1": d += 4 + if len(v) > 3 and v[-4] == "1": d += 8 + h = hex(d)[2:] + h + if len(v) < 4: break + v = v[:-4] + return h + + def bv2bin(self, v): + if v == "true": return "1" + if v == "false": return "0" + return v[2:] + + def get(self, expr): + self.write("(get-value (%s))" % (expr)) + return self.parse(self.read())[0][1] + + def get_net(self, mod_name, net_name, state_name): + return self.get("(|%s_n %s| %s)" % (mod_name, net_name, state_name)) + + def get_net_bool(self, mod_name, net_name, state_name): + v = self.get_net(mod_name, net_name, state_name) + assert v in ["true", "false"] + return 1 if v == "true" else 0 + + def get_net_hex(self, mod_name, net_name, state_name): + return self.bv2hex(self.get_net(mod_name, net_name, state_name)) + + def get_net_bin(self, mod_name, net_name, state_name): + return self.bv2bin(self.get_net(mod_name, net_name, state_name)) + + def wait(self): + self.p.wait() + |