
[BUG]: Piecewise not in torch_mappings #639

@tbuckworth


What happened?

After fitting a PySR model with "greater" as a binary operator, exporting it to PyTorch failed with the following error:

KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

I see that Piecewise was added to the mappings in #433, so I'm surprised to see this error.

I attempted to fix this myself, without success.
I tried adding mappings such as:

{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}

but then the same error was raised for sympy.functions.elementary.piecewise.ExprCondPair, and then for sympy.logic.boolalg.BooleanTrue.
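This chain of errors fits an exporter that walks the SymPy expression tree, looks up every node's type in a mapping, and recurses into all of a node's arguments before calling the mapped function. The sketch below is a hypothetical, stdlib-only illustration of that pattern, not PySR's actual export code: `Node`, `evaluate` and `missing_type` are made-up names, "greater" is assumed to be represented as `Piecewise((1.0, x > y), (0.0, True))`, and the comparison itself is assumed to be mapped already. Each new mapping entry only moves the `KeyError` one level further down the tree.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class Node:
    # Hypothetical stand-in for a SymPy node: a type name plus its arguments.
    kind: str
    args: Tuple[Any, ...] = ()


def evaluate(node: Any, mapping: Dict[str, Callable], env: Dict[str, float]) -> Any:
    """Recursively evaluate `node`, looking up every non-leaf type in `mapping`."""
    if isinstance(node, (int, float)):  # numeric constant leaf
        return node
    if isinstance(node, str):  # variable leaf, resolved from `env`
        return env[node]
    if node.kind not in mapping:
        raise KeyError(node.kind)  # same failure mode as the export error
    # Children are evaluated first, so their node types must be mapped as well.
    children = [evaluate(arg, mapping, env) for arg in node.args]
    return mapping[node.kind](*children)


# Assumed tree for Piecewise((1.0, x > y), (0.0, True)): each branch is an
# ExprCondPair, and the catch-all condition is a BooleanTrue node.
expr = Node("Piecewise", (
    Node("ExprCondPair", (1.0, Node("StrictGreaterThan", ("x", "y")))),
    Node("ExprCondPair", (0.0, Node("BooleanTrue"))),
))


def missing_type(mapping: Dict[str, Callable]) -> Any:
    """Return the first unmapped node type hit during evaluation, or None."""
    try:
        evaluate(expr, mapping, {"x": 2.0, "y": 1.0})
    except KeyError as err:
        return err.args[0]
    return None


# Start with the comparison mapped, then add one entry per KeyError.
mapping: Dict[str, Callable] = {"StrictGreaterThan": lambda a, b: a > b}
errors = []
for kind, func in [
    ("Piecewise", lambda *pairs: next(e for e, c in pairs if c)),  # first true branch wins
    ("ExprCondPair", lambda e, c: (e, c)),
    ("BooleanTrue", lambda: True),
]:
    errors.append(missing_type(mapping))
    mapping[kind] = func
errors.append(missing_type(mapping))
print(errors)  # ['Piecewise', 'ExprCondPair', 'BooleanTrue', None]
```

In this sketch the `Piecewise` entry receives already-evaluated `(value, condition)` pairs, which is why the pair type and the `True` condition each need their own entry before evaluation can succeed.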

In the end, I added:

extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}

But even this produced the following error:

KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

Hopefully I'm just missing something obvious?
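Both `Piecewise` and SymPy's `ITE` (where `ITE(A, B, C)` returns `B` if `A` is true and `C` otherwise) describe an element-wise selection, which is what `torch.where(condition, a, b)` computes for tensors. Below is a minimal sketch of those semantics, assuming plain Python lists in place of tensors and a hypothetical list-based `where` helper in place of `torch.where`. The Piecewise branches are folded from last to first, so the first branch whose condition holds takes priority. The `torch.where(x[1], x[0], y[0])` mapping tried in the issue is the two-branch case of this fold when the second condition is the catch-all `True`.

```python
import math
from typing import Any, List, Sequence, Tuple

Branch = Tuple[Sequence[Any], Sequence[bool]]  # (values, condition) per element


def where(cond: Sequence[bool], a: Sequence[Any], b: Sequence[Any]) -> List[Any]:
    """Hypothetical list-based stand-in for torch.where: a[i] if cond[i] else b[i]."""
    return [ai if ci else bi for ci, ai, bi in zip(cond, a, b)]


def piecewise(*branches: Branch) -> List[Any]:
    """Element-wise Piecewise: the first branch whose condition holds wins."""
    last_values, last_cond = branches[-1]
    # Where no condition holds, Piecewise is undefined; NaN marks that here.
    result = where(last_cond, last_values, [math.nan] * len(last_values))
    # Fold earlier branches over the result so that they take priority.
    for values, cond in reversed(branches[:-1]):
        result = where(cond, values, result)
    return result


def ite(cond: Sequence[bool], if_true: Sequence[bool], if_false: Sequence[bool]) -> List[bool]:
    """Element-wise ITE(A, B, C): B where A holds, C elsewhere."""
    return where(cond, if_true, if_false)


# "greater" written as Piecewise((1.0, x > y), (0.0, True)):
x = [0.5, 2.0, 3.0]
y = [1.0, 1.0, 3.0]
n = len(x)
gt = [xi > yi for xi, yi in zip(x, y)]
greater = piecewise(([1.0] * n, gt), ([0.0] * n, [True] * n))
print(greater)  # [0.0, 1.0, 0.0]
```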

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

Labels: bug
