Skip to content

Commit

Permalink
use graphql core extensions field to store contrib resolvers (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
reallistic authored Feb 7, 2025
1 parent 000ca3d commit 748c47c
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 26 deletions.
15 changes: 10 additions & 5 deletions ariadne/contrib/federation/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import cast
from typing import Optional

from graphql import GraphQLNamedType
from graphql.type import GraphQLSchema

from ...interfaces import InterfaceType
from ...types import Resolver
from ...utils import type_get_extension
from ...utils import type_implements_interface
from ...utils import type_set_extension


class FederatedInterfaceType(InterfaceType):
Expand All @@ -26,17 +30,18 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:

if callable(self._reference_resolver):
graphql_type = schema.type_map.get(self.name)
setattr(
graphql_type = cast(GraphQLNamedType, graphql_type)
type_set_extension(
graphql_type,
"__resolve_reference__",
self._reference_resolver,
)

for object_type in schema.type_map.values():
if type_implements_interface(self.name, object_type) and not hasattr(
object_type, "__resolve_reference__"
):
setattr(
if type_implements_interface(
self.name, object_type
) and not type_get_extension(object_type, "__resolve_reference__"):
type_set_extension(
object_type,
"__resolve_reference__",
self._reference_resolver,
Expand Down
6 changes: 5 additions & 1 deletion ariadne/contrib/federation/objects.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import cast
from typing import Optional

from graphql import GraphQLNamedType
from graphql.type import GraphQLSchema

from ...objects import ObjectType
from ...types import Resolver
from ...utils import type_set_extension


class FederatedObjectType(ObjectType):
Expand All @@ -25,7 +28,8 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:

if callable(self._reference_resolver):
graphql_type = schema.type_map.get(self.name)
setattr(
graphql_type = cast(GraphQLNamedType, graphql_type)
type_set_extension(
graphql_type,
"__resolve_reference__",
self._reference_resolver,
Expand Down
22 changes: 9 additions & 13 deletions ariadne/contrib/federation/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import os
from typing import Dict, List, Optional, Type, Union, cast
from warnings import warn

from graphql import extend_schema, parse
from graphql.language import DocumentNode
Expand Down Expand Up @@ -114,7 +115,7 @@ def make_federated_schema(

# Add the federation type definitions.
if has_entities:
schema = extend_federated_schema(schema, parse(federation_entity_type_defs))
schema = extend_schema(schema, parse(federation_entity_type_defs))

# Add _entities query.
entity_type = schema.get_type("_Entity")
Expand Down Expand Up @@ -142,20 +143,15 @@ def extend_federated_schema(
assume_valid: bool = False,
assume_valid_sdl: bool = False,
) -> GraphQLSchema:
extended_schema = extend_schema(
# This wrapper function is no longer needed and can be removed in the future.
# It is kept for backwards compatibility with previous versions of Ariadne
warn(
"extend_federated_schema is deprecated and will be removed in future versions of Ariadne. "
"Use graphql.extend_schema instead."
)
return extend_schema(
schema,
document_ast,
assume_valid,
assume_valid_sdl,
)

for k, v in schema.type_map.items():
resolve_reference = getattr(v, "__resolve_reference__", None)
if resolve_reference and k in extended_schema.type_map:
setattr(
extended_schema.type_map[k],
"__resolve_reference__",
resolve_reference,
)

return extended_schema
3 changes: 2 additions & 1 deletion ariadne/contrib/federation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
GraphQLSchema,
)

from ariadne.utils import type_get_extension

_allowed_directives = [
"skip", # Default directive as per specs.
Expand Down Expand Up @@ -91,7 +92,7 @@ def resolve_entities(_: Any, info: GraphQLResolveInfo, **kwargs) -> Any:
f" was found in the schema",
)

resolve_reference = getattr(
resolve_reference = type_get_extension(
type_object,
"__resolve_reference__",
lambda o, i, r: reference,
Expand Down
18 changes: 12 additions & 6 deletions ariadne/contrib/relay/objects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from base64 import b64decode
from inspect import iscoroutinefunction
from typing import Optional, Tuple
from typing import Optional, Tuple, cast

from graphql import GraphQLNamedType
from graphql.pyutils import is_awaitable
from graphql.type import GraphQLSchema

Expand All @@ -16,6 +17,8 @@
GlobalIDTuple,
)
from ariadne.types import Resolver
from ariadne.utils import type_get_extension
from ariadne.utils import type_set_extension


def decode_global_id(kwargs) -> GlobalIDTuple:
Expand Down Expand Up @@ -81,7 +84,8 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:

if callable(self._node_resolver):
graphql_type = schema.type_map.get(self.name)
setattr(
graphql_type = cast(GraphQLNamedType, graphql_type)
type_set_extension(
graphql_type,
"__resolve_node__",
self._node_resolver,
Expand Down Expand Up @@ -115,10 +119,12 @@ def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]:

def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver:
type_object = schema.get_type(type_name)
try:
return getattr(type_object, "__resolve_node__")
except AttributeError as exc:
raise ValueError(f"No node resolver for type {type_name}") from exc
resolver: Optional[Resolver] = None
if type_object:
resolver = type_get_extension(type_object, "__resolve_node__")
if not resolver:
raise ValueError(f"No node resolver for type {type_name}")
return resolver

def resolve_node(self, obj, info, *args, **kwargs):
type_name, _ = self.global_id_decoder(kwargs)
Expand Down
15 changes: 15 additions & 0 deletions ariadne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional, Union, Callable, Dict, Any, cast
from warnings import warn

from graphql import GraphQLNamedType
from graphql.language import DocumentNode, OperationDefinitionNode, OperationType
from graphql import GraphQLError, GraphQLType, parse

Expand Down Expand Up @@ -250,3 +251,17 @@ def context_value_one_arg_deprecated(): # TODO: remove in 0.20
DeprecationWarning,
stacklevel=2,
)


def type_set_extension(
object_type: GraphQLNamedType, extension_name: str, value: Any
) -> None:
if getattr(object_type, "extensions", None) is None:
object_type.extensions = {}
object_type.extensions[extension_name] = value


def type_get_extension(
object_type: GraphQLNamedType, extension_name: str, fallback: Any = None
) -> Any:
return getattr(object_type, "extensions", {}).get(extension_name, fallback)
9 changes: 9 additions & 0 deletions tests/relay/test_objects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from graphql import extend_schema
from graphql import graphql_sync
from graphql import parse
from pytest_mock import MockFixture

from ariadne import make_executable_schema
Expand Down Expand Up @@ -56,6 +58,13 @@ def resolve_ship(*_):

assert relay_query.get_node_resolver("Ship", schema) is resolve_ship

# extended schema re-creates the graphql object types
extended_schema = extend_schema(
schema, parse("extend type Query { fleet: [Ship] }")
)

assert relay_query.get_node_resolver("Ship", extended_schema) is resolve_ship


def test_query_type_node_field_resolver():
# pylint: disable=protected-access,comparison-with-callable
Expand Down

0 comments on commit 748c47c

Please sign in to comment.