@@ -27,6 +27,7 @@ def randint_sample(shape):
2727
2828class AllReduceTest (jt_multiprocess .MultiProcessTest ):
2929
30+ @jtu .ignore_warning (category = DeprecationWarning )
3031 def test_psum_simple (self ):
3132 f = jax .pmap (lambda x : lax .psum (x , "i" ), "i" , devices = jax .devices ())
3233 np .testing .assert_array_equal (
@@ -37,6 +38,7 @@ def test_psum_simple(self):
3738 @parameterized .parameters (
3839 (np .int32 ,), (jnp .float32 ,), (jnp .float16 ,), (jnp .bfloat16 ,)
3940 )
41+ @jtu .ignore_warning (category = DeprecationWarning )
4042 def test_psum (self , dtype ):
4143 f = jax .pmap (lambda x : lax .psum (x , "i" ), axis_name = "i" )
4244 xs = randint_sample (
@@ -47,6 +49,7 @@ def test_psum(self, dtype):
4749 for actual in out :
4850 jtu .check_close (actual , expected )
4951
52+ @jtu .ignore_warning (category = DeprecationWarning )
5053 def test_psum_subset_devices (self ):
5154 f = jax .pmap (
5255 lambda x : lax .psum (x , "i" ), axis_name = "i" , devices = jax .local_devices ()
@@ -57,6 +60,7 @@ def test_psum_subset_devices(self):
5760 for actual in out :
5861 np .testing .assert_array_equal (actual , expected )
5962
63+ @jtu .ignore_warning (category = DeprecationWarning )
6064 def test_psum_del (self ): # b/171945402
6165 f = jax .pmap (lambda x : lax .psum (x , "i" ), axis_name = "i" )
6266 g = jax .pmap (lambda x : lax .psum (x , "i" ), axis_name = "i" )
@@ -73,6 +77,7 @@ def test_psum_del(self): # b/171945402
7377 for actual in out :
7478 np .testing .assert_array_equal (actual , expected )
7579
80+ @jtu .ignore_warning (category = DeprecationWarning )
7681 def test_psum_multiple_operands (self ):
7782 f = jax .pmap (lambda x : lax .psum (x , "i" ), axis_name = "i" )
7883 xs = randint_sample ([jax .process_count (), jax .local_device_count (), 100 ])
@@ -85,6 +90,7 @@ def test_psum_multiple_operands(self):
8590 for actual in out_ys :
8691 np .testing .assert_array_equal (actual , expected_ys )
8792
93+ @jtu .ignore_warning (category = DeprecationWarning )
8894 def test_psum_axis_index_groups (self ):
8995 devices = list (range (jax .device_count ()))
9096 axis_index_groups = [devices [0 ::2 ], devices [1 ::2 ]]
0 commit comments