Skip to content

Commit 2238474

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
[pmap] Deprecate jax.sharding.PmapSharding, jax.device_put_replicated and jax.device_put_sharded.
PiperOrigin-RevId: 831569381
1 parent ba6a5ff commit 2238474

16 files changed

+95
-13
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4949
GPU (({jax-issue}`#33062`).
5050
5151
* Deprecations:
52+
* `jax.sharding.PmapSharding` is now deprecated. Please use
53+
`jax.NamedSharding` instead.
54+
* `jx.device_put_replicated` is now deprecated. Please use `jax.device_put`
55+
with the appropriate sharding instead.
56+
* `jax.device_put_sharded` is now deprecated. Please use `jax.device_put` with
57+
the appropriate sharding instead.
5258
* Default `axis_types` of `jax.make_mesh` will change in JAX v0.9.0 to return
5359
`jax.sharding.AxisType.Explicit`. Leaving axis_types unspecified will raise a
5460
`DeprecationWarning`.

jax/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@
9797
from jax._src.xla_bridge import device_count as device_count
9898
from jax._src.api import device_get as device_get
9999
from jax._src.api import device_put as device_put
100-
from jax._src.api import device_put_sharded as device_put_sharded
101-
from jax._src.api import device_put_replicated as device_put_replicated
100+
from jax._src.api import device_put_sharded as _deprecated_device_put_sharded
101+
from jax._src.api import device_put_replicated as _deprecated_device_put_replicated
102102
from jax._src.xla_bridge import devices as devices
103103
from jax._src.api import disable_jit as disable_jit
104104
from jax._src.api import eval_shape as eval_shape
@@ -197,6 +197,16 @@
197197
"jax.ArrayRef is deprecated; use jax.Ref instead.",
198198
Ref
199199
),
200+
# Added for v0.8.1
201+
"device_put_replicated": (
202+
"jax.device_put_replicated is deprecated; use jax.device_put instead.",
203+
_deprecated_device_put_replicated
204+
),
205+
# Added for v0.8.1
206+
"device_put_sharded": (
207+
"jax.device_put_sharded is deprecated; use jax.device_put instead.",
208+
_deprecated_device_put_sharded
209+
),
200210
# Finalized 2025-03-25; remove after 2025-06-25
201211
"treedef_is_leaf": (
202212
"jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.",
@@ -238,6 +248,8 @@
238248
if _typing.TYPE_CHECKING:
239249
array_ref = new_ref
240250
ArrayRef = Ref
251+
device_put_replicated = _deprecated_device_put_replicated
252+
device_put_sharded = _deprecated_device_put_sharded
241253
else:
242254
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
243255
__getattr__ = _deprecation_getattr(__name__, _deprecations)
@@ -247,3 +259,5 @@
247259
import jax.lib # TODO(phawkins): remove this export. # noqa: F401
248260

249261
# trailer
262+
del _deprecated_device_put_sharded
263+
del _deprecated_device_put_replicated

jax/_src/api.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,23 +2871,23 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
28712871
>>> import jax
28722872
>>> devices = jax.local_devices()
28732873
>>> x = [jax.numpy.ones(5) for device in devices]
2874-
>>> y = jax.device_put_sharded(x, devices)
2875-
>>> np.allclose(y, jax.numpy.stack(x))
2874+
>>> y = jax.device_put_sharded(x, devices) # doctest: +SKIP
2875+
>>> np.allclose(y, jax.numpy.stack(x)) # doctest: +SKIP
28762876
True
28772877
28782878
Passing a list of nested container objects with arrays at the leaves for
28792879
``shards`` corresponds to stacking the shards at each leaf. This requires
28802880
all entries in the list to have the same tree structure:
28812881
28822882
>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
2883-
>>> y = jax.device_put_sharded(x, devices)
2884-
>>> type(y)
2883+
>>> y = jax.device_put_sharded(x, devices) # doctest: +SKIP
2884+
>>> type(y) # doctest: +SKIP
28852885
<class 'tuple'>
2886-
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
2887-
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
2888-
>>> np.allclose(y[0], y0)
2886+
>>> y0 = jax.device_put_sharded([a for a, b in x], devices) # doctest: +SKIP
2887+
>>> y1 = jax.device_put_sharded([b for a, b in x], devices) # doctest: +SKIP
2888+
>>> np.allclose(y[0], y0) # doctest: +SKIP
28892889
True
2890-
>>> np.allclose(y[1], y1)
2890+
>>> np.allclose(y[1], y1) # doctest: +SKIP
28912891
True
28922892
28932893
See Also:
@@ -2953,8 +2953,8 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
29532953
>>> import jax
29542954
>>> devices = jax.local_devices()
29552955
>>> x = jax.numpy.array([1., 2., 3.])
2956-
>>> y = jax.device_put_replicated(x, devices)
2957-
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
2956+
>>> y = jax.device_put_replicated(x, devices) # doctest: +SKIP
2957+
>>> np.allclose(y, jax.numpy.stack([x for _ in devices])) # doctest: +SKIP
29582958
True
29592959
29602960
See Also:

jax/sharding.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from jax._src.sharding_impls import (
2020
NamedSharding as NamedSharding,
2121
SingleDeviceSharding as SingleDeviceSharding,
22-
PmapSharding as PmapSharding,
22+
PmapSharding as _deprecated_PmapSharding,
2323
set_mesh as set_mesh,
2424
)
2525
from jax._src.partition_spec import (
@@ -39,3 +39,21 @@
3939
auto_axes as auto_axes,
4040
explicit_axes as explicit_axes,
4141
)
42+
43+
_deprecations = {
44+
# Added for v0.8.1
45+
"PmapSharding": (
46+
"jax.sharding.PmapSharding is deprecated; use jax.sharding.NamedSharding instead.",
47+
_deprecated_PmapSharding
48+
),
49+
}
50+
51+
import typing as _typing
52+
if _typing.TYPE_CHECKING:
53+
PmapSharding = _deprecated_PmapSharding
54+
else:
55+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
56+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
57+
del _deprecation_getattr
58+
del _typing
59+
del _deprecated_PmapSharding

tests/array_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,7 @@ def test_uneven_shard_error(self):
946946
r"factors: \[4, 2\] should evenly divide the shape\)"):
947947
mps.shard_shape((8, 3))
948948

