Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix AsyncSession type annotations for exec() #58

Merged
merged 8 commits into from
Oct 23, 2023
50 changes: 43 additions & 7 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload

from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable

from ...engine.result import ScalarResult
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar

_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")


class AsyncSession(_AsyncSession):
Expand Down Expand Up @@ -40,14 +40,46 @@ def __init__(
Session(bind=bind, binds=binds, **kw) # type: ignore
)

@overload
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...

@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...

async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_T]:
) -> ScalarResult[_TSelectParam]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
Expand All @@ -60,3 +92,7 @@ async def exec(
bind_arguments=bind_arguments,
**kw,
)

async def __aenter__(self) -> "AsyncSession":
# PyCharm does not understand TypeVar here :/
return await super().__aenter__()
2 changes: 1 addition & 1 deletion sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal

from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
from ..sql.expression import Select, SelectOfScalar

_TSelectParam = TypeVar("_TSelectParam")

Expand Down