11import asyncio
2- from asyncio import Task
32from contextvars import ContextVar
43from typing import Dict , Optional , Union
54
@@ -22,6 +21,9 @@ def create_middleware_and_session_proxy():
2221 _Session : Optional [async_sessionmaker ] = None
2322 _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
2423 _multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
24+ _task_session_ctx : ContextVar [Optional [AsyncSession ]] = ContextVar (
25+ "_task_session_ctx" , default = None
26+ )
2527 _commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
2628 # Usage of context vars inside closures is not recommended, since they are not properly
2729 # garbage collected, but in our use case context var is created on program startup and
@@ -90,28 +92,26 @@ async def execute_query(query):
9092 ```
9193 """
9294 commit_on_exit = _commit_on_exit_ctx .get ()
93- task : Task = asyncio .current_task () # type: ignore
94- if not hasattr (task , "_db_session" ):
95- task ._db_session = _Session () # type: ignore
96-
97- def cleanup (future ):
98- session = getattr (task , "_db_session" , None )
99- if session :
100-
101- async def do_cleanup ():
102- try :
103- if future .exception ():
104- await session .rollback ()
105- else :
106- if commit_on_exit :
107- await session .commit ()
108- finally :
109- await session .close ()
110-
111- asyncio .create_task (do_cleanup ())
112-
113- task .add_done_callback (cleanup )
114- return task ._db_session # type: ignore
95+ session = _task_session_ctx .get ()
96+ if session is None :
97+ session = _Session ()
98+ _task_session_ctx .set (session )
99+
100+ async def cleanup ():
101+ try :
102+ if commit_on_exit :
103+ await session .commit ()
104+ except Exception :
105+ await session .rollback ()
106+ raise
107+ finally :
108+ await session .close ()
109+ _task_session_ctx .set (None )
110+
111+ task = asyncio .current_task ()
112+ if task is not None :
113+ task .add_done_callback (lambda t : asyncio .create_task (cleanup ()))
114+ return session
115115 else :
116116 session = _session .get ()
117117 if session is None :
@@ -139,23 +139,24 @@ async def __aenter__(self):
139139 if self .multi_sessions :
140140 self .multi_sessions_token = _multi_sessions_ctx .set (True )
141141 self .commit_on_exit_token = _commit_on_exit_ctx .set (self .commit_on_exit )
142-
143- self .token = _session .set (_Session (** self .session_args ))
142+ else :
143+ self .token = _session .set (_Session (** self .session_args ))
144144 return type (self )
145145
146146 async def __aexit__ (self , exc_type , exc_value , traceback ):
147- session = _session .get ()
148- try :
149- if exc_type is not None :
150- await session .rollback ()
151- elif self .commit_on_exit :
152- await session .commit ()
153- finally :
154- await session .close ()
155- _session .reset (self .token )
156- if self .multi_sessions_token is not None :
157- _multi_sessions_ctx .reset (self .multi_sessions_token )
158- _commit_on_exit_ctx .reset (self .commit_on_exit_token )
147+ if self .multi_sessions :
148+ _multi_sessions_ctx .reset (self .multi_sessions_token )
149+ _commit_on_exit_ctx .reset (self .commit_on_exit_token )
150+ else :
151+ session = _session .get ()
152+ try :
153+ if exc_type is not None :
154+ await session .rollback ()
155+ elif self .commit_on_exit :
156+ await session .commit ()
157+ finally :
158+ await session .close ()
159+ _session .reset (self .token )
159160
160161 return SQLAlchemyMiddleware , DBSession
161162
0 commit comments