Add MLX backend #1365
Conversation
Codecov Report
❌ Your patch status has failed because the patch coverage (76.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1365 +/- ##
==========================================
- Coverage 81.65% 81.58% -0.07%
==========================================
Files 232 242 +10
Lines 53081 53771 +690
Branches 9403 9468 +65
==========================================
+ Hits 43342 43870 +528
- Misses 7286 7421 +135
- Partials 2453 2480 +27
I suggest basing yourself on the numba linker; torch has a lot of hacks we hopefully don't need here.
Thanks for the pointer. I simplified the one method. Do you think that
Yeah, you shouldn't need that; you just need a call to typify on the runtime inputs as well.
Still need to get this to run:
import pytensor
pytensor.config.mode = "MLX"
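For context, a minimal end-to-end usage sketch, assuming the MLX linker ends up registered under the mode name "MLX" as the snippet above sets:

```python
import pytensor
import pytensor.tensor as pt

# Build a tiny graph and compile it with the (assumed) MLX mode name
x = pt.vector("x")
y = pt.exp(x).sum()

f = pytensor.function([x], y, mode="MLX")
print(f([0.0, 1.0, 2.0]))
```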
Hey, big thanks to @jessegrabowski and @ricardoV94 for helping with this PR! I feel the PR is huge enough. Should we make a first merge and start to iterate on next versions, cleaning up and making everything more consistent with other backends? Thanks to @williambdean for opening the PR!
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
Is this true, or just copy/pasta from JAX?
I need to check more here!
Note this has been changed in JAX to raise when the condition is known to be False. We should do the same (or just implement an assert, if mlx allows that, I suspect it does not)
So, this was true. I'm moving to an approach similar to JAX's.
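A rough sketch of what "the JAX way" could look like here: raise eagerly when a condition is concrete and statically False, otherwise pass the value through. The function names and the `.item()` check are assumptions for illustration, not the PR's final code:

```python
import mlx.core as mx


def mlx_funcify_check_and_raise_sketch(op, **kwargs):
    def check_and_raise(value, *conditions):
        for cond in conditions:
            # A concrete scalar condition can be evaluated right away;
            # raise the Op's exception if it is known to be False.
            if cond.size == 1 and not bool(cond.item()):
                raise op.exc_type(op.msg)
        return value

    return check_and_raise
```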
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
Not needed anymore since they fixed it upstream, right?
I think we still need it. Where do you see it's fixed? We are using this blockwise conv1d.
Sure but blockwise will call vmap on the core op, so we only need to dispatch core Conv1D to MLX Conv1D, then the blockwise variant will work automatically
Can you confirm we don't need this specialized implementation anymore?
Also note that Convolve1d changed in main. Now mode is not a property of the Op but a runtime value of 1 when full and 0 when valid. Check how JAX is implemented now
Okay, the MLX path now matches the upstream design: blockwise simply vmaps the core op, so no special MLX-only wrapper is needed anymore. We are effectively using the same contract JAX uses, except JAX has a static mode, while MLX now works for dynamic modes as well :)
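A minimal illustration of that contract (helper name assumed, not the PR's code): the Blockwise dispatch only needs to wrap the core Op's MLX implementation in `mx.vmap` once per batch dimension, so dispatching the core Conv1d is enough.

```python
import mlx.core as mx


def blockwise_from_core(core_fn, n_batch_dims):
    # Wrap the core implementation in vmap once per leading batch dimension
    fn = core_fn
    for _ in range(n_batch_dims):
        fn = mx.vmap(fn)
    return fn


# e.g. a batched elementwise add over one leading batch dimension
batched_add = blockwise_from_core(lambda a, b: a + b, n_batch_dims=1)
print(batched_add(mx.ones((4, 3)), mx.ones((4, 3))))
```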
# Convert scalar to array if needed
if isinstance(x, int | float) or (
    isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
    x = mx.array(x)
Should not be needed
MLX’s mx.transpose rejects plain Python or NumPy scalars (TypeError: transpose(): incompatible function arguments), so without the conversion we would crash whenever a 0-d value reaches a dimshuffle.
PyTensor will not send Python or NumPy scalars to dimshuffle; that means you did something wrong, probably when converting constants.
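A hedged sketch of the suggested fix: do the conversion once when constants are typified, rather than special-casing scalars inside DimShuffle (function name assumed, not the PR's actual dispatch):

```python
import numpy as np
import mlx.core as mx


def mlx_typify_sketch(data):
    # Convert NumPy arrays and Python/NumPy scalars to MLX arrays up front,
    # so downstream ops like transpose only ever see mx.array inputs.
    if isinstance(data, (np.ndarray, np.generic, int, float)):
        return mx.array(data)
    return data
```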
pytensor/link/mlx/dispatch/math.py (outdated diff):

@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
Like CAReduce, it should have a second-level dispatch. Also we need to enforce the runtime_broadcastable checks (same in Alloc). And we should have a default implementation for that second-level dispatch that tries to use getattr(MLX, "func_name"), similar to how JAX does it already.
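A sketch of that second-level dispatch with a `getattr` default, loosely following the JAX backend's pattern; the function name and the lowercased-class-name lookup are assumptions for illustration:

```python
from functools import singledispatch

import mlx.core as mx


@singledispatch
def mlx_funcify_scalar_op(scalar_op):
    # Default: try an mlx.core function with the same lowercased name as the
    # scalar Op's class (e.g. Exp -> mx.exp, Add -> mx.add).
    name = type(scalar_op).__name__.lower()
    mlx_func = getattr(mx, name, None)
    if mlx_func is None:
        raise NotImplementedError(f"MLX does not provide a function for {scalar_op}")
    return mlx_func
```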
Done!
    return softmax_grad


@mlx_funcify.register(Softplus)
Delete this? You have one in elemwise already
Done!
Still here
Where? Can you point?
You have mlx_funcify.register(Softplus) in link/mlx/dispatch/elemwise.py, but you also have it in .../math.py (where it should be). Same for Cast below.
Were we going to split this PR up into core functionality and op implementations? What are the next steps here?
This PR seems to be in an okay state; we can merge as is once the comments are addressed.
@ricardoV94 I think it's all applied.
@pytest.mark.xfail(reason="Reshape Op is not supported yet")
def test_mlx_Reshape_various_shapes():
Isn't this similar to test_mlx_Reshape_concrete below? Combine with that?
I mean, it doesn't seem appropriate in this case, as they're testing different scenarios of the reshape operation. test_mlx_Reshape_various_shapes focuses on testing different dimensional transformations with static/constant shapes, while test_mlx_Reshape_concrete_shape focuses on testing computed/dynamic shapes where the shape is derived from the input tensor's properties. Maybe they can be renamed? But I feel they are two different things!
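For illustration, the two scenarios being distinguished (a sketch, not the actual test bodies):

```python
import pytensor.tensor as pt

x = pt.vector("x")
static_reshape = pt.reshape(x, (2, 3))                 # constant target shape
dynamic_reshape = pt.reshape(x, (x.shape[0] // 2, 2))  # shape computed from the input
```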
compare_mlx_and_py([], [out_pt], [])


@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode")
Is inplace optimization something MLX does by itself (like JAX)? In that case we don't need to worry about or test it.
I could not find any explicit statement in the MLX documentation or codebase that MLX performs in-place optimization / mutation merging automatically in the same way that JAX sometimes does.
So you're supposed to write x[idx] = 0 for inplacing? That's valid MLX syntax?
Yes, that should be possible (I'll write code and double check). I'll start to make issues around MLX.
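A quick check of the syntax under discussion, assuming `mlx.core` arrays support item assignment for in-place style updates as the MLX docs describe:

```python
import mlx.core as mx

a = mx.zeros((3,))
a[1] = 5.0   # item assignment on an MLX array
print(a)     # expected: array([0, 5, 0], dtype=float32)
```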
mx = pytest.importorskip("mlx.core")


def test_mlx_Subtensor_basic():
do you have any test with symbolic indices? Is that supported by MLX?
No, MLX does not currently support general symbolic indexing or boolean-masked indexing (i.e. indexing where the selection is determined by a mask) in its graph mode. It does support index arrays.
That's even more restrictive than JAX. JAX indices can be symbolic as long as they determine the output shape from the input shape. So if input indices have static shape (3,), that's fine, as it knows the output will also have shape=(3,), regardless of the values.
That's my understanding.
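For reference, a small sketch of what "index arrays" support means in practice (integer-array indexing and mx.take; boolean masks are the unsupported case discussed above):

```python
import mlx.core as mx

x = mx.arange(10)
idx = mx.array([1, 3, 5])

print(x[idx])           # advanced indexing with an integer index array
print(mx.take(x, idx))  # equivalent gather via mx.take
```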
I did another pass; this is looking like 99% there. Left another round of comments and some change requests.
One thing I don't like about this PR is that it's basically impossible to predict where function dispatches are. Elemwise was in math.py, but CAReduce was in elemwise.py. basic.py seems to be for shared tools, but it also has FunctionGraph, CheckAndRaise, and DeepCopyOp. core.py seems to have tensor constructors? IMO we should aim for each linker directory to follow 1:1 the
@jessegrabowski I resolved all your comments. @ricardoV94 I replied to the other outstanding comments. My opinion would be: if we are 99% there with this PR, the other 1% can be addressed in another clean PR? If you tell me the things we would like to modify, I can make issues and we can kick off work to improve MLX and apply those changes.
I like this idea. By the way @jessegrabowski, definitely for another PR? But I could do this.
Let the party begin :D
Description
Getting the ball rolling; started with #1350.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1365.org.readthedocs.build/en/1365/