-
Notifications
You must be signed in to change notification settings - Fork 143
Add MLX backend #1365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add MLX backend #1365
Changes from 82 commits
Commits
Show all changes
88 commits
Select commit
Hold shift + click to select a range
9446d80
mlx poc
williambdean 8a38e2f
add test for dot
williambdean d6feeba
restore pytorch
williambdean c8f959e
wrap in mx.array
williambdean 513ee3a
modify the pytorch jit
williambdean 59f4a88
move file
williambdean 07e21e4
dont wrap
williambdean 0583bf7
attempt to fix github action
williambdean 9022edd
change the rtol
williambdean ebe96e0
add init file
williambdean c859db0
skip if not installed
williambdean 90321ba
remove torch related code / comments
williambdean 5e51402
simplify the fgraph_convert
williambdean 488ea5a
assert type
williambdean d714fbc
simplify the internal
williambdean 081806f
remove the language
williambdean 6e312e0
Adding operations in pytensor
cetagostini b8a95ea
add extension
williambdean 4083fe1
make compare function
williambdean 71ad63d
rename function
williambdean 9f67c2c
correct the function name
williambdean fa47b1a
tests for elemwise
williambdean 292c01b
Changes
cetagostini 9133b3c
Toma tu tomate William
cetagostini 1d68d5e
Pushing changes with the core shit.
cetagostini 2014390
add more tests
williambdean 89567aa
additional tests
williambdean dccee53
test for switch with mlx
williambdean ff871c2
Pushing code
cetagostini 0c2fec1
Changes
cetagostini 5275bf5
A lot of new code
cetagostini 004ed73
almost there baby william
cetagostini 5257c96
Another push small
cetagostini 323045c
fix for all
williambdean 0abac67
fix for carlos
williambdean 199f17c
just return the compiled func
williambdean 7b6a3d2
A change for willy may!
cetagostini 710b563
FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs)
cetagostini 1e1d8f9
THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
cetagostini b5c02a7
refactor to use getattr
williambdean 8df5b09
bring argmax test
williambdean 454fda9
use deepcopy
williambdean 9299a28
move some tests
williambdean b4a9642
Guys, I'm getting sad. We need help yisus!!!!!
cetagostini 30850f3
WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
cetagostini 84665e5
RETURN, WHAT A SHAME! Sad times are coming.
cetagostini 3041340
AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
cetagostini 36f886b
AI RULES BABY MY MATE
cetagostini ca7c77f
I'm going for pizzas, it was an incredible day!
cetagostini 9688407
test conv1d case
williambdean 5d34fa6
SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
cetagostini b2fac8e
pre-commit
cetagostini 10e1a40
Almost working
cetagostini c6841cb
Last PR sampling working
cetagostini e81ba94
Requested changes by Ricardo
cetagostini 1da4530
Pre commit changes
cetagostini d6f6e2a
More changes from Ricardo
cetagostini 3d144db
Pre Commit RUN
cetagostini 8300fd4
Adding more operations for complex model
cetagostini d47de98
Working with simple model
cetagostini cfcb910
Change bad name
cetagostini 481e3ad
Correcting test by Ricardo
cetagostini 9527f6c
Changing synth test
cetagostini 13a700a
Optimizing reshape
cetagostini fb46008
Comment
cetagostini 70734c9
Small changes and adding small benchmark
cetagostini a43f1cf
Changes with Ricardo
cetagostini 5e53537
improving benchmark
cetagostini 127b896
pre commit
cetagostini 26a6d14
benchs
cetagostini b2e924d
Changes on the branch
cetagostini a550919
Feedback from Ricardo
cetagostini e54c32f
update test based on llm recommendation
cetagostini d5a4bf8
Streamline Blockwise impl
jessegrabowski 11faf7a
clean up imports
jessegrabowski 2a86028
adjust github test.yml
jessegrabowski 53cdf49
adjust github test.yml
jessegrabowski d22851a
skip mlx tests in benchmark ci
jessegrabowski ed0d687
Absolute imports
jessegrabowski 2589de4
Use `importorskip` in mlx tests
jessegrabowski d16d245
address feedback
jessegrabowski 9f41a4e
Add function names and remove wrappers
jessegrabowski bad7c90
Copy jax CARReduce test
jessegrabowski 433a2cb
Move alloc tests to test_core.py
jessegrabowski 2421a6f
Handle dynamic shapes to AllocEmpty in non-compiled mode
jessegrabowski d33cda0
Simplify mlx_funcify_CAReduce
jessegrabowski 5940630
Delete AI cruft
jessegrabowski e484ba4
move all elemwise dispatches to elemwise.py
jessegrabowski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,6 @@ __pycache__ | |
\#*\# | ||
build | ||
compiled/*.cpp | ||
core.* | ||
cutils_ext.cpp | ||
dist | ||
doc/.build/ | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from pytensor.link.mlx.linker import MLXLinker |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# isort: off | ||
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify | ||
|
||
import pytensor.link.mlx.dispatch.math | ||
import pytensor.link.mlx.dispatch.basic | ||
import pytensor.link.mlx.dispatch.elemwise | ||
import pytensor.link.mlx.dispatch.shape | ||
import pytensor.link.mlx.dispatch.subtensor | ||
import pytensor.link.mlx.dispatch.core | ||
import pytensor.link.mlx.dispatch.signal | ||
import pytensor.link.mlx.dispatch.signal.conv | ||
import pytensor.link.mlx.dispatch.blockwise | ||
# isort: on |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import warnings | ||
from copy import deepcopy | ||
from functools import singledispatch | ||
from types import NoneType | ||
|
||
import mlx.core as mx | ||
import numpy as np | ||
|
||
from pytensor.compile.ops import DeepCopyOp | ||
from pytensor.graph import Constant | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.link.utils import fgraph_to_python | ||
from pytensor.raise_op import Assert, CheckAndRaise | ||
|
||
|
||
@singledispatch | ||
def mlx_typify(data, **kwargs): | ||
raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}") | ||
|
||
|
||
@mlx_typify.register(np.ndarray) | ||
def mlx_typify_tensor(data, dtype=None, **kwargs): | ||
return mx.array(data, dtype=dtype) | ||
|
||
|
||
@mlx_typify.register(slice) | ||
@mlx_typify.register(NoneType) | ||
@mlx_typify.register(mx.array) | ||
def mlx_typify_no_conversion_needed(data, **kwargs): | ||
return data | ||
|
||
|
||
@mlx_typify.register(int) | ||
@mlx_typify.register(float) | ||
def mlx_typify_python_scalar(data, **kwargs): | ||
return mx.array(data) | ||
|
||
|
||
@mlx_typify.register(bool) | ||
@mlx_typify.register(np.bool_) | ||
def mlx_typify_bool(data, **kwargs): | ||
return bool(data) | ||
|
||
|
||
@mlx_typify.register(np.integer) | ||
@mlx_typify.register(np.floating) | ||
@mlx_typify.register(np.complexfloating) | ||
def mlx_typify_numpy_scalar(data, **kwargs): | ||
return mx.array(data) | ||
|
||
|
||
@singledispatch | ||
def mlx_funcify(op, node=None, storage_map=None, **kwargs): | ||
"""Create a MLX compatible function from an PyTensor `Op`.""" | ||
raise NotImplementedError( | ||
f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation" | ||
) | ||
|
||
|
||
@mlx_funcify.register(FunctionGraph) | ||
def mlx_funcify_FunctionGraph( | ||
fgraph, | ||
node=None, | ||
fgraph_name="mlx_funcified_fgraph", | ||
conversion_func=mlx_funcify, | ||
**kwargs, | ||
): | ||
built_kwargs = {"conversion_func": conversion_func, **kwargs} | ||
return fgraph_to_python( | ||
fgraph, | ||
conversion_func, | ||
type_conversion_fn=mlx_typify, | ||
fgraph_name=fgraph_name, | ||
**built_kwargs, | ||
) | ||
|
||
|
||
@mlx_funcify.register(DeepCopyOp) | ||
def mlx_funcify_DeepCopyOp(op, **kwargs): | ||
def deepcopyop(x): | ||
return deepcopy(x) | ||
|
||
return deepcopyop | ||
|
||
|
||
@mlx_funcify.register(Assert) | ||
@mlx_funcify.register(CheckAndRaise) | ||
def mlx_funcify_CheckAndRaise(op, node, **kwargs): | ||
conds = node.inputs[1:] | ||
if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds): | ||
raise op.exc_type(op.msg) | ||
|
||
warnings.warn( | ||
f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""", | ||
stacklevel=2, | ||
) | ||
|
||
def assert_fn(x, *inputs): | ||
return x | ||
|
||
return assert_fn |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import mlx.core as mx | ||
|
||
from pytensor.link.mlx.dispatch import mlx_funcify | ||
from pytensor.tensor.blockwise import Blockwise | ||
|
||
|
||
@mlx_funcify.register(Blockwise) | ||
def funcify_Blockwise(op: Blockwise, node, **kwargs): | ||
# 2) Otherwise, get the core python function for this Blockwise | ||
core_node = op._create_dummy_core_node(node.inputs) | ||
core_f = mlx_funcify(op.core_op, core_node) | ||
|
||
# 3) Determine how many inputs correspond to batch dimensions | ||
n_batch = op.batch_ndim(node) | ||
|
||
# 4) Handle case where no vectorization is needed | ||
if n_batch == 0: | ||
return core_f | ||
|
||
# 5) Vectorize using mx.vmap over any batched inputs | ||
in_axes: list[int | None] = [] | ||
for inp, sig in zip(node.inputs, op.inputs_sig): | ||
batch_ndim = inp.type.ndim - len(sig) | ||
if batch_ndim == 0: | ||
in_axes.append(None) | ||
continue | ||
|
||
batch_bcast = inp.type.broadcastable[:batch_ndim] | ||
# If all batch dims are broadcastable (size 1), treat input as static | ||
in_axes.append(0 if not all(batch_bcast) else None) | ||
|
||
if not any(axis == 0 for axis in in_axes): | ||
return core_f | ||
|
||
return mx.vmap(core_f, in_axes=tuple(in_axes)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.