Description
What happened?
sympy2torch produces a module that fails when called if a function of a constant is present in the expression.
For example:
```python
from sympy import symbols, exp
from pysr import sympy2torch
import torch

x, y = symbols("x y")
expression = exp(2)
module = sympy2torch(expression, [x, y])
X = torch.rand(100, 2).float() * 10
torch_out = module(X)
```
This produces the following error:

```
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float
```

Other functions of a constant, such as log(4), fail in the same way.
The current mapping in export_torch.py is sympy.exp: torch.exp.
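A quick check that does not involve PySR is consistent with this mapping being the culprit. torch.exp accepts only tensors, so a constant that reaches it as a plain Python number raises the same TypeError, while the same value wrapped in a tensor works. This is a minimal sketch; the helper name `accepts_python_scalar` is hypothetical.

```python
import torch


def accepts_python_scalar(fn, value):
    """Hypothetical helper: return True if fn(value) runs without a TypeError."""
    try:
        fn(value)
        return True
    except TypeError:
        return False


# A bare Python float is rejected, matching the error in this report...
print(accepts_python_scalar(torch.exp, 2.0))                # False
print(accepts_python_scalar(torch.log, 4.0))                # False
# ...while the same constant as a 0-d tensor is accepted.
print(accepts_python_scalar(torch.exp, torch.tensor(2.0)))  # True
```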
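I believe that defining a wrapper such as

```python
import torch

def exp(x):
    # Promote a plain Python number (e.g. the constant in exp(2)) to a float tensor
    return torch.exp(torch.as_tensor(x, dtype=torch.float32))
```

and then using the mapping sympy.exp: exp might work, but I have been unable to test it. Adding it via extra_sympy_mappings doesn't work, I think because those mappings are appended after the existing ones and don't override the original sympy.exp entry.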
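A slightly more general form of the wrapper idea, sketched under the assumption that each mapped torch function only needs its non-tensor arguments promoted to tensors: wrap the function once and convert any plain number with torch.as_tensor before the call. The name `scalar_safe` is hypothetical, and how to register such wrappers with sympy2torch (given the override problem described above) is left open.

```python
import torch


def scalar_safe(torch_fn):
    """Hypothetical helper: wrap a torch function so Python numbers are accepted."""

    def wrapped(*args):
        # Convert plain numbers (e.g. the 2 in exp(2)) to tensors of the default
        # floating dtype; arguments that are already tensors pass through untouched.
        converted = [
            a if isinstance(a, torch.Tensor)
            else torch.as_tensor(a, dtype=torch.get_default_dtype())
            for a in args
        ]
        return torch_fn(*converted)

    return wrapped


safe_exp = scalar_safe(torch.exp)
safe_log = scalar_safe(torch.log)

print(safe_exp(2))                    # 0-d tensor holding e**2
print(safe_log(4.0))                  # 0-d tensor holding log(4)
print(safe_exp(torch.rand(3)).shape)  # torch.Size([3])
```

For a constant input the wrapper returns a 0-d tensor rather than one value per row of X; whether the caller then needs to broadcast it is a separate question.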
Alternatively, evaluating constant subexpressions to numbers wherever possible (e.g. exp(2) becomes 7.38905609893) might solve the problem for all functions at once.
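A minimal sketch of that constant-folding idea using only SymPy: every applied function whose arguments are all numbers is replaced by its floating-point value via evalf() before the expression would be passed to sympy2torch. The helper name `fold_constant_functions` is hypothetical, and this has not been checked against sympy2torch itself.

```python
import sympy
from sympy import exp, log, symbols


def fold_constant_functions(expr):
    """Hypothetical helper: evaluate function calls with purely numeric arguments."""
    return expr.replace(
        # Match applied functions such as exp(2) or log(4), but not exp(x).
        lambda e: isinstance(e, sympy.Function) and e.is_number,
        # Replace each match with its floating-point value.
        lambda e: e.evalf(),
    )


x, y = symbols("x y")
print(fold_constant_functions(exp(2)))
print(fold_constant_functions(x * exp(2) + log(4)))
print(fold_constant_functions(exp(x)))  # exp(x)
```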
Version
0.18.4
Operating System
Linux
Package Manager
pip
Interface
Script (i.e., python my_script.py)