Skip to content

Commit

Permalink
🐛 fix: orm imports
Browse files Browse the repository at this point in the history
  • Loading branch information
gazorby committed Feb 17, 2022
1 parent 02d4bc2 commit dc27ee5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 35 deletions.
55 changes: 37 additions & 18 deletions strawberry/experimental/pydantic/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,27 @@
from inspect import isclass
from typing import TYPE_CHECKING, Any, List, Optional, Type

from pydantic.fields import UndefinedType
from pydantic import BaseModel
from pydantic.fields import UndefinedType

from .lazy_types import LazyModelType


_ormar_found = find_spec("ormar")
_sqlmodel_found = find_spec("sqlmodel")
_ormar_found = bool(find_spec("ormar"))
_sqlmodel_found = bool(find_spec("sqlmodel"))

if _ormar_found:
import ormar
from ormar.fields.through_field import ThroughField
from ormar.queryset.field_accessor import FieldAccessor

if _sqlmodel_found:
import sqlmodel
from sqlalchemy.orm.relationships import RelationshipProperty


if TYPE_CHECKING:
from ormar import Model
import ormar


def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str):
Expand All @@ -33,11 +43,11 @@ def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str):

f_info = field.field_info
f_accessor = None
child_type: Optional[Type["Model"]] = f_info.to
child_type: Optional[Type["ormar.Model"]] = f_info.to
_required = field.required
elif isinstance(getattr(model, name, None), FieldAccessor):
field = getattr(model, name)
assert isinstance(field, FieldAccessor)
assert isinstance(field, FieldAccessor) # mypy
f_accessor = field
child_type = field._model
_required = False
Expand Down Expand Up @@ -65,26 +75,35 @@ def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str):


def is_sqlmodel_field(field) -> bool:
if not _sqlmodel_found:
return False
from sqlalchemy.orm.relationships import RelationshipProperty

return getattr(field, "property", None).__class__ is RelationshipProperty
return (
_sqlmodel_found
and getattr(field, "property", None).__class__ is RelationshipProperty
)


def is_ormar_field(field) -> bool:
if not _ormar_found:
return False
from ormar.queryset.field_accessor import FieldAccessor
return _ormar_found and isinstance(field, FieldAccessor)

return isinstance(field, FieldAccessor)

def get_ormar_accessors(model):
return {
name: getattr(model, name)
for name, f in model.Meta.model_fields.items()
if not issubclass(type(f), ThroughField)
}

def is_ormar_model(model: Type[BaseModel]) -> bool:
from ormar import Model

return isclass(model) and issubclass(model, Model)
def get_sqlmodel_relationships(model):
return {name: f for name, f in model.__dict__.items() if is_sqlmodel_field(f)}


def is_orm_field(field) -> bool:
return is_ormar_field(field) or is_sqlmodel_field(field)


def is_ormar_model(model: Type[BaseModel]) -> bool:
return _ormar_found and isclass(model) and issubclass(model, ormar.Model)


def is_sqlmodel_model(model: Type[BaseModel]) -> bool:
return _sqlmodel_found and issubclass(model, sqlmodel.SQLModel)
28 changes: 11 additions & 17 deletions strawberry/experimental/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
is_optional,
)

from ormar.fields.through_field import ThroughField
from sqlmodel import SQLModel
from sqlalchemy.orm.relationships import RelationshipProperty
import ormar

from strawberry.experimental.pydantic.orm import is_orm_field
from strawberry.experimental.pydantic.orm import (
get_ormar_accessors,
get_sqlmodel_relationships,
is_orm_field,
is_ormar_model,
is_sqlmodel_model,
)


def normalize_type(type_) -> Any:
Expand Down Expand Up @@ -147,21 +149,13 @@ def ensure_all_auto_fields_in_pydantic(


def get_model_fields(model) -> Dict[str, Any]:
if issubclass(model, ormar.Model):
model_fields = {
name: getattr(model, name)
for name, f in model.Meta.model_fields.items()
if not issubclass(type(f), ThroughField)
}
if is_ormar_model(model):
model_fields = get_ormar_accessors(model)
model_fields.update(model.__fields__)
return model_fields

if issubclass(model, SQLModel):
model_fields = {
name: f
for name, f in model.__dict__.items()
if getattr(f, "property", None).__class__ is RelationshipProperty
}
if is_sqlmodel_model(model):
model_fields = get_sqlmodel_relationships(model)
model_fields.update(model.__fields__)
return model_fields

Expand Down

0 comments on commit dc27ee5

Please sign in to comment.