2 changes: 1 addition & 1 deletion docs/source/reference/command_line.rst
@@ -453,7 +453,7 @@ Below is a list with all available subcommands.
--broker-host HOSTNAME Hostname for the message broker. [default: 127.0.0.1]
--broker-port INTEGER Port for the message broker. [default: 5672]
--broker-virtual-host TEXT Name of the virtual host for the message broker without
leading forward slash. [default: ""]
leading forward slash.
--repository DIRECTORY Absolute path to the file repository.
--test-profile Designate the profile to be used for running the test
suite only.
14 changes: 13 additions & 1 deletion src/aiida/cmdline/commands/cmd_archive.py
@@ -138,6 +138,16 @@ def inspect(ctx, archive, version, meta_data, database):
help='Determine entities to export, but do not create the archive. Deprecated, please use `--dry-run` instead.',
)
@options.DRY_RUN(help='Determine entities to export, but do not create the archive.')
@click.option(
'--tmp-dir',
type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path),
help=(
'Location where the temporary directory will be written during archive creation. '
'The directory must exist and be writable, and defaults to the parent directory of the output file. '
'This parameter is useful when the output directory has limited space or when you want to use a specific '
'filesystem (e.g., faster storage) for temporary operations.'
),
)
@decorators.with_dbenv()
def create(
output_file,
@@ -160,6 +170,7 @@ def create(
batch_size,
test_run,
dry_run,
tmp_dir,
):
"""Create an archive from all or part of a profiles's data.

@@ -211,6 +222,7 @@ def create(
'compression': compress,
'batch_size': batch_size,
'test_run': dry_run,
'tmp_dir': tmp_dir,
}

if AIIDA_LOGGER.level <= logging.REPORT: # type: ignore[attr-defined]
@@ -327,7 +339,7 @@ class ExtrasImportCode(Enum):
'--extras-mode-new',
type=click.Choice(EXTRAS_MODE_NEW),
default='import',
help='Specify whether to import extras of new nodes: ' 'import: import extras. ' 'none: do not import extras.',
help='Specify whether to import extras of new nodes: import: import extras. none: do not import extras.',
)
@click.option(
'--comment-mode',
220 changes: 131 additions & 89 deletions src/aiida/tools/archive/create.py
@@ -12,6 +12,7 @@
stored in a single file.
"""

import os
import shutil
import tempfile
from datetime import datetime
@@ -59,6 +60,7 @@ def create_archive(
compression: int = 6,
test_run: bool = False,
backend: Optional[StorageBackend] = None,
tmp_dir: Optional[Union[str, Path]] = None,
**traversal_rules: bool,
) -> Path:
"""Export AiiDA data to an archive file.
@@ -139,6 +141,11 @@ def create_archive(

:param backend: the backend to export from. If not specified, the default backend is used.

:param tmp_dir: Location where the temporary directory will be written during archive creation.
The directory must exist and be writable, and defaults to the parent directory of the output file.
This parameter is useful when the output directory has limited space or when you want to use a specific
filesystem (e.g., faster storage) for temporary operations.

:param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
what rule names are toggleable and what the defaults are.

@@ -239,7 +246,7 @@ def querybuilder():
entity_ids[EntityTypes.USER].add(entry.pk)
else:
raise ArchiveExportError(
f'I was given {entry} ({type(entry)}),' ' which is not a User, Node, Computer, or Group instance'
f'I was given {entry} ({type(entry)}), which is not a User, Node, Computer, or Group instance'
)
group_nodes, link_data = _collect_required_entities(
querybuilder,
@@ -280,94 +287,129 @@ def querybuilder():

EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}')

# Handle temporary directory configuration
if tmp_dir is not None:
tmp_dir = Path(tmp_dir)
if not tmp_dir.exists():
EXPORT_LOGGER.warning(f"Specified temporary directory '{tmp_dir}' doesn't exist. Creating it.")
Member

This behaviour is different from the CLI command, which does expect the temporary directory to exist. It also seems to conflict with the docstring of the function.
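
For illustration, a minimal sketch of what the function-side check might look like if it mirrored the CLI's exists=True contract and refused to create the directory; the helper name _validate_tmp_dir and the import path for ArchiveExportError are assumptions, not code from this PR:

import os
from pathlib import Path

from aiida.tools.archive.exceptions import ArchiveExportError


def _validate_tmp_dir(tmp_dir) -> Path:
    """Require an existing, writable directory, mirroring the CLI option's exists=True."""
    tmp_dir = Path(tmp_dir)
    if not tmp_dir.is_dir():
        raise ArchiveExportError(f"Specified temporary directory '{tmp_dir}' does not exist or is not a directory")
    if not os.access(tmp_dir, os.W_OK | os.X_OK):
        raise ArchiveExportError(f"Specified temporary directory '{tmp_dir}' is not writable")
    return tmp_dir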

tmp_dir.mkdir(parents=True)
if not tmp_dir.is_dir():
msg = f"Specified temporary directory '{tmp_dir}' is not a directory"
raise ArchiveExportError(msg)
# Check if directory is writable
# Taken from: https://stackoverflow.com/a/2113511
if not os.access(tmp_dir, os.W_OK | os.X_OK):
msg = f"Specified temporary directory '{tmp_dir}' is not writable"
raise ArchiveExportError(msg)

else:
# Create temporary directory in the same folder as the output file
tmp_dir = filename.parent

# Create and open the archive for writing.
# We create in a temp dir then move to final place at end,
# so that the user cannot end up with a half written archive on errors
with tempfile.TemporaryDirectory() as tmpdir:
tmp_filename = Path(tmpdir) / 'export.zip'
with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
# add metadata
writer.update_metadata(
{
'ctime': datetime.now().isoformat(),
'creation_parameters': {
'entities_starting_set': None
if entities is None
else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
'include_authinfos': include_authinfos,
'include_comments': include_comments,
'include_logs': include_logs,
'graph_traversal_rules': full_traversal_rules,
},
}
)
# stream entity data to the archive
with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
for etype, ids in entity_ids.items():
if etype == EntityTypes.NODE and strip_checkpoints:

def transform(row):
data = row['entity']
if data.get('node_type', '').startswith('process.'):
data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
return data
else:

def transform(row):
return row['entity']

progress.set_description_str(f'Archiving database: {etype.value}s')
if ids:
for nrows, rows in batch_iter(
querybuilder()
.append(
entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
)
.iterdict(batch_size=batch_size),
batch_size,
transform,
):
writer.bulk_insert(etype, rows)
progress.update(nrows)

# stream links
progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')

def transform(d):
return {
'input_id': d.source_id,
'output_id': d.target_id,
'label': d.link_label,
'type': d.link_type,
try:
Contributor Author

Diff is large here because of the additional try-except and the indent that goes with it (to capture disk-space errors) but the actual code inside should be the same!

Contributor Author

git diff --ignore-all-space main src/aiida/tools/archive/create.py gives instead only:

diff --git a/src/aiida/tools/archive/create.py b/src/aiida/tools/archive/create.py
index 94ca88cd4..c4f0671d5 100644
--- a/src/aiida/tools/archive/create.py
+++ b/src/aiida/tools/archive/create.py
@@ -12,6 +12,7 @@ The archive is a subset of the provenance graph,
 stored in a single file.
 """
 
+import os
 import shutil
 import tempfile
 from datetime import datetime
@@ -59,6 +60,7 @@ def create_archive(
     compression: int = 6,
     test_run: bool = False,
     backend: Optional[StorageBackend] = None,
+    tmp_dir: Optional[Union[str, Path]] = None,
     **traversal_rules: bool,
 ) -> Path:
     """Export AiiDA data to an archive file.
@@ -139,6 +141,12 @@ def create_archive(
 
     :param backend: the backend to export from. If not specified, the default backend is used.
 
+    :param tmp_dir: Directory to use for temporary files during archive creation.
+        If not specified, a temporary directory will be created in the same directory as the output file
+        with a '.aiida-export-' prefix. This parameter is useful when the output directory has limited
+        space or when you want to use a specific filesystem (e.g., faster storage) for temporary operations.
+        The directory must exist and be writable.
+
     :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
         what rule names are toggleable and what the defaults are.
 
@@ -280,10 +288,32 @@ def create_archive(
 
     EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}')
 
+    # Handle temporary directory configuration
+    tmp_prefix = '.aiida-export-'
+    if tmp_dir is not None:
+        tmp_dir = Path(tmp_dir)
+        if not tmp_dir.exists():
+            EXPORT_LOGGER.warning(f"Specified temporary directory '{tmp_dir}' doesn't exist. Creating it.")
+            tmp_dir.mkdir(parents=True)
+        if not tmp_dir.is_dir():
+            msg = f"Specified temporary directory '{tmp_dir}' is not a directory"
+            raise ArchiveExportError(msg)
+        # Check if directory is writable
+        # Taken from: https://stackoverflow.com/a/2113511
+        if not os.access(tmp_dir, os.W_OK | os.X_OK):
+            msg = f"Specified temporary directory '{tmp_dir}' is not writable"
+            raise ArchiveExportError(msg)
+
+    else:
+        # Create temporary directory in the same folder as the output file
+        tmp_dir = filename.parent
+
     # Create and open the archive for writing.
     # We create in a temp dir then move to final place at end,
     # so that the user cannot end up with a half written archive on errors
-    with tempfile.TemporaryDirectory() as tmpdir:
+    try:
+        tmp_dir.mkdir(parents=True, exist_ok=True)
+        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=tmp_prefix) as tmpdir:
             tmp_filename = Path(tmpdir) / 'export.zip'
             with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
                 # add metadata
@@ -302,7 +332,9 @@ def create_archive(
                     }
                 )
                 # stream entity data to the archive
-            with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
+                with get_progress_reporter()(
+                    desc='Archiving database: ', total=sum(entity_counts.values())
+                ) as progress:
                     for etype, ids in entity_ids.items():
                         if etype == EntityTypes.NODE and strip_checkpoints:
 
@@ -359,7 +391,9 @@ def create_archive(
 
                 # stream node repository files to the archive
                 if entity_ids[EntityTypes.NODE]:
-                _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)
+                    _stream_repo_files(
+                        archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size
+                    )
 
                 EXPORT_LOGGER.report('Finalizing archive creation...')
 
@@ -368,6 +402,16 @@ def create_archive(
 
             filename.parent.mkdir(parents=True, exist_ok=True)
             shutil.move(tmp_filename, filename)
+    except OSError as e:
+        if e.errno == 28:  # No space left on device
+            msg = (
+                f"Insufficient disk space in temporary directory '{tmp_dir}'. "
+                f'Consider using --tmp-dir to specify a location with more available space.'
+            )
+            raise ArchiveExportError(msg) from e
+
+        msg = f'Failed to create temporary directory: {e}'
+        raise ArchiveExportError(msg) from e
 
     EXPORT_LOGGER.report('Archive created successfully')

Member

Thanks for the clarification @GeigerJ2! This'll help in my review <3

tmp_dir.mkdir(parents=True, exist_ok=True)
Member

Why is this line added here? Everything related to checking the tmp_dir input should be above, and for the default there should be no need to create the parent directory, the command should fail in case that doesn't exist.
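
As a sketch of the consolidation suggested here, all directory handling could live in the validation branch so the extra mkdir before the with block disappears. The placeholders filename and tmp_dir stand in for the function arguments of create_archive, and the import path for ArchiveExportError is assumed:

import os
import tempfile
from pathlib import Path

from aiida.tools.archive.exceptions import ArchiveExportError

# Placeholders standing in for the arguments of create_archive().
filename = Path('export.aiida')
tmp_dir = None  # or a user-supplied path

if tmp_dir is not None:
    tmp_dir = Path(tmp_dir)
    if not tmp_dir.is_dir():
        raise ArchiveExportError(f"Specified temporary directory '{tmp_dir}' does not exist")
    if not os.access(tmp_dir, os.W_OK | os.X_OK):
        raise ArchiveExportError(f"Specified temporary directory '{tmp_dir}' is not writable")
else:
    # Default: alongside the output file; no parent directories are created here.
    tmp_dir = filename.parent

with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir:
    pass  # archive writing proceeds as in the rest of the diff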

with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir:
tmp_filename = Path(tmpdir) / 'export.zip'
with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
# add metadata
writer.update_metadata(
{
'ctime': datetime.now().isoformat(),
'creation_parameters': {
'entities_starting_set': None
if entities is None
else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
'include_authinfos': include_authinfos,
'include_comments': include_comments,
'include_logs': include_logs,
'graph_traversal_rules': full_traversal_rules,
},
}
)
# stream entity data to the archive
with get_progress_reporter()(
desc='Archiving database: ', total=sum(entity_counts.values())
) as progress:
for etype, ids in entity_ids.items():
if etype == EntityTypes.NODE and strip_checkpoints:

def transform(row):
data = row['entity']
if data.get('node_type', '').startswith('process.'):
data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
return data
else:

def transform(row):
return row['entity']

progress.set_description_str(f'Archiving database: {etype.value}s')
if ids:
for nrows, rows in batch_iter(
querybuilder()
.append(
entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
)
.iterdict(batch_size=batch_size),
batch_size,
transform,
):
writer.bulk_insert(etype, rows)
progress.update(nrows)

# stream links
progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')

def transform(d):
return {
'input_id': d.source_id,
'output_id': d.target_id,
'label': d.link_label,
'type': d.link_type,
}

for nrows, rows in batch_iter(link_data, batch_size, transform):
writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
progress.update(nrows)
del link_data # release memory

# stream group_nodes
progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')

def transform(d):
return {'dbgroup_id': d[0], 'dbnode_id': d[1]}

for nrows, rows in batch_iter(group_nodes, batch_size, transform):
writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
progress.update(nrows)
del group_nodes # release memory

# stream node repository files to the archive
if entity_ids[EntityTypes.NODE]:
_stream_repo_files(
archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size
)

EXPORT_LOGGER.report('Finalizing archive creation...')

if filename.exists():
Member

I know it's not part of what you just put in the try-except block, but it's interesting. I'd move this logic to the lines after the comment on line 167 (#check/set archive file path), since that is dealing with the target filename of the archive existing and whether or not to overwrite it.

filename.unlink()

filename.parent.mkdir(parents=True, exist_ok=True)
Member

Again, not your code, but something to think about: Apparently you can just write the archive to a folder that doesn't exist and the command will create it and all non-existent parents. I think that's ok. But the behaviour is then different from the tmp_dir one. I suppose there you don't want to just create some directory with a bunch of parents that you then have to clean up after?

shutil.move(tmp_filename, filename)
except OSError as e:
if e.errno == 28: # No space left on device
msg = (
f"Insufficient disk space in temporary directory '{tmp_dir}'. "
f'Consider using --tmp-dir to specify a location with more available space.'
Member

Note that this can also happen when the create_archive function is used directly:

Suggested change
f'Consider using --tmp-dir to specify a location with more available space.'
f'Consider using `tmp-dir` to specify a location with more available space.'

Member

I'm still checking the code above in more detail, but what happens if there is enough space in the tmp_dir, but not in the target directory for the archive? Will the shutil.move command also fail with OSError with number 28? In that case, won't we be misleading the user?
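
One way to avoid blaming the temporary directory when it is actually the target filesystem that is full would be to handle the final move separately. A rough sketch, reusing the tmp_filename and filename names from the diff above (both assumed to be defined by the surrounding code) and the ArchiveExportError import path used elsewhere:

import errno
import shutil

from aiida.tools.archive.exceptions import ArchiveExportError

try:
    shutil.move(tmp_filename, filename)
except OSError as exc:
    if exc.errno == errno.ENOSPC:  # errno 28: no space left on device
        msg = f"Insufficient disk space in the output directory '{filename.parent}' while moving the archive."
        raise ArchiveExportError(msg) from exc
    raise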

)
raise ArchiveExportError(msg) from e

for nrows, rows in batch_iter(link_data, batch_size, transform):
writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
progress.update(nrows)
del link_data # release memory

# stream group_nodes
progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')

def transform(d):
return {'dbgroup_id': d[0], 'dbnode_id': d[1]}

for nrows, rows in batch_iter(group_nodes, batch_size, transform):
writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
progress.update(nrows)
del group_nodes # release memory

# stream node repository files to the archive
if entity_ids[EntityTypes.NODE]:
_stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)

EXPORT_LOGGER.report('Finalizing archive creation...')

if filename.exists():
filename.unlink()

filename.parent.mkdir(parents=True, exist_ok=True)
shutil.move(tmp_filename, filename)
msg = f'Failed to create temporary directory: {e}'
raise ArchiveExportError(msg) from e

EXPORT_LOGGER.report('Archive created successfully')

@@ -668,7 +710,7 @@ def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size:
if unsealed_node_pks:
raise ExportValidationError(
'All ProcessNodes must be sealed before they can be exported. '
f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
f'Node(s) with PK(s): {", ".join(str(pk) for pk in unsealed_node_pks)} is/are not sealed.'
)


@@ -759,18 +801,18 @@ def get_init_summary(
"""Get summary for archive initialisation"""
parameters = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]]

result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}"
result = f'\n{tabulate(parameters, headers=["Archive Parameters", ""])}'

inclusions = [
['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'],
['Computer Authinfos', include_authinfos],
['Node Comments', include_comments],
['Node Logs', include_logs],
]
result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}"
result += f'\n\n{tabulate(inclusions, headers=["Inclusion rules", ""])}'

if not collect_all:
rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()]
result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}"
rules_table = [[f'Follow links {" ".join(name.split("_"))}s', value] for name, value in traversal_rules.items()]
result += f'\n\n{tabulate(rules_table, headers=["Traversal rules", ""])}'

return result + '\n'
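
For reference, a minimal usage sketch of the new keyword from the Python API; the profile, group label, and paths below are placeholders, not values from this PR:

from pathlib import Path

from aiida import load_profile, orm
from aiida.tools.archive import create_archive

load_profile()  # assumes a configured default profile

group = orm.load_group('my-group')  # placeholder group label
create_archive(
    [group],
    filename=Path('/scratch/exports/my-group.aiida'),
    tmp_dir=Path('/scratch/tmp'),  # must exist and be writable
)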
14 changes: 14 additions & 0 deletions tests/cmdline/commands/test_archive_create.py
@@ -208,3 +208,17 @@ def test_info_empty_archive(run_cli_command):
filename_input = get_archive_file('empty.aiida', filepath='export/migrate')
result = run_cli_command(cmd_archive.archive_info, [filename_input], raises=True)
assert 'archive file unreadable' in result.output


def test_create_tmp_dir_option(run_cli_command, tmp_path):
"""Test that the --tmp-dir CLI option passes through correctly."""
node = Dict().store()

custom_tmp = tmp_path / 'custom_tmp'
custom_tmp.mkdir()
filename_output = tmp_path / 'archive.aiida'

options = ['--tmp-dir', str(custom_tmp), '-N', node.pk, '--', filename_output]

run_cli_command(cmd_archive.create, options)
assert filename_output.is_file()
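
A companion test for the failure path might look like the sketch below (not part of this PR); it relies on the exists=True constraint of the click option and on the raises=True behaviour of run_cli_command as used in the existing tests:

def test_create_tmp_dir_missing(run_cli_command, tmp_path):
    """Test that a non-existent --tmp-dir is rejected by the CLI before any export starts."""
    node = Dict().store()

    missing_tmp = tmp_path / 'does_not_exist'
    filename_output = tmp_path / 'archive.aiida'

    options = ['--tmp-dir', str(missing_tmp), '-N', node.pk, '--', filename_output]

    run_cli_command(cmd_archive.create, options, raises=True)
    assert not filename_output.is_file()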