From dc27ee5809829950036b37c4acbfe859f61e150a Mon Sep 17 00:00:00 2001 From: gazorby Date: Thu, 17 Feb 2022 21:26:08 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20orm=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- strawberry/experimental/pydantic/orm.py | 55 +++++++++++++++-------- strawberry/experimental/pydantic/utils.py | 28 +++++------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/strawberry/experimental/pydantic/orm.py b/strawberry/experimental/pydantic/orm.py index 03d1b83def..0d9246cfd8 100644 --- a/strawberry/experimental/pydantic/orm.py +++ b/strawberry/experimental/pydantic/orm.py @@ -8,17 +8,27 @@ from inspect import isclass from typing import TYPE_CHECKING, Any, List, Optional, Type -from pydantic.fields import UndefinedType from pydantic import BaseModel +from pydantic.fields import UndefinedType from .lazy_types import LazyModelType -_ormar_found = find_spec("ormar") -_sqlmodel_found = find_spec("sqlmodel") +_ormar_found = bool(find_spec("ormar")) +_sqlmodel_found = bool(find_spec("sqlmodel")) + +if _ormar_found: + import ormar + from ormar.fields.through_field import ThroughField + from ormar.queryset.field_accessor import FieldAccessor + +if _sqlmodel_found: + import sqlmodel + from sqlalchemy.orm.relationships import RelationshipProperty + if TYPE_CHECKING: - from ormar import Model + import ormar def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str): @@ -33,11 +43,11 @@ def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str): f_info = field.field_info f_accessor = None - child_type: Optional[Type["Model"]] = f_info.to + child_type: Optional[Type["ormar.Model"]] = f_info.to _required = field.required elif isinstance(getattr(model, name, None), FieldAccessor): field = getattr(model, name) - assert isinstance(field, FieldAccessor) + assert isinstance(field, FieldAccessor) # mypy f_accessor = field child_type = field._model _required = False @@ -65,26 +75,35 @@ def replace_ormar_types(type_: Any, model: Type[BaseModel], name: str): def is_sqlmodel_field(field) -> bool: - if not _sqlmodel_found: - return False - from sqlalchemy.orm.relationships import RelationshipProperty - - return getattr(field, "property", None).__class__ is RelationshipProperty + return ( + _sqlmodel_found + and getattr(field, "property", None).__class__ is RelationshipProperty + ) def is_ormar_field(field) -> bool: - if not _ormar_found: - return False - from ormar.queryset.field_accessor import FieldAccessor + return _ormar_found and isinstance(field, FieldAccessor) - return isinstance(field, FieldAccessor) +def get_ormar_accessors(model): + return { + name: getattr(model, name) + for name, f in model.Meta.model_fields.items() + if not issubclass(type(f), ThroughField) + } -def is_ormar_model(model: Type[BaseModel]) -> bool: - from ormar import Model - return isclass(model) and issubclass(model, Model) +def get_sqlmodel_relationships(model): + return {name: f for name, f in model.__dict__.items() if is_sqlmodel_field(f)} def is_orm_field(field) -> bool: return is_ormar_field(field) or is_sqlmodel_field(field) + + +def is_ormar_model(model: Type[BaseModel]) -> bool: + return _ormar_found and isclass(model) and issubclass(model, ormar.Model) + + +def is_sqlmodel_model(model: Type[BaseModel]) -> bool: + return _sqlmodel_found and issubclass(model, sqlmodel.SQLModel) diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 8df6eaa931..0ac55e036f 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -20,12 +20,14 @@ is_optional, ) -from ormar.fields.through_field import ThroughField -from sqlmodel import SQLModel -from sqlalchemy.orm.relationships import RelationshipProperty -import ormar -from strawberry.experimental.pydantic.orm import is_orm_field +from strawberry.experimental.pydantic.orm import ( + get_ormar_accessors, + get_sqlmodel_relationships, + is_orm_field, + is_ormar_model, + is_sqlmodel_model, +) def normalize_type(type_) -> Any: @@ -147,21 +149,13 @@ def ensure_all_auto_fields_in_pydantic( def get_model_fields(model) -> Dict[str, Any]: - if issubclass(model, ormar.Model): - model_fields = { - name: getattr(model, name) - for name, f in model.Meta.model_fields.items() - if not issubclass(type(f), ThroughField) - } + if is_ormar_model(model): + model_fields = get_ormar_accessors(model) model_fields.update(model.__fields__) return model_fields - if issubclass(model, SQLModel): - model_fields = { - name: f - for name, f in model.__dict__.items() - if getattr(f, "property", None).__class__ is RelationshipProperty - } + if is_sqlmodel_model(model): + model_fields = get_sqlmodel_relationships(model) model_fields.update(model.__fields__) return model_fields