diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd index e6b963d3..ffb69310 100644 --- a/symengine/lib/symengine.pxd +++ b/symengine/lib/symengine.pxd @@ -133,6 +133,7 @@ cdef extern from "" namespace "SymEngine": RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Number] rcp_static_cast_Number "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil + RCP[const Dummy] rcp_static_cast_Dummy "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Add] rcp_static_cast_Add "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Mul] rcp_static_cast_Mul "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Pow] rcp_static_cast_Pow "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil @@ -180,7 +181,7 @@ cdef extern from "" namespace "SymEngine": Symbol(string name) nogil string get_name() nogil cdef cppclass Dummy(Symbol): - pass + size_t get_index() cdef extern from "" namespace "SymEngine": cdef cppclass Number(Basic): @@ -322,6 +323,7 @@ cdef extern from "" namespace "SymEngine": rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp"(string name) nogil rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp"() nogil rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp"(string name) nogil + rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp"(string &name, size_t index) nogil rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp"(string name, PyObject * pyobj, bool use_pickle) except + rcp_const_basic make_rcp_Constant "SymEngine::make_rcp"(string name) nogil rcp_const_basic make_rcp_Infty "SymEngine::make_rcp"(RCP[const Number] i) nogil diff --git a/symengine/lib/symengine_wrapper.in.pyx b/symengine/lib/symengine_wrapper.in.pyx index d5c68b23..7b574c6e 100644 --- a/symengine/lib/symengine_wrapper.in.pyx +++ b/symengine/lib/symengine_wrapper.in.pyx @@ -278,10 +278,10 @@ def sympy2symengine(a, raise_error=False): """ import sympy from sympy.core.function import AppliedUndef as sympy_AppliedUndef - if isinstance(a, sympy.Symbol): + if isinstance(a, sympy.Dummy): + return Dummy(a.name, a.dummy_index) + elif isinstance(a, sympy.Symbol): return Symbol(a.name) - elif isinstance(a, sympy.Dummy): - return Dummy(a.name) elif isinstance(a, sympy.Mul): return mul(*[sympy2symengine(x, raise_error) for x in a.args]) elif isinstance(a, sympy.Add): @@ -1304,10 +1304,10 @@ cdef class Symbol(Expr): return sympy.Symbol(str(self)) def __reduce__(self): - if type(self) == Symbol: + if type(self) in (Symbol, Dummy): return Basic.__reduce__(self) else: - raise NotImplementedError("pickling for Symbol subclass not implemented") + raise NotImplementedError("pickling for subclass of Symbol or Dummy not implemented") def _sage_(self): import sage.all as sage @@ -1340,15 +1340,20 @@ cdef class Symbol(Expr): cdef class Dummy(Symbol): - def __init__(Basic self, name=None, *args, **kwargs): - if name is None: - self.thisptr = symengine.make_rcp_Dummy() + def __init__(Basic self, name=None, dummy_index=None, *args, **kwargs): + cdef size_t index + if dummy_index is None: + if name is None: + self.thisptr = symengine.make_rcp_Dummy() + else: + self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8")) else: - self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8")) + index = dummy_index + self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"), index) def _sympy_(self): import sympy - return sympy.Dummy(str(self)[1:]) + return sympy.Dummy(name=self.name, dummy_index=self.dummy_index) @property def is_Dummy(self): @@ -1358,6 +1363,12 @@ cdef class Dummy(Symbol): def func(self): return self.__class__ + @property + def dummy_index(self): + cdef RCP[const symengine.Dummy] this = \ + symengine.rcp_static_cast_Dummy(self.thisptr) + cdef size_t index = deref(this).get_index() + return index def symarray(prefix, shape, **kwargs): """ Creates an nd-array of symbols diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index 5ae64a75..a51a3251 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -1,4 +1,4 @@ -from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol +from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol, Dummy from symengine.test_utilities import raises import pickle import unittest @@ -57,3 +57,19 @@ def test_llvm_double(): ll = pickle.loads(ss) inp = [1, 2, 3] assert np.allclose(l(inp), ll(inp)) + + +def _check_pickling_roundtrip(arg): + s2 = pickle.dumps(arg) + arg2 = pickle.loads(s2) + assert arg == arg2 + s3 = pickle.dumps(arg2) + arg3 = pickle.loads(s3) + assert arg == arg3 + + +def test_pickling_roundtrip(): + x, y, z = symbols('x y z') + _check_pickling_roundtrip(x+y) + _check_pickling_roundtrip(Dummy('d')) + _check_pickling_roundtrip(Dummy('d') - z) diff --git a/symengine/tests/test_symbol.py b/symengine/tests/test_symbol.py index 49825914..dc22d8ab 100644 --- a/symengine/tests/test_symbol.py +++ b/symengine/tests/test_symbol.py @@ -156,6 +156,9 @@ def test_dummy(): x2 = Symbol('x') xdummy1 = Dummy('x') xdummy2 = Dummy('x') + assert xdummy1.dummy_index != xdummy2.dummy_index # maybe test using "less than"? + assert xdummy1.name == 'x' + assert xdummy2.name == 'x' assert x1 == x2 assert x1 != xdummy1 diff --git a/symengine/tests/test_sympy_conv.py b/symengine/tests/test_sympy_conv.py index 5d173dc4..1d9790a7 100644 --- a/symengine/tests/test_sympy_conv.py +++ b/symengine/tests/test_sympy_conv.py @@ -1,6 +1,6 @@ from symengine import (Symbol, Integer, sympify, SympifyError, log, function_symbol, I, E, pi, oo, zoo, nan, true, false, - exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot, + exp, gamma, have_mpfr, have_mpc, DenseMatrix, Dummy, sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth, asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio, Catalan, EulerGamma, UnevaluatedExpr, RealDouble) @@ -833,3 +833,24 @@ def test_conv_large_integers(): if have_sympy: c = a._sympy_() d = sympify(c) + + +def _check_sympy_roundtrip(arg): + arg_sy1 = sympy.sympify(arg) + arg_se2 = sympify(arg_sy1) + assert arg == arg_se2 + arg_sy2 = sympy.sympify(arg_se2) + assert arg_sy2 == arg_sy1 + arg_se3 = sympify(arg_sy2) + assert arg_se3 == arg + + +@unittest.skipIf(not have_sympy, "SymPy not installed") +def test_sympy_roundtrip(): + x = Symbol("x") + y = Symbol("y") + d = Dummy("d") + _check_sympy_roundtrip(x) + _check_sympy_roundtrip(x+y) + _check_sympy_roundtrip(x**y) + _check_sympy_roundtrip(d)