1- import asyncio
1+ import warnings
22from contextvars import ContextVar
3- from typing import Dict , Optional , Type , Union
3+ from typing import Dict , Optional , Set , Type , Union
44
55from sqlalchemy .engine .url import URL
66from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , create_async_engine
@@ -32,6 +32,9 @@ def create_middleware_and_session_proxy() -> tuple:
3232 _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
3333 _multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
3434 _commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
35+ _tracked_sessions : ContextVar [Optional [Set [AsyncSession ]]] = ContextVar (
36+ "_tracked_sessions" , default = None
37+ )
3538 # Usage of context vars inside closures is not recommended, since they are not properly
3639 # garbage collected, but in our use case context var is created on program startup and
3740 # is used throughout the whole its lifecycle.
@@ -103,23 +106,21 @@ async def execute_query(query):
103106 await asyncio.gather(*tasks)
104107 ```
105108 """
106- commit_on_exit = _commit_on_exit_ctx .get ()
107109 # Always create a new session for each access when multi_sessions=True
108110 session = _Session ()
109111
110- async def cleanup ():
111- try :
112- if commit_on_exit :
113- await session .commit ()
114- except Exception :
115- await session .rollback ()
116- raise
117- finally :
118- await session .close ()
112+ # Track the session for cleanup in __aexit__
113+ tracked = _tracked_sessions .get ()
114+ if tracked is not None :
115+ tracked .add (session )
116+ else :
117+ warnings .warn (
118+ """Session created in multi_sessions mode but no tracking set found.
119+ This session may leak if not properly closed.""" ,
120+ ResourceWarning ,
121+ stacklevel = 2 ,
122+ )
119123
120- task = asyncio .current_task ()
121- if task is not None :
122- task .add_done_callback (lambda t : asyncio .create_task (cleanup ()))
123124 return session
124125 else :
125126 session = _session .get ()
@@ -136,6 +137,7 @@ def __init__(
136137 ):
137138 self .token = None
138139 self .commit_on_exit_token = None
140+ self .tracked_sessions_token = None
139141 self .session_args = session_args or {}
140142 self .commit_on_exit = commit_on_exit
141143 self .multi_sessions = multi_sessions
@@ -147,21 +149,73 @@ async def __aenter__(self):
147149 if self .multi_sessions :
148150 self .multi_sessions_token = _multi_sessions_ctx .set (True )
149151 self .commit_on_exit_token = _commit_on_exit_ctx .set (self .commit_on_exit )
152+ self .tracked_sessions_token = _tracked_sessions .set (set ())
150153 else :
151154 self .token = _session .set (_Session (** self .session_args ))
152155 return type (self )
153156
154157 async def __aexit__ (self , exc_type , exc_value , traceback ):
155158 if self .multi_sessions :
159+ # Clean up all tracked sessions
160+ tracked_sessions = _tracked_sessions .get ()
161+ if tracked_sessions :
162+ cleanup_errors = []
163+ for session in tracked_sessions :
164+ try :
165+ if exc_type is not None :
166+ await session .rollback ()
167+ elif self .commit_on_exit :
168+ try :
169+ await session .commit ()
170+ except Exception as commit_error :
171+ warnings .warn (
172+ f"Failed to commit in multi_sessions: { commit_error } " ,
173+ RuntimeWarning ,
174+ stacklevel = 2 ,
175+ )
176+ await session .rollback ()
177+ cleanup_errors .append (commit_error )
178+ except Exception as cleanup_error :
179+ warnings .warn (
180+ f"Failed to rollback session in multi_sessions: { cleanup_error } " ,
181+ RuntimeWarning ,
182+ stacklevel = 2 ,
183+ )
184+ cleanup_errors .append (cleanup_error )
185+ finally :
186+ try :
187+ await session .close ()
188+ except Exception as close_error :
189+ warnings .warn (
190+ f"Failed to close session in multi_session: { close_error } " ,
191+ ResourceWarning ,
192+ stacklevel = 2 ,
193+ )
194+ cleanup_errors .append (close_error )
195+
196+ if cleanup_errors and exc_type is None :
197+ warnings .warn (
198+ f"Encountered { len (cleanup_errors )} error(s) during session cleanup" ,
199+ RuntimeWarning ,
200+ stacklevel = 2 ,
201+ )
202+
203+ # Reset context vars
204+ _tracked_sessions .reset (self .tracked_sessions_token )
156205 _multi_sessions_ctx .reset (self .multi_sessions_token )
157206 _commit_on_exit_ctx .reset (self .commit_on_exit_token )
158207 else :
208+ # Standard single-session mode
159209 session = _session .get ()
160210 try :
161211 if exc_type is not None :
162212 await session .rollback ()
163213 elif self .commit_on_exit :
164- await session .commit ()
214+ try :
215+ await session .commit ()
216+ except Exception :
217+ await session .rollback ()
218+ raise
165219 finally :
166220 await session .close ()
167221 _session .reset (self .token )
0 commit comments