Merged
88 commits
9446d80
mlx poc
williambdean Apr 11, 2025
8a38e2f
add test for dot
williambdean Apr 11, 2025
d6feeba
restore pytorch
williambdean Apr 11, 2025
c8f959e
wrap in mx.array
williambdean Apr 11, 2025
513ee3a
modify the pytorch jit
williambdean Apr 11, 2025
59f4a88
move file
williambdean Apr 11, 2025
07e21e4
dont wrap
williambdean Apr 11, 2025
0583bf7
attempt to fix github action
williambdean Apr 11, 2025
9022edd
change the rtol
williambdean Apr 11, 2025
ebe96e0
add init file
williambdean Apr 11, 2025
c859db0
skip if not installed
williambdean Apr 11, 2025
90321ba
remove torch related code / comments
williambdean Apr 11, 2025
5e51402
simplify the fgraph_convert
williambdean Apr 12, 2025
488ea5a
assert type
williambdean Apr 12, 2025
d714fbc
simplify the internal
williambdean Apr 18, 2025
081806f
remove the language
williambdean Apr 18, 2025
6e312e0
Adding operations in pytensor
cetagostini Apr 18, 2025
b8a95ea
add extension
williambdean Apr 18, 2025
4083fe1
make compare function
williambdean Apr 18, 2025
71ad63d
rename function
williambdean Apr 18, 2025
9f67c2c
correct the function name
williambdean Apr 18, 2025
fa47b1a
tests for elemwise
williambdean Apr 18, 2025
292c01b
Changes
cetagostini Apr 18, 2025
9133b3c
Take your tomato, William
cetagostini Apr 18, 2025
1d68d5e
Pushing changes with the core shit.
cetagostini Apr 18, 2025
2014390
add more tests
williambdean Apr 18, 2025
89567aa
additional tests
williambdean Apr 18, 2025
dccee53
test for switch with mlx
williambdean Apr 18, 2025
ff871c2
Pushing code
cetagostini Apr 18, 2025
0c2fec1
Changes
cetagostini Apr 18, 2025
5275bf5
A lot of new code
cetagostini Apr 18, 2025
004ed73
almost there baby william
cetagostini Apr 18, 2025
5257c96
Another push small
cetagostini Apr 18, 2025
323045c
fix for all
williambdean Apr 18, 2025
0abac67
fix for carlos
williambdean Apr 18, 2025
199f17c
just return the compiled func
williambdean Apr 19, 2025
7b6a3d2
A change for willy may!
cetagostini Apr 19, 2025
710b563
FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs)
cetagostini Apr 19, 2025
1e1d8f9
THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
cetagostini Apr 19, 2025
b5c02a7
refactor to use getattr
williambdean Apr 19, 2025
8df5b09
bring argmax test
williambdean Apr 19, 2025
454fda9
use deepcopy
williambdean Apr 19, 2025
9299a28
move some tests
williambdean Apr 19, 2025
b4a9642
Guys, I'm getting sad. We need help yisus!!!!!
cetagostini Apr 19, 2025
30850f3
WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
cetagostini Apr 19, 2025
84665e5
RETURN, WHAT A SHAME! Sad times are coming.
cetagostini Apr 19, 2025
3041340
AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
cetagostini Apr 19, 2025
36f886b
AI RULES BABY MY MATE
cetagostini Apr 19, 2025
ca7c77f
I'm going for pizzas, it was an incredible day!
cetagostini Apr 19, 2025
9688407
test conv1d case
williambdean Apr 19, 2025
5d34fa6
SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
cetagostini Apr 19, 2025
b2fac8e
pre-commit
cetagostini Apr 19, 2025
10e1a40
Almost working
cetagostini Apr 19, 2025
c6841cb
Last PR sampling working
cetagostini Apr 23, 2025
e81ba94
Requested changes by Ricardo
cetagostini Jun 2, 2025
1da4530
Pre commit changes
cetagostini Jun 2, 2025
d6f6e2a
More changes from Ricardo
cetagostini Jun 8, 2025
3d144db
Pre Commit RUN
cetagostini Jun 8, 2025
8300fd4
Adding more operations for complex model
cetagostini Jun 8, 2025
d47de98
Working with simple model
cetagostini Jun 9, 2025
cfcb910
Change bad name
cetagostini Jun 9, 2025
481e3ad
Correcting test by Ricardo
cetagostini Jun 9, 2025
9527f6c
Changing synth test
cetagostini Jun 9, 2025
13a700a
Optimizing reshape
cetagostini Jun 9, 2025
fb46008
Comment
cetagostini Jun 9, 2025
70734c9
Small changes and adding small benchmark
cetagostini Jun 9, 2025
a43f1cf
Changes with Ricardo
cetagostini Jun 10, 2025
5e53537
improving benchmark
cetagostini Jun 10, 2025
127b896
pre commit
cetagostini Jun 10, 2025
26a6d14
benchs
cetagostini Jun 10, 2025
b2e924d
Changes on the branch
cetagostini Jul 11, 2025
a550919
Feedback from Ricardo
cetagostini Oct 1, 2025
e54c32f
update test based on llm recommendation
cetagostini Oct 5, 2025
d5a4bf8
Streamline Blockwise impl
jessegrabowski Oct 8, 2025
11faf7a
clean up imports
jessegrabowski Oct 8, 2025
2a86028
adjust github test.yml
jessegrabowski Oct 8, 2025
53cdf49
adjust github test.yml
jessegrabowski Oct 8, 2025
d22851a
skip mlx tests in benchmark ci
jessegrabowski Oct 8, 2025
ed0d687
Absolute imports
jessegrabowski Oct 8, 2025
2589de4
Use `importorskip` in mlx tests
jessegrabowski Oct 8, 2025
d16d245
address feedback
jessegrabowski Oct 8, 2025
9f41a4e
Add function names and remove wrappers
jessegrabowski Oct 9, 2025
bad7c90
Copy jax CARReduce test
jessegrabowski Oct 9, 2025
433a2cb
Move alloc tests to test_core.py
jessegrabowski Oct 9, 2025
2421a6f
Handle dynamic shapes to AllocEmpty in non-compiled mode
jessegrabowski Oct 9, 2025
d33cda0
Simplify mlx_funcify_CAReduce
jessegrabowski Oct 9, 2025
5940630
Delete AI cruft
jessegrabowski Oct 9, 2025
e484ba4
move all elemwise dispatches to elemwise.py
jessegrabowski Oct 9, 2025
15 changes: 14 additions & 1 deletion .github/workflows/test.yml
@@ -81,6 +81,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-mlx: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
@@ -106,6 +107,7 @@
install-numba: 0
install-jax: 0
install-torch: 0
install-mlx: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
@@ -149,7 +151,16 @@
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
- os: "macos-15"
python-version: "3.11"
fast-compile: 0
float32: 0
install-mlx: 1
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/link/mlx"
- os: "macos-15"
python-version: "3.13"
fast-compile: 0
float32: 0
@@ -194,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi

pip install -e ./
@@ -210,6 +222,7 @@
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
INSTALL_MLX: ${{ matrix.install-mlx }}
OS: ${{ matrix.os}}

- name: Run tests
1 change: 0 additions & 1 deletion .gitignore
@@ -27,7 +27,6 @@ __pycache__
\#*\#
build
compiled/*.cpp
core.*
cutils_ext.cpp
dist
doc/.build/
436 changes: 436 additions & 0 deletions doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

Large diffs are not rendered by default.
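Since the draft notebook is not rendered here, a minimal sketch of the kind of MLX-vs-JAX timing comparison it performs might look like the following (the model, sizes, and repeat counts are illustrative, not taken from the notebook):

import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.tanh(x @ x.T).sum()
x_val = np.random.default_rng(0).normal(size=(256, 256)).astype("float32")

for mode in ("MLX", "JAX"):
    fn = pytensor.function([x], y, mode=mode)
    fn(x_val)  # warm-up call: triggers backend compilation
    # Note: MLX evaluates lazily; a fair timing assumes the linker forces
    # evaluation when returning outputs.
    print(mode, timeit.timeit(lambda: fn(x_val), number=100))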

17 changes: 17 additions & 0 deletions pytensor/compile/mode.py
@@ -27,6 +27,7 @@
from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
"mlx": MLXLinker(),
}


@@ -504,13 +506,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
),
)

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)


predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
}

_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
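With the mode registered above, it can be selected by name when compiling a PyTensor function. A minimal sketch, assuming the `mlx` package is installed on an Apple-silicon machine (the graph itself is illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# mode="MLX" resolves to the predefined Mode above, which pairs MLXLinker
# with the fast_run rewrites minus the listed exclusions.
fn = pytensor.function([x], y, mode="MLX")
fn([0.0, 1.0, 2.0])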
1 change: 1 addition & 0 deletions pytensor/link/mlx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pytensor.link.mlx.linker import MLXLinker
13 changes: 13 additions & 0 deletions pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,13 @@
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
101 changes: 101 additions & 0 deletions pytensor/link/mlx/dispatch/basic.py
@@ -0,0 +1,101 @@
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def mlx_typify(data, **kwargs):
raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
def mlx_typify_tensor(data, dtype=None, **kwargs):
return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
return data


@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
return mx.array(data)


@mlx_typify.register(bool)
@mlx_typify.register(np.bool_)
def mlx_typify_bool(data, **kwargs):
return bool(data)


@mlx_typify.register(np.integer)
@mlx_typify.register(np.floating)
@mlx_typify.register(np.complexfloating)
def mlx_typify_numpy_scalar(data, **kwargs):
return mx.array(data)


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a MLX compatible function from an PyTensor `Op`."""
raise NotImplementedError(
f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
)


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="mlx_funcified_fgraph",
conversion_func=mlx_funcify,
**kwargs,
):
built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
conversion_func,
type_conversion_fn=mlx_typify,
fgraph_name=fgraph_name,
**built_kwargs,
)


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return deepcopy(x)

return deepcopyop


@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, node, **kwargs):
conds = node.inputs[1:]
if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
raise op.exc_type(op.msg)

warnings.warn(
f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
stacklevel=2,
)

def assert_fn(x, *inputs):
return x

return assert_fn
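As a sketch of how the two dispatchers above cooperate (assuming `mlx` is installed; the graph is illustrative): `mlx_funcify` turns a `FunctionGraph` into a plain Python function over MLX arrays, and `mlx_typify` converts the inputs.

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.fg import FunctionGraph
from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify

x = pt.vector("x")
fgraph = FunctionGraph([x], [pt.exp(x)])

fn = mlx_funcify(fgraph)  # dispatches to mlx_funcify_FunctionGraph
(result,) = fn(mlx_typify(np.arange(3.0)))  # result is an mx.array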
35 changes: 35 additions & 0 deletions pytensor/link/mlx/dispatch/blockwise.py
@@ -0,0 +1,35 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise


@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
# 1) Get the core Python function for this Blockwise
core_node = op._create_dummy_core_node(node.inputs)
core_f = mlx_funcify(op.core_op, core_node)

# 2) Determine how many batch dimensions the node has
n_batch = op.batch_ndim(node)

# 3) Handle the case where no vectorization is needed
if n_batch == 0:
return core_f

# 4) Vectorize with mx.vmap over any batched inputs
in_axes: list[int | None] = []
for inp, sig in zip(node.inputs, op.inputs_sig):
batch_ndim = inp.type.ndim - len(sig)
if batch_ndim == 0:
in_axes.append(None)
continue

batch_bcast = inp.type.broadcastable[:batch_ndim]
# If all batch dims are broadcastable (size 1), treat input as static
in_axes.append(0 if not all(batch_bcast) else None)

if not any(axis == 0 for axis in in_axes):
return core_f

return mx.vmap(core_f, in_axes=tuple(in_axes))
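Concretely, for a core op with signature "(m,n),(n)->(m)" where only the first input carries a batch dimension, the dispatcher builds the equivalent of the following (shapes are illustrative; assumes `mlx` is installed):

import mlx.core as mx

def core_f(A, x):  # core matrix-vector product, signature (m,n),(n)->(m)
    return A @ x

A = mx.ones((4, 3, 2))  # one batch dim -> in_axes entry 0
x = mx.ones((2,))       # no batch dims -> in_axes entry None

batched = mx.vmap(core_f, in_axes=(0, None))
batched(A, x).shape  # (4, 3): the batch dim is mapped, x is broadcast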