Skip to content

Commit 97839d9

Browse files
committed
add kwargs to trainer
1 parent 0925a41 commit 97839d9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def predict_from_file(
8181
input_path: _PATH,
8282
save_to: _PATH = "predictions.csv",
8383
classes_path: Optional[_PATH] = None,
84+
**kwargs,
8485
) -> None:
8586
"""
8687
Loads a model from a checkpoint and makes predictions on input data from a file.
@@ -90,7 +91,7 @@ def predict_from_file(
9091
checkpoint_path: Path to the model checkpoint.
9192
input_path: Path to the input file containing SMILES strings.
9293
save_to: Path to save the predictions CSV file.
93-
classes_path: Optional path to a file containing class names.
94+
classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
9495
"""
9596
loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
9697
with open(input_path, "r") as input:

0 commit comments

Comments
 (0)