Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/aiida_trains_pot/mace/mace_train_plugin/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,23 @@ def define(cls, spec):
"ERROR_OUT_OF_WALLTIME",
message="The calculation stopped prematurely because it ran out of walltime.",
)
spec.exit_code(
410,
"ERROR_PREPROCESS_FAILED",
message="Preprocessing stage failed or did not produce expected outputs.",
)

spec.exit_code(
420,
"ERROR_TRAINING_FAILED",
message="Training stage failed or did not complete successfully.",
)

spec.exit_code(
430,
"ERROR_POSTPROCESS_FAILED",
message="Postprocessing stage failed or final model files are missing.",
)

def prepare_for_submission(self, folder):
"""Create input files.
Expand Down Expand Up @@ -387,6 +404,7 @@ def prepare_for_submission(self, folder):
"results",
"logs",
"_scheduler-std*",
"processed_data",
]

return calcinfo
88 changes: 76 additions & 12 deletions src/aiida_trains_pot/mace/mace_train_plugin/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,28 +211,79 @@ def parse(self, **kwargs):
# Check that folder content is as expected
files_retrieved = self.retrieved.list_object_names()
files_expected = ["aiida.model"]
exit_code = None

do_preprocess = getattr(self.node.inputs, "do_preprocess", None)

if do_preprocess and do_preprocess.value:
# Check processed_data folder exists
if "processed_data" not in files_retrieved:
self.logger.error("processed_data folder missing.")
exit_code = self.exit_codes.ERROR_PREPROCESS_FAILED

# List contents of processed_data
processed_contents = self.retrieved.list_object_names("processed_data")

required_entries = {"statistics.json", "train", "val", "test"}

if not required_entries.issubset(set(processed_contents)):
self.logger.error(
f"processed_data incomplete. Found: {processed_contents}, " f"expected at least: {required_entries}"
)
exit_code = self.exit_codes.ERROR_PREPROCESS_FAILED

# Ensure train/val/test are not empty
for subfolder in ["train", "val", "test"]:
subfolder_path = os.path.join("processed_data", subfolder)
subfolder_contents = self.retrieved.list_object_names(subfolder_path)

if not subfolder_contents:
self.logger.error(f"{subfolder_path} is empty.")
exit_code = self.exit_codes.ERROR_PREPROCESS_FAILED

if "_scheduler-stderr.txt" in files_retrieved:
with self.retrieved.open("_scheduler-stderr.txt", "r") as handle:
scheduler_err = handle.read()

if "DUE TO TIME LIMIT" in scheduler_err:
exit_code = self.exit_codes.ERROR_OUT_OF_WALLTIME
else:
self.logger.error(
f"Expected to find '_scheduler-stderr.txt' in retrieved folder, found '{files_retrieved}'"
)
exit_code = self.exit_codes.ERROR_MISSING_OUTPUT_FILES

# add error of out of walltime
if "mace.out" in self.retrieved.list_object_names():
with self.retrieved.open("mace.out", "rb") as handle:
output_node = SinglefileData(file=handle)
if len(parse_start_from_singlefiledata(output_node)) == 0:
return self.exit_codes.ERROR_TRAINING_FAILED
if (len(parse_start_from_singlefiledata(output_node)) > 0) and (
len(parse_complete_from_singlefiledata(output_node)) == 0
):
return self.exit_codes.ERROR_OUT_OF_WALLTIME
exit_code = self.exit_codes.ERROR_TRAINING_FAILED
else:
self.logger.error(f"Expected to find 'mace.out' in retrieved folder, found '{files_retrieved}'")
exit_code = self.exit_codes.ERROR_MISSING_OUTPUT_FILES

# Note: set(A) <= set(B) checks whether A is a subset of B
if not set(files_expected) <= set(files_retrieved):
self.logger.error(f"Found files '{files_retrieved}', expected to find '{files_expected}'")
return self.exit_codes.ERROR_MISSING_OUTPUT_FILES
exit_code = self.exit_codes.ERROR_POSTPROCESS_FAILED

if exit_code:
return exit_code

# add output file
for file in files_retrieved:
output_filename = file
for output_filename in files_retrieved:
self.logger.info(f"Parsing '{output_filename}'")
if "checkpoint" in output_filename or "logs" in output_filename or "results" in output_filename:
if "results" in output_filename:
if output_filename in ["results", "checkpoints", "logs", "processed_data"]:
try:
folder_contents = self.retrieved.list_object_names(output_filename)
except Exception as exc:
self.logger.error(f"Failed to access folder '{output_filename}': {exc}")
exit_code = self.exit_codes.ERROR_POSTPROCESS_FAILED

if "results" in output_filename:
for file_in_folder in folder_contents:
if not file_in_folder.endswith(".png"):
file_path = os.path.join(output_filename, file_in_folder)
Expand All @@ -244,15 +295,18 @@ def parse(self, **kwargs):
# output_node = FolderData(folder=handle)
# self.out(output_filename, output_node)
folder_data = FolderData()
folder_contents = self.retrieved.list_object_names(output_filename)
for file_in_folder in folder_contents:
file_path = os.path.join(output_filename, file_in_folder)
with self.retrieved.open(file_path, "rb") as handle:
folder_data.put_object_from_filelike(handle, file_in_folder)
self.out(output_filename, folder_data)
else:
with self.retrieved.open(output_filename, "rb") as handle:
output_node = SinglefileData(file=handle)
try:
with self.retrieved.open(output_filename, "rb") as handle:
output_node = SinglefileData(file=handle)
except Exception as exc:
self.logger.error(f"Failed opening file '{output_filename}': {exc}")
exit_code = self.exit_codes.ERROR_POSTPROCESS_FAILED

if "aiida_swa.model-lammps" in output_filename or "aiida_stagetwo.model-lammps" in output_filename:
self.out("model_stage2_lammps", output_node)
Expand All @@ -271,5 +325,15 @@ def parse(self, **kwargs):
self.out("model_stage1_pytorch", output_node)

elif "mace" in output_filename:
self.out("RMSE", List(parse_tables_from_singlefiledata(output_node)))
try:
rmse_data = parse_tables_from_singlefiledata(output_node)
if rmse_data:
self.out("RMSE", List(rmse_data))
except Exception as exc:
self.logger.error(f"Failed parsing RMSE from mace.out: {exc}")
exit_code = self.exit_codes.ERROR_POSTPROCESS_FAILED

if exit_code:
return exit_code

return ExitCode(0)