949+
@jtu.ignore_warning(category=DeprecationWarning)
949950
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
950951
def test_pmap_sharding_hash_eq(self):
951952
if jax.device_count() < 2:
@@ -1056,6 +1057,7 @@ def test_pspec_tuple(self):
10561057
('sharded_dim_2', (4, 2, 4), 2),
10571058
('sharded_dim_1_1', (2, 4), 1)
10581059
)
1060+
@jtu.ignore_warning(category=DeprecationWarning)
10591061
def test_default_pmap_sharding(self, shape, sharded_dim):
10601062
if jax.device_count() < 4:
10611063
self.skipTest('Test needs >= 4 devices.')
@@ -1081,6 +1083,7 @@ def test_default_pmap_sharding(self, shape, sharded_dim):
10811083
self.assertEqual(actual_sharding.sharding_spec, expected_sharding.sharding_spec)
10821084
self.assertEqual(actual_sharding._device_assignment, expected_sharding._device_assignment)
10831085

1086+
@jtu.ignore_warning(category=DeprecationWarning)
10841087
def test_default_pmap_sharding_with_devices(self):
10851088
if jax.device_count() < 4:
10861089
self.skipTest('Test needs >= 4 devices.')
@@ -1090,6 +1093,7 @@ def test_default_pmap_sharding_with_devices(self):
10901093
ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order)
10911094
self.assertEqual(ps._device_assignment, new_order)
10921095

1096+
@jtu.ignore_warning(category=DeprecationWarning)
10931097
def test_default_pmap_sharding_replicated(self):
10941098
x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32)
10951099
x = jax.pmap(lambda x: x, in_axes=0, out_axes=None, axis_name='x')(x)

