@@ -106,5 +106,64 @@ def add(params, a, *, b):
106106 np .testing .assert_allclose (np .float64 (y ), np .float64 (5 * x + 10 ))
107107
108108
109+ class DataShardingTest (parameterized .TestCase ):
110+ def setUp (self ):
111+ if jax .device_count () < 4 :
112+ self .skipTest ('At least 4 devices required' )
113+
114+ @parameterized .product (num_devices = ["all" , 2 ])
115+ def test_prefetch_to_device (self , num_devices ):
116+ devices = jax .local_devices ()
117+ if isinstance (num_devices , int ):
118+ devices = devices [:num_devices ]
119+ shape = (len (devices ), 4 , 16 , 16 , 3 )
120+ iterator = (jnp .ones (shape ) for _ in range (4 ))
121+
122+ data_iter = jax_utils .prefetch_to_device (iterator , size = 3 , devices = devices )
123+ for _ in range (4 ):
124+ data = next (data_iter )
125+ self .assertEqual (data .shape , shape )
126+ self .assertIsNotNone (data .sharding )
127+ sharding_slices_per_device = data .sharding .devices_indices_map (tuple (data .shape ))
128+ self .assertEqual (len (sharding_slices_per_device ), len (devices ))
129+ # Here we check that sharding_slices_per_device is like
130+ # Device(id=2): (slice(2, 3, None), slice(None, None, None), ..., slice(None, None, None))
131+ for i , dev in enumerate (devices ):
132+ sharding_slice = sharding_slices_per_device [dev ]
133+ self .assertEqual (sharding_slice [0 ], slice (i + 0 , i + 1 , None ))
134+ for sharding_slice_j in sharding_slice [1 :]:
135+ self .assertEqual (sharding_slice_j , slice (None , None , None ))
136+
137+ @parameterized .product (num_devices = ["all" , 2 ])
138+ def test_replicate (self , num_devices ):
139+ devices = jax .local_devices ()
140+ if isinstance (num_devices , int ):
141+ devices = devices [:num_devices ]
142+ num_batches = 5
143+ shape = (2 , 3 )
144+ data_tree = [
145+ i * jnp .ones ((2 , 3 )) for i in range (num_batches - 2 )
146+ ] + [4 , 5 * np .ones (shape )]
147+ out_tree = jax_utils .replicate (data_tree , devices = devices )
148+
149+ def check_sharding (p ):
150+ if p .ndim == 1 :
151+ self .assertEqual (p .shape , (len (devices ),))
152+ else :
153+ self .assertEqual (p .shape , (len (devices ), * shape ))
154+ self .assertIsNotNone (p .sharding )
155+ sharding_slices_per_device = p .sharding .devices_indices_map (tuple (p .shape ))
156+ self .assertEqual (len (sharding_slices_per_device ), len (devices ))
157+ # Here we check that sharding_slices_per_device is like
158+ # Device(id=2): (slice(2, 3, None), slice(None, None, None), slice(None, None, None))
159+ for i , dev in enumerate (devices ):
160+ sharding_slice = sharding_slices_per_device [dev ]
161+ self .assertEqual (sharding_slice [0 ], slice (i + 0 , i + 1 , None ))
162+ for sharding_slice_j in sharding_slice [1 :]:
163+ self .assertEqual (sharding_slice_j , slice (None , None , None ))
164+
165+ jax .tree .map (check_sharding , out_tree )
166+
167+
109168if __name__ == '__main__' :
110169 absltest .main ()
0 commit comments