Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions silnlp/nmt/clearml_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def _load_config(self) -> None:
if self.task is None:
with (exp_dir / "config.yml").open("r", encoding="utf-8") as file:
config = yaml.safe_load(file)
config["use_default_model_dir"] = self.use_default_model_dir
if config is None or len(config.keys()) == 0:
raise RuntimeError("Config file has no contents.")
config["use_default_model_dir"] = self.use_default_model_dir
self.config = create_config(exp_dir, config)
return
# There is a ClearML task - lets' do more complex importing.
Expand All @@ -139,7 +139,6 @@ def _load_config(self) -> None:
# read in the project/experiment yaml file
with (exp_dir / "config.yml").open("r", encoding="utf-8") as file:
config = yaml.safe_load(file)
config["use_default_model_dir"] = self.use_default_model_dir
else:
config = {}
if config is None or len(config.keys()) == 0:
Expand All @@ -153,5 +152,5 @@ def _load_config(self) -> None:
exp_dir.mkdir(parents=True, exist_ok=True)
with (exp_dir / "config.yml").open("w+", encoding="utf-8") as file:
yaml.safe_dump(data=config, stream=file)

config["use_default_model_dir"] = self.use_default_model_dir
self.config = create_config(exp_dir, config)
2 changes: 1 addition & 1 deletion silnlp/nmt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def translate(self):
translator = TranslationTask(
name=self.name,
checkpoint=config.get("checkpoint", "last"),
save_checkpoints=self.save_checkpoints,
use_default_model_dir=True if not (self.run_test or self.run_train) else self.save_checkpoints,
commit=self.commit,
)

Expand Down
9 changes: 6 additions & 3 deletions silnlp/nmt/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def translate(
class TranslationTask:
name: str
checkpoint: Union[str, int] = "last"
save_checkpoints: bool = False
use_default_model_dir: bool = True
clearml_queue: Optional[str] = None
commit: Optional[str] = None

Expand Down Expand Up @@ -267,7 +267,7 @@ def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Co
project_suffix="_infer",
experiment_suffix=experiment_suffix,
commit=self.commit,
use_default_model_dir=self.save_checkpoints,
use_default_model_dir=self.use_default_model_dir,
)
self.name = clearml.name

Expand Down Expand Up @@ -399,7 +399,10 @@ def main() -> None:
get_git_revision_hash()

translator = TranslationTask(
name=args.experiment, checkpoint=args.checkpoint, clearml_queue=args.clearml_queue, commit=args.commit
name=args.experiment,
checkpoint=args.checkpoint,
clearml_queue=args.clearml_queue,
commit=args.commit,
)

postprocess_handler = PostprocessHandler([PostprocessConfig(vars(args))])
Expand Down