From 8ed856d32259c21297831f59a9be7ce90ceb6f4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 18 Nov 2023 12:30:37 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Upgrade=20SQLAlchemy=20to=202.0,=20?= =?UTF-8?q?including=20initial=20work=20by=20farahats9=20(#700)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mohamed Farahat Co-authored-by: Stefan Borer Co-authored-by: Peter Landry --- .github/workflows/test.yml | 2 + pyproject.toml | 14 +- scripts/generate_select.py | 6 +- sqlmodel/__init__.py | 53 ++- sqlmodel/engine/__init__.py | 0 sqlmodel/engine/create.py | 139 ------ sqlmodel/engine/result.py | 79 ---- sqlmodel/ext/asyncio/session.py | 144 ++++-- sqlmodel/main.py | 27 +- sqlmodel/orm/session.py | 109 +++-- sqlmodel/sql/expression.py | 428 +++++++++++++----- sqlmodel/sql/expression.py.jinja2 | 255 ++++++++++- sqlmodel/sql/sqltypes.py | 8 +- .../test_delete/test_tutorial001.py | 6 +- .../test_limit_and_offset/test_tutorial001.py | 6 +- .../test_multiple_models/test_tutorial001.py | 4 +- .../test_multiple_models/test_tutorial002.py | 4 +- .../test_read_one/test_tutorial001.py | 4 +- .../test_relationships/test_tutorial001.py | 6 +- .../test_response_model/test_tutorial001.py | 4 +- .../test_tutorial001.py | 6 +- .../test_simple_hero_api/test_tutorial001.py | 4 +- .../test_teams/test_tutorial001.py | 6 +- .../test_update/test_tutorial001.py | 6 +- 24 files changed, 809 insertions(+), 511 deletions(-) delete mode 100644 sqlmodel/engine/__init__.py delete mode 100644 sqlmodel/engine/create.py delete mode 100644 sqlmodel/engine/result.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201abc7c22..c3b07f484e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,6 +56,8 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - name: Lint + # Do not run on Python 3.7 as mypy behaves differently + if: matrix.python-version != '3.7' run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/pyproject.toml b/pyproject.toml index 23fa79bf31..515bbaf66c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,8 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" -SQLAlchemy = ">=1.4.36,<2.0.0" +SQLAlchemy = ">=2.0.0,<2.1.0" pydantic = "^1.9.0" -sqlalchemy2-stubs = {version = "*", allow-prereleases = true} [tool.poetry.group.dev.dependencies] pytest = "^7.0.1" @@ -45,9 +44,10 @@ pillow = "^9.3.0" cairosvg = "^2.5.2" mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = ">=6.2,<8.0"} -fastapi = "^0.68.1" -requests = "^2.26.0" +fastapi = "^0.103.2" ruff = "^0.1.2" +# For FastAPI tests +httpx = "0.24.1" [build-system] requires = ["poetry-core"] @@ -80,6 +80,12 @@ strict = true module = "sqlmodel.sql.expression" warn_unused_ignores = false +[[tool.mypy.overrides]] +module = "docs_src.*" +disallow_incomplete_defs = false +disallow_untyped_defs = false +disallow_untyped_calls = false + [tool.ruff] select = [ "E", # pycodestyle errors diff --git a/scripts/generate_select.py b/scripts/generate_select.py index f8aa30023f..88e0e0a997 100644 --- a/scripts/generate_select.py +++ b/scripts/generate_select.py @@ -34,9 +34,9 @@ class Arg(BaseModel): arg = Arg(name=f"entity_{i}", annotation=t_var) ret_type = t_var else: - t_type = f"_TModel_{i}" - t_var = f"Type[{t_type}]" - arg = Arg(name=f"entity_{i}", annotation=t_var) + t_type = f"_T{i}" + t_var = f"_TCCA[{t_type}]" + arg = Arg(name=f"__ent{i}", annotation=t_var) ret_type = t_type args.append(arg) return_types.append(ret_type) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 495ac9c8a8..e943257165 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -1,9 +1,12 @@ __version__ = "0.0.11" # Re-export from SQLAlchemy +from sqlalchemy.engine import create_engine as create_engine from sqlalchemy.engine import create_mock_engine as create_mock_engine from sqlalchemy.engine import engine_from_config as engine_from_config from sqlalchemy.inspection import inspect as inspect +from sqlalchemy.pool import QueuePool as QueuePool +from sqlalchemy.pool import StaticPool as StaticPool from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA from sqlalchemy.schema import DDL as DDL from sqlalchemy.schema import CheckConstraint as CheckConstraint @@ -21,7 +24,6 @@ from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from sqlalchemy.schema import Sequence as Sequence from sqlalchemy.schema import Table as Table -from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData from sqlalchemy.schema import UniqueConstraint as UniqueConstraint from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT from sqlalchemy.sql import ( @@ -32,26 +34,14 @@ LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) from sqlalchemy.sql import alias as alias -from sqlalchemy.sql import all_ as all_ -from sqlalchemy.sql import and_ as and_ -from sqlalchemy.sql import any_ as any_ -from sqlalchemy.sql import asc as asc -from sqlalchemy.sql import between as between from sqlalchemy.sql import bindparam as bindparam -from sqlalchemy.sql import case as case -from sqlalchemy.sql import cast as cast -from sqlalchemy.sql import collate as collate from sqlalchemy.sql import column as column from sqlalchemy.sql import delete as delete -from sqlalchemy.sql import desc as desc -from sqlalchemy.sql import distinct as distinct from sqlalchemy.sql import except_ as except_ from sqlalchemy.sql import except_all as except_all from sqlalchemy.sql import exists as exists -from sqlalchemy.sql import extract as extract from sqlalchemy.sql import false as false from sqlalchemy.sql import func as func -from sqlalchemy.sql import funcfilter as funcfilter from sqlalchemy.sql import insert as insert from sqlalchemy.sql import intersect as intersect from sqlalchemy.sql import intersect_all as intersect_all @@ -61,28 +51,19 @@ from sqlalchemy.sql import literal as literal from sqlalchemy.sql import literal_column as literal_column from sqlalchemy.sql import modifier as modifier -from sqlalchemy.sql import not_ as not_ from sqlalchemy.sql import null as null -from sqlalchemy.sql import nulls_first as nulls_first -from sqlalchemy.sql import nulls_last as nulls_last from sqlalchemy.sql import nullsfirst as nullsfirst from sqlalchemy.sql import nullslast as nullslast -from sqlalchemy.sql import or_ as or_ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam -from sqlalchemy.sql import over as over -from sqlalchemy.sql import subquery as subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text from sqlalchemy.sql import true as true -from sqlalchemy.sql import tuple_ as tuple_ -from sqlalchemy.sql import type_coerce as type_coerce from sqlalchemy.sql import union as union from sqlalchemy.sql import union_all as union_all from sqlalchemy.sql import update as update from sqlalchemy.sql import values as values -from sqlalchemy.sql import within_group as within_group from sqlalchemy.types import ARRAY as ARRAY from sqlalchemy.types import BIGINT as BIGINT from sqlalchemy.types import BINARY as BINARY @@ -93,6 +74,8 @@ from sqlalchemy.types import DATE as DATE from sqlalchemy.types import DATETIME as DATETIME from sqlalchemy.types import DECIMAL as DECIMAL +from sqlalchemy.types import DOUBLE as DOUBLE +from sqlalchemy.types import DOUBLE_PRECISION as DOUBLE_PRECISION from sqlalchemy.types import FLOAT as FLOAT from sqlalchemy.types import INT as INT from sqlalchemy.types import INTEGER as INTEGER @@ -105,12 +88,14 @@ from sqlalchemy.types import TEXT as TEXT from sqlalchemy.types import TIME as TIME from sqlalchemy.types import TIMESTAMP as TIMESTAMP +from sqlalchemy.types import UUID as UUID from sqlalchemy.types import VARBINARY as VARBINARY from sqlalchemy.types import VARCHAR as VARCHAR from sqlalchemy.types import BigInteger as BigInteger from sqlalchemy.types import Boolean as Boolean from sqlalchemy.types import Date as Date from sqlalchemy.types import DateTime as DateTime +from sqlalchemy.types import Double as Double from sqlalchemy.types import Enum as Enum from sqlalchemy.types import Float as Float from sqlalchemy.types import Integer as Integer @@ -122,16 +107,38 @@ from sqlalchemy.types import String as String from sqlalchemy.types import Text as Text from sqlalchemy.types import Time as Time +from sqlalchemy.types import TupleType as TupleType from sqlalchemy.types import TypeDecorator as TypeDecorator from sqlalchemy.types import Unicode as Unicode from sqlalchemy.types import UnicodeText as UnicodeText +from sqlalchemy.types import Uuid as Uuid # From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic -from .engine.create import create_engine as create_engine from .main import Field as Field from .main import Relationship as Relationship from .main import SQLModel as SQLModel from .orm.session import Session as Session +from .sql.expression import all_ as all_ +from .sql.expression import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import case as case +from .sql.expression import cast as cast from .sql.expression import col as col +from .sql.expression import collate as collate +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression import extract as extract +from .sql.expression import funcfilter as funcfilter +from .sql.expression import not_ as not_ +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import or_ as or_ +from .sql.expression import over as over from .sql.expression import select as select +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import within_group as within_group +from .sql.sqltypes import GUID as GUID from .sql.sqltypes import AutoString as AutoString diff --git a/sqlmodel/engine/__init__.py b/sqlmodel/engine/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py deleted file mode 100644 index b2d567b1b1..0000000000 --- a/sqlmodel/engine/create.py +++ /dev/null @@ -1,139 +0,0 @@ -import json -import sqlite3 -from typing import Any, Callable, Dict, List, Optional, Type, Union - -from sqlalchemy import create_engine as _create_engine -from sqlalchemy.engine.url import URL -from sqlalchemy.future import Engine as _FutureEngine -from sqlalchemy.pool import Pool -from typing_extensions import Literal, TypedDict - -from ..default import Default, _DefaultPlaceholder - -# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here - -_Debug = Literal["debug"] - -_IsolationLevel = Literal[ - "SERIALIZABLE", - "REPEATABLE READ", - "READ COMMITTED", - "READ UNCOMMITTED", - "AUTOCOMMIT", -] -_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"] -_ResetOnReturn = Literal["rollback", "commit"] - - -class _SQLiteConnectArgs(TypedDict, total=False): - timeout: float - detect_types: Any - isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]] - check_same_thread: bool - factory: Type[sqlite3.Connection] - cached_statements: int - uri: bool - - -_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]] - - -# Re-define create_engine to have by default future=True, and assume that's what is used -# Also show the default values used for each parameter, but don't set them unless -# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't -# support pool connection arguments. -def create_engine( - url: Union[str, URL], - *, - connect_args: _ConnectArgs = Default({}), # type: ignore - echo: Union[bool, _Debug] = Default(False), - echo_pool: Union[bool, _Debug] = Default(False), - enable_from_linting: bool = Default(True), - encoding: str = Default("utf-8"), - execution_options: Dict[Any, Any] = Default({}), - future: bool = True, - hide_parameters: bool = Default(False), - implicit_returning: bool = Default(True), - isolation_level: Optional[_IsolationLevel] = Default(None), - json_deserializer: Callable[..., Any] = Default(json.loads), - json_serializer: Callable[..., Any] = Default(json.dumps), - label_length: Optional[int] = Default(None), - logging_name: Optional[str] = Default(None), - max_identifier_length: Optional[int] = Default(None), - max_overflow: int = Default(10), - module: Optional[Any] = Default(None), - paramstyle: Optional[_ParamStyle] = Default(None), - pool: Optional[Pool] = Default(None), - poolclass: Optional[Type[Pool]] = Default(None), - pool_logging_name: Optional[str] = Default(None), - pool_pre_ping: bool = Default(False), - pool_size: int = Default(5), - pool_recycle: int = Default(-1), - pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"), - pool_timeout: float = Default(30), - pool_use_lifo: bool = Default(False), - plugins: Optional[List[str]] = Default(None), - query_cache_size: Optional[int] = Default(None), - **kwargs: Any, -) -> _FutureEngine: - current_kwargs: Dict[str, Any] = { - "future": future, - } - if not isinstance(echo, _DefaultPlaceholder): - current_kwargs["echo"] = echo - if not isinstance(echo_pool, _DefaultPlaceholder): - current_kwargs["echo_pool"] = echo_pool - if not isinstance(enable_from_linting, _DefaultPlaceholder): - current_kwargs["enable_from_linting"] = enable_from_linting - if not isinstance(connect_args, _DefaultPlaceholder): - current_kwargs["connect_args"] = connect_args - if not isinstance(encoding, _DefaultPlaceholder): - current_kwargs["encoding"] = encoding - if not isinstance(execution_options, _DefaultPlaceholder): - current_kwargs["execution_options"] = execution_options - if not isinstance(hide_parameters, _DefaultPlaceholder): - current_kwargs["hide_parameters"] = hide_parameters - if not isinstance(implicit_returning, _DefaultPlaceholder): - current_kwargs["implicit_returning"] = implicit_returning - if not isinstance(isolation_level, _DefaultPlaceholder): - current_kwargs["isolation_level"] = isolation_level - if not isinstance(json_deserializer, _DefaultPlaceholder): - current_kwargs["json_deserializer"] = json_deserializer - if not isinstance(json_serializer, _DefaultPlaceholder): - current_kwargs["json_serializer"] = json_serializer - if not isinstance(label_length, _DefaultPlaceholder): - current_kwargs["label_length"] = label_length - if not isinstance(logging_name, _DefaultPlaceholder): - current_kwargs["logging_name"] = logging_name - if not isinstance(max_identifier_length, _DefaultPlaceholder): - current_kwargs["max_identifier_length"] = max_identifier_length - if not isinstance(max_overflow, _DefaultPlaceholder): - current_kwargs["max_overflow"] = max_overflow - if not isinstance(module, _DefaultPlaceholder): - current_kwargs["module"] = module - if not isinstance(paramstyle, _DefaultPlaceholder): - current_kwargs["paramstyle"] = paramstyle - if not isinstance(pool, _DefaultPlaceholder): - current_kwargs["pool"] = pool - if not isinstance(poolclass, _DefaultPlaceholder): - current_kwargs["poolclass"] = poolclass - if not isinstance(pool_logging_name, _DefaultPlaceholder): - current_kwargs["pool_logging_name"] = pool_logging_name - if not isinstance(pool_pre_ping, _DefaultPlaceholder): - current_kwargs["pool_pre_ping"] = pool_pre_ping - if not isinstance(pool_size, _DefaultPlaceholder): - current_kwargs["pool_size"] = pool_size - if not isinstance(pool_recycle, _DefaultPlaceholder): - current_kwargs["pool_recycle"] = pool_recycle - if not isinstance(pool_reset_on_return, _DefaultPlaceholder): - current_kwargs["pool_reset_on_return"] = pool_reset_on_return - if not isinstance(pool_timeout, _DefaultPlaceholder): - current_kwargs["pool_timeout"] = pool_timeout - if not isinstance(pool_use_lifo, _DefaultPlaceholder): - current_kwargs["pool_use_lifo"] = pool_use_lifo - if not isinstance(plugins, _DefaultPlaceholder): - current_kwargs["plugins"] = plugins - if not isinstance(query_cache_size, _DefaultPlaceholder): - current_kwargs["query_cache_size"] = query_cache_size - current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) # type: ignore diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py deleted file mode 100644 index 7a25422227..0000000000 --- a/sqlmodel/engine/result.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Generic, Iterator, List, Optional, TypeVar - -from sqlalchemy.engine.result import Result as _Result -from sqlalchemy.engine.result import ScalarResult as _ScalarResult - -_T = TypeVar("_T") - - -class ScalarResult(_ScalarResult, Generic[_T]): - def all(self) -> List[_T]: - return super().all() - - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: - return super().partitions(size) - - def fetchall(self) -> List[_T]: - return super().fetchall() - - def fetchmany(self, size: Optional[int] = None) -> List[_T]: - return super().fetchmany(size) - - def __iter__(self) -> Iterator[_T]: - return super().__iter__() - - def __next__(self) -> _T: - return super().__next__() # type: ignore - - def first(self) -> Optional[_T]: - return super().first() - - def one_or_none(self) -> Optional[_T]: - return super().one_or_none() - - def one(self) -> _T: - return super().one() # type: ignore - - -class Result(_Result, Generic[_T]): - def scalars(self, index: int = 0) -> ScalarResult[_T]: - return super().scalars(index) # type: ignore - - def __iter__(self) -> Iterator[_T]: # type: ignore - return super().__iter__() # type: ignore - - def __next__(self) -> _T: # type: ignore - return super().__next__() # type: ignore - - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore - return super().partitions(size) # type: ignore - - def fetchall(self) -> List[_T]: # type: ignore - return super().fetchall() # type: ignore - - def fetchone(self) -> Optional[_T]: # type: ignore - return super().fetchone() # type: ignore - - def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore - return super().fetchmany() # type: ignore - - def all(self) -> List[_T]: # type: ignore - return super().all() # type: ignore - - def first(self) -> Optional[_T]: # type: ignore - return super().first() # type: ignore - - def one_or_none(self) -> Optional[_T]: # type: ignore - return super().one_or_none() # type: ignore - - def scalar_one(self) -> _T: - return super().scalar_one() # type: ignore - - def scalar_one_or_none(self) -> Optional[_T]: - return super().scalar_one_or_none() - - def one(self) -> _T: # type: ignore - return super().one() # type: ignore - - def scalar(self) -> Optional[_T]: - return super().scalar() diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index f500c44dc2..012d8ef5e4 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,45 +1,38 @@ -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + overload, +) from sqlalchemy import util +from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams +from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession -from sqlalchemy.ext.asyncio import engine -from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.ext.asyncio.result import _ensure_sync_result +from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS +from sqlalchemy.orm._typing import OrmExecuteOptionsParameter +from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.util.concurrency import greenlet_spawn +from typing_extensions import deprecated -from ...engine.result import Result, ScalarResult from ...orm.session import Session from ...sql.base import Executable from ...sql.expression import Select, SelectOfScalar -_TSelectParam = TypeVar("_TSelectParam") +_TSelectParam = TypeVar("_TSelectParam", bound=Any) class AsyncSession(_AsyncSession): + sync_session_class: Type[Session] = Session sync_session: Session - def __init__( - self, - bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, - binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, - **kw: Any, - ): - # All the same code of the original AsyncSession - kw["future"] = True - if bind: - self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) # type: ignore - - if binds: - self.binds = binds - binds = { - key: engine._get_sync_engine_or_connection(b) # type: ignore - for key, b in binds.items() - } - - self.sync_session = self._proxied = self._assign_proxied( # type: ignore - Session(bind=bind, binds=binds, **kw) # type: ignore - ) - @overload async def exec( self, @@ -47,11 +40,10 @@ async def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, - ) -> Result[_TSelectParam]: + ) -> TupleResult[_TSelectParam]: ... @overload @@ -61,10 +53,9 @@ async def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, ) -> ScalarResult[_TSelectParam]: ... @@ -75,20 +66,87 @@ async def exec( SelectOfScalar[_TSelectParam], Executable[_TSelectParam], ], + *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: - # TODO: the documentation says execution_options accepts a dict, but only - # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? - execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore - - return await greenlet_spawn( + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( self.sync_session.exec, statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, - **kw, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + result_value = await _ensure_sync_result( + cast(Result[_TSelectParam], result), self.exec + ) + return result_value # type: ignore + + @deprecated( + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = await session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = await session.exec(select(Hero)).all() + ``` + """ + ) + async def execute( # type: ignore + self, + statement: _Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = await session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = await session.exec(select(Hero)).all() + ``` + """ + return await super().execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, ) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2b69dd2a75..c30af5779f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -45,12 +45,19 @@ inspect, ) from sqlalchemy import Enum as sa_Enum -from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship +from sqlalchemy.orm import ( + Mapped, + RelationshipProperty, + declared_attr, + registry, + relationship, +) from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time +from typing_extensions import get_origin from .sql.sqltypes import GUID, AutoString @@ -483,7 +490,16 @@ def __init__( # over anything else, use that and continue with the next attribute setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue - ann = cls.__annotations__[rel_name] + raw_ann = cls.__annotations__[rel_name] + origin = get_origin(raw_ann) + if origin is Mapped: + ann = raw_ann.__args__[0] + else: + ann = raw_ann + # Plain forward references, for models not yet defined, are not + # handled well by SQLAlchemy without Mapped, so, wrap the + # annotations in Mapped here + cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type] temp_field = ModelField.infer( name=rel_name, value=rel_info, @@ -511,9 +527,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( # type: ignore - relationship_to, *rel_args, **rel_kwargs - ) + rel_value = relationship(relationship_to, *rel_args, **rel_kwargs) setattr(cls, rel_name, rel_value) # Fix #315 # SQLAlchemy no longer uses dict_ # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77 @@ -642,6 +656,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six class Config: orm_mode = True @@ -685,7 +700,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): + if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 0c70c290ae..6050d5fbc1 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,16 +1,27 @@ -from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + TypeVar, + Union, + overload, +) from sqlalchemy import util +from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams +from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session +from sqlalchemy.orm._typing import OrmExecuteOptionsParameter +from sqlalchemy.sql._typing import _ColumnsClauseArgument from sqlalchemy.sql.base import Executable as _Executable -from typing_extensions import Literal +from sqlmodel.sql.base import Executable +from sqlmodel.sql.expression import Select, SelectOfScalar +from typing_extensions import deprecated -from ..engine.result import Result, ScalarResult -from ..sql.base import Executable -from ..sql.expression import Select, SelectOfScalar - -_TSelectParam = TypeVar("_TSelectParam") +_TSelectParam = TypeVar("_TSelectParam", bound=Any) class Session(_Session): @@ -21,11 +32,10 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, - ) -> Result[_TSelectParam]: + ) -> TupleResult[_TSelectParam]: ... @overload @@ -35,10 +45,9 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, ) -> ScalarResult[_TSelectParam]: ... @@ -52,11 +61,10 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, - ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: results = super().execute( statement, params=params, @@ -64,21 +72,40 @@ def exec( bind_arguments=bind_arguments, _parent_execute_state=_parent_execute_state, _add_event=_add_event, - **kw, ) if isinstance(statement, SelectOfScalar): - return results.scalars() # type: ignore + return results.scalars() return results # type: ignore - def execute( + @deprecated( + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = session.exec(select(Hero)).all() + ``` + """ + ) + def execute( # type: ignore self, statement: _Executable, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, ) -> Result[Any]: """ 🚨 You probably want to use `session.exec()` instead of `session.execute()`. @@ -98,17 +125,16 @@ def execute( heroes = session.exec(select(Hero)).all() ``` """ - return super().execute( # type: ignore + return super().execute( statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, _parent_execute_state=_parent_execute_state, _add_event=_add_event, - **kw, ) - def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": + @deprecated( """ 🚨 You probably want to use `session.exec()` instead of `session.query()`. @@ -118,24 +144,17 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) + ) + def query( # type: ignore + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> _Query[Any]: + """ + 🚨 You probably want to use `session.exec()` instead of `session.query()`. - def get( - self, - entity: Type[_TSelectParam], - ident: Any, - options: Optional[Sequence[Any]] = None, - populate_existing: bool = False, - with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, - identity_token: Optional[Any] = None, - execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, - ) -> Optional[_TSelectParam]: - return super().get( - entity, - ident, - options=options, - populate_existing=populate_existing, - with_for_update=with_for_update, - identity_token=identity_token, - execution_options=execution_options, - ) + `session.exec()` is SQLModel's own short version with increased type + annotations. + + Or otherwise you might want to use `session.execute()` instead of + `session.query()`. + """ + return super().query(*entities, **kwargs) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 264e39cba7..a8a572501c 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -2,10 +2,10 @@ from datetime import datetime from typing import ( - TYPE_CHECKING, Any, - Generic, + Iterable, Mapping, + Optional, Sequence, Tuple, Type, @@ -15,15 +15,223 @@ ) from uuid import UUID -from sqlalchemy import Column -from sqlalchemy.orm import InstrumentedAttribute -from sqlalchemy.sql.elements import ColumnClause +import sqlalchemy +from sqlalchemy import ( + Column, + ColumnElement, + Extract, + FunctionElement, + FunctionFilter, + Label, + Over, + TypeCoerce, + WithinGroup, +) +from sqlalchemy.orm import InstrumentedAttribute, Mapped +from sqlalchemy.sql._typing import ( + _ColumnExpressionArgument, + _ColumnExpressionOrLiteralArgument, + _ColumnExpressionOrStrLabelArgument, +) +from sqlalchemy.sql.elements import ( + BinaryExpression, + Case, + Cast, + CollectionAggregate, + ColumnClause, + SQLCoreOperations, + TryCast, + UnaryExpression, +) from sqlalchemy.sql.expression import Select as _Select +from sqlalchemy.sql.roles import TypedColumnsClauseRole +from sqlalchemy.sql.type_api import TypeEngine +from typing_extensions import Literal, Self + +_T = TypeVar("_T") + +_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]] + +# Redefine operatos that would only take a column expresion to also take the (virtual) +# types of Pydantic models, e.g. str instead of only Mapped[str]. + + +def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]: + return sqlalchemy.all_(expr) # type: ignore[arg-type] + + +def and_( + initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool], + *clauses: Union[_ColumnExpressionArgument[bool], bool], +) -> ColumnElement[bool]: + return sqlalchemy.and_(initial_clause, *clauses) # type: ignore[arg-type] + + +def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]: + return sqlalchemy.any_(expr) # type: ignore[arg-type] + + +def asc( + column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T], +) -> UnaryExpression[_T]: + return sqlalchemy.asc(column) # type: ignore[arg-type] + + +def collate( + expression: Union[_ColumnExpressionArgument[str], str], collation: str +) -> BinaryExpression[str]: + return sqlalchemy.collate(expression, collation) # type: ignore[arg-type] + + +def between( + expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: + return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric) # type: ignore[arg-type] + + +def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]: + return sqlalchemy.not_(clause) # type: ignore[arg-type] + + +def case( + *whens: Union[ + Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: + return sqlalchemy.case(*whens, value=value, else_=else_) # type: ignore[arg-type] + + +def cast( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> Cast[_T]: + return sqlalchemy.cast(expression, type_) # type: ignore[arg-type] + + +def try_cast( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> TryCast[_T]: + return sqlalchemy.try_cast(expression, type_) # type: ignore[arg-type] + + +def desc( + column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T], +) -> UnaryExpression[_T]: + return sqlalchemy.desc(column) # type: ignore[arg-type] + + +def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.distinct(expr) # type: ignore[arg-type] + + +def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.bitwise_not(expr) # type: ignore[arg-type] + + +def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract: + return sqlalchemy.extract(field, expr) # type: ignore[arg-type] + + +def funcfilter( + func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool] +) -> FunctionFilter[_T]: + return sqlalchemy.funcfilter(func, *criterion) # type: ignore[arg-type] + + +def label( + name: str, + element: Union[_ColumnExpressionArgument[_T], _T], + type_: Optional["_TypeEngineArgument[_T]"] = None, +) -> Label[_T]: + return sqlalchemy.label(name, element, type_=type_) # type: ignore[arg-type] + + +def nulls_first( + column: Union[_ColumnExpressionArgument[_T], _T] +) -> UnaryExpression[_T]: + return sqlalchemy.nulls_first(column) # type: ignore[arg-type] + + +def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.nulls_last(column) # type: ignore[arg-type] + + +def or_( # type: ignore[empty-body] + initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool], + *clauses: Union[_ColumnExpressionArgument[bool], bool], +) -> ColumnElement[bool]: + return sqlalchemy.or_(initial_clause, *clauses) # type: ignore[arg-type] + + +def over( + element: FunctionElement[_T], + partition_by: Optional[ + Union[ + Iterable[Union[_ColumnExpressionArgument[Any], Any]], + _ColumnExpressionArgument[Any], + Any, + ] + ] = None, + order_by: Optional[ + Union[ + Iterable[Union[_ColumnExpressionArgument[Any], Any]], + _ColumnExpressionArgument[Any], + Any, + ] + ] = None, + range_: Optional[Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: + return sqlalchemy.over( + element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows + ) # type: ignore[arg-type] + + +def tuple_( + *clauses: Union[_ColumnExpressionArgument[Any], Any], + types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None, +) -> Tuple[Any, ...]: + return sqlalchemy.tuple_(*clauses, types=types) # type: ignore[return-value] + + +def type_coerce( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> TypeCoerce[_T]: + return sqlalchemy.type_coerce(expression, type_) # type: ignore[arg-type] + + +def within_group( + element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any] +) -> WithinGroup[_T]: + return sqlalchemy.within_group(element, *order_by) # type: ignore[arg-type] + + +# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share +# where and having without having type overlap incompatibility in session.exec(). +class SelectBase(_Select[Tuple[_T]]): + inherit_cache = True + + def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self: + """Return a new `Select` construct with the given expression added to + its `WHERE` clause, joined to the existing clause via `AND`, if any. + """ + return super().where(*whereclause) # type: ignore[arg-type] -_TSelect = TypeVar("_TSelect") + def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self: + """Return a new `Select` construct with the given expression added to + its `HAVING` clause, joined to the existing clause via `AND`, if any. + """ + return super().having(*having) # type: ignore[arg-type] -class Select(_Select, Generic[_TSelect]): +class Select(SelectBase[_T]): inherit_cache = True @@ -31,12 +239,15 @@ class Select(_Select, Generic[_TSelect]): # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): +class SelectOfScalar(SelectBase[_T]): inherit_cache = True -if TYPE_CHECKING: # pragma: no cover - from ..main import SQLModel +_TCCA = Union[ + TypedColumnsClauseRole[_T], + SQLCoreOperations[_T], + Type[_T], +] # Generated TypeVars start @@ -56,7 +267,7 @@ class SelectOfScalar(_Select, Generic[_TSelect]): None, ) -_TModel_0 = TypeVar("_TModel_0", bound="SQLModel") +_T0 = TypeVar("_T0") _TScalar_1 = TypeVar( @@ -74,7 +285,7 @@ class SelectOfScalar(_Select, Generic[_TSelect]): None, ) -_TModel_1 = TypeVar("_TModel_1", bound="SQLModel") +_T1 = TypeVar("_T1") _TScalar_2 = TypeVar( @@ -92,7 +303,7 @@ class SelectOfScalar(_Select, Generic[_TSelect]): None, ) -_TModel_2 = TypeVar("_TModel_2", bound="SQLModel") +_T2 = TypeVar("_T2") _TScalar_3 = TypeVar( @@ -110,19 +321,19 @@ class SelectOfScalar(_Select, Generic[_TSelect]): None, ) -_TModel_3 = TypeVar("_TModel_3", bound="SQLModel") +_T3 = TypeVar("_T3") # Generated TypeVars end @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... @@ -133,7 +344,6 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1]]: ... @@ -141,27 +351,24 @@ def select( # type: ignore @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1]]: + __ent1: _TCCA[_T1], +) -> Select[Tuple[_TScalar_0, _T1]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1]]: +) -> Select[Tuple[_T0, _TScalar_1]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1]]: + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], +) -> Select[Tuple[_T0, _T1]]: ... @@ -170,7 +377,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @@ -179,69 +385,62 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - entity_2: Type[_TModel_2], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: + __ent2: _TCCA[_T2], +) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: +) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], +) -> Select[Tuple[_TScalar_0, _T1, _T2]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: +) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, - entity_2: Type[_TModel_2], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: + __ent2: _TCCA[_T2], +) -> Select[Tuple[_T0, _TScalar_1, _T2]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: +) -> Select[Tuple[_T0, _T1, _TScalar_2]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], +) -> Select[Tuple[_T0, _T1, _T2]]: ... @@ -251,7 +450,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -261,9 +459,8 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: + __ent3: _TCCA[_T3], +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _T3]]: ... @@ -271,10 +468,9 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - entity_2: Type[_TModel_2], + __ent2: _TCCA[_T2], entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: +) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _TScalar_3]]: ... @@ -282,156 +478,142 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - entity_2: Type[_TModel_2], - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: +) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: + __ent3: _TCCA[_T3], +) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: +) -> Select[Tuple[_TScalar_0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_TScalar_0, _T1, _T2, _T3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: +) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, - entity_2: Type[_TModel_2], + __ent2: _TCCA[_T2], entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: +) -> Select[Tuple[_T0, _TScalar_1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], + __ent0: _TCCA[_T0], entity_1: _TScalar_1, - entity_2: Type[_TModel_2], - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _TScalar_1, _T2, _T3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: +) -> Select[Tuple[_T0, _T1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], entity_2: _TScalar_2, - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], entity_3: _TScalar_3, - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: +) -> Select[Tuple[_T0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - entity_0: Type[_TModel_0], - entity_1: Type[_TModel_1], - entity_2: Type[_TModel_2], - entity_3: Type[_TModel_3], - **kw: Any, -) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ... # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) + return Select(*entities) -# TODO: add several @overload from Python types to SQLAlchemy equivalents -def col(column_expression: Any) -> ColumnClause: # type: ignore +def col(column_expression: _T) -> Mapped[_T]: if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 26d12a0395..f1a25419c0 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -1,9 +1,9 @@ from datetime import datetime from typing import ( - TYPE_CHECKING, Any, - Generic, + Iterable, Mapping, + Optional, Sequence, Tuple, Type, @@ -13,28 +13,243 @@ from typing import ( ) from uuid import UUID -from sqlalchemy import Column -from sqlalchemy.orm import InstrumentedAttribute -from sqlalchemy.sql.elements import ColumnClause +import sqlalchemy +from sqlalchemy import ( + Column, + ColumnElement, + Extract, + FunctionElement, + FunctionFilter, + Label, + Over, + TypeCoerce, + WithinGroup, +) +from sqlalchemy.orm import InstrumentedAttribute, Mapped +from sqlalchemy.sql._typing import ( + _ColumnExpressionArgument, + _ColumnExpressionOrLiteralArgument, + _ColumnExpressionOrStrLabelArgument, +) +from sqlalchemy.sql.elements import ( + BinaryExpression, + Case, + Cast, + CollectionAggregate, + ColumnClause, + SQLCoreOperations, + TryCast, + UnaryExpression, +) from sqlalchemy.sql.expression import Select as _Select +from sqlalchemy.sql.roles import TypedColumnsClauseRole +from sqlalchemy.sql.type_api import TypeEngine +from typing_extensions import Literal, Self + +_T = TypeVar("_T") + +_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]] + +# Redefine operatos that would only take a column expresion to also take the (virtual) +# types of Pydantic models, e.g. str instead of only Mapped[str]. + + +def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]: + return sqlalchemy.all_(expr) # type: ignore[arg-type] + + +def and_( + initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool], + *clauses: Union[_ColumnExpressionArgument[bool], bool], +) -> ColumnElement[bool]: + return sqlalchemy.and_(initial_clause, *clauses) # type: ignore[arg-type] + + +def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]: + return sqlalchemy.any_(expr) # type: ignore[arg-type] + + +def asc( + column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T], +) -> UnaryExpression[_T]: + return sqlalchemy.asc(column) # type: ignore[arg-type] + + +def collate( + expression: Union[_ColumnExpressionArgument[str], str], collation: str +) -> BinaryExpression[str]: + return sqlalchemy.collate(expression, collation) # type: ignore[arg-type] + + +def between( + expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: + return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric) # type: ignore[arg-type] + + +def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]: + return sqlalchemy.not_(clause) # type: ignore[arg-type] + + +def case( + *whens: Union[ + Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: + return sqlalchemy.case(*whens, value=value, else_=else_) # type: ignore[arg-type] + + +def cast( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> Cast[_T]: + return sqlalchemy.cast(expression, type_) # type: ignore[arg-type] + + +def try_cast( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> TryCast[_T]: + return sqlalchemy.try_cast(expression, type_) # type: ignore[arg-type] + + +def desc( + column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T], +) -> UnaryExpression[_T]: + return sqlalchemy.desc(column) # type: ignore[arg-type] + + +def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.distinct(expr) # type: ignore[arg-type] + -_TSelect = TypeVar("_TSelect") +def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.bitwise_not(expr) # type: ignore[arg-type] -class Select(_Select, Generic[_TSelect]): + +def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract: + return sqlalchemy.extract(field, expr) # type: ignore[arg-type] + + +def funcfilter( + func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool] +) -> FunctionFilter[_T]: + return sqlalchemy.funcfilter(func, *criterion) # type: ignore[arg-type] + + +def label( + name: str, + element: Union[_ColumnExpressionArgument[_T], _T], + type_: Optional["_TypeEngineArgument[_T]"] = None, +) -> Label[_T]: + return sqlalchemy.label(name, element, type_=type_) # type: ignore[arg-type] + + +def nulls_first( + column: Union[_ColumnExpressionArgument[_T], _T] +) -> UnaryExpression[_T]: + return sqlalchemy.nulls_first(column) # type: ignore[arg-type] + + +def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]: + return sqlalchemy.nulls_last(column) # type: ignore[arg-type] + + +def or_( # type: ignore[empty-body] + initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool], + *clauses: Union[_ColumnExpressionArgument[bool], bool], +) -> ColumnElement[bool]: + return sqlalchemy.or_(initial_clause, *clauses) # type: ignore[arg-type] + + +def over( + element: FunctionElement[_T], + partition_by: Optional[ + Union[ + Iterable[Union[_ColumnExpressionArgument[Any], Any]], + _ColumnExpressionArgument[Any], + Any, + ] + ] = None, + order_by: Optional[ + Union[ + Iterable[Union[_ColumnExpressionArgument[Any], Any]], + _ColumnExpressionArgument[Any], + Any, + ] + ] = None, + range_: Optional[Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: + return sqlalchemy.over( + element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows + ) # type: ignore[arg-type] + + +def tuple_( + *clauses: Union[_ColumnExpressionArgument[Any], Any], + types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None, +) -> Tuple[Any, ...]: + return sqlalchemy.tuple_(*clauses, types=types) # type: ignore[return-value] + + +def type_coerce( + expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any], + type_: "_TypeEngineArgument[_T]", +) -> TypeCoerce[_T]: + return sqlalchemy.type_coerce(expression, type_) # type: ignore[arg-type] + + +def within_group( + element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any] +) -> WithinGroup[_T]: + return sqlalchemy.within_group(element, *order_by) # type: ignore[arg-type] + + +# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share +# where and having without having type overlap incompatibility in session.exec(). +class SelectBase(_Select[Tuple[_T]]): inherit_cache = True + def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self: + """Return a new `Select` construct with the given expression added to + its `WHERE` clause, joined to the existing clause via `AND`, if any. + """ + return super().where(*whereclause) # type: ignore[arg-type] + + def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self: + """Return a new `Select` construct with the given expression added to + its `HAVING` clause, joined to the existing clause via `AND`, if any. + """ + return super().having(*having) # type: ignore[arg-type] + + +class Select(SelectBase[_T]): + inherit_cache = True + + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): +class SelectOfScalar(SelectBase[_T]): inherit_cache = True -if TYPE_CHECKING: # pragma: no cover - from ..main import SQLModel + +_TCCA = Union[ + TypedColumnsClauseRole[_T], + SQLCoreOperations[_T], + Type[_T], +] # Generated TypeVars start + {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", @@ -51,19 +266,19 @@ _TScalar_{{ i }} = TypeVar( None, ) -_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") +_T{{ i }} = TypeVar("_T{{ i }}") {% endfor %} # Generated TypeVars end @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... @@ -73,7 +288,7 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... @@ -81,14 +296,14 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore + +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) + return Select(*entities) -# TODO: add several @overload from Python types to SQLAlchemy equivalents -def col(column_expression: Any) -> ColumnClause: # type: ignore +def col(column_expression: _T) -> Mapped[_T]: if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 17d9b06126..5a4bb04ef1 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore + return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) @@ -32,11 +32,11 @@ class GUID(types.TypeDecorator): # type: ignore impl = CHAR cache_ok = True - def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) # type: ignore + return dialect.type_descriptor(UUID()) else: - return dialect.type_descriptor(CHAR(32)) # type: ignore + return dialect.type_descriptor(CHAR(32)) def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: diff --git a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py index b08affb920..6a55d6cb98 100644 --- a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py @@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py index 0aee3ca004..2709231504 100644 --- a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py @@ -64,7 +64,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -239,7 +239,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py index 8d99cf9f5b..dc5a3cb8ff 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py @@ -5,7 +5,7 @@ from sqlmodel.pool import StaticPool openapi_schema = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -103,7 +103,7 @@ "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py index 94a41b3076..e3c20404c0 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py @@ -5,7 +5,7 @@ from sqlmodel.pool import StaticPool openapi_schema = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -103,7 +103,7 @@ "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py index 0609ae41ff..0a599574d5 100644 --- a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py @@ -3,7 +3,7 @@ from sqlmodel.pool import StaticPool openapi_schema = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -135,7 +135,7 @@ "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py index 8869862e95..fb08b9a5fd 100644 --- a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py @@ -107,7 +107,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -622,7 +622,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py index ebb3046ef3..968fefa8ca 100644 --- a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py @@ -3,7 +3,7 @@ from sqlmodel.pool import StaticPool openapi_schema = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -91,7 +91,7 @@ "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py index cb0a6f9282..6f97cbf92b 100644 --- a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py @@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py index eb834ec2a4..435155d6e9 100644 --- a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py @@ -3,7 +3,7 @@ from sqlmodel.pool import StaticPool openapi_schema = { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -79,7 +79,7 @@ "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py index e66c975142..42f87cef76 100644 --- a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py @@ -94,7 +94,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -579,7 +579,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, diff --git a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py index 49906256c9..a4573ef11b 100644 --- a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py @@ -66,7 +66,7 @@ def test_tutorial(clear_sqlmodel): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { - "openapi": "3.0.2", + "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { "/heroes/": { @@ -294,7 +294,9 @@ def test_tutorial(clear_sqlmodel): "loc": { "title": "Location", "type": "array", - "items": {"type": "string"}, + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"},