Description
Bug Report
Describe the bug
We want to convert the TiRex model to ONNX. The conversion itself succeeds, but when we try to run inference on CPU (via ONNX Runtime, on a macOS M4 chip and on Linux x86-64; see details below) we get this error:
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for MatMul(13) node with name '/blocks.0/slstm_layer/slstm_cell/MatMul'
This happens because of the bfloat16 dtype in `recurrent_kernel = self._recurrent_kernel_.to(dtype=torch.bfloat16)` (line 55). This is a learnable parameter, defined in lines 35-37.
As soon as we change it to `float32`, everything works correctly. However, we would like to keep it in `bfloat16`, because prediction accuracy is better with that dtype.

Questions: is `bfloat16` not supported for `MatMul` operations? Or is there another reason why inference fails when `recurrent_kernel` has the `bfloat16` dtype?
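For context on why the `float32` workaround is numerically safe: bfloat16 is simply the upper 16 bits of an IEEE-754 float32 (same sign bit and 8-bit exponent, mantissa truncated to 7 bits), so widening a bfloat16 tensor to float32 is exact and loses no information; only the narrowing direction rounds. A minimal pure-Python sketch of both directions (function names are ours, for illustration only):

```python
import struct


def bfloat16_bits_to_float32(bits: int) -> float:
    # Widening is exact: a bfloat16 bit pattern becomes a float32
    # by appending 16 zero mantissa bits.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]


def float32_to_bfloat16_bits(value: float) -> int:
    # Narrowing by truncation (round-toward-zero) for simplicity;
    # frameworks like PyTorch use round-to-nearest-even instead.
    (u,) = struct.unpack(">I", struct.pack(">f", value))
    return u >> 16


# Example: 1.5 round-trips exactly because its mantissa fits in 7 bits.
print(bfloat16_bits_to_float32(0x3FC0))        # 1.5
print(hex(float32_to_bfloat16_bits(1.5)))      # 0x3fc0
```

This suggests that casting `recurrent_kernel` to `float32` just before export (or inserting a `Cast` node around the `MatMul`) should reproduce the bfloat16 weights bit-exactly; any accuracy difference would come from the wider float32 arithmetic, not from the cast itself.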
System information
- OS Platform and Distribution macOS 26.01.1 (M4 Chip), as well as Linux Ubuntu 20.04 (x86-64), as well as Oracle Linux Server 8.9 (Fedora / RHEL family, x86-64)
- ONNX version: 1.19.0
- ONNX Runtime version: 1.23.1
- Python version: 3.11.13
- Opset version: 18