Skip to content

Commit c0ae8f4

Browse files
authored
fix: surface FileObject session errors; align commit/rollback semantics (#580)
Storage failures are now properly tracked and raised when there's an issue during the session tracking/file handling listener.
1 parent 05a413b commit c0ae8f4

File tree

6 files changed

+554
-50
lines changed

6 files changed

+554
-50
lines changed

advanced_alchemy/_listeners.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,17 @@ def is_async_context() -> bool:
5454
return _is_async_context.get()
5555

5656

57-
def _get_session_tracker(create: bool = True) -> Optional["FileObjectSessionTracker"]:
57+
def _get_session_tracker(
58+
create: bool = True, session: Optional["Session"] = None
59+
) -> Optional["FileObjectSessionTracker"]:
5860
from advanced_alchemy.types.file_object import FileObjectSessionTracker
5961

6062
tracker = _current_session_tracker.get()
6163
if tracker is None and create:
62-
tracker = FileObjectSessionTracker()
64+
raise_on_error = True
65+
if session is not None:
66+
raise_on_error = session.info.get("file_object_raise_on_error", True)
67+
tracker = FileObjectSessionTracker(raise_on_error=raise_on_error)
6368
_current_session_tracker.set(tracker)
6469
return tracker
6570

@@ -348,7 +353,7 @@ def before_flush(cls, session: "Session", flush_context: "UOWTransaction", insta
348353
if not cls._is_listener_enabled(session):
349354
return
350355

351-
tracker = _get_session_tracker(create=True)
356+
tracker = _get_session_tracker(create=True, session=session)
352357
if not tracker:
353358
return
354359

advanced_alchemy/config/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
181181
This is a listener that will automatically save and delete :class:`FileObject <advanced_alchemy.types.file_object.FileObject>` instances when they are saved or deleted.
182182
183183
Disable if you plan to bring your own save/delete mechanism for these columns"""
184+
file_object_raise_on_error: bool = True
185+
"""Control FileObject error handling behavior.
186+
187+
- ``False``: Log warnings on file operation failures, don't raise exceptions
188+
- ``True`` (default): Raise exceptions on file operation failures
189+
"""
184190
_SESSION_SCOPE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
185191
"""Internal counter for ensuring unique identification of session scope keys in the class."""
186192
_ENGINE_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
@@ -208,6 +214,13 @@ def __post_init__(self) -> None:
208214

209215
setup_file_object_listeners()
210216

217+
# Store file_object_raise_on_error in session_config.info
218+
# Ensure session_config.info is a dict (convert from Empty if needed)
219+
if self.session_config.info is Empty:
220+
self.session_config.info = {}
221+
if isinstance(self.session_config.info, dict):
222+
self.session_config.info["file_object_raise_on_error"] = self.file_object_raise_on_error
223+
211224
def __hash__(self) -> int: # pragma: no cover
212225
return hash(
213226
(

advanced_alchemy/types/file_object/session_tracker.py

Lines changed: 129 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33

44
import asyncio
55
import logging
6+
import sys
67
from typing import TYPE_CHECKING, Any, Union
78

9+
if sys.version_info >= (3, 11):
10+
from builtins import ExceptionGroup
11+
else:
12+
from exceptiongroup import ExceptionGroup # type: ignore[import-not-found,unused-ignore]
13+
814
if TYPE_CHECKING:
9-
from collections.abc import Awaitable
1015
from pathlib import Path
1116

1217
from advanced_alchemy.types.file_object import FileObject
@@ -17,8 +22,20 @@
1722
class FileObjectSessionTracker:
1823
"""Tracks FileObject changes within a single session transaction."""
1924

20-
def __init__(self) -> None:
21-
"""Initialize the tracker."""
25+
def __init__(self, raise_on_error: bool = False) -> None:
26+
"""Initialize empty tracking state.
27+
28+
Args:
29+
raise_on_error: If True, raise exceptions on file operation failures.
30+
If False, log warnings and continue.
31+
32+
Internal structures:
33+
- ``pending_saves``: ``FileObject -> data`` to be saved on commit
34+
- ``pending_deletes``: ``FileObject`` instances to delete on commit
35+
- ``_saved_in_transaction``: successfully saved objects used for
36+
selective cleanup on rollback
37+
"""
38+
self.raise_on_error = raise_on_error
2239
# Stores objects that have pending data to be saved on commit.
2340
# Maps FileObject -> data source (bytes or Path)
2441
self.pending_saves: "dict[FileObject, Union[bytes, Path]]" = {}
@@ -47,43 +64,94 @@ def commit(self) -> None:
4764
for obj, data in self.pending_saves.items():
4865
try:
4966
obj.save(data)
50-
except Exception as e: # noqa: BLE001
51-
logger.warning("Error saving file for object %s: %s", obj, e.__cause__)
67+
self._saved_in_transaction.add(obj)
68+
except Exception:
69+
if self.raise_on_error:
70+
logger.exception("error saving file for object %s", obj)
71+
raise
72+
logger.warning("error saving file for object %s", obj, exc_info=True)
73+
5274
for obj in self.pending_deletes:
5375
try:
5476
obj.delete()
5577
except FileNotFoundError:
56-
# Ignore if the file is already gone (shouldn't happen often here)
5778
pass
58-
except Exception as e: # noqa: BLE001
59-
logger.warning("Error deleting file for object %s: %s", obj, e.__cause__)
79+
except Exception:
80+
if self.raise_on_error:
81+
logger.exception("error deleting file for object %s", obj)
82+
raise
83+
logger.warning("error deleting file for object %s", obj, exc_info=True)
84+
6085
self.clear()
6186

6287
async def commit_async(self) -> None:
6388
"""Process pending saves and deletes after a successful commit."""
64-
save_tasks: list[Awaitable[Any]] = []
65-
for obj, data in self.pending_saves.items():
66-
save_tasks.append(obj.save_async(data))
67-
self._saved_in_transaction.add(obj)
68-
69-
delete_tasks: list[Awaitable[Any]] = [obj.delete_async() for obj in self.pending_deletes]
7089

71-
# Run save and delete tasks concurrently
72-
save_results = await asyncio.gather(*save_tasks, return_exceptions=True)
73-
delete_results = await asyncio.gather(*delete_tasks, return_exceptions=True)
74-
75-
# Process save results (log errors)
76-
for result, (obj, _data) in zip(save_results, self.pending_saves.items()):
77-
if isinstance(result, Exception):
78-
logger.warning("Error saving file for object %s: %s", obj, result.__cause__)
79-
# Process delete results (log errors, ignore FileNotFoundError)
80-
for result, obj_to_delete in zip(delete_results, self.pending_deletes):
90+
save_items: "list[tuple[FileObject, Union[bytes, Path]]]" = list(self.pending_saves.items())
91+
delete_items: "list[FileObject]" = list(self.pending_deletes)
92+
93+
save_results: "list[Any]" = await asyncio.gather(
94+
*(obj.save_async(data) for obj, data in save_items),
95+
return_exceptions=True,
96+
)
97+
delete_results: "list[Any]" = await asyncio.gather(
98+
*(obj.delete_async() for obj in delete_items),
99+
return_exceptions=True,
100+
)
101+
102+
errors: list[Exception] = []
103+
104+
for (obj, _data), result in zip(save_items, save_results):
105+
if isinstance(result, BaseException):
106+
if isinstance(result, Exception):
107+
if self.raise_on_error:
108+
logger.error(
109+
"error saving file for object %s",
110+
obj,
111+
exc_info=(type(result), result, result.__traceback__),
112+
)
113+
else:
114+
# Legacy behavior: warning level
115+
logger.warning(
116+
"error saving file for object %s",
117+
obj,
118+
exc_info=(type(result), result, result.__traceback__),
119+
)
120+
errors.append(result)
121+
else:
122+
# BaseException (e.g., CancelledError) - always raise
123+
raise result
124+
else:
125+
self._saved_in_transaction.add(obj)
126+
127+
for obj_to_delete, result in zip(delete_items, delete_results):
81128
if isinstance(result, FileNotFoundError):
82129
continue
83-
if isinstance(result, Exception):
84-
logger.warning("Error deleting file %s: %s", obj_to_delete.path, result.__cause__)
85-
86-
self.clear()
130+
if isinstance(result, BaseException):
131+
if isinstance(result, Exception):
132+
if self.raise_on_error:
133+
logger.error(
134+
"error deleting file %s",
135+
obj_to_delete.path or obj_to_delete,
136+
exc_info=(type(result), result, result.__traceback__),
137+
)
138+
else:
139+
logger.warning(
140+
"error deleting file %s",
141+
obj_to_delete.path or obj_to_delete,
142+
exc_info=(type(result), result, result.__traceback__),
143+
)
144+
errors.append(result)
145+
else:
146+
raise result
147+
148+
if errors and self.raise_on_error:
149+
if len(errors) == 1:
150+
raise errors[0]
151+
msg = "multiple FileObject operation failures"
152+
raise ExceptionGroup(msg, errors)
153+
if not errors:
154+
self.clear()
87155

88156
def rollback(self) -> None:
89157
"""Clean up files saved during a transaction that is being rolled back."""
@@ -94,30 +162,45 @@ def rollback(self) -> None:
94162
except FileNotFoundError:
95163
# Ignore if the file is already gone (shouldn't happen often here)
96164
pass
97-
except Exception as e: # noqa: BLE001
98-
logger.warning("Error deleting file during rollback %s: %s", obj.path, e.__cause__)
165+
except Exception:
166+
logger.exception("error deleting file during rollback %s", obj.path or obj)
167+
raise
99168
self.clear()
100169

101170
async def rollback_async(self) -> None:
102171
"""Clean up files saved during a transaction that is being rolled back."""
103-
rollback_delete_tasks: list[Awaitable[Any]] = []
104-
objects_to_delete_on_rollback: list[FileObject] = []
105-
# Only delete files that were actually saved *during this transaction*
106-
for obj in self._saved_in_transaction:
107-
if obj.path:
108-
rollback_delete_tasks.append(obj.delete_async())
109-
objects_to_delete_on_rollback.append(obj)
110-
111-
for task, obj_to_delete in zip(rollback_delete_tasks, objects_to_delete_on_rollback):
112-
try:
113-
await task
114-
except FileNotFoundError:
115-
# Ignore if the file is already gone (shouldn't happen often here)
116-
pass
117-
except Exception as e: # noqa: BLE001
118-
logger.warning("Error deleting file during rollback %s: %s", obj_to_delete.path, e.__cause__)
172+
objects_to_delete = [obj for obj in self._saved_in_transaction if obj.path]
173+
if not objects_to_delete:
174+
self.clear()
175+
return
176+
177+
delete_results = await asyncio.gather(
178+
*(obj.delete_async() for obj in objects_to_delete),
179+
return_exceptions=True,
180+
)
181+
182+
errors: list[Exception] = []
183+
for obj, result in zip(objects_to_delete, delete_results):
184+
if isinstance(result, FileNotFoundError):
185+
continue
186+
if isinstance(result, BaseException):
187+
if isinstance(result, Exception):
188+
logger.error(
189+
"error deleting file during rollback %s",
190+
obj.path or obj,
191+
exc_info=(type(result), result, result.__traceback__),
192+
)
193+
errors.append(result)
194+
else:
195+
# Propagate BaseExceptions like CancelledError
196+
raise result
119197

120198
self.clear()
199+
if errors:
200+
if len(errors) == 1:
201+
raise errors[0]
202+
msg = "multiple FileObject rollback failures"
203+
raise ExceptionGroup(msg, errors)
121204

122205
def clear(self) -> None:
123206
"""Clear the tracker's state."""

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"typing-extensions>=4.0.0",
3131
"greenlet",
3232
"eval-type-backport ; python_full_version < '3.10'",
33+
"exceptiongroup ; python_full_version < '3.11'",
3334
]
3435
description = "Ready-to-go SQLAlchemy concoctions."
3536
keywords = ["sqlalchemy", "alembic", "litestar", "sanic", "fastapi", "flask"]

tests/unit/test_extensions/test_litestar/test_init_plugin/test_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_session_config_dict_with_no_provided_config(
4646
) -> None:
4747
"""Test session_config_dict with no provided config."""
4848
config = config_cls()
49-
assert config.session_config_dict == {}
49+
# Config now includes file_object_raise_on_error in session info by default
50+
assert config.session_config_dict == {"info": {"file_object_raise_on_error": True}}
5051

5152

5253
def test_config_create_engine_if_engine_instance_provided(

0 commit comments

Comments
 (0)