Skip to content

Commit

Permalink
Merged PR 6873571: Fix and use SQLModel AsyncSession
Browse files Browse the repository at this point in the history
The whole point of SQLModel is that it returns typed objects that play nice with Pydantic and allow your editor's autocomplete functions to work properly. Except when you're using `sqlalchemy.AsyncSession`, it... doesn't. And there *is* a `sqlmodel.AsyncSession`, but it's broken. And there is a *fix* for it in an upstream PR, but it's not merged yet.
fastapi/sqlmodel#58

This PR just copies and uses the upstream PR implementation of `sqlmodel.AsyncSession`. Most of the actual diff here is related to the side-benefit that we no longer have to unwrap the row-tuples that `sqlalchemy.execute` returns, but most of the actual benefit for doing this is that we'll now actually get our appropriately-typed model objects back out of `session.exec` instead of `Any`.

Related work items: #15767106
  • Loading branch information
sdherr committed Oct 11, 2022
1 parent 09bb29e commit 3d1901f
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 35 deletions.
10 changes: 5 additions & 5 deletions server/app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import fastapi_microsoft_identity
from fastapi import Depends, HTTPException, Request, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.exc import NoResultFound
from sqlmodel import select

from app.core.db import get_session
from app.core.db import AsyncSession, get_session
from app.core.models import Account, RepoAccess, Role
from app.core.schemas import RepoId

