[BUG]: torch export fails for expressions with constant inputs e.g. exp(2) #656

@tbuckworth

Description

What happened?

`sympy2torch` produces a module that fails when called if the expression contains a function applied to a constant.

For example:

```python
from sympy import symbols, exp
from pysr import sympy2torch
import torch

x, y = symbols("x y")

expression = exp(2)

module = sympy2torch(expression, [x, y])

X = torch.rand(100, 2).float() * 10

torch_out = module(X)
```

produces this error:

```
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float
```

I've tried other expressions, such as `log(4)`, and they produce the same error.

The current mapping in `export_torch.py` is `sympy.exp: torch.exp`.
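The root cause can be seen in isolation with plain PyTorch. This is a minimal sketch that assumes, as the error message suggests, that the constant reaches `torch.exp` as a Python float rather than as a tensor:

```python
import torch

# Passing a Python float, as apparently happens for a constant such as exp(2),
# reproduces the TypeError from the report.
try:
    torch.exp(2.0)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Converting the constant to a tensor first avoids the error.
value = torch.exp(torch.as_tensor(2.0))
print(value.item())
```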

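I believe that defining

```python
import torch

def exp(x):
    # torch.exp only accepts tensors, so convert Python numbers (e.g. 2.0)
    # first; torch.as_tensor returns tensor inputs unchanged.
    return torch.exp(torch.as_tensor(x))
```

and then using the mapping `sympy.exp: exp` might work, but I have been unable to test it. Adding it to `extra_sympy_mappings` doesn't work, I think because those mappings are chained after the existing ones and so don't override the original entry.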
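The suspected override problem can be illustrated with Python's `collections.ChainMap`. This sketch rests on an assumption not verified here: that the export merges the built-in and user-supplied mappings with a `ChainMap`, built-ins first. The string keys and values are hypothetical stand-ins for sympy functions and torch callables.

```python
from collections import ChainMap

# Hypothetical stand-ins: the real mappings would hold callables such as torch.exp.
default_mappings = {"exp": "torch.exp", "log": "torch.log"}
extra_mappings = {"exp": "wrapped_exp"}

# Extras chained after the defaults: lookups hit the defaults first,
# so the user's "exp" entry is shadowed.
appended = ChainMap(default_mappings, extra_mappings)

# Extras chained before the defaults: the user's entry wins,
# and keys missing from the extras fall through to the defaults.
prepended = ChainMap(extra_mappings, default_mappings)

print(appended["exp"])   # torch.exp
print(prepended["exp"])  # wrapped_exp
print(prepended["log"])  # torch.log
```

`ChainMap` returns the value from the first mapping that contains the key, so only the order of the chained mappings decides which `exp` is used.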

Alternatively, evaluating constant subexpressions to numbers wherever possible (e.g. `exp(2)` becomes `7.38905609893`) might solve the problem for all such expressions.
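A minimal sketch of that idea uses sympy's `replace` with a filter and a callable, assuming it is applied to the expression before it is passed to `sympy2torch`. `fold_constants` is a hypothetical helper, not part of PySR.

```python
import sympy


def fold_constants(expr):
    """Replace every function application with no free symbols
    (e.g. exp(2), log(4)) by its numerical value as a sympy.Float."""
    return expr.replace(
        lambda e: isinstance(e, sympy.Function) and not e.free_symbols,
        lambda e: e.evalf(),
    )


x, y = sympy.symbols("x y")
print(fold_constants(sympy.exp(2)))          # a Float close to 7.389
print(fold_constants(x + sympy.log(4)))      # x plus a Float close to 1.386
print(fold_constants(sympy.exp(x) * y))      # unchanged: exp(x) depends on x
```

The sketch folds only function applications without free symbols rather than calling `evalf()` on the whole expression, because `evalf()` would also turn integer exponents such as `x**2` into `x**2.0`.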

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

Metadata

Labels

bug (Something isn't working)
