Skip to content

Commit 258e65a

Browse files
authored
Merge pull request #97 from ChEB-AI/fix/save_out_dim_to_checkpoint
Fix `input_dim` and `out_dim` Not getting saved in model checkpoint
2 parents 2633810 + d18dd7a commit 258e65a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

chebai/models/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def __init__(
4545
if exclude_hyperparameter_logging is None:
4646
exclude_hyperparameter_logging = tuple()
4747
self.criterion = criterion
48+
assert out_dim is not None, "out_dim must be specified"
49+
assert input_dim is not None, "input_dim must be specified"
50+
self.out_dim = out_dim
51+
self.input_dim = input_dim
52+
4853
self.save_hyperparameters(
4954
ignore=[
5055
"criterion",
@@ -55,10 +60,8 @@ def __init__(
5560
]
5661
)
5762

58-
self.out_dim = out_dim
59-
self.input_dim = input_dim
60-
assert out_dim is not None, "out_dim must be specified"
61-
assert input_dim is not None, "input_dim must be specified"
63+
self.hparams["out_dim"] = out_dim
64+
self.hparams["input_dim"] = input_dim
6265

6366
if optimizer_kwargs:
6467
self.optimizer_kwargs = optimizer_kwargs

0 commit comments

Comments
 (0)