11import asyncio
22from contextvars import ContextVar
3- from typing import Dict , Optional , Union
3+ from typing import Dict , Optional , Type , Union
44
5- from sqlalchemy .engine import Engine
65from sqlalchemy .engine .url import URL
7- from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker , create_async_engine
6+ from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , create_async_engine
87from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
98from starlette .requests import Request
109from starlette .types import ASGIApp
1110
12- from fastapi_async_sqlalchemy .exceptions import MissingSessionError , SessionNotInitialisedError
11+ from fastapi_async_sqlalchemy .exceptions import (
12+ MissingSessionError ,
13+ SessionNotInitialisedError ,
14+ )
1315
1416try :
15- from sqlalchemy .ext .asyncio import async_sessionmaker # noqa: F811
17+ from sqlalchemy .ext .asyncio import async_sessionmaker
1618except ImportError :
17- from sqlalchemy .orm import sessionmaker as async_sessionmaker
19+ from sqlalchemy .orm import sessionmaker as async_sessionmaker # type: ignore
20+
21+ # Try to import SQLModel's AsyncSession which has the .exec() method
22+ try :
23+ from sqlmodel .ext .asyncio .session import AsyncSession as SQLModelAsyncSession
24+
25+ DefaultAsyncSession : Type [AsyncSession ] = SQLModelAsyncSession # type: ignore
26+ except ImportError :
27+ DefaultAsyncSession : Type [AsyncSession ] = AsyncSession # type: ignore
1828
1929
20- def create_middleware_and_session_proxy ():
30+ def create_middleware_and_session_proxy () -> tuple :
2131 _Session : Optional [async_sessionmaker ] = None
2232 _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
2333 _multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
@@ -31,9 +41,9 @@ def __init__(
3141 self ,
3242 app : ASGIApp ,
3343 db_url : Optional [Union [str , URL ]] = None ,
34- custom_engine : Optional [Engine ] = None ,
35- engine_args : Dict = None ,
36- session_args : Dict = None ,
44+ custom_engine : Optional [AsyncEngine ] = None ,
45+ engine_args : Optional [ Dict ] = None ,
46+ session_args : Optional [ Dict ] = None ,
3747 commit_on_exit : bool = False ,
3848 ):
3949 super ().__init__ (app )
@@ -44,13 +54,18 @@ def __init__(
4454 if not custom_engine and not db_url :
4555 raise ValueError ("You need to pass a db_url or a custom_engine parameter." )
4656 if not custom_engine :
57+ if db_url is None :
58+ raise ValueError ("db_url cannot be None when custom_engine is not provided" )
4759 engine = create_async_engine (db_url , ** engine_args )
4860 else :
4961 engine = custom_engine
5062
5163 nonlocal _Session
5264 _Session = async_sessionmaker (
53- engine , class_ = AsyncSession , expire_on_commit = False , ** session_args
65+ engine ,
66+ class_ = DefaultAsyncSession ,
67+ expire_on_commit = False ,
68+ ** session_args ,
5469 )
5570
5671 async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
@@ -115,7 +130,7 @@ async def cleanup():
115130 class DBSession (metaclass = DBSessionMeta ):
116131 def __init__ (
117132 self ,
118- session_args : Dict = None ,
133+ session_args : Optional [ Dict ] = None ,
119134 commit_on_exit : bool = False ,
120135 multi_sessions : bool = False ,
121136 ):
0 commit comments