Skip to content

Commit

Permalink
feat: Add filter and order parameters to group relay API (#2863)
Browse files Browse the repository at this point in the history
Backported-from: main (24.12)
Backported-to: 24.09
Backport-of: 2863
  • Loading branch information
fregataa committed Oct 17, 2024
1 parent dcfd1f9 commit b9801fd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/2863.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `filter` and `order` parameters to Group GQL Relay API.
9 changes: 6 additions & 3 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine as SAEngine
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import DeclarativeMeta, registry
from sqlalchemy.types import CHAR, SchemaType, TypeDecorator
from sqlalchemy.types import CHAR, SchemaType, TypeDecorator, Unicode, UnicodeText

from ai.backend.common import validators as tx
from ai.backend.common.auth import PublicKey
Expand Down Expand Up @@ -450,7 +450,7 @@ class URLColumn(TypeDecorator):
A column type for URL strings
"""

impl = sa.types.UnicodeText
impl = UnicodeText
cache_ok = True

def process_bind_param(self, value: Optional[yarl.URL], dialect: Dialect) -> Optional[str]:
Expand Down Expand Up @@ -621,7 +621,7 @@ class SlugType(TypeDecorator):
A type wrapper for slug type string
"""

impl = sa.types.Unicode
impl = Unicode
cache_ok = True

def __init__(
Expand All @@ -640,6 +640,9 @@ def __init__(
allow_unicode=allow_unicode,
)

def coerce_compared_value(self, op, value):
return Unicode()

def process_bind_param(self, value: str, dialect) -> str:
try:
self._tx_slug.check(value)
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,12 @@ class Queries(graphene.ObjectType):
group_node = graphene.Field(
GroupNode, id=graphene.String(required=True), description="Added in 24.03.0."
)
group_nodes = PaginatedConnectionField(GroupConnection, description="Added in 24.03.0.")
group_nodes = PaginatedConnectionField(
GroupConnection,
description="Added in 24.03.0.",
filter=graphene.String(description="Added in 24.09.0."),
order=graphene.String(description="Added in 24.09.0."),
)

group = graphene.Field(
Group,
Expand Down
36 changes: 32 additions & 4 deletions src/ai/backend/manager/models/gql_models/group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Self,
Expand All @@ -8,6 +9,7 @@

import graphene
import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from graphene.types.datetime import DateTime as GQLDateTime

from ..base import (
Expand All @@ -22,14 +24,36 @@
ConnectionResolverResult,
)
from ..group import AssocGroupUserRow, GroupRow, ProjectType
from ..minilang.ordering import QueryOrderParser
from ..minilang.queryfilter import QueryFilterParser
from ..minilang.ordering import OrderSpecItem, QueryOrderParser
from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser
from .user import UserConnection, UserNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext
from ..scaling_group import ScalingGroup

_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = {
"id": ("id", None),
"row_id": ("id", None),
"name": ("name", None),
"is_active": ("is_active", None),
"created_at": ("created_at", dtparse),
"modified_at": ("modified_at", dtparse),
"domain_name": ("domain_name", None),
"resource_policy": ("resource_policy", None),
}

_queryorder_colmap: Mapping[str, OrderSpecItem] = {
"id": ("id", None),
"row_id": ("id", None),
"name": ("name", None),
"is_active": ("is_active", None),
"created_at": ("created_at", dtparse),
"modified_at": ("modified_at", dtparse),
"domain_name": ("domain_name", None),
"resource_policy": ("resource_policy", None),
}


class GroupInput(graphene.InputObjectType):
type = graphene.String(
Expand Down Expand Up @@ -203,10 +227,14 @@ async def get_connection(
) -> ConnectionResolverResult[Self]:
graph_ctx: GraphQueryContext = info.context
_filter_arg = (
FilterExprArg(filter_expr, QueryFilterParser()) if filter_expr is not None else None
FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec))
if filter_expr is not None
else None
)
_order_expr = (
OrderExprArg(order_expr, QueryOrderParser()) if order_expr is not None else None
OrderExprArg(order_expr, QueryOrderParser(_queryorder_colmap))
if order_expr is not None
else None
)
(
query,
Expand Down

0 comments on commit b9801fd

Please sign in to comment.