tests/debugging_primitives_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,7 @@ def test_visualize_wide_array(self):
11051105
""")
11061106
self.assertEqual(output(), expected)
11071107

1108+
@jtu.ignore_warning(category=DeprecationWarning)
11081109
def test_visualize_pmap_sharding(self):
11091110
ss = pxla.ShardingSpec(
11101111
sharding=(pxla.Unstacked(8),),

tests/multiprocess/all_gather_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class AllGatherTest(jt_multiprocess.MultiProcessTest):
3030
@parameterized.parameters(
3131
(np.int32,), (jnp.float32,), (jnp.float16,), (jnp.bfloat16,)
3232
)
33+
@jtu.ignore_warning(category=DeprecationWarning)
3334
def test_all_gather(self, dtype):
3435
f = jax.pmap(lambda x: lax.all_gather(x, "i"), axis_name="i")
3536
xs = randint_sample(

tests/multiprocess/all_reduce_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def randint_sample(shape):
2727

2828
class AllReduceTest(jt_multiprocess.MultiProcessTest):
2929

30+
@jtu.ignore_warning(category=DeprecationWarning)
3031
def test_psum_simple(self):
3132
f = jax.pmap(lambda x: lax.psum(x, "i"), "i", devices=jax.devices())
3233
np.testing.assert_array_equal(
@@ -37,6 +38,7 @@ def test_psum_simple(self):
3738
@parameterized.parameters(
3839
(np.int32,), (jnp.float32,), (jnp.float16,), (jnp.bfloat16,)
3940
)
41+
@jtu.ignore_warning(category=DeprecationWarning)
4042
def test_psum(self, dtype):
4143
f = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
4244
xs = randint_sample(
@@ -47,6 +49,7 @@ def test_psum(self, dtype):
4749
for actual in out:
4850
jtu.check_close(actual, expected)
4951

52+
@jtu.ignore_warning(category=DeprecationWarning)
5053
def test_psum_subset_devices(self):
5154
f = jax.pmap(
5255
lambda x: lax.psum(x, "i"), axis_name="i", devices=jax.local_devices()
@@ -57,6 +60,7 @@ def test_psum_subset_devices(self):
5760
for actual in out:
5861
np.testing.assert_array_equal(actual, expected)
5962

63+
@jtu.ignore_warning(category=DeprecationWarning)
6064
def test_psum_del(self): # b/171945402
6165
f = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
6266
g = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
@@ -73,6 +77,7 @@ def test_psum_del(self): # b/171945402
7377
for actual in out:
7478
np.testing.assert_array_equal(actual, expected)
7579

80+
@jtu.ignore_warning(category=DeprecationWarning)
7681
def test_psum_multiple_operands(self):
7782
f = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
7883
xs = randint_sample([jax.process_count(), jax.local_device_count(), 100])
@@ -85,6 +90,7 @@ def test_psum_multiple_operands(self):
8590
for actual in out_ys:
8691
np.testing.assert_array_equal(actual, expected_ys)
8792

93+
@jtu.ignore_warning(category=DeprecationWarning)
8894
def test_psum_axis_index_groups(self):
8995
devices = list(range(jax.device_count()))
9096
axis_index_groups = [devices[0::2], devices[1::2]]

tests/multiprocess/all_to_all_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class AllToAllTest(jt_multiprocess.MultiProcessTest):
3030

3131
@parameterized.parameters((np.int32,), (jnp.float32,), (jnp.float16,),
3232
(jnp.bfloat16,))
33+
@jtu.ignore_warning(category=DeprecationWarning)
3334
def test_all_to_all(self, dtype):
3435
f = jax.pmap(
3536
lambda x: lax.all_to_all(x, "i", split_axis=0, concat_axis=0),

tests/multiprocess/multihost_utils_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ def test_host_local_to_global_reshard_committed_single_device_array(self):
419419
)
420420
np.testing.assert_array_equal(o.data, global_data[o.index])
421421

422+
@jtu.ignore_warning(category=DeprecationWarning)
422423
def test_host_local_to_global_replicated(self):
423424
num_local_devices = jax.local_device_count()
424425
global_mesh = jax.sharding.Mesh(jax.devices(), axis_names=['x'])
@@ -435,6 +436,7 @@ def test_host_local_to_global_replicated(self):
435436
# Array is accessible on every host.
436437
np.testing.assert_array_equal(out, local_input_data)
437438

439+
@jtu.ignore_warning(category=DeprecationWarning)
438440
def test_host_local_to_global_locally_replicated(self):
439441
# Make an array which is locally replicated but sharded across hosts.
440442
num_processes = jax.process_count()

0 commit comments

Comments
 (0)