33
44import asyncio
55import logging
6+ import sys
67from typing import TYPE_CHECKING , Any , Union
78
9+ if sys .version_info >= (3 , 11 ):
10+ from builtins import ExceptionGroup
11+ else :
12+ from exceptiongroup import ExceptionGroup # type: ignore[import-not-found,unused-ignore]
13+
814if TYPE_CHECKING :
9- from collections .abc import Awaitable
1015 from pathlib import Path
1116
1217 from advanced_alchemy .types .file_object import FileObject
1722class FileObjectSessionTracker :
1823 """Tracks FileObject changes within a single session transaction."""
1924
20- def __init__ (self ) -> None :
21- """Initialize the tracker."""
25+ def __init__ (self , raise_on_error : bool = False ) -> None :
26+ """Initialize empty tracking state.
27+
28+ Args:
29+ raise_on_error: If True, raise exceptions on file operation failures.
30+ If False, log warnings and continue.
31+
32+ Internal structures:
33+ - ``pending_saves``: ``FileObject -> data`` to be saved on commit
34+ - ``pending_deletes``: ``FileObject`` instances to delete on commit
35+ - ``_saved_in_transaction``: successfully saved objects used for
36+ selective cleanup on rollback
37+ """
38+ self .raise_on_error = raise_on_error
2239 # Stores objects that have pending data to be saved on commit.
2340 # Maps FileObject -> data source (bytes or Path)
2441 self .pending_saves : "dict[FileObject, Union[bytes, Path]]" = {}
@@ -47,43 +64,94 @@ def commit(self) -> None:
4764 for obj , data in self .pending_saves .items ():
4865 try :
4966 obj .save (data )
50- except Exception as e : # noqa: BLE001
51- logger .warning ("Error saving file for object %s: %s" , obj , e .__cause__ )
67+ self ._saved_in_transaction .add (obj )
68+ except Exception :
69+ if self .raise_on_error :
70+ logger .exception ("error saving file for object %s" , obj )
71+ raise
72+ logger .warning ("error saving file for object %s" , obj , exc_info = True )
73+
5274 for obj in self .pending_deletes :
5375 try :
5476 obj .delete ()
5577 except FileNotFoundError :
56- # Ignore if the file is already gone (shouldn't happen often here)
5778 pass
58- except Exception as e : # noqa: BLE001
59- logger .warning ("Error deleting file for object %s: %s" , obj , e .__cause__ )
79+ except Exception :
80+ if self .raise_on_error :
81+ logger .exception ("error deleting file for object %s" , obj )
82+ raise
83+ logger .warning ("error deleting file for object %s" , obj , exc_info = True )
84+
6085 self .clear ()
6186
6287 async def commit_async (self ) -> None :
6388 """Process pending saves and deletes after a successful commit."""
64- save_tasks : list [Awaitable [Any ]] = []
65- for obj , data in self .pending_saves .items ():
66- save_tasks .append (obj .save_async (data ))
67- self ._saved_in_transaction .add (obj )
68-
69- delete_tasks : list [Awaitable [Any ]] = [obj .delete_async () for obj in self .pending_deletes ]
7089
71- # Run save and delete tasks concurrently
72- save_results = await asyncio .gather (* save_tasks , return_exceptions = True )
73- delete_results = await asyncio .gather (* delete_tasks , return_exceptions = True )
74-
75- # Process save results (log errors)
76- for result , (obj , _data ) in zip (save_results , self .pending_saves .items ()):
77- if isinstance (result , Exception ):
78- logger .warning ("Error saving file for object %s: %s" , obj , result .__cause__ )
79- # Process delete results (log errors, ignore FileNotFoundError)
80- for result , obj_to_delete in zip (delete_results , self .pending_deletes ):
90+ save_items : "list[tuple[FileObject, Union[bytes, Path]]]" = list (self .pending_saves .items ())
91+ delete_items : "list[FileObject]" = list (self .pending_deletes )
92+
93+ save_results : "list[Any]" = await asyncio .gather (
94+ * (obj .save_async (data ) for obj , data in save_items ),
95+ return_exceptions = True ,
96+ )
97+ delete_results : "list[Any]" = await asyncio .gather (
98+ * (obj .delete_async () for obj in delete_items ),
99+ return_exceptions = True ,
100+ )
101+
102+ errors : list [Exception ] = []
103+
104+ for (obj , _data ), result in zip (save_items , save_results ):
105+ if isinstance (result , BaseException ):
106+ if isinstance (result , Exception ):
107+ if self .raise_on_error :
108+ logger .error (
109+ "error saving file for object %s" ,
110+ obj ,
111+ exc_info = (type (result ), result , result .__traceback__ ),
112+ )
113+ else :
114+ # Legacy behavior: warning level
115+ logger .warning (
116+ "error saving file for object %s" ,
117+ obj ,
118+ exc_info = (type (result ), result , result .__traceback__ ),
119+ )
120+ errors .append (result )
121+ else :
122+ # BaseException (e.g., CancelledError) - always raise
123+ raise result
124+ else :
125+ self ._saved_in_transaction .add (obj )
126+
127+ for obj_to_delete , result in zip (delete_items , delete_results ):
81128 if isinstance (result , FileNotFoundError ):
82129 continue
83- if isinstance (result , Exception ):
84- logger .warning ("Error deleting file %s: %s" , obj_to_delete .path , result .__cause__ )
85-
86- self .clear ()
130+ if isinstance (result , BaseException ):
131+ if isinstance (result , Exception ):
132+ if self .raise_on_error :
133+ logger .error (
134+ "error deleting file %s" ,
135+ obj_to_delete .path or obj_to_delete ,
136+ exc_info = (type (result ), result , result .__traceback__ ),
137+ )
138+ else :
139+ logger .warning (
140+ "error deleting file %s" ,
141+ obj_to_delete .path or obj_to_delete ,
142+ exc_info = (type (result ), result , result .__traceback__ ),
143+ )
144+ errors .append (result )
145+ else :
146+ raise result
147+
148+ if errors and self .raise_on_error :
149+ if len (errors ) == 1 :
150+ raise errors [0 ]
151+ msg = "multiple FileObject operation failures"
152+ raise ExceptionGroup (msg , errors )
153+ if not errors :
154+ self .clear ()
87155
88156 def rollback (self ) -> None :
89157 """Clean up files saved during a transaction that is being rolled back."""
@@ -94,30 +162,45 @@ def rollback(self) -> None:
94162 except FileNotFoundError :
95163 # Ignore if the file is already gone (shouldn't happen often here)
96164 pass
97- except Exception as e : # noqa: BLE001
98- logger .warning ("Error deleting file during rollback %s: %s" , obj .path , e .__cause__ )
165+ except Exception :
166+ logger .exception ("error deleting file during rollback %s" , obj .path or obj )
167+ raise
99168 self .clear ()
100169
101170 async def rollback_async (self ) -> None :
102171 """Clean up files saved during a transaction that is being rolled back."""
103- rollback_delete_tasks : list [Awaitable [Any ]] = []
104- objects_to_delete_on_rollback : list [FileObject ] = []
105- # Only delete files that were actually saved *during this transaction*
106- for obj in self ._saved_in_transaction :
107- if obj .path :
108- rollback_delete_tasks .append (obj .delete_async ())
109- objects_to_delete_on_rollback .append (obj )
110-
111- for task , obj_to_delete in zip (rollback_delete_tasks , objects_to_delete_on_rollback ):
112- try :
113- await task
114- except FileNotFoundError :
115- # Ignore if the file is already gone (shouldn't happen often here)
116- pass
117- except Exception as e : # noqa: BLE001
118- logger .warning ("Error deleting file during rollback %s: %s" , obj_to_delete .path , e .__cause__ )
172+ objects_to_delete = [obj for obj in self ._saved_in_transaction if obj .path ]
173+ if not objects_to_delete :
174+ self .clear ()
175+ return
176+
177+ delete_results = await asyncio .gather (
178+ * (obj .delete_async () for obj in objects_to_delete ),
179+ return_exceptions = True ,
180+ )
181+
182+ errors : list [Exception ] = []
183+ for obj , result in zip (objects_to_delete , delete_results ):
184+ if isinstance (result , FileNotFoundError ):
185+ continue
186+ if isinstance (result , BaseException ):
187+ if isinstance (result , Exception ):
188+ logger .error (
189+ "error deleting file during rollback %s" ,
190+ obj .path or obj ,
191+ exc_info = (type (result ), result , result .__traceback__ ),
192+ )
193+ errors .append (result )
194+ else :
195+ # Propagate BaseExceptions like CancelledError
196+ raise result
119197
120198 self .clear ()
199+ if errors :
200+ if len (errors ) == 1 :
201+ raise errors [0 ]
202+ msg = "multiple FileObject rollback failures"
203+ raise ExceptionGroup (msg , errors )
121204
122205 def clear (self ) -> None :
123206 """Clear the tracker's state."""
0 commit comments