Skip to content

Commit cb2c7d5

Browse files
committed
update
1 parent 6cd4d62 commit cb2c7d5

File tree

5 files changed

+25
-40
lines changed

5 files changed

+25
-40
lines changed

src/lightning/fabric/plugins/io/checkpoint_io.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,15 @@ def requires_cpu_collectives(self) -> bool:
8080
return False
8181

8282
@property
83-
def _requires_state_conversion(self) -> bool:
84-
"""Whether the Strategy must pre-convert stateful objects into ``state_dict`` form before calling this
85-
CheckpointIO.
86-
87-
CheckpointIO implementations that perform in-place loading may expect the provided
88-
``state`` to already contain plain dictionaries instead of high-level objects such
89-
as ``nn.Module`` or ``Optimizer``. When this returns ``True``, the Strategy should
90-
convert the state using its internal state-extraction logic prior to save/load.
83+
def _restore_after_setup(self) -> bool:
84+
"""Whether checkpoint restoration should be delayed until after the Strategy setup phase.
85+
86+
Some checkpoint implementations require the distributed environment, device placement,
87+
or wrapped modules to be fully initialized before loading state. When this returns
88+
``True``, the Trainer/Strategy will restore the checkpoint only after setup has completed.
89+
90+
This is primarily used by distributed checkpointing backends that depend on collective
91+
communication during load.
9192
9293
"""
9394
return False

src/lightning/fabric/plugins/io/distributed_async_io.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
async_type = state_dict_saver.AsyncCheckpointerType(self._checkpointer_type)
113113
default_save_options["async_checkpointer_type"] = async_type
114114
default_save_options["planner"] = DefaultSavePlanner(enable_plan_caching=enable_plan_caching)
115-
print(f"{default_save_options=}")
115+
116116
self.save_options = {**default_save_options, **(save_options or {})}
117117
self.load_options = dict(load_options or {})
118118
self._disable_safe_warnings()
@@ -137,7 +137,8 @@ def _wait(self) -> None:
137137

138138
@override
139139
@property
140-
def _requires_state_conversion(self) -> bool:
140+
def _restore_after_setup(self) -> bool:
141+
"""Requires delayed restoration until after Strategy setup."""
141142
return True
142143

143144
@property

src/lightning/fabric/strategies/strategy.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -336,25 +336,17 @@ def load_checkpoint(
336336
337337
"""
338338
torch.cuda.empty_cache()
339-
if self.checkpoint_io._requires_state_conversion and state is not None:
340-
if not isinstance(state, dict):
341-
raise ValueError(
342-
"When using a CheckpointIO that requires state conversion, the `state` argument must be a dict."
343-
)
344-
# update in_place so non-tensor objects get updated as well when using in-place loading
345-
state = self._convert_stateful_objects_in_state(state, filter={}, in_place=True)
346-
347-
# in-place loading requires state to be a dict
348-
_state = state if isinstance(state, dict) else None
349-
checkpoint = self.checkpoint_io.load_checkpoint(path, state=_state, weights_only=weights_only)
339+
converted_state = state
340+
if state is not None:
341+
converted_state = self._convert_stateful_objects_in_state(
342+
state,
343+
filter={},
344+
)
345+
346+
checkpoint = self.checkpoint_io.load_checkpoint(path, state=converted_state, weights_only=weights_only)
350347
if not state:
351348
return checkpoint
352349

353-
if checkpoint == {}:
354-
# In-place loaders (e.g., DCP) return {} to signal that the state
355-
# has already been fully restored by the CheckpointIO implementation.
356-
return {}
357-
358350
if isinstance(state, Module):
359351
self.load_module_state_dict(module=state, state_dict=checkpoint, strict=strict)
360352
return {}
@@ -422,13 +414,7 @@ def _convert_stateful_objects_in_state(
422414
self,
423415
state: dict[str, Union[Module, Optimizer, Any]],
424416
filter: dict[str, Callable[[str, Any], bool]],
425-
in_place: bool = False,
426417
) -> dict[str, Any]:
427-
if in_place and filter != {}:
428-
raise ValueError(
429-
"In-place conversion does not support filtering. Please set `in_place=False` to apply the filter."
430-
)
431-
432418
converted_state: dict[str, Any] = {}
433419
for key, obj in state.items():
434420
# convert the state
@@ -441,11 +427,8 @@ def _convert_stateful_objects_in_state(
441427
else:
442428
converted = obj
443429

444-
if in_place:
445-
state[key] = converted
446-
else:
447-
_apply_filter(key, filter, converted, converted_state)
448-
return converted_state if not in_place else state
430+
_apply_filter(key, filter, converted, converted_state)
431+
return converted_state
449432

450433

451434
class _BackwardSyncControl(ABC):

src/lightning/pytorch/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def restore_checkpoint_after_setup(self) -> bool:
457457
If ``True``, restore checkpoint after strategy setup.
458458
459459
"""
460-
return self.checkpoint_io._requires_state_conversion
460+
return self.checkpoint_io._restore_after_setup
461461

462462
@property
463463
def lightning_restore_optimizer(self) -> bool:

tests/tests_fabric/plugins/io/test_distributed_async_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def test_async_checkpointio_requires_cpu_collectives():
111111

112112

113113
@RunIf(min_torch="2.4")
114-
def test_async_checkpointio_requires_state_conversion():
115-
assert DistributedAsyncCheckpointIO()._requires_state_conversion is True
114+
def test_async_checkpointio_requires_restore_after_setup():
115+
assert DistributedAsyncCheckpointIO()._restore_after_setup is True
116116

117117

118118
@RunIf(min_torch="2.4")

0 commit comments

Comments (0)