Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fastapi_async_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

__all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"]

__version__ = "0.7.0.dev5"
__version__ = "0.7.1"
86 changes: 70 additions & 16 deletions fastapi_async_sqlalchemy/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import warnings
from contextvars import ContextVar
from typing import Dict, Optional, Type, Union
from typing import Dict, Optional, Set, Type, Union

from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
Expand Down Expand Up @@ -32,6 +32,9 @@ def create_middleware_and_session_proxy() -> tuple:
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
_tracked_sessions: ContextVar[Optional[Set[AsyncSession]]] = ContextVar(
"_tracked_sessions", default=None
)
# Usage of context vars inside closures is not recommended, since they are not properly
# garbage collected, but in our use case context var is created on program startup and
# is used throughout the whole its lifecycle.
Expand Down Expand Up @@ -103,23 +106,21 @@ async def execute_query(query):
await asyncio.gather(*tasks)
```
"""
commit_on_exit = _commit_on_exit_ctx.get()
# Always create a new session for each access when multi_sessions=True
session = _Session()

async def cleanup():
try:
if commit_on_exit:
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
# Track the session for cleanup in __aexit__
tracked = _tracked_sessions.get()
if tracked is not None:
tracked.add(session)
else:
warnings.warn(
"""Session created in multi_sessions mode but no tracking set found.
This session may leak if not properly closed.""",
ResourceWarning,
stacklevel=2,
)

task = asyncio.current_task()
if task is not None:
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
return session
else:
session = _session.get()
Expand All @@ -136,6 +137,7 @@ def __init__(
):
self.token = None
self.commit_on_exit_token = None
self.tracked_sessions_token = None
self.session_args = session_args or {}
self.commit_on_exit = commit_on_exit
self.multi_sessions = multi_sessions
Expand All @@ -147,21 +149,73 @@ async def __aenter__(self):
if self.multi_sessions:
self.multi_sessions_token = _multi_sessions_ctx.set(True)
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
self.tracked_sessions_token = _tracked_sessions.set(set())
else:
self.token = _session.set(_Session(**self.session_args))
return type(self)

async def __aexit__(self, exc_type, exc_value, traceback):
if self.multi_sessions:
# Clean up all tracked sessions
tracked_sessions = _tracked_sessions.get()
if tracked_sessions:
cleanup_errors = []
for session in tracked_sessions:
try:
if exc_type is not None:
await session.rollback()
elif self.commit_on_exit:
try:
await session.commit()
except Exception as commit_error:
warnings.warn(
f"Failed to commit in multi_sessions: {commit_error}",
RuntimeWarning,
stacklevel=2,
)
await session.rollback()
cleanup_errors.append(commit_error)
except Exception as cleanup_error:
warnings.warn(
f"Failed to rollback session in multi_sessions: {cleanup_error}",
RuntimeWarning,
stacklevel=2,
)
cleanup_errors.append(cleanup_error)
finally:
try:
await session.close()
except Exception as close_error:
warnings.warn(
f"Failed to close session in multi_session: {close_error}",
ResourceWarning,
stacklevel=2,
)
cleanup_errors.append(close_error)

if cleanup_errors and exc_type is None:
warnings.warn(
f"Encountered {len(cleanup_errors)} error(s) during session cleanup",
RuntimeWarning,
stacklevel=2,
)

# Reset context vars
_tracked_sessions.reset(self.tracked_sessions_token)
_multi_sessions_ctx.reset(self.multi_sessions_token)
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
else:
# Standard single-session mode
session = _session.get()
try:
if exc_type is not None:
await session.rollback()
elif self.commit_on_exit:
await session.commit()
try:
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
_session.reset(self.token)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ SQLAlchemy>=1.4.19
sqlmodel>=0.0.24
asyncpg>=0.27.0
aiosqlite==0.20.0
sqlparse==0.5.1
sqlparse>=0.5.4
starlette>=0.13.6
toml>=0.10.1
typed-ast>=1.4.2
Expand Down
107 changes: 107 additions & 0 deletions tests/test_custom_engine_branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Targeted test to ensure custom_engine branch (line 61) is executed
"""

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text

from fastapi_async_sqlalchemy.middleware import create_middleware_and_session_proxy


@pytest.mark.asyncio
async def test_custom_engine_branch_with_actual_usage():
"""
Ensure custom_engine else branch (line 61) is covered
by actually using the middleware with a custom engine
"""
# Create fresh middleware and db instances
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()

# Create a custom async engine
custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)

app = FastAPI()

# Add middleware with custom_engine
# This should execute: else: engine = custom_engine (line 61)
app.add_middleware(SQLAlchemyMiddleware, custom_engine=custom_engine, commit_on_exit=False)

# Create endpoint to test the session works
@app.get("/test")
async def test_endpoint():
async with db():
result = await db.session.execute(text("SELECT 42 as value"))
value = result.scalar()
return {"value": value}

# Test the endpoint
client = TestClient(app)
response = client.get("/test")

assert response.status_code == 200
assert response.json()["value"] == 42


def test_custom_engine_without_db_url():
"""
Verify custom_engine can be used without providing db_url
This ensures the else branch is used
"""
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()

app = FastAPI()

# Create custom engine
custom_engine = create_async_engine("sqlite+aiosqlite:///:memory:")

# Initialize middleware with ONLY custom_engine (no db_url)
# This should take the else branch at line 61
middleware = SQLAlchemyMiddleware(
app, custom_engine=custom_engine, engine_args={}, session_args={}
)

assert middleware is not None
assert middleware.commit_on_exit is False


def test_custom_engine_with_session_args():
"""
Test custom_engine with various session_args
"""
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()

app = FastAPI()

custom_engine = create_async_engine("sqlite+aiosqlite://")

# Use custom engine with session args
middleware = SQLAlchemyMiddleware(
app, custom_engine=custom_engine, session_args={"autoflush": False}, commit_on_exit=True
)

assert middleware is not None
assert middleware.commit_on_exit is True


def test_custom_engine_multiple_instances():
"""
Test multiple middleware instances with different custom engines
"""
SQLAlchemyMiddleware1, db1 = create_middleware_and_session_proxy()
SQLAlchemyMiddleware2, db2 = create_middleware_and_session_proxy()

app = FastAPI()

# Create two different custom engines
engine1 = create_async_engine("sqlite+aiosqlite:///:memory:")
engine2 = create_async_engine("sqlite+aiosqlite://")

# Create two middleware instances
middleware1 = SQLAlchemyMiddleware1(app, custom_engine=engine1)
middleware2 = SQLAlchemyMiddleware2(app, custom_engine=engine2)

assert middleware1 is not None
assert middleware2 is not None
Loading
Loading