Commit 5acf4e6

Rewritten flax.jax_utils.prefetch_to_device and flax.jax_utils.replicate using jax.device_put
1 parent bee81ed commit 5acf4e6
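
For context, here is a minimal sketch (not part of the commit; the mesh and axis names are illustrative) of the jax.device_put pattern the rewrite is built on: lay a 1-D mesh over the target devices, describe the placement with a NamedSharding, and let jax.device_put do the transfer.

# Illustrative sketch of the jax.device_put + NamedSharding pattern;
# the axis name "data" is made up, not the one used in the diff below.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

devices = jax.local_devices()
mesh = jax.make_mesh((len(devices),), ("data",), devices=devices)
sharding = NamedSharding(mesh, P("data"))  # split the leading axis

x = jnp.ones((len(devices), 8))  # leading axis matches device count
y = jax.device_put(x, sharding)  # one (1, 8) shard per device
print(y.sharding.devices_indices_map(x.shape))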

File tree

2 files changed: +92 -6 lines changed

flax/jax_utils.py

Lines changed: 33 additions & 6 deletions
@@ -25,6 +25,7 @@
 from jax import core, lax
 from jax.extend import linear_util as lu
 from jax.interpreters import partial_eval as pe
+from jax.sharding import NamedSharding, PartitionSpec as P, AxisType


 def _pmap_device_order():
@@ -42,7 +43,24 @@ def replicate(tree, devices=None):
     A new pytree containing the replicated arrays.
   """
   devices = devices or _pmap_device_order()
-  return jax.device_put_replicated(tree, devices)
+  mesh = jax.make_mesh(
+    (len(devices),),
+    ("_flax_jax_utils_replicate_data_axis",),
+    (AxisType.Auto,),
+    devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_replicate_data_axis"))
+
+  def _device_put_replicated(x):
+    if isinstance(x, (jax.Array, np.ndarray)):
+      buf = x[None]
+    else:
+      buf = jnp.asarray(x)[None]
+    buf = jnp.concat([buf] * len(devices))
+    return jax.device_put(buf, data_sharding)
+
+  with jax.set_mesh(mesh):
+    return jax.tree.map(_device_put_replicated, tree)


 def unreplicate(tree):
@@ -137,17 +155,26 @@ def prefetch_to_device(iterator, size, devices=None):
   queue = collections.deque()
   devices = _pmap_device_order() if devices is None else devices

+  mesh = jax.make_mesh(
+    (len(devices),),
+    ("_flax_jax_utils_prefetch_to_device_data_axis",),
+    (AxisType.Auto,),
+    devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_prefetch_to_device_data_axis"))
+
   def _prefetch(xs):
-    return jax.device_put_sharded(list(xs), devices)
+    return jax.device_put(xs, data_sharding)

   def enqueue(n):  # Enqueues *up to* `n` elements from the iterator.
     for data in itertools.islice(iterator, n):
       queue.append(jax.tree_util.tree_map(_prefetch, data))

-  enqueue(size)  # Fill up the buffer.
-  while queue:
-    yield queue.popleft()
-    enqueue(1)
+  with jax.set_mesh(mesh):
+    enqueue(size)  # Fill up the buffer.
+    while queue:
+      yield queue.popleft()
+      enqueue(1)


 def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)):
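
A small usage sketch of the rewritten replicate (the pytree below is hypothetical; shapes are for illustration only): every leaf is stacked once per device along a new leading axis, with scalars promoted to arrays first.

# Hypothetical usage of the rewritten replicate; the pytree is made up.
import jax
import jax.numpy as jnp
from flax import jax_utils

params = {"w": jnp.ones((2, 3)), "step": 0}
replicated = jax_utils.replicate(params)  # defaults to all local devices

n = jax.local_device_count()
assert replicated["w"].shape == (n, 2, 3)
assert replicated["step"].shape == (n,)  # scalar promoted, then stacked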

tests/jax_utils_test.py

Lines changed: 59 additions & 0 deletions
@@ -106,5 +106,64 @@ def add(params, a, *, b):
     np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))


+class DataShardingTest(parameterized.TestCase):
+  def setUp(self):
+    if jax.device_count() < 4:
+      self.skipTest('At least 4 devices required')
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_prefetch_to_device(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    shape = (len(devices), 4, 16, 16, 3)
+    iterator = (jnp.ones(shape) for _ in range(4))
+
+    data_iter = jax_utils.prefetch_to_device(iterator, size=3, devices=devices)
+    for _ in range(4):
+      data = next(data_iter)
+      self.assertEqual(data.shape, shape)
+      self.assertIsNotNone(data.sharding)
+      sharding_slices_per_device = data.sharding.devices_indices_map(tuple(data.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Here we check that sharding_slices_per_device is like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), ..., slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_replicate(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    num_batches = 5
+    shape = (2, 3)
+    data_tree = [
+      i * jnp.ones((2, 3)) for i in range(num_batches - 2)
+    ] + [4, 5 * np.ones(shape)]
+    out_tree = jax_utils.replicate(data_tree, devices=devices)
+
+    def check_sharding(p):
+      if p.ndim == 1:
+        self.assertEqual(p.shape, (len(devices),))
+      else:
+        self.assertEqual(p.shape, (len(devices), *shape))
+      self.assertIsNotNone(p.sharding)
+      sharding_slices_per_device = p.sharding.devices_indices_map(tuple(p.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Here we check that sharding_slices_per_device is like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+    jax.tree.map(check_sharding, out_tree)
+
+
 if __name__ == '__main__':
   absltest.main()
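
Likewise, a hypothetical driver loop for the rewritten prefetch_to_device (batch shape and iterator are made up): each yielded batch arrives already sharded across its leading axis while the next `size` batches are staged.

# Hypothetical driver loop; batch shape and count are illustrative.
import jax
import jax.numpy as jnp
from flax import jax_utils

n = jax.local_device_count()
batches = (jnp.ones((n, 4, 16)) for _ in range(8))

for batch in jax_utils.prefetch_to_device(batches, size=2):
  # slice i of the leading axis lives on devices[i]
  assert len(batch.sharding.devices_indices_map(batch.shape)) == n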
