Skip to content

Commit ed1ca45

Browse files
committed
update
1 parent f7bd53a commit ed1ca45

File tree

4 files changed

+7
-55
lines changed

4 files changed

+7
-55
lines changed

src/lightning/pytorch/strategies/strategy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
         return self._lightning_module

-    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None, state: dict[str, Any] = None) -> dict[str, Any]:
+    def load_checkpoint(
+        self, checkpoint_path: _PATH, weights_only: Optional[bool] = None, state: dict[str, Any] = None
+    ) -> dict[str, Any]:
         torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only, state=state)

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Op
         rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
         with pl_legacy_patch():
             _current_state = self.dump_checkpoint(weights_only=weights_only)
-            loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only=weights_only, state=_current_state)
+            loaded_checkpoint = self.trainer.strategy.load_checkpoint(
+                checkpoint_path, weights_only=weights_only, state=_current_state
+            )
         self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
8688

8789
def _select_ckpt_path(

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ def _run(

         if self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
-            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
+            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, weights_only)

         # restore optimizers, etc.
         log.debug(f"{self.__class__.__name__}: restoring training state")

tmp/main.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)