Skip to content

Commit d090ad9

Browse files
authored
Merge pull request #30 from h0rn3t/multi_sessions_wip-3
bump version to 0.7.1 and enhance multi-session cleanup with tracking
2 parents 1e6077e + a67f25a commit d090ad9

File tree

8 files changed

+1223
-18
lines changed

8 files changed

+1223
-18
lines changed

fastapi_async_sqlalchemy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66

77
__all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"]
88

9-
__version__ = "0.7.0.dev5"
9+
__version__ = "0.7.1"

fastapi_async_sqlalchemy/middleware.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import asyncio
1+
import warnings
22
from contextvars import ContextVar
3-
from typing import Dict, Optional, Type, Union
3+
from typing import Dict, Optional, Set, Type, Union
44

55
from sqlalchemy.engine.url import URL
66
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
@@ -32,6 +32,9 @@ def create_middleware_and_session_proxy() -> tuple:
3232
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
3333
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
3434
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
35+
_tracked_sessions: ContextVar[Optional[Set[AsyncSession]]] = ContextVar(
36+
"_tracked_sessions", default=None
37+
)
3538
# Usage of context vars inside closures is not recommended, since they are not properly
3639
# garbage collected, but in our use case context var is created on program startup and
3740
# is used throughout the whole its lifecycle.
@@ -103,23 +106,21 @@ async def execute_query(query):
103106
await asyncio.gather(*tasks)
104107
```
105108
"""
106-
commit_on_exit = _commit_on_exit_ctx.get()
107109
# Always create a new session for each access when multi_sessions=True
108110
session = _Session()
109111

110-
async def cleanup():
111-
try:
112-
if commit_on_exit:
113-
await session.commit()
114-
except Exception:
115-
await session.rollback()
116-
raise
117-
finally:
118-
await session.close()
112+
# Track the session for cleanup in __aexit__
113+
tracked = _tracked_sessions.get()
114+
if tracked is not None:
115+
tracked.add(session)
116+
else:
117+
warnings.warn(
118+
"""Session created in multi_sessions mode but no tracking set found.
119+
This session may leak if not properly closed.""",
120+
ResourceWarning,
121+
stacklevel=2,
122+
)
119123

120-
task = asyncio.current_task()
121-
if task is not None:
122-
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
123124
return session
124125
else:
125126
session = _session.get()
@@ -136,6 +137,7 @@ def __init__(
136137
):
137138
self.token = None
138139
self.commit_on_exit_token = None
140+
self.tracked_sessions_token = None
139141
self.session_args = session_args or {}
140142
self.commit_on_exit = commit_on_exit
141143
self.multi_sessions = multi_sessions
@@ -147,21 +149,73 @@ async def __aenter__(self):
147149
if self.multi_sessions:
148150
self.multi_sessions_token = _multi_sessions_ctx.set(True)
149151
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
152+
self.tracked_sessions_token = _tracked_sessions.set(set())
150153
else:
151154
self.token = _session.set(_Session(**self.session_args))
152155
return type(self)
153156

154157
async def __aexit__(self, exc_type, exc_value, traceback):
155158
if self.multi_sessions:
159+
# Clean up all tracked sessions
160+
tracked_sessions = _tracked_sessions.get()
161+
if tracked_sessions:
162+
cleanup_errors = []
163+
for session in tracked_sessions:
164+
try:
165+
if exc_type is not None:
166+
await session.rollback()
167+
elif self.commit_on_exit:
168+
try:
169+
await session.commit()
170+
except Exception as commit_error:
171+
warnings.warn(
172+
f"Failed to commit in multi_sessions: {commit_error}",
173+
RuntimeWarning,
174+
stacklevel=2,
175+
)
176+
await session.rollback()
177+
cleanup_errors.append(commit_error)
178+
except Exception as cleanup_error:
179+
warnings.warn(
180+
f"Failed to rollback session in multi_sessions: {cleanup_error}",
181+
RuntimeWarning,
182+
stacklevel=2,
183+
)
184+
cleanup_errors.append(cleanup_error)
185+
finally:
186+
try:
187+
await session.close()
188+
except Exception as close_error:
189+
warnings.warn(
190+
f"Failed to close session in multi_session: {close_error}",
191+
ResourceWarning,
192+
stacklevel=2,
193+
)
194+
cleanup_errors.append(close_error)
195+
196+
if cleanup_errors and exc_type is None:
197+
warnings.warn(
198+
f"Encountered {len(cleanup_errors)} error(s) during session cleanup",
199+
RuntimeWarning,
200+
stacklevel=2,
201+
)
202+
203+
# Reset context vars
204+
_tracked_sessions.reset(self.tracked_sessions_token)
156205
_multi_sessions_ctx.reset(self.multi_sessions_token)
157206
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
158207
else:
208+
# Standard single-session mode
159209
session = _session.get()
160210
try:
161211
if exc_type is not None:
162212
await session.rollback()
163213
elif self.commit_on_exit:
164-
await session.commit()
214+
try:
215+
await session.commit()
216+
except Exception:
217+
await session.rollback()
218+
raise
165219
finally:
166220
await session.close()
167221
_session.reset(self.token)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ SQLAlchemy>=1.4.19
3131
sqlmodel>=0.0.24
3232
asyncpg>=0.27.0
3333
aiosqlite==0.20.0
34-
sqlparse==0.5.1
34+
sqlparse>=0.5.4
3535
starlette>=0.13.6
3636
toml>=0.10.1
3737
typed-ast>=1.4.2

tests/test_custom_engine_branch.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
Targeted test to ensure custom_engine branch (line 61) is executed
3+
"""
4+
5+
import pytest
6+
from fastapi import FastAPI
7+
from fastapi.testclient import TestClient
8+
from sqlalchemy.ext.asyncio import create_async_engine
9+
from sqlalchemy.sql import text
10+
11+
from fastapi_async_sqlalchemy.middleware import create_middleware_and_session_proxy
12+
13+
14+
@pytest.mark.asyncio
15+
async def test_custom_engine_branch_with_actual_usage():
16+
"""
17+
Ensure custom_engine else branch (line 61) is covered
18+
by actually using the middleware with a custom engine
19+
"""
20+
# Create fresh middleware and db instances
21+
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
22+
23+
# Create a custom async engine
24+
custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
25+
26+
app = FastAPI()
27+
28+
# Add middleware with custom_engine
29+
# This should execute: else: engine = custom_engine (line 61)
30+
app.add_middleware(SQLAlchemyMiddleware, custom_engine=custom_engine, commit_on_exit=False)
31+
32+
# Create endpoint to test the session works
33+
@app.get("/test")
34+
async def test_endpoint():
35+
async with db():
36+
result = await db.session.execute(text("SELECT 42 as value"))
37+
value = result.scalar()
38+
return {"value": value}
39+
40+
# Test the endpoint
41+
client = TestClient(app)
42+
response = client.get("/test")
43+
44+
assert response.status_code == 200
45+
assert response.json()["value"] == 42
46+
47+
48+
def test_custom_engine_without_db_url():
49+
"""
50+
Verify custom_engine can be used without providing db_url
51+
This ensures the else branch is used
52+
"""
53+
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
54+
55+
app = FastAPI()
56+
57+
# Create custom engine
58+
custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:")
59+
60+
# Initialize middleware with ONLY custom_engine (no db_url)
61+
# This should take the else branch at line 61
62+
middleware = SQLAlchemyMiddleware(
63+
app, custom_engine=custom_engine, engine_args={}, session_args={}
64+
)
65+
66+
assert middleware is not None
67+
assert middleware.commit_on_exit is False
68+
69+
70+
def test_custom_engine_with_session_args():
71+
"""
72+
Test custom_engine with various session_args
73+
"""
74+
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()
75+
76+
app = FastAPI()
77+
78+
custom_engine = create_async_engine("sqlite+aiosqlite://")
79+
80+
# Use custom engine with session args
81+
middleware = SQLAlchemyMiddleware(
82+
app, custom_engine=custom_engine, session_args={"autoflush": False}, commit_on_exit=True
83+
)
84+
85+
assert middleware is not None
86+
assert middleware.commit_on_exit is True
87+
88+
89+
def test_custom_engine_multiple_instances():
90+
"""
91+
Test multiple middleware instances with different custom engines
92+
"""
93+
SQLAlchemyMiddleware1, db1 = create_middleware_and_session_proxy()
94+
SQLAlchemyMiddleware2, db2 = create_middleware_and_session_proxy()
95+
96+
app = FastAPI()
97+
98+
# Create two different custom engines
99+
engine1 = create_async_engine("sqlite+aiosqlite:///:memory:")
100+
engine2 = create_async_engine("sqlite+aiosqlite://")
101+
102+
# Create two middleware instances
103+
middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine1)
104+
middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine2)
105+
106+
assert middleware1 is not None
107+
assert middleware2 is not None

0 commit comments

Comments
 (0)