Fix ModelParallel OOM issue during weight loading #21723
Changes from 1 commit
```diff
@@ -33,6 +33,14 @@
     reason="Backend specific test",
 )
 class JaxDistributionLibTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
     def _create_jax_layout(self, sharding):
         # Use jax_layout.Format or jax_layout.Layout if available.
         if hasattr(jax_layout, "Format"):
```

Review comment on `def _require_min_devices(self, min_devices):`:

> How come this was not needed before?
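The review question above touches on device availability: every guarded test assumes the host exposes at least 8 (or 4) JAX devices. As a side note, not part of this PR, a single-CPU machine can satisfy the guard by forcing the XLA host platform to expose virtual devices before JAX is first imported:

```python
# Sketch, not part of this diff: expose 8 virtual CPU devices so that
# _require_min_devices(8) does not skip on a single-CPU machine.
# The flag must be set before the first `import jax` in the process.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
)

import jax

print(len(jax.devices()))  # -> 8 virtual CPU devices
```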
```diff
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
         return sharding

     def test_list_devices(self):
+        self._require_min_devices(8)
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
         )

     def test_distribute_tensor(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
```
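For reference, a minimal standalone sketch (assuming 8 devices are visible, e.g. via the flag above) of the 2x4 "batch"/"model" mesh that `test_distribute_tensor` and the following tests build, and of how a tensor ends up sharded across it:

```python
# Minimal sketch of the (2, 4) "batch"/"model" mesh used by these tests.
# Assumes 8 JAX devices are available.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("batch", "model"))

# Shard an (8, 4) array: rows split over "batch", columns over "model".
x = jax.numpy.arange(32.0).reshape(8, 4)
sharded = jax.device_put(x, NamedSharding(mesh, P("batch", "model")))
print(sharded.sharding)  # NamedSharding with spec PartitionSpec('batch', 'model')
```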
```diff
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_variable(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_input_data(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         # The multi-process test lives in g3.
         jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_tensor_with_jax_layout(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
         )

     def test_distribute_variable_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
         )

     def test_distribute_input_data_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
         self.assertEqual(backend_dlib.num_processes(), 1)

     def test_to_backend_mesh(self):
+        self._require_min_devices(8)
         devices = [f"cpu:{i}" for i in range(8)]
         shape = (4, 2)
         axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
         self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

     def test_to_backend_layout(self):
+        self._require_min_devices(8)
         axes = ["data", None]
         mesh = distribution_lib.DeviceMesh(
             (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
```
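`test_to_backend_mesh` and `test_to_backend_layout` cover the translation from Keras distribution objects into JAX ones. A rough illustration of the Keras-level objects involved, sketched with the public `keras.distribution` API rather than the private backend helpers tested here (an assumption for clarity, not code from this diff):

```python
# Sketch (assumed public-API mirror of the backend helpers tested above):
# the Keras mesh/layout objects that the JAX backend converts into
# jax.sharding.Mesh / NamedSharding. Assumes at least 8 devices.
from keras import distribution

devices = distribution.list_devices()[:8]
mesh = distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
# axes=["data", None]: shard dim 0 over "data", replicate dim 1.
layout = distribution.TensorLayout(axes=["data", None], device_mesh=mesh)
```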
```diff
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
             backend_dlib._to_backend_layout(layout)

     def test_variable_assignment_reuse_layout(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
         model.fit(inputs, labels)

     def test_e2e_model_parallel_model(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
         model.fit(inputs, labels)

     def test_e2e_model_parallel_with_output_sharding(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
         )

     def test_distribute_data_input(self):
+        self._require_min_devices(4)
         per_process_batch = jax.numpy.arange(24).reshape(
             6, 4
         )  # Example input array
```
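The guarded e2e tests above exercise the `ModelParallel` path named in the PR title, where weight loading can run out of memory. As a hedged illustration (assumed usage of the public `keras.distribution` API, not code from this diff; the exact signature may differ across Keras versions), the kind of setup those tests correspond to:

```python
# Sketch (assumption, for illustration only): ModelParallel over a
# (2, 4) "batch"/"model" mesh, as exercised by the e2e tests above.
# Assumes exactly 8 devices are visible.
import keras
from keras import distribution

mesh = distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=["batch", "model"],
    devices=distribution.list_devices(),
)
layout_map = distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")  # shard Dense kernels on "model"

distribution.set_distribution(
    distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

# Weights created under this distribution scope follow the layout map;
# the PR title concerns OOM when loading weights in this configuration.
model = keras.Sequential(
    [keras.layers.Dense(16, name="dense_1"), keras.layers.Dense(4, name="dense_2")]
)
model.build((None, 8))
```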