Parallel setup / Auto Sharding mode for general tensor operation? #32494
-
Update: why does JAX not automatically let some devices idle when that is what the shapes call for? (It feels quite rigid outside the ML-training scope.) I tried writing a `jax.device_put` macro that matches the device count to the tensor dimensions:

```python
import jax
import jax.numpy as jnp
import jax.lax as lax
import numpy as np
from jax.sharding import Mesh
from jax.tree_util import Partial
from typing import Sequence

uint = lambda el, b=1: el // b + b * bool(el % b)        # ceiling division by b
lmax = lambda *l: int(.5 * (sum(l) + abs(l[0] - l[1])))  # max of two numbers

def tileme(info: Sequence[int] | int = (8,), devices=np.array(jax.devices())):
    """Automatic dispatch with convenient parameters.

    devices = None or list of devices
    """
    axsn = tuple(str(el) for el in range(len(info)))
    devices = np.asarray(devices)
    maxDev = devices.shape
    # Keep only as many devices along each mesh axis as the tensor dimension allows.
    mesh = Mesh(
        devices[tuple(slice(lmax(el, el // uint(el / al))) for el, al in zip(info, maxDev))],
        axis_names=axsn,
        axis_types=(jax.sharding.AxisType.Auto,),
    )
    spec = jax.P(*axsn)
    sharding = jax.sharding.NamedSharding(mesh, spec)
    return Partial(jax.device_put, device=sharding, may_alias=True)
```
```python
precision = jax.lax.Precision('high')

@jax.jit
def f(x: jax.Array):
    A = jnp.ones((x.shape[-1], 3), dtype=jnp.float32)
    # Contract the last axis of x against the first axis of A (no batch dims).
    return jax.lax.dot_general(
        tileme(info=x.shape[-1:])(x),
        tileme(info=x.shape[-1:])(A),
        dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
        precision=precision,
    )
```

This works on:

```python
z = jnp.ones((4, 3, 2, 7))
f(z)
```
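For completeness, this is how I sanity-check what sharding the output actually ends up with; it is purely diagnostic (the reshape is only because the visualizer wants a 1-D or 2-D array):

```python
out = f(z)
print(out.sharding)  # which NamedSharding the Auto mesh produced for the result
# Optional, needs the `rich` package installed:
jax.debug.visualize_array_sharding(out.reshape(-1, out.shape[-1]))
```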
It seems I must explicitly state every part of the tensor shape as axis info in the `jax.device_put` call, but that should not be necessary: a `jax.Array`, like a `numpy.array`, is ultimately a contiguous buffer addressed with regular strides.

The same exercise for case 1:

```python
Y = jnp.ones((32, 64, 7))
O = lax.psum(tileme(info=Y.shape[1:2])(Y), tuple(range(1, 2)))
O.sharding  # NamedSharding(mesh=Mesh('0': 16, axis_types=(Auto,)), spec=PartitionSpec('0',), memory_kind=device)
```

Now the sharding works, case 1 solved! Just a gentle nudge to @jakevdp, @emilyfertig, @superbobry, @hawkinsp, @cgarciae on the topic.
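For comparison, here is the same reduction expressed as a plain positional sum under `jax.jit`, only to see what output sharding the compiler picks on the Auto mesh; I am not claiming the two calls are interchangeable:

```python
@jax.jit
def reduce_middle(y):
    # Positional reduction over axis 1; with Auto axis types the partitioner
    # is free to choose how the intermediate and the output are sharded.
    return jnp.sum(y, axis=1)

O_jit = reduce_middle(tileme(info=Y.shape[1:2])(Y))
print(O_jit.sharding)
```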
-
Hello,
This is an attempt to get a simple working guide on parallel operations:

I searched through a lot of examples, but none highlighted the Auto sharding mode (`jax.sharding.AxisType.Auto`) in a way I can work with. I have many applications in mind where stating how my mesh dispatches an Array is not advantageous: I would rather have JAX decide the mesh configuration arbitrarily before the computation.
cf https://docs.jax.dev/en/latest/sharded-computation.html
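For context, here is the minimal Auto-mode pattern I pieced together from that page; the mesh simply spans whatever devices are available, and the axis name, shapes, and function are just illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

# One mesh axis over all devices, marked Auto so the compiler, not me,
# decides how intermediates and outputs are laid out.
devices = np.array(jax.devices())
mesh = jax.sharding.Mesh(devices, ('batch',),
                         axis_types=(jax.sharding.AxisType.Auto,))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))

n = jax.device_count()
x = jax.device_put(jnp.ones((4 * n, 128)), sharding)  # only the input is placed explicitly

@jax.jit
def g(x):
    return jnp.tanh(x) @ jnp.ones((128, 8))

print(g(x).sharding)  # chosen by the compiler; no out_shardings were specified
```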
Here are 3 applications I'd like to tackle:
How to parallelize case 3, given #11394?
This question also remains when I want to allocate a contiguous chunk (or chunks) of memory and "fill it in parallel" with the case 3 solution (if we can get one); the access pattern I have in mind is sketched below. Case 3 is also useful when I want to run mixed-size mini-batches (images of different sizes).
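To make the "fill in parallel" part concrete, this is roughly the access pattern I mean, written sequentially for clarity; the helper name and shapes are made up, and the open question is how to get the per-chunk writes to actually run in parallel across devices:

```python
@jax.jit
def fill_buffer(chunks):
    # chunks: (num_chunks, chunk_len); each chunk conceptually comes from an
    # independent piece of work and is written into one contiguous buffer.
    buf = jnp.zeros((chunks.shape[0] * chunks.shape[1],), dtype=chunks.dtype)

    def body(i, buf):
        return jax.lax.dynamic_update_slice(buf, chunks[i], (i * chunks.shape[1],))

    return jax.lax.fori_loop(0, chunks.shape[0], body, buf)

filled = fill_buffer(jnp.arange(12.0).reshape(3, 4))
```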
Is the `SingleDeviceSharding` state due to `jax.device_put` working in eager mode (with dispatch onto the mesh happening only at computation time)?
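To make that question concrete, this is the kind of check I am running; it is diagnostic only, and the leading dimension is just chosen as a multiple of the device count so the sharded axis divides evenly:

```python
import jax
import jax.numpy as jnp
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('0',),
                         axis_types=(jax.sharding.AxisType.Auto,))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('0'))

Y = jnp.ones((2 * jax.device_count(), 64, 7))
print(Y.sharding)                        # SingleDeviceSharding right after creation
Y_sharded = jax.device_put(Y, sharding)  # eager: the transfer happens here, outside any jit
print(Y_sharded.sharding)                # already a NamedSharding before any computation
```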
Thanks in advance!
NB: setting `Pmode = 1` vs `Pmode = 0` has a slight negative impact on performance for reverse-mode Jacobian applications (measured with `timeit.timeit` and the `.block_until_ready()` method). I don't see a substantial speedup like in a truly parallelized case.
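For reference, the timing pattern behind those measurements is roughly this (the helper name is just illustrative):

```python
import timeit

def bench(fn, *args, number=10):
    fn(*args).block_until_ready()  # warm-up call so compilation stays outside the timing
    # block_until_ready() keeps JAX's async dispatch from hiding the real runtime.
    return timeit.timeit(lambda: fn(*args).block_until_ready(), number=number) / number
```

e.g. `bench(f, z)` for the jitted example above.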