Skip to content

Commit

Permalink
Merge pull request sympy#24009 from kunalsheth/master
Browse files Browse the repository at this point in the history
Make SMT Printer warning logs more pythonic.
  • Loading branch information
oscarbenjamin authored Dec 20, 2022
2 parents 659a924 + 3bc7834 commit c87cd36
Show file tree
Hide file tree
Showing 2 changed files with 388 additions and 290 deletions.
31 changes: 15 additions & 16 deletions sympy/printing/smtlib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import builtins
import typing

import sympy
Expand Down Expand Up @@ -215,7 +214,7 @@ def smtlib_code(
symbol_table=None,
known_types=None, known_constants=None, known_functions=None,
prefix_expressions=None, suffix_expressions=None,
log_warn=builtins.print
log_warn=None
):
r"""Converts ``expr`` to a string of smtlib code.
Expand Down Expand Up @@ -252,20 +251,21 @@ def smtlib_code(
log_warn: lambda function, optional
A function to record all warnings during potentially risky operations.
Soundness is a core value in SMT solving, so it is good to log all assumptions made.
If not given, builtins ``print`` will be used.
Examples
========
>>> noop = (lambda _: None)
>>> from sympy import smtlib_code, symbols, sin, Eq
>>> x = symbols('x')
>>> smtlib_code(sin(x).series(x).removeO(), log_warn=noop)
>>> smtlib_code(sin(x).series(x).removeO(), log_warn=print)
Could not infer type of `x`. Defaulting to float.
Non-Boolean expression `x**5/120 - x**3/6 + x` will not be asserted. Converting to SMTLib verbatim.
'(declare-const x Real)\n(+ x (* (/ -1 6) (pow x 3)) (* (/ 1 120) (pow x 5)))'
>>> from sympy import Rational
>>> x, y, tau = symbols("x, y, tau")
>>> smtlib_code((2*tau)**Rational(7, 2), log_warn=noop)
>>> smtlib_code((2*tau)**Rational(7, 2), log_warn=print)
Could not infer type of `tau`. Defaulting to float.
Non-Boolean expression `8*sqrt(2)*tau**(7/2)` will not be asserted. Converting to SMTLib verbatim.
'(declare-const tau Real)\n(* 8 (pow 2 (/ 1 2)) (pow tau (/ 7 2)))'
``Piecewise`` expressions are implemented with ``ite`` expressions by default.
Expand All @@ -275,8 +275,7 @@ def smtlib_code(
>>> from sympy import Piecewise
>>> pw = Piecewise((x + 1, x > 0), (x, True))
>>> smtlib_code(Eq(pw, 3))
Could not infer type of `x`. Defaulting to float.
>>> smtlib_code(Eq(pw, 3), symbol_table={x: float}, log_warn=print)
'(declare-const x Real)\n(assert (= (ite (> x 0) (+ 1 x) x) 3))'
Custom printing can be defined for certain types by passing a dictionary of
Expand All @@ -293,11 +292,11 @@ def smtlib_code(
>>> user_def_funcs = { # functions defined by the user must have their types specified explicitly
... g: Callable[[int], float],
... }
>>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs)
>>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs, log_warn=print)
Non-Boolean expression `f(x) + g(x)` will not be asserted. Converting to SMTLib verbatim.
'(declare-const x Int)\n(declare-fun g (Int) Real)\n(sum (existing_smtlib_fcn x) (g x))'
"""
if not log_warn: log_warn = (lambda _: None)
log_warn = log_warn or (lambda _: None)

if not isinstance(expr, list): expr = [expr]
expr = [
Expand Down Expand Up @@ -359,16 +358,16 @@ def smtlib_code(
if type(fnc) not in p._known_functions and not fnc.is_Piecewise}
declarations = \
[
_auto_declare_smtlib(sym, p, log_warn=log_warn)
_auto_declare_smtlib(sym, p, log_warn)
for sym in constants.values()
] + [
_auto_declare_smtlib(fnc, p, log_warn=log_warn)
_auto_declare_smtlib(fnc, p, log_warn)
for fnc in functions.values()
]
declarations = [decl for decl in declarations if decl]

if auto_assert:
expr = [_auto_assert_smtlib(e, p, log_warn=log_warn) for e in expr]
expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr]

# return SMTLibPrinter().doprint(expr)
return '\n'.join([
Expand All @@ -395,7 +394,7 @@ def smtlib_code(
])


def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn=print):
def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
if sym.is_Symbol:
type_signature = p.symbol_table[sym]
assert isinstance(type_signature, type)
Expand All @@ -416,7 +415,7 @@ def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter,
return None


def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn=print):
def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
if isinstance(e, Boolean) or (
e in p.symbol_table and p.symbol_table[e] == bool
) or (
Expand Down
Loading

0 comments on commit c87cd36

Please sign in to comment.