Description
What happened?
After fitting a PySR model with "greater" as a binary operator, exporting it to Torch failed with the following error:
KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
I've seen that in #433 Piecewise was added to the mappings, so I'm surprised to see this error.
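Listing the SymPy classes inside such an expression shows what a converter that maps one SymPy class at a time has to handle. The sketch below assumes that PySR exports "greater" to SymPy as Piecewise((1.0, x > y), (0.0, True)). That exact form is an assumption, not something stated in this report. The helper node_types is hypothetical and exists only for illustration.

```python
import sympy
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.logic.boolalg import BooleanTrue

x, y = sympy.symbols("x y")

# Assumption: a "greater" node is exported to SymPy as a Piecewise of this
# shape (1.0 where x > y, else 0.0). Not confirmed by the report above.
expr = sympy.Piecewise((1.0, x > y), (0.0, True))


def node_types(e):
    """Hypothetical helper: collect every SymPy class in the tree of e."""
    found = {type(e)}
    for arg in e.args:
        found |= node_types(arg)
    return found


types = node_types(expr)
# A class-by-class SymPy-to-Torch converter needs a mapping for each of these,
# not just for Piecewise itself.
for cls in sorted(types, key=lambda c: c.__name__):
    print(cls.__name__)
```

Under this assumption, the tree contains Piecewise, ExprCondPair, and BooleanTrue (the condition True becomes sympy.true), among other classes. This would be consistent with the sequence of errors described below.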
I tried to fix this myself, but it didn't work out. First, I tried adding mappings such as:
{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}
but then the same error was raised for sympy.functions.elementary.piecewise.ExprCondPair, and after that for sympy.logic.boolalg.BooleanTrue. In the end, I added:
extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}
But even this produced the following error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
Hopefully I am just missing something obvious?
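As a reference for what a working Torch mapping should compute, SymPy's NumPy backend can already evaluate such a Piecewise, because its printer turns Piecewise into numpy.select. This is a minimal sketch that again assumes "greater" is exported as Piecewise((1.0, x > y), (0.0, True)). It is a numerical cross-check only, not a fix for the Torch export.

```python
import numpy as np
import sympy

x, y = sympy.symbols("x y")

# Assumption: same Piecewise form for "greater" as PySR is presumed to export.
expr = sympy.Piecewise((1.0, x > y), (0.0, True))

# lambdify with the NumPy backend prints Piecewise as numpy.select, giving a
# reference function to compare a hand-written Torch mapping against.
reference = sympy.lambdify((x, y), expr, modules="numpy")

xs = np.array([2.0, 1.0, 1.0])
ys = np.array([1.0, 2.0, 1.0])
print(reference(xs, ys))
```

Note that x == y yields 0.0, because the comparison is strict. Any custom mapping, such as one built on torch.where(x > y, 1.0, 0.0), should match this on ties as well.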
Version
0.18.4
Operating System
Linux
Package Manager
pip
Interface
Script (i.e., python my_script.py)
Relevant log output
No response
Extra Info
No response