diff --git a/src/ai/backend/client/cli/__init__.py b/src/ai/backend/client/cli/__init__.py index 0157fb41a4..f6453d6ea1 100644 --- a/src/ai/backend/client/cli/__init__.py +++ b/src/ai/backend/client/cli/__init__.py @@ -9,6 +9,7 @@ from . import session # noqa # type: ignore from . import session_template # noqa # type: ignore from . import vfolder # noqa # type: ignore +from . import network # noqa # type: ignore from . import app, logs, proxy # noqa # type: ignore # To include the main module as an explicit dependency diff --git a/src/ai/backend/client/cli/network.py b/src/ai/backend/client/cli/network.py new file mode 100644 index 0000000000..319c80a78c --- /dev/null +++ b/src/ai/backend/client/cli/network.py @@ -0,0 +1,171 @@ +import sys +import uuid +from typing import Any, Iterable + +import click + +from ai.backend.cli.main import main +from ai.backend.cli.types import ExitCode +from ai.backend.client.cli.extensions import pass_ctx_obj +from ai.backend.client.cli.types import CLIContext +from ai.backend.client.exceptions import BackendAPIError +from ai.backend.client.session import Session + +from ..output.fields import network_fields +from .pretty import print_done + +_default_list_fields = ( + network_fields["name"], + network_fields["ref_name"], + network_fields["driver"], + network_fields["created_at"], +) + + +@main.group() +def network(): + """Set of inter-container network operations""" + + +@network.command() +@pass_ctx_obj +@click.argument("project", type=str, metavar="PROJECT_ID_OR_NAME") +@click.argument("name", type=str, metavar="NAME") +@click.option("-d", "--driver", default=None, help="Set the network driver.") +def create(ctx: CLIContext, project, name, driver): + """Create a new network interface.""" + + with Session() as session: + proj_id: str | None = None + + try: + uuid.UUID(project) + except ValueError: + pass + else: + if session.Group.detail(project): + proj_id = project + + if not proj_id: + projects = session.Group.from_name(project) + if not projects: + ctx.output.print_fail(f"Project '{project}' not found.") + sys.exit(ExitCode.FAILURE) + proj_id = projects[0]["id"] + + try: + network = session.Network.create(proj_id, name, driver=driver) + print_done(f"Network {name} (ID {network.network_id}) created.") + except Exception as e: + ctx.output.print_error(e) + sys.exit(ExitCode.FAILURE) + + +@network.command() +@pass_ctx_obj +@click.option( + "-f", + "--format", + default=None, + help="Display only specified fields. When specifying multiple fields separate them with comma (,).", +) +@click.option("--filter", "filter_", default=None, help="Set the query filter expression.") +@click.option("--order", default=None, help="Set the query ordering expression.") +@click.option("--offset", default=0, help="The index of the current page start for pagination.") +@click.option("--limit", type=int, default=None, help="The page size for pagination.") +def list(ctx: CLIContext, format, filter_, order, offset, limit): + """List all available network interfaces.""" + + if format: + try: + fields = [network_fields[f.strip()] for f in format.split(",")] + except KeyError as e: + ctx.output.print_fail(f"Field {str(e)} not found") + sys.exit(ExitCode.FAILURE) + else: + fields = None + with Session() as session: + try: + fetch_func = lambda pg_offset, pg_size: session.Network.paginated_list( + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + fields=fields, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(ExitCode.FAILURE) + + +@network.command() +@pass_ctx_obj +@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME") +@click.option( + "-f", + "--format", + default=None, + help="Display only specified fields. When specifying multiple fields separate them with comma (,).", +) +def get(ctx: CLIContext, network, format): + fields: Iterable[Any] + if format: + try: + fields = [network_fields[f.strip()] for f in format.split(",")] + except KeyError as e: + ctx.output.print_fail(f"Field {str(e)} not found") + sys.exit(ExitCode.FAILURE) + else: + fields = _default_list_fields + + with Session() as session: + try: + network_info = session.Network(uuid.UUID(network)).get(fields=fields) + except (ValueError, BackendAPIError): + networks = session.Network.paginated_list(filter=f'name == "{network}"', fields=fields) + if networks.total_count == 0: + ctx.output.print_fail(f"Network {network} not found.") + sys.exit(ExitCode.FAILURE) + if networks.total_count > 1: + ctx.output.print_fail( + f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue." + ) + sys.exit(ExitCode.FAILURE) + network_info = networks.items[0] + + ctx.output.print_item(network_info, fields) + + +@network.command() +@pass_ctx_obj +@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME") +def delete(ctx: CLIContext, network): + with Session() as session: + try: + network_info = session.Network(uuid.UUID(network)).get(fields=[network_fields["id"]]) + except (ValueError, BackendAPIError): + networks = session.Network.paginated_list( + filter=f'name == "{network}"', fields=[network_fields["id"]] + ) + if networks.total_count == 0: + ctx.output.print_fail(f"Network {network} not found.") + sys.exit(ExitCode.FAILURE) + if networks.total_count > 1: + ctx.output.print_fail( + f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue." + ) + sys.exit(ExitCode.FAILURE) + network_info = networks.items[0] + + try: + session.Network(uuid.UUID(network_info["row_id"])).delete() + print_done(f"Network {network} has been deleted.") + except BackendAPIError as e: + ctx.output.print_fail(f"Failed to delete network {network}:") + ctx.output.print_error(e) + sys.exit(ExitCode.FAILURE) diff --git a/src/ai/backend/client/cli/session/args.py b/src/ai/backend/client/cli/session/args.py index 250449f0cb..2f646a5ccd 100644 --- a/src/ai/backend/client/cli/session/args.py +++ b/src/ai/backend/client/cli/session/args.py @@ -154,6 +154,12 @@ "User should be a member of the group to execute the code." ), ), + click.option( + "--network", + metavar="NETWORK_NAME_OR_ID", + default=None, + help="Network name or ID to which the session will be connected. Only networks residing at the same project can be attached to the session.", + ), ] diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index 7c7b2e3236..2aaf8937ef 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -31,7 +31,7 @@ from ...compat import asyncio_run from ...exceptions import BackendAPIError from ...func.session import ComputeSession -from ...output.fields import session_fields +from ...output.fields import network_fields, session_fields from ...output.types import FieldSpec from ...session import AsyncSession, Session from .. import events @@ -161,6 +161,7 @@ def create( # resource grouping domain: str | None, # click_start_option group: str | None, # click_start_option + network: str | None, # click_start_option ) -> None: """ Prepare and start a single compute session without executing codes. @@ -189,6 +190,25 @@ def create( assigned_agent_list = assign_agent with Session() as session: try: + if network: + try: + network_info = session.Network(uuid.UUID(network)).get() + except (ValueError, BackendAPIError): + networks = session.Network.paginated_list( + filter=f'name == "{network}"', + fields=[network_fields["id"], network_fields["name"]], + ) + if networks.total_count == 0: + print_fail(f"Network {network} not found.") + sys.exit(ExitCode.FAILURE) + if networks.total_count > 1: + print_fail( + f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue." + ) + sys.exit(ExitCode.FAILURE) + network_info = networks.items[0] + network_id = network_info["row_id"] + compute_session = session.ComputeSession.get_or_create( image, name=name, @@ -220,6 +240,7 @@ def create( architecture=architecture, preopen_ports=preopen_ports, assign_agent=assigned_agent_list, + attach_network=network_id, ) except Exception as e: print_error(e) diff --git a/src/ai/backend/client/func/network.py b/src/ai/backend/client/func/network.py new file mode 100644 index 0000000000..b28b9f04a1 --- /dev/null +++ b/src/ai/backend/client/func/network.py @@ -0,0 +1,136 @@ +from typing import Sequence +from uuid import UUID + +from ..output.fields import network_fields +from ..output.types import FieldSpec, RelayPaginatedResult +from ..pagination import execute_paginated_relay_query +from ..session import api_session +from .base import BaseFunction, api_function + +__all__ = ("Network",) + +_default_list_fields = ( + network_fields["name"], + network_fields["ref_name"], + network_fields["driver"], + network_fields["created_at"], +) + + +class Network(BaseFunction): + @api_function + @classmethod + async def paginated_list( + cls, + *, + fields: Sequence[FieldSpec] | None = None, + page_offset: int = 0, + page_size: int = 20, + filter: str | None = None, + order: str | None = None, + ) -> RelayPaginatedResult[dict]: + """ + Fetches the list of created networks in this cluster. + """ + return await execute_paginated_relay_query( + "networks", + { + "filter": (filter, "String"), + "order": (order, "String"), + }, + fields or _default_list_fields, + limit=page_size, + offset=page_offset, + ) + + @api_function + @classmethod + async def create( + cls, + project_id: str, + name: str, + *, + driver: str | None = None, + ) -> "Network": + """ + Creates a new network. + :param project_id: The ID of the project to which the network belongs. + :param name: The name of the network. + :param driver: (Optional) The driver of the network. If not specified, the default driver will be used. + :return: The created network. + """ + q = ( + "mutation($name: String!, $project_id: UUID!, $driver: String) {" + " create_network(name: $name, project_id: $project_id, driver: $driver) {" + " network {" + " row_id" + " }" + " }" + "}" + ) + data = await api_session.get().Admin._query( + q, + { + "name": name, + "project_id": project_id, + "driver": driver, + }, + ) + return cls(network_id=UUID(data["create_network"]["network"]["row_id"])) + + def __init__(self, network_id: UUID) -> None: + """ + :param network_id: The ID of the network. Pass `row_id` value (not `id`) of the network info fetched by `paginated_list`. + """ + super().__init__() + self.network_id = network_id + + @api_function + async def get( + self, + fields: Sequence[FieldSpec] | None = None, + ) -> dict: + """ + Fetches the information of the network. + """ + q = "query($id: String!) {" " network(id: $id) {" " $fields" " }" "}" + q = q.replace("$fields", " ".join(f.field_ref for f in (fields or _default_list_fields))) + data = await api_session.get().Admin._query(q, {"id": str(self.network_id)}) + return data["images"] + + @api_function + async def update(self, name: str) -> None: + """ + Updates network. + """ + q = ( + "mutation($network: String!, $props: UpdateNetworkInput!) {" + " modify_network(network: $network, props: $props) {" + " ok msg" + " }" + "}" + ) + variables = { + "network": str(self.network_id), + "props": {"name": name}, + } + data = await api_session.get().Admin._query(q, variables) + return data["modify_network"] + + @api_function + async def delete(self) -> None: + """ + Deletes network. Delete only works for networks that are not attached to active session. + """ + q = ( + "mutation($network: String!) {" + " delete_network(network: $network) {" + " ok msg" + " }" + "}" + ) + variables = { + "network": str(self.network_id), + } + data = await api_session.get().Admin._query(q, variables) + return data["delete_network"] diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 8f838f9f4b..9f09c8b034 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -196,6 +196,7 @@ async def get_or_create( owner_access_key: Optional[str] = None, preopen_ports: Optional[list[int]] = None, assign_agent: Optional[list[str]] = None, + attach_network: Optional[str] = None, ) -> ComputeSession: """ Get-or-creates a compute session. @@ -258,7 +259,9 @@ async def get_or_create( :param tag: An optional string to annotate extra information. :param owner: An optional access key that owns the created session. (Only available to administrators) + :param attach_network: An optional string to select which network to attach to session. Must supply network ID (not name). + .. versionadded:: 24.09.0 :returns: The :class:`ComputeSession` instance. """ if name is not None: @@ -299,6 +302,7 @@ async def get_or_create( if api_session.get().api_version >= (8, "20240915"): if priority is not None: params["priority"] = priority + params["config"]["attach_network"] = attach_network if api_session.get().api_version >= (6, "20220315"): params["dependencies"] = dependencies params["callback_url"] = callback_url diff --git a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py index 4d4eeac034..c22bfd691d 100644 --- a/src/ai/backend/client/output/fields.py +++ b/src/ai/backend/client/output/fields.py @@ -343,3 +343,16 @@ FieldSpec("quota_scope_id"), FieldSpec("storage_host_name"), ]) + + +network_fields = FieldSet([ + FieldSpec(field_ref="row_id", field_name="id", alt_name="id"), + FieldSpec("name"), + FieldSpec("ref_name"), + FieldSpec("driver"), + FieldSpec("domain_name"), + FieldSpec("project"), + FieldSpec("options"), + FieldSpec("created_at"), + FieldSpec("updated_at", "Last Updated"), +]) diff --git a/src/ai/backend/client/output/types.py b/src/ai/backend/client/output/types.py index ba93bb2184..3feb50a04c 100644 --- a/src/ai/backend/client/output/types.py +++ b/src/ai/backend/client/output/types.py @@ -121,6 +121,14 @@ class PaginatedResult(Generic[T]): fields: Sequence[FieldSpec] +@attr.define(slots=True) +class RelayPaginatedResult(Generic[T]): + total_count: int + items: Sequence[T] + fields: Sequence[FieldSpec] + next_cursor: str | None + + class BaseOutputHandler(metaclass=ABCMeta): def __init__(self, cli_context: CLIContext) -> None: self.ctx = cli_context diff --git a/src/ai/backend/client/pagination.py b/src/ai/backend/client/pagination.py index a807176235..926041c2e1 100644 --- a/src/ai/backend/client/pagination.py +++ b/src/ai/backend/client/pagination.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Final, Sequence, Tuple, TypeVar from .exceptions import BackendAPIVersionError -from .output.types import FieldSpec, PaginatedResult +from .output.types import FieldSpec, PaginatedResult, RelayPaginatedResult from .session import api_session MAX_PAGE_SIZE: Final = 100 @@ -55,6 +55,55 @@ async def execute_paginated_query( ) +async def execute_paginated_relay_query( + root_field: str, + variables: Dict[str, Tuple[Any, str]], + fields: Sequence[FieldSpec], + *, + limit: int | None = None, + offset: int | None = None, + after: str | None = None, + before: str | None = None, +) -> RelayPaginatedResult: + if limit and limit > MAX_PAGE_SIZE: + raise ValueError(f"The page size cannot exceed {MAX_PAGE_SIZE}") + if limit and limit < MIN_PAGE_SIZE: + raise ValueError(f"The page size cannot be less than {MIN_PAGE_SIZE}") + query = """ + query($limit:Int, $after:String, $offset:Int, $before:String, $var_decls) { + $root_field( + first:$limit, offset:$offset, after:$after, before:$before $var_args) { + edges { node { $fields } cursor } + count + } + }""" + query = query.replace("$root_field", root_field) + query = query.replace("$fields", " ".join(f.field_ref for f in fields)) + query = query.replace( + "$var_decls", + ", ".join(f"${key}: {value[1]}" for key, value in variables.items()), + ) + query = query.replace( + "$var_args", + ", ".join(f"{key}:${key}" for key in variables.keys()), + ) + query = textwrap.dedent(query).strip() + var_values = {key: value[0] for key, value in variables.items()} + var_values["limit"] = limit + var_values["offset"] = offset + var_values["after"] = after + var_values["before"] = before + data = await api_session.get().Admin._query(query, var_values) + return RelayPaginatedResult( + total_count=data[root_field]["count"], + items=[x["node"] for x in data[root_field]["edges"]], + fields=fields, + next_cursor=data[root_field]["edges"][0]["cursor"] + if len(data[root_field]["edges"]) > 0 + else None, + ) + + async def fetch_paginated_result( root_field: str, variables: Dict[str, Tuple[Any, str]], diff --git a/src/ai/backend/client/session.py b/src/ai/backend/client/session.py index 312b626fbd..7723a62650 100644 --- a/src/ai/backend/client/session.py +++ b/src/ai/backend/client/session.py @@ -272,6 +272,7 @@ class BaseSession(metaclass=abc.ABCMeta): "Service", "Model", "QuotaScope", + "Network", ) aiohttp_session: aiohttp.ClientSession @@ -306,6 +307,7 @@ def __init__( from .func.keypair_resource_policy import KeypairResourcePolicy from .func.manager import Manager from .func.model import Model + from .func.network import Network from .func.quota_scope import QuotaScope from .func.resource import Resource from .func.scaling_group import ScalingGroup @@ -344,6 +346,7 @@ def __init__( self.Service = Service self.Model = Model self.QuotaScope = QuotaScope + self.Network = Network @property def proxy_mode(self) -> bool: diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index bf75301783..f34dae1dd7 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -411,7 +411,8 @@ t.Key("network", default=_config_defaults["network"]): t.Dict({ t.Key("inter-container", default=_config_defaults["network"]["inter-container"]): t.Dict({ t.Key( - "default-driver", default=_config_defaults["network"]["inter-container"]["default-driver"] + "default-driver", + default=_config_defaults["network"]["inter-container"]["default-driver"], ): t.Null | t.String, }).allow_extra("*"), t.Key("subnet", default=_config_defaults["network"]["subnet"]): t.Dict({ diff --git a/src/ai/backend/manager/models/network.py b/src/ai/backend/manager/models/network.py index e4f177e4c5..5e44713b3f 100644 --- a/src/ai/backend/manager/models/network.py +++ b/src/ai/backend/manager/models/network.py @@ -332,9 +332,7 @@ async def mutate( network_info = await network_plugin.create_network() network_name = network_info.network_id except Exception: - log.exception( - f"Failed to create the inter-container network (plugin: {_driver})" - ) + log.exception(f"Failed to create the inter-container network (plugin: {_driver})") raise async def _do_mutate() -> CreateNetwork: