mypy_einsum is a Mypy plugin for type checking np.einsum, jax.numpy.einsum, and torch.einsum operations.
The Einstein summation convention can be used to compute many multi-dimensional, linear algebraic array operations. einsum provides a succinct way of representing these.
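As an illustration, the snippet below expresses a few common operations with `np.einsum` and checks each one against its conventional NumPy equivalent. The arrays are arbitrary examples chosen for this sketch.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = np.arange(9, 18).reshape(3, 3)

# Matrix multiplication: the shared index k is summed over.
assert np.array_equal(np.einsum("ik,kj->ij", a, b), a @ b)

# Transpose: only the order of the output indices changes.
assert np.array_equal(np.einsum("ij->ji", a), a.T)

# Trace: a repeated index with an empty output sums the diagonal.
assert np.einsum("ii->", a) == np.trace(a)

# Row sums: dropping j from the output sums over it.
assert np.array_equal(np.einsum("ij->i", a), a.sum(axis=1))
```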
However, because einsum equations are passed as strings, typos and other bugs are easy to overlook, and linters cannot help because they do not inspect the contents of the string.
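For instance, the typo in the sketch below (an output index `l` that never appears in the inputs) goes unnoticed until the line actually executes, at which point NumPy raises a `ValueError`. The arrays and the variable `caught` are illustrative only.

```python
import numpy as np

a = np.ones((2, 3))
b = np.ones((3, 4))

caught = None
try:
    # Typo: "l" should be "k"; it never appears in an input subscript.
    np.einsum("ij,jk->il", a, b)
except ValueError as err:
    # Without static checking, the mistake is only found at runtime.
    caught = err
    print(f"Caught only at runtime: {err}")
```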
mypy_einsum is a Mypy plugin that statically verifies the correctness of einsum equations without needing to execute the code.
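To give a sense of what verifying an equation can involve, here is a minimal sketch of a few such checks in plain Python. It is not mypy_einsum's implementation: `check_einsum_equation` is a hypothetical helper, it looks only at the equation string and the number of operands (not shapes or dtypes), and it skips over the `...` broadcasting placeholder without validating it.

```python
from __future__ import annotations


def check_einsum_equation(equation: str, num_operands: int) -> list[str]:
    """Return error messages for an einsum equation (hypothetical helper)."""
    eq = equation.replace(" ", "")
    if eq.count("->") > 1:
        return ["An einsum equation may contain at most one '->'."]

    errors: list[str] = []
    # Split into the input subscripts and the optional explicit output.
    inputs_part, _, output = eq.partition("->")
    has_output = "->" in eq
    inputs = inputs_part.split(",")

    # There must be one comma-separated subscript group per operand.
    if len(inputs) != num_operands:
        errors.append(
            "Number of einsum subscripts must be equal to the number of operands."
        )

    if has_output:
        # Ignore the "..." broadcasting placeholder when comparing labels.
        input_labels = set("".join(inputs).replace("...", ""))
        output_labels = output.replace("...", "")
        for label in output_labels:
            if label not in input_labels:
                errors.append(f"Output subscript {label!r} does not appear in any input.")
        if len(set(output_labels)) != len(output_labels):
            errors.append("Output subscripts must not be repeated.")
    return errors
```

For example, `check_einsum_equation("ik,kj->ij", 1)` returns only the operand-count message, while passing `2` operands returns an empty list.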
mypy_einsum can be installed with pip:
pip install mypy-einsum
To enable the plugin, add it to your project's Mypy configuration file, usually mypy.ini:
[mypy]
plugins = mypy_einsum
or pyproject.toml:
[tool.mypy]
plugins = ["mypy_einsum"]
Can you spot the 🐛 without running the code?
import numpy as np
a = np.arange(9).reshape(3, 3)
np.einsum("ik,kj->ij", a)
mypy_einsum will catch it for you:
❯ mypy example.py --pretty
example.py:5: error: Number of einsum subscripts must be equal to the
number of operands. [einsum]
np.einsum("ik,kj->ij", a)
^~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)
After fixing it, mypy succeeds 🎉:
np.einsum("ik,kj->ij", a, a)
❯ mypy example.py
Success: no issues found in 1 source file
mypy_einsum aims to never raise warnings for valid einsum operations. If you encounter a warning that you believe is incorrect, or think mypy_einsum is failing to report an error, please let us know. Contributions are very welcome!