From 748c47c3432a619b58e9b5301179e5ef205a719a Mon Sep 17 00:00:00 2001 From: Michael <3686226+reallistic@users.noreply.github.com> Date: Fri, 7 Feb 2025 00:12:13 -0800 Subject: [PATCH] use graphql core extensions field to store contrib resolvers (#1228) --- ariadne/contrib/federation/interfaces.py | 15 ++++++++++----- ariadne/contrib/federation/objects.py | 6 +++++- ariadne/contrib/federation/schema.py | 22 +++++++++------------- ariadne/contrib/federation/utils.py | 3 ++- ariadne/contrib/relay/objects.py | 18 ++++++++++++------ ariadne/utils.py | 15 +++++++++++++++ tests/relay/test_objects.py | 9 +++++++++ 7 files changed, 62 insertions(+), 26 deletions(-) diff --git a/ariadne/contrib/federation/interfaces.py b/ariadne/contrib/federation/interfaces.py index 4b3da0e35..b36377eec 100644 --- a/ariadne/contrib/federation/interfaces.py +++ b/ariadne/contrib/federation/interfaces.py @@ -1,10 +1,14 @@ +from typing import cast from typing import Optional +from graphql import GraphQLNamedType from graphql.type import GraphQLSchema from ...interfaces import InterfaceType from ...types import Resolver +from ...utils import type_get_extension from ...utils import type_implements_interface +from ...utils import type_set_extension class FederatedInterfaceType(InterfaceType): @@ -26,17 +30,18 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None: if callable(self._reference_resolver): graphql_type = schema.type_map.get(self.name) - setattr( + graphql_type = cast(GraphQLNamedType, graphql_type) + type_set_extension( graphql_type, "__resolve_reference__", self._reference_resolver, ) for object_type in schema.type_map.values(): - if type_implements_interface(self.name, object_type) and not hasattr( - object_type, "__resolve_reference__" - ): - setattr( + if type_implements_interface( + self.name, object_type + ) and not type_get_extension(object_type, "__resolve_reference__"): + type_set_extension( object_type, "__resolve_reference__", self._reference_resolver, diff --git a/ariadne/contrib/federation/objects.py b/ariadne/contrib/federation/objects.py index 6789d26e9..99e428b4f 100644 --- a/ariadne/contrib/federation/objects.py +++ b/ariadne/contrib/federation/objects.py @@ -1,9 +1,12 @@ +from typing import cast from typing import Optional +from graphql import GraphQLNamedType from graphql.type import GraphQLSchema from ...objects import ObjectType from ...types import Resolver +from ...utils import type_set_extension class FederatedObjectType(ObjectType): @@ -25,7 +28,8 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None: if callable(self._reference_resolver): graphql_type = schema.type_map.get(self.name) - setattr( + graphql_type = cast(GraphQLNamedType, graphql_type) + type_set_extension( graphql_type, "__resolve_reference__", self._reference_resolver, diff --git a/ariadne/contrib/federation/schema.py b/ariadne/contrib/federation/schema.py index 20d724b27..8a35a69de 100644 --- a/ariadne/contrib/federation/schema.py +++ b/ariadne/contrib/federation/schema.py @@ -1,6 +1,7 @@ import re import os from typing import Dict, List, Optional, Type, Union, cast +from warnings import warn from graphql import extend_schema, parse from graphql.language import DocumentNode @@ -114,7 +115,7 @@ def make_federated_schema( # Add the federation type definitions. if has_entities: - schema = extend_federated_schema(schema, parse(federation_entity_type_defs)) + schema = extend_schema(schema, parse(federation_entity_type_defs)) # Add _entities query. entity_type = schema.get_type("_Entity") @@ -142,20 +143,15 @@ def extend_federated_schema( assume_valid: bool = False, assume_valid_sdl: bool = False, ) -> GraphQLSchema: - extended_schema = extend_schema( + # This wrapper function is no longer needed and can be removed in the future. + # It is kept for backwards compatibility with previous versions of Ariadne + warn( + "extend_federated_schema is deprecated and will be removed in future versions of Ariadne. " + "Use graphql.extend_schema instead." + ) + return extend_schema( schema, document_ast, assume_valid, assume_valid_sdl, ) - - for k, v in schema.type_map.items(): - resolve_reference = getattr(v, "__resolve_reference__", None) - if resolve_reference and k in extended_schema.type_map: - setattr( - extended_schema.type_map[k], - "__resolve_reference__", - resolve_reference, - ) - - return extended_schema diff --git a/ariadne/contrib/federation/utils.py b/ariadne/contrib/federation/utils.py index a93acc45e..ff9b8b117 100644 --- a/ariadne/contrib/federation/utils.py +++ b/ariadne/contrib/federation/utils.py @@ -18,6 +18,7 @@ GraphQLSchema, ) +from ariadne.utils import type_get_extension _allowed_directives = [ "skip", # Default directive as per specs. @@ -91,7 +92,7 @@ def resolve_entities(_: Any, info: GraphQLResolveInfo, **kwargs) -> Any: f" was found in the schema", ) - resolve_reference = getattr( + resolve_reference = type_get_extension( type_object, "__resolve_reference__", lambda o, i, r: reference, diff --git a/ariadne/contrib/relay/objects.py b/ariadne/contrib/relay/objects.py index 8d6436783..674b21804 100644 --- a/ariadne/contrib/relay/objects.py +++ b/ariadne/contrib/relay/objects.py @@ -1,7 +1,8 @@ from base64 import b64decode from inspect import iscoroutinefunction -from typing import Optional, Tuple +from typing import Optional, Tuple, cast +from graphql import GraphQLNamedType from graphql.pyutils import is_awaitable from graphql.type import GraphQLSchema @@ -16,6 +17,8 @@ GlobalIDTuple, ) from ariadne.types import Resolver +from ariadne.utils import type_get_extension +from ariadne.utils import type_set_extension def decode_global_id(kwargs) -> GlobalIDTuple: @@ -81,7 +84,8 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None: if callable(self._node_resolver): graphql_type = schema.type_map.get(self.name) - setattr( + graphql_type = cast(GraphQLNamedType, graphql_type) + type_set_extension( graphql_type, "__resolve_node__", self._node_resolver, @@ -115,10 +119,12 @@ def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]: def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver: type_object = schema.get_type(type_name) - try: - return getattr(type_object, "__resolve_node__") - except AttributeError as exc: - raise ValueError(f"No node resolver for type {type_name}") from exc + resolver: Optional[Resolver] = None + if type_object: + resolver = type_get_extension(type_object, "__resolve_node__") + if not resolver: + raise ValueError(f"No node resolver for type {type_name}") + return resolver def resolve_node(self, obj, info, *args, **kwargs): type_name, _ = self.global_id_decoder(kwargs) diff --git a/ariadne/utils.py b/ariadne/utils.py index 3d0a326ba..0cf5dfe10 100644 --- a/ariadne/utils.py +++ b/ariadne/utils.py @@ -4,6 +4,7 @@ from typing import Optional, Union, Callable, Dict, Any, cast from warnings import warn +from graphql import GraphQLNamedType from graphql.language import DocumentNode, OperationDefinitionNode, OperationType from graphql import GraphQLError, GraphQLType, parse @@ -250,3 +251,17 @@ def context_value_one_arg_deprecated(): # TODO: remove in 0.20 DeprecationWarning, stacklevel=2, ) + + +def type_set_extension( + object_type: GraphQLNamedType, extension_name: str, value: Any +) -> None: + if getattr(object_type, "extensions", None) is None: + object_type.extensions = {} + object_type.extensions[extension_name] = value + + +def type_get_extension( + object_type: GraphQLNamedType, extension_name: str, fallback: Any = None +) -> Any: + return getattr(object_type, "extensions", {}).get(extension_name, fallback) diff --git a/tests/relay/test_objects.py b/tests/relay/test_objects.py index 004c6fe42..e8c2ca546 100644 --- a/tests/relay/test_objects.py +++ b/tests/relay/test_objects.py @@ -1,5 +1,7 @@ import pytest +from graphql import extend_schema from graphql import graphql_sync +from graphql import parse from pytest_mock import MockFixture from ariadne import make_executable_schema @@ -56,6 +58,13 @@ def resolve_ship(*_): assert relay_query.get_node_resolver("Ship", schema) is resolve_ship + # extended schema re-creates the graphql object types + extended_schema = extend_schema( + schema, parse("extend type Query { fleet: [Ship] }") + ) + + assert relay_query.get_node_resolver("Ship", extended_schema) is resolve_ship + def test_query_type_node_field_resolver(): # pylint: disable=protected-access,comparison-with-callable