Skip to content

Commit

Permalink
🐛 Fix AsyncSession type annotations for exec() (#58)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
  • Loading branch information
Bobronium and tiangolo authored Oct 23, 2023
1 parent b8996f0 commit 9732c5a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
46 changes: 39 additions & 7 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload

from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable

from ...engine.result import ScalarResult
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar

_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")


class AsyncSession(_AsyncSession):
Expand Down Expand Up @@ -40,14 +40,46 @@ def __init__(
Session(bind=bind, binds=binds, **kw) # type: ignore
)

@overload
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...

@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...

async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_T]:
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal

from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
from ..sql.expression import Select, SelectOfScalar

_TSelectParam = TypeVar("_TSelectParam")

Expand Down

0 comments on commit 9732c5a

Please sign in to comment.