Skip to content

Commit 2816d8c

Browse files
fix(strategies): correct device_mesh type hint in FSDP strategies (#21581)
1 parent 283ce77 commit 2816d8c

File tree

6 files changed: 28 additions and 4 deletions

src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424

2525
### Fixed
2626

27-
-
27+
- Fixed `device_mesh` type hint in `FSDPStrategy` to accept a 2-element tuple via the CLI ([#21581](https://github.com/Lightning-AI/pytorch-lightning/pull/21581))
2828

2929
---
3030

src/lightning/fabric/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,7 @@ def __init__(
150150
activation_checkpointing_policy: Optional["_POLICY"] = None,
151151
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
152152
state_dict_type: Literal["full", "sharded"] = "sharded",
153-
device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
153+
device_mesh: Optional[Union[tuple[int, int], "DeviceMesh"]] = None,
154154
**kwargs: Any,
155155
) -> None:
156156
super().__init__(

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727

2828
- Fixed `val_check_interval` raising `ValueError` when `limit_val_batches=0` and interval exceeds training batches ([#21560](https://github.com/Lightning-AI/pytorch-lightning/pull/21560))
2929

30-
-
30+
- Fixed `device_mesh` type hint in `FSDPStrategy` to accept a 2-element tuple via the CLI ([#21581](https://github.com/Lightning-AI/pytorch-lightning/pull/21581))
3131

3232
- Fixed ``RichModelSummary`` model size display formatting ([#21467](https://github.com/Lightning-AI/pytorch-lightning/pull/21467))
3333

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def __init__(
160160
activation_checkpointing_policy: Optional["_POLICY"] = None,
161161
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
162162
state_dict_type: Literal["full", "sharded"] = "full",
163-
device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
163+
device_mesh: Optional[Union[tuple[int, int], "DeviceMesh"]] = None,
164164
**kwargs: Any,
165165
) -> None:
166166
super().__init__(

tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -402,3 +402,15 @@ def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch):
402402
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=4):
403403
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config
404404
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config
405+
406+
407+
def test_device_mesh_type_annotation():
408+
"""Test that ``device_mesh`` type hint accepts a 2-element tuple via jsonargparse (#21580)."""
409+
jsonargparse = pytest.importorskip("jsonargparse")
410+
from inspect import signature
411+
412+
annot = signature(FSDPStrategy).parameters["device_mesh"].annotation
413+
parser = jsonargparse.ArgumentParser()
414+
parser.add_argument("--device_mesh", type=annot)
415+
args = parser.parse_args(["--device_mesh=[1, 4]"])
416+
assert args.device_mesh == (1, 4)

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -966,3 +966,15 @@ def configure_optimizers(self):
966966
max_steps=4,
967967
)
968968
trainer.fit(model, ckpt_path=checkpoint_path_full)
969+
970+
971+
def test_device_mesh_type_annotation():
972+
"""Test that ``device_mesh`` type hint accepts a 2-element tuple via jsonargparse (#21580)."""
973+
jsonargparse = pytest.importorskip("jsonargparse")
974+
from inspect import signature
975+
976+
annot = signature(FSDPStrategy).parameters["device_mesh"].annotation
977+
parser = jsonargparse.ArgumentParser()
978+
parser.add_argument("--device_mesh", type=annot)
979+
args = parser.parse_args(["--device_mesh=[1, 4]"])
980+
assert args.device_mesh == (1, 4)

0 commit comments

Comments
 (0)