diff --git a/transient/image.py b/transient/image.py index 958d5fe..25b5251 100644 --- a/transient/image.py +++ b/transient/image.py @@ -49,37 +49,40 @@ def retrieve_image( # Do partial downloads into the working directory in the backend dest_name = os.path.basename(destination) temp_destination = os.path.join(store.working, dest_name) - fd = self.__lock_backend_destination(temp_destination) - # We now hold the lock. Either another process started the retrieval - # and died (or never started at all) or they completed. If the final file exists, - # the must have completed successfully so just return. - if os.path.exists(destination): - logging.info("Retrieval completed by another processes. Skipping.") - os.close(fd) - return None + with utils.cleanup_on_error( + temp_destination, + "wb+", + opener=self.__lock_backend_destination, + rename=destination, + ) as temp_file: + # We now hold the lock. Either another process started the retrieval + # and died (or never started at all) or they completed. If the final file exists, + # they must have completed successfully so just return. + if os.path.exists(destination): + logging.info("Retrieval completed by another processes. Skipping.") + temp_file.unlink() + return - with os.fdopen(fd, "wb+") as temp_file: + temp_file.truncate(0) # might have been partially downloaded self._do_retrieve_image(store, spec, temp_file) - # Now that the entire file is retrieved, atomically move it to the destination. - # This avoids issues where a process was killed in the middle of retrieval - os.rename(temp_destination, destination) - # There is a qemu hotkey to commit a 'snapshot' to the backing file. # Making the backend images read-only prevents this. 
- os.chmod(destination, stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH) + os.fchmod(temp_file.fileno(), stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH) def _do_retrieve_image( self, store: "ImageStore", spec: "ImageSpec", destination: IO[bytes] ) -> None: raise RuntimeError("Protocol did not implement '_do_retrieve_image'") - def __lock_backend_destination(self, dest: str) -> int: + @staticmethod + def __lock_backend_destination(dest: str, flags: int) -> int: # By default, python 'open' call will truncate writable files. We can't allow that # as we don't yet hold the flock (and there is no way to open _and_ flock in one # call). So we use os.open to avoid the truncate. - fd = os.open(dest, os.O_RDWR | os.O_CREAT) + flags &= ~os.O_TRUNC + fd = os.open(dest, flags) logging.debug(f"Attempting to acquire lock of '{dest}'") diff --git a/transient/utils.py b/transient/utils.py index 8087d9d..68388be 100644 --- a/transient/utils.py +++ b/transient/utils.py @@ -14,6 +14,7 @@ import uuid import sys import zlib +from contextlib import contextmanager try: import importlib.resources as pkg_resources @@ -21,7 +22,19 @@ # Try backported to PY<37 `importlib_resources`. import importlib_resources as pkg_resources # type: ignore -from typing import cast, Optional, ContextManager, List, Union, IO, Any, Tuple, Callable +from typing import ( + cast, + Optional, + ContextManager, + List, + Union, + IO, + Any, + Tuple, + Iterator, + Callable, + no_type_check, +) from . import static # From the typeshed Popen definitions @@ -407,3 +420,53 @@ def exit(self) -> None: else: errcode = 1 sys.exit(errcode) + + +# mypy doesn't really support monkey patching like this. Don't check this +# function for now to make the rest of the code happy. Might be a solution +# with custom protocols or something. +@contextmanager +@no_type_check +def cleanup_on_error( + *args: Any, rename: Optional[str] = None, **kwargs: Any +) -> Iterator[IO[Any]]: + """Open (or create) a file, as with open(). 
If the context manager exits + with an error, then the file will be deleted. Otherwise it will be left + alone. If the rename parameter is given, then the file will be atomically + renamed on success. + + The returned file object will have the following extra attributes: + - rename: The file path after rename. May be cleared or changed. + - unlink(): Deletes the file, regardless of success or failure. + + If you want to use this to wrap a file descriptor, you might need to + manually set fp.name, or use an opener function. + """ + + fp = open(*args, **kwargs) + unlinked = False + + def unlink(): + nonlocal unlinked + if unlinked: + return + + unlinked = True + try: + os.unlink(fp.name) + except AttributeError: + logging.warning("Cannot remove temp file: We don't know it's name") + except OSError as e: + logging.warning("Cannot remove temp file %s: %s", fp.name, e) + + with fp: + try: + fp.rename = rename + fp.unlink = unlink + yield fp + rename = getattr(fp, "rename", None) + if not unlinked and rename is not None: + os.rename(fp.name, rename) + except: + unlink() + raise