We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 707f3e3 commit 5d7841cCopy full SHA for 5d7841c
src/stamp/modeling/train.py
@@ -179,11 +179,11 @@ def train_model_(
179
logger=CSVLogger(save_dir=output_dir),
180
)
181
trainer.fit(
182
- model=cast(lightning.LightningModule, torch.compile(model)),
+ model=model,
183
train_dataloaders=train_dl,
184
val_dataloaders=valid_dl,
185
186
- shutil.move(model_checkpoint.best_model_path, output_dir / "model.ckpt")
+ shutil.copy(model_checkpoint.best_model_path, output_dir / "model.ckpt")
187
188
return LitVisionTransformer.load_from_checkpoint(model_checkpoint.best_model_path)
189
0 commit comments