
Conversation

williambdean
Contributor

@williambdean williambdean commented Apr 11, 2025

Description

Getting the ball rolling on #1350

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1365.org.readthedocs.build/en/1365/

@williambdean williambdean marked this pull request as draft April 11, 2025 15:43
@williambdean williambdean marked this pull request as ready for review April 11, 2025 17:54

codecov bot commented Apr 11, 2025

Codecov Report

❌ Patch coverage is 76.23188% with 164 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.58%. Comparing base (5227759) to head (e484ba4).
⚠️ Report is 15 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/mlx/dispatch/core.py 54.76% 60 Missing and 16 partials ⚠️
pytensor/link/mlx/dispatch/elemwise.py 77.08% 55 Missing ⚠️
pytensor/link/mlx/dispatch/math.py 72.91% 11 Missing and 2 partials ⚠️
pytensor/link/mlx/dispatch/blockwise.py 66.66% 4 Missing and 3 partials ⚠️
pytensor/link/mlx/dispatch/basic.py 92.98% 3 Missing and 1 partial ⚠️
pytensor/link/mlx/dispatch/subtensor.py 93.84% 0 Missing and 4 partials ⚠️
pytensor/link/mlx/dispatch/signal/conv.py 87.50% 2 Missing and 1 partial ⚠️
pytensor/link/mlx/dispatch/shape.py 93.10% 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (76.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1365      +/-   ##
==========================================
- Coverage   81.65%   81.58%   -0.07%     
==========================================
  Files         232      242      +10     
  Lines       53081    53771     +690     
  Branches     9403     9468      +65     
==========================================
+ Hits        43342    43870     +528     
- Misses       7286     7421     +135     
- Partials     2453     2480      +27     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 85.00% <100.00%> (+0.13%) ⬆️
pytensor/link/mlx/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/mlx/linker.py 100.00% <100.00%> (ø)
pytensor/link/pytorch/linker.py 100.00% <ø> (ø)
pytensor/link/mlx/dispatch/shape.py 93.10% <93.10%> (ø)
pytensor/link/mlx/dispatch/signal/conv.py 87.50% <87.50%> (ø)
pytensor/link/mlx/dispatch/basic.py 92.98% <92.98%> (ø)
pytensor/link/mlx/dispatch/subtensor.py 93.84% <93.84%> (ø)
pytensor/link/mlx/dispatch/blockwise.py 66.66% <66.66%> (ø)
pytensor/link/mlx/dispatch/math.py 72.91% <72.91%> (ø)
... and 2 more

... and 1 file with indirect coverage changes


@ricardoV94
Member

I suggest basing yourself on the numba linker; torch has a lot of hacks we hopefully don't need here.

@williambdean
Contributor Author

williambdean commented Apr 12, 2025

Thanks for the pointer. I simplified the one method. Do you think that gen_functors can be removed as well? The only commonality with pytorch then is that no input can be a numpy array.

@ricardoV94
Member

ricardoV94 commented Apr 13, 2025

Yeah, you shouldn't need that; you just need a call to typify on the runtime inputs as well.
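
For illustration, a minimal sketch of what such an mlx_typify dispatch could look like, mirroring jax_typify/pytorch_typify (the name and the exact registration set here are assumptions, not the final API):

from functools import singledispatch

import mlx.core as mx
import numpy as np


@singledispatch
def mlx_typify(data, **kwargs):
    # Fallback: let mx.array decide whether it can handle the value
    return mx.array(data)


@mlx_typify.register(np.ndarray)
def mlx_typify_ndarray(data, **kwargs):
    # Runtime NumPy inputs are converted once, on the way into the compiled function
    return mx.array(data)


@mlx_typify.register(mx.array)
def mlx_typify_passthrough(data, **kwargs):
    # Already an MLX array, nothing to do
    return data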

@williambdean
Contributor Author

Still need to get this to run:

import pytensor

pytensor.config.mode = "MLX"
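
For context, the target usage would look roughly like the following once the MLX mode is registered (a hypothetical smoke test, not something that ran at the time of this comment):

import numpy as np

import pytensor
import pytensor.tensor as pt

pytensor.config.mode = "MLX"

x = pt.vector("x")
y = pt.exp(x) + 1

# Compiling with the default mode ("MLX" as set above) should route through the MLX linker
f = pytensor.function([x], y)
print(f(np.arange(3, dtype="float32")))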

@cetagostini
Contributor

Hey, big thanks to @jessegrabowski and @ricardoV94 for helping with this PR!

I feel the PR is huge enough already. Should we do a first merge and start iterating in the next versions, cleaning things up and making everything more consistent with the other backends?

Thanks to @williambdean for opening the PR!

Comment on lines 67 to 101
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
Member

Is this true, or just copy/pasta from JAX?

Contributor

I need to check more here!

Member

Note this has been changed in JAX to raise when the condition is known to be False. We should do the same (or just implement an assert, if MLX allows that; I suspect it does not).

Contributor

So, this was true. I'm moving to an approach similar to JAX's.
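
A sketch of that JAX-like behaviour, raising eagerly when a condition is a constant known to be False and otherwise warning that the check gets dropped (it reuses the names from the snippet above and assumes Constant from pytensor.graph.basic is imported; whether MLX can support a true runtime assert is still open):

@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, node, **kwargs):
    # Inputs are (value, *conditions)
    conds = node.inputs[1:]
    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
        # The assertion is statically known to fail: raise at compile time, like JAX does now
        raise op.exc_type(op.msg)

    warnings.warn(
        f"Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn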

from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
Member

Not needed anymore since they fixed it upstream, right?

Contributor

I think we still need it. Where do you see that it's fixed? We are using this blockwise conv1d.

Member

Sure, but Blockwise will call vmap on the core op, so we only need to dispatch the core Conv1D to MLX Conv1D; then the blockwise variant will work automatically.

Member

Can you confirm we don't need this specialized implementation anymore?

Also note that Convolve1d changed in main: now mode is not a property of the Op but a runtime value, 1 when full and 0 when valid. Check how the JAX version is implemented now.

Contributor

Okay, the MLX path now matches the upstream design: blockwise simply vmaps the core op, so no special MLX-only wrapper is needed anymore. We are effectively using the same contract JAX uses, except JAX has a static mode, while MLX now works for dynamic modes as well :)
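
For reference, a minimal core dispatch along those lines might look like the following sketch; Convolve1d and its runtime full/valid flag come from the discussion above, while the import locations and the exact runtime signature are assumptions:

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.signal.conv import Convolve1d


@mlx_funcify.register(Convolve1d)
def mlx_funcify_Convolve1d(op, **kwargs):
    def conv1d(data, kernel, mode):
        # `mode` is a runtime scalar: truthy -> "full", falsy -> "valid"
        return mx.convolve(data, kernel, mode="full" if mode else "valid")

    return conv1d

Blockwise would then batch this core implementation through its vmap machinery, so no MLX-specific blockwise_conv1d wrapper is needed.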

Comment on lines +21 to +29
# Convert scalar to array if needed
if isinstance(x, int | float) or (
    isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
    x = mx.array(x)
Member

Should not be needed

Contributor

MLX’s mx.transpose rejects plain Python or NumPy scalars (TypeError: transpose(): incompatible function arguments), so without the conversion we would crash whenever a 0-d value reaches a dimshuffle.

Member

PyTensor will not send Python or NumPy scalars to dimshuffle; that means you did something wrong, probably when converting constants.



@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
Member

Like CAReduce, it should have a second-level dispatch. Also, we need to enforce the runtime_broadcastable checks (same in Alloc). And we should have a default implementation for that second-level dispatch that tries to use getattr(MLX, "func_name"), similar to how JAX does it already.
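
A rough sketch of that getattr-based default for the second-level dispatch (the use of nfunc_spec as the name source and the exact layering are assumptions, following how the JAX backend falls back to its numpy-like namespace; mlx_funcify and Elemwise are assumed imported as in the excerpt above):

from functools import singledispatch

import mlx.core as mx


@singledispatch
def mlx_funcify_scalar_op(op, **kwargs):
    # Default: look for a same-named function in mlx.core (e.g. "add", "exp", "sin")
    name = op.nfunc_spec[0] if getattr(op, "nfunc_spec", None) else getattr(op, "name", None)
    mlx_func = getattr(mx, name, None) if name else None
    if mlx_func is None:
        raise NotImplementedError(f"No MLX implementation for scalar op {op}")
    return mlx_func


@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
    # First level dispatches on Elemwise, second level on the wrapped scalar op
    scalar_fn = mlx_funcify_scalar_op(op.scalar_op, **kwargs)

    def elemwise(*inputs):
        return scalar_fn(*inputs)

    return elemwise

Specific scalar Ops (e.g. Softplus or Cast) can then register their own overrides on mlx_funcify_scalar_op instead of each getting a top-level Elemwise special case.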

Contributor

Done!

return softmax_grad


@mlx_funcify.register(Softplus)
Member

Delete this? You have one in elemwise already

Contributor

Done!

Member

Still here

Contributor

Where? Can you point?

Member

You have mlx_funcify.register(Softplus) in link/mlx/dispatch/elemwise.py, but you also have it in .../math.py (where it should be). Same for Cast below

@williambdean
Contributor Author

Were we going to split this PR up into core functionality and op implementations? What are the next steps here?

@ricardoV94
Member

ricardoV94 commented Jun 3, 2025

This PR seems to be in an okay state; we can merge it as is once the comments are addressed.

@cetagostini
Contributor

@ricardoV94 I think it's all applied.



@pytest.mark.xfail(reason="Reshape Op is not supported yet")
def test_mlx_Reshape_various_shapes():
Member

Isn't this similar to test_mlx_Reshape_concrete below? Combine with that?

Contributor

I mean, it doesn't seem appropriate in this case, as they're testing different scenarios of the reshape operation:

  1. test_mlx_Reshape_various_shapes focuses on testing different dimensional transformations with static/constant shapes.
  2. test_mlx_Reshape_concrete_shape focuses on testing computed/dynamic shapes, where the shape is derived from the input tensor's properties.

Maybe they can be renamed? But I feel they are two different things! A sketch of both flavours follows below.
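
For concreteness, the two flavours might look roughly like this (compare_mlx_and_py is the helper used elsewhere in this test suite; the import location and the exact graphs here are assumptions):

import numpy as np

import pytensor.tensor as pt
from tests.link.mlx.test_basic import compare_mlx_and_py  # assumed helper location


def test_mlx_Reshape_various_shapes():
    # Static/constant target shape
    x = pt.vector("x")
    out = pt.reshape(x, (2, 3))
    compare_mlx_and_py([x], [out], [np.arange(6, dtype="float32")])


def test_mlx_Reshape_concrete_shape():
    # Target shape derived from the input tensor's own properties
    x = pt.vector("x")
    out = pt.reshape(x, (x.shape[0] // 2, 2))
    compare_mlx_and_py([x], [out], [np.arange(6, dtype="float32")])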

@cetagostini
Contributor

cetagostini commented Jun 9, 2025

The current implementation allows sampling a simple pymc-marketing model with the MLX backend, on both GPU and CPU. Nevertheless, complex models still have issues.

[image: screenshot of the sampling output]


compare_mlx_and_py([], [out_pt], [])


@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode")
Member

@ricardoV94 ricardoV94 Oct 8, 2025

Is inplace optimization something MLX does by itself (like JAX)? In that case we don't need to worry about or test them.

Contributor

I could not find any explicit statement in the MLX documentation or codebase that MLX performs in-place optimization / mutation merging automatically in the same way that JAX sometimes does.

Member

So you're supposed to write x[idx] = 0 for inplacing? That's valid MLX syntax?

Contributor

@cetagostini cetagostini Oct 9, 2025

Yes, that should be possible (I'll write code and double-check). I'll start opening issues around MLX.
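
A quick check of that syntax could look like the snippet below; whether MLX's item assignment gives the in-place semantics PyTensor's destructive Ops expect is exactly what needs verifying:

import mlx.core as mx

x = mx.zeros((5,))
x[1:3] = 1.0  # MLX arrays accept NumPy-style item assignment
print(x)      # expected: array([0, 1, 1, 0, 0], dtype=float32)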

mx = pytest.importorskip("mlx.core")


def test_mlx_Subtensor_basic():
Member

do you have any test with symbolic indices? Is that supported by MLX?

Contributor

No, MLX does not currently support general symbolic indexing or boolean-masked indexing (i.e. indexing where the selection is determined by a mask) in its graph mode. It does support index arrays.

Member

That's even more restrictive than JAX. JAX indices can be symbolic as long as they determine the output shape from the input shape. So if input indices have static shape (3,), that's fine, as it knows the output will also have shape=(3,), regardless of the values.

Contributor

That's my understanding.
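
For reference, a probe for the JAX-style case described above (symbolic integer index arrays with static shapes) might look like this; whether MLX accepts it is exactly what this thread questions, so treat it as a hypothetical test, with compare_mlx_and_py assumed in scope:

import numpy as np

import pytensor.tensor as pt


def test_mlx_Subtensor_symbolic_indices():
    x = pt.vector("x")
    idx = pt.lvector("idx")  # values unknown until runtime, ndim known
    out = x[idx]
    compare_mlx_and_py(
        [x, idx],
        [out],
        [np.arange(5, dtype="float32"), np.array([0, 2, 4])],
    )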

Member

@ricardoV94 ricardoV94 left a comment

I did another pass; this is looking like it's 99% there. I left another round of comments and some change requests before we merge.

@jessegrabowski
Member

One thing I don't like about this PR is that it's basically impossible to predict where function dispatches are. Elemwise was in math.py, but CAReduce was in elemwise.py. basic.py seems to be for shared tools, but it also has FunctionGraph, CheckAndRaise, and DeepCopyOp. core.py seems to have tensor constructors?

IMO we should aim for each linker directory to follow the pytensor.tensor module 1:1, with the dispatch for each Op in a file/module of the same name.

@cetagostini
Contributor

cetagostini commented Oct 9, 2025

@jessegrabowski @ricardoV94 I resolved all your comments.

I replied to the other remaining comments. My take: if we are 99% there with the PR, can the other 1% be addressed in another, cleaner PR? If you tell me what you'd like to modify, I can open issues and we can kick off the work to improve MLX and apply those changes.

@cetagostini
Contributor

One thing I don't like about this PR is that it's basically impossible to predict where function dispatches are. Elemwise was in math.py, but CAReduce was in elemwise.py. basic.py seems to be for shared tools, but it also has FunctionGraph, CheckAndRaise, and DeepCopyOp. core.py seems to have tensor constructors?

IMO we should aim for each linker directory to follow the pytensor.tensor module 1:1, with the dispatch for each Op in a file/module of the same name.

I like this idea, by the way @jessegrabowski. Definitely for another PR? But I could do this.

@ricardoV94 ricardoV94 changed the title MLX backend POC Add MLX backend Oct 9, 2025
@ricardoV94 ricardoV94 merged commit 934306f into pymc-devs:main Oct 9, 2025
57 of 58 checks passed
@ricardoV94
Member

Let the party begin :D