Expand Down Expand Up @@ -39,7 +38,8 @@ async def get_active_account(

statement = select(Account).where(Account.oid == oid)
try:
account = (await session.execute(statement)).one()[0]
results = await session.exec(statement)
account = results.one()
except NoResultFound:
raise HTTPException(
status_code=403, detail=f"Domain UUID {id} is not provisioned in PMC. {SUPPORT}"
Expand Down Expand Up @@ -109,7 +109,7 @@ async def requires_repo_permission(
statement = select(RepoAccess).where(
RepoAccess.account_id == account.id, RepoAccess.repo_id == id
)
if (await session.execute(statement)).one_or_none():
if (await session.exec(statement)).one_or_none():
return

raise HTTPException(
Expand Down
19 changes: 9 additions & 10 deletions server/app/api/routes/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from typing import Any, List

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlmodel import select

from app.core.db import get_session
from app.core.db import AsyncSession, get_session
from app.core.models import Account, OwnedPackage, RepoAccess
from app.core.schemas import (
AccountRepoPackagePermissionUpdate,
Expand Down Expand Up @@ -41,7 +40,8 @@ async def _get_named_accounts(session: AsyncSession, account_names: List[str]) -
ret = []
for name in account_names:
statement = select(Account).where(Account.name == name)
account = (await session.execute(statement)).one()[0]
results = await session.exec(statement)
account = results.one()
ret.append(account)
return ret

Expand All @@ -52,8 +52,8 @@ async def list_repo_access(
session: AsyncSession = Depends(get_session),
) -> List[RepoAccessResponse]:
statement = select(RepoAccess)
results = (await session.execute(statement)).all()
return [x[0] for x in results]
results = await session.exec(statement)
return list(results.all())


@router.post("/access/repo/{id}/clone_from/{original_id}/", response_model=List[RepoAccessResponse])
Expand All @@ -64,15 +64,14 @@ async def clone_repo_access_from(
) -> Any:
"""Additively clone the repo permissions from another repo."""
statement = select(RepoAccess).where(RepoAccess.repo_id == id)
current_perms = (await session.execute(statement)).all()
current_perms_accounts = [x[0].account_id for x in current_perms]
current_perms = (await session.exec(statement)).all()
current_perms_accounts = [x.account_id for x in current_perms]

statement = select(RepoAccess).where(RepoAccess.repo_id == original_id)
original_perms = (await session.execute(statement)).all()
original_perms = (await session.exec(statement)).all()

new_perms = []
for perm in original_perms:
perm = perm[0] # unwrap the row tuple
if perm.account_id not in current_perms_accounts:
new_perm = RepoAccess(account_id=perm.account_id, repo_id=id, operator=perm.operator)
new_perms.append(new_perm)
Expand Down
7 changes: 3 additions & 4 deletions server/app/api/routes/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.sql.selectable import Select

from app.core.db import get_session
from app.core.db import AsyncSession, get_session
from app.core.models import Account
from app.core.schemas import (
AccountCreate,
Expand All @@ -25,10 +24,10 @@ async def _get_list(
) -> Tuple[List[Any], int]:
"""Takes a query and returns a page of results and count of total results."""
count_query = select(func.count()).select_from(query.subquery())
count = (await session.execute(count_query)).scalar_one()
count = (await session.exec(count_query)).scalar_one()

query = query.limit(limit).offset(offset)
results = (await session.execute(query)).scalars().all()
results = (await session.exec(query)).scalars().all()

return results, count

Expand Down
16 changes: 5 additions & 11 deletions server/app/api/routes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typing import Any, Optional

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.api.auth import (
get_active_account,
Expand All @@ -13,7 +12,7 @@
requires_repo_permission,
)
from app.core.config import settings
from app.core.db import get_session
from app.core.db import AsyncSession, get_session
from app.core.models import Account, OwnedPackage, RepoAccess, Role
from app.core.schemas import (
PackageListResponse,
Expand Down Expand Up @@ -128,10 +127,7 @@ async def update_packages(
statement = select(RepoAccess).where(
RepoAccess.account_id == account.id, RepoAccess.repo_id == id
)
repo_perm = (await session.execute(statement)).one_or_none()
# sqlalchemy returns things from `session.execute` wrapped in tuples, for legacy reasons
if repo_perm:
repo_perm = repo_perm[0]
repo_perm = (await session.exec(statement)).one_or_none()

if account.role == Role.Publisher and not repo_perm:
raise HTTPException(
Expand All @@ -142,10 +138,8 @@ async def update_packages(
# Create a mapping of package names to accounts that are allowed to modify them in this repo.
package_name_to_account_id = defaultdict(set)
statement = select(OwnedPackage).where(OwnedPackage.repo_id == id)
for owned_package_tuple in await session.execute(statement):
# sqlalchemy returns things from `session.execute` wrapped in tuples, for legacy reasons
op = owned_package_tuple[0]
package_name_to_account_id[op.package_name].add(op.account_id)
for owned_package in await session.exec(statement):
package_name_to_account_id[owned_package.package_name].add(owned_package.account_id)

# Next enforce package adding permissions
if add_names and account.role not in (Role.Repo_Admin, Role.Publisher):
Expand Down
113 changes: 110 additions & 3 deletions server/app/core/db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,118 @@
from typing import AsyncGenerator
from typing import Any, AsyncGenerator, Mapping, Optional, Sequence, TypeVar, Union, overload

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import engine as _engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel import Session, SQLModel
from sqlmodel.engine.result import Result, ScalarResult
from sqlmodel.sql.base import Executable
from sqlmodel.sql.expression import Select, SelectOfScalar

from app.core.config import settings

_TSelectParam = TypeVar("_TSelectParam")


class AsyncSession(_AsyncSession):
"""
SQLModel provides a Session wrapper over the regular sqlalchemy session that:
1) unwraps the legacy rows-are-returned-as-objects-wrapped-in-tuples behavior if you call "exec"
2) passes through Pydantic type hints of the returned objects.
SQLModel has an equivalent wrapper for AsyncSession (sqlmodel.ext.asyncio.session.AsyncSession),
but it's busted and throws type errors if you try to use it normally like you would Session.
There's an upstream PR to fix it, but it's not merged yet. Until it's merged and fixed let's
just copy-and-paste it in ourselves.
https://github.com/tiangolo/sqlmodel/pull/58
"""

sync_session: Session

def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
if bind:
self.bind = bind
bind = _engine._get_sync_engine_or_connection(bind) # type: ignore

if binds:
self.binds = binds
binds = {
key: _engine._get_sync_engine_or_connection(b) # type: ignore
for key, b in binds.items()
}

self.sync_session = self._proxied = self._assign_proxied( # type: ignore
Session(bind=bind, binds=binds, **kw) # type: ignore
)

@overload
async def exec(
self,
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...

@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...

async def exec( # type: ignore
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore

return await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)

async def __aenter__(self) -> "AsyncSession":
# PyCharm does not understand TypeVar here :/
return await super().__aenter__()


engine = create_async_engine(settings.db_uri(), **settings.db_engine_args())
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

Expand Down
4 changes: 2 additions & 2 deletions server/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import pytest_asyncio
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel

from app.api.auth import get_active_account
from app.core.config import settings
from app.core.db import async_session, get_session
from app.core.db import AsyncSession, async_session, get_session
from app.core.models import Account, Role
from app.core.schemas import RepoType
from app.main import app as fastapi_app
Expand Down

0 comments on commit 3d1901f

Please sign in to comment.