Skip to content

Commit

Permalink
Create network management CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
kyujin-cho committed Oct 8, 2024
1 parent 09c0553 commit 8904ac4
Show file tree
Hide file tree
Showing 12 changed files with 417 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/ai/backend/client/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import session # noqa # type: ignore
from . import session_template # noqa # type: ignore
from . import vfolder # noqa # type: ignore
from . import network # noqa # type: ignore
from . import app, logs, proxy # noqa # type: ignore

# To include the main module as an explicit dependency
Expand Down
171 changes: 171 additions & 0 deletions src/ai/backend/client/cli/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import sys
import uuid
from typing import Any, Iterable

import click

from ai.backend.cli.main import main
from ai.backend.cli.types import ExitCode
from ai.backend.client.cli.extensions import pass_ctx_obj
from ai.backend.client.cli.types import CLIContext
from ai.backend.client.exceptions import BackendAPIError
from ai.backend.client.session import Session

from ..output.fields import network_fields
from .pretty import print_done

# Field specs used by the `get` command when the user does not pass --format.
_default_list_fields = tuple(
    network_fields[key] for key in ("name", "ref_name", "driver", "created_at")
)


# Registers the `network` sub-command group on the top-level `main` CLI group;
# the sub-commands in this module attach to it via `@network.command()`.
# The docstring doubles as the help text of the group.
@main.group()
def network() -> None:
    """Set of inter-container network operations"""


@network.command()
@pass_ctx_obj
@click.argument("project", type=str, metavar="PROJECT_ID_OR_NAME")
@click.argument("name", type=str, metavar="NAME")
@click.option("-d", "--driver", default=None, help="Set the network driver.")
def create(ctx: CLIContext, project, name, driver):
    """Create a new network interface."""
    # NOTE: the docstring above is what click shows as the command help text.
    #
    # PROJECT is resolved in two steps:
    #   1. if it parses as a UUID and `Group.detail()` returns something truthy,
    #      it is used as the project ID as-is;
    #   2. otherwise it is looked up as a project name and the first match wins.

    with Session() as api:
        target_project: str | None = None

        try:
            uuid.UUID(project)
            looks_like_uuid = True
        except ValueError:
            looks_like_uuid = False

        if looks_like_uuid and api.Group.detail(project):
            target_project = project

        if not target_project:
            matches = api.Group.from_name(project)
            if not matches:
                ctx.output.print_fail(f"Project '{project}' not found.")
                sys.exit(ExitCode.FAILURE)
            target_project = matches[0]["id"]

        try:
            created = api.Network.create(target_project, name, driver=driver)
            print_done(f"Network {name} (ID {created.network_id}) created.")
        except Exception as exc:
            ctx.output.print_error(exc)
            sys.exit(ExitCode.FAILURE)


@network.command()
@pass_ctx_obj
@click.option(
    "-f",
    "--format",
    default=None,
    help="Display only specified fields. When specifying multiple fields separate them with comma (,).",
)
@click.option("--filter", "filter_", default=None, help="Set the query filter expression.")
@click.option("--order", default=None, help="Set the query ordering expression.")
@click.option("--offset", default=0, help="The index of the current page start for pagination.")
@click.option("--limit", type=int, default=None, help="The page size for pagination.")
def list(ctx: CLIContext, format, filter_, order, offset, limit):
    """List all available network interfaces."""
    # NOTE: the docstring above is what click shows as the command help text.

    # `None` lets `Network.paginated_list()` fall back to its own default fields.
    fields = None
    if format:
        try:
            selected = []
            for key in format.split(","):
                selected.append(network_fields[key.strip()])
        except KeyError as e:
            ctx.output.print_fail(f"Field {str(e)} not found")
            sys.exit(ExitCode.FAILURE)
        fields = selected

    with Session() as api:

        def fetch_page(pg_offset, pg_size):
            # Page fetcher handed to the paginated printer; filter/order/fields are fixed per run.
            return api.Network.paginated_list(
                page_offset=pg_offset,
                page_size=pg_size,
                filter=filter_,
                order=order,
                fields=fields,
            )

        try:
            ctx.output.print_paginated_list(
                fetch_page,
                initial_page_offset=offset,
                page_size=limit,
            )
        except Exception as exc:
            ctx.output.print_error(exc)
            sys.exit(ExitCode.FAILURE)


@network.command()
@pass_ctx_obj
@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME")
@click.option(
    "-f",
    "--format",
    default=None,
    help="Display only specified fields. When specifying multiple fields separate them with comma (,).",
)
def get(ctx: CLIContext, network, format):
    """Show the information of a network interface specified by its ID or name."""
    # The docstring above is the click help text; without it `network get --help`
    # had no description, unlike the sibling `create` and `list` commands.
    fields: Iterable[Any]
    if format:
        try:
            fields = [network_fields[f.strip()] for f in format.split(",")]
        except KeyError as e:
            ctx.output.print_fail(f"Field {str(e)} not found")
            sys.exit(ExitCode.FAILURE)
    else:
        fields = _default_list_fields

    with Session() as session:
        try:
            # First interpret the argument as a network ID.
            network_info = session.Network(uuid.UUID(network)).get(fields=fields)
        except (ValueError, BackendAPIError):
            # Not a UUID (ValueError) or the ID lookup failed on the server side:
            # fall back to an exact-name search.
            networks = session.Network.paginated_list(filter=f'name == "{network}"', fields=fields)
            if networks.total_count == 0:
                ctx.output.print_fail(f"Network {network} not found.")
                sys.exit(ExitCode.FAILURE)
            if networks.total_count > 1:
                # This branch is reached only when the name is ambiguous.
                ctx.output.print_fail(
                    f"Multiple networks found with name {network}. Try mentioning network ID instead of name to resolve the issue."
                )
                sys.exit(ExitCode.FAILURE)
            network_info = networks.items[0]

        ctx.output.print_item(network_info, fields)


@network.command()
@pass_ctx_obj
@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME")
def delete(ctx: CLIContext, network):
    """Delete a network interface specified by its ID or name."""
    # The docstring above is the click help text; without it `network delete --help`
    # had no description, unlike the sibling `create` and `list` commands.
    with Session() as session:
        try:
            # First interpret the argument as a network ID. The lookup is only an
            # existence check: once it succeeds the parsed UUID itself is the ID to
            # delete, so nothing needs to be read back from the response (which was
            # requested with the "id" field only).
            network_id = uuid.UUID(network)
            session.Network(network_id).get(fields=[network_fields["id"]])
        except (ValueError, BackendAPIError):
            # Not a UUID (ValueError) or the ID lookup failed on the server side:
            # fall back to an exact-name search.
            networks = session.Network.paginated_list(
                filter=f'name == "{network}"', fields=[network_fields["id"]]
            )
            if networks.total_count == 0:
                ctx.output.print_fail(f"Network {network} not found.")
                sys.exit(ExitCode.FAILURE)
            if networks.total_count > 1:
                # This branch is reached only when the name is ambiguous.
                ctx.output.print_fail(
                    f"Multiple networks found with name {network}. Try mentioning network ID instead of name to resolve the issue."
                )
                sys.exit(ExitCode.FAILURE)
            # NOTE(review): only the "id" field spec is requested above while "row_id"
            # is read here -- confirm that the listed items actually carry "row_id".
            network_id = uuid.UUID(networks.items[0]["row_id"])

        try:
            session.Network(network_id).delete()
            print_done(f"Network {network} has been deleted.")
        except BackendAPIError as e:
            ctx.output.print_fail(f"Failed to delete network {network}:")
            ctx.output.print_error(e)
            sys.exit(ExitCode.FAILURE)
6 changes: 6 additions & 0 deletions src/ai/backend/client/cli/session/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@
"User should be a member of the group to execute the code."
),
),
click.option(
"--network",
metavar="NETWORK_NAME_OR_ID",
default=None,
help="Network name or ID to which the session will be connected. Only networks residing at the same project can be attached to the session.",
),
]


Expand Down
23 changes: 22 additions & 1 deletion src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ...compat import asyncio_run
from ...exceptions import BackendAPIError
from ...func.session import ComputeSession
from ...output.fields import session_fields
from ...output.fields import network_fields, session_fields
from ...output.types import FieldSpec
from ...session import AsyncSession, Session
from .. import events
Expand Down Expand Up @@ -161,6 +161,7 @@ def create(
# resource grouping
domain: str | None, # click_start_option
group: str | None, # click_start_option
network: str | None, # click_start_option
) -> None:
"""
Prepare and start a single compute session without executing codes.
Expand Down Expand Up @@ -189,6 +190,25 @@ def create(
assigned_agent_list = assign_agent
with Session() as session:
try:
if network:
try:
network_info = session.Network(uuid.UUID(network)).get()
except (ValueError, BackendAPIError):
networks = session.Network.paginated_list(
filter=f'name == "{network}"',
fields=[network_fields["id"], network_fields["name"]],
)
if networks.total_count == 0:
print_fail(f"Network {network} not found.")
sys.exit(ExitCode.FAILURE)
if networks.total_count > 1:
print_fail(
f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue."
)
sys.exit(ExitCode.FAILURE)
network_info = networks.items[0]
network_id = network_info["row_id"]

compute_session = session.ComputeSession.get_or_create(
image,
name=name,
Expand Down Expand Up @@ -220,6 +240,7 @@ def create(
architecture=architecture,
preopen_ports=preopen_ports,
assign_agent=assigned_agent_list,
attach_network=network_id,
)
except Exception as e:
print_error(e)
Expand Down
136 changes: 136 additions & 0 deletions src/ai/backend/client/func/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Sequence
from uuid import UUID

from ..output.fields import network_fields
from ..output.types import FieldSpec, RelayPaginatedResult
from ..pagination import execute_paginated_relay_query
from ..session import api_session
from .base import BaseFunction, api_function

__all__ = ("Network",)

# Field specs used by `Network.paginated_list()` and `Network.get()` when the
# caller does not pass `fields`.
_default_list_fields = tuple(
    network_fields[key] for key in ("name", "ref_name", "driver", "created_at")
)


class Network(BaseFunction):
    """
    Client-side API functions for networks.

    The class-level methods list and create networks; an instance wraps a single
    network ID and exposes ``get()``, ``update()`` and ``delete()`` for it.
    """

    @api_function
    @classmethod
    async def paginated_list(
        cls,
        *,
        fields: Sequence[FieldSpec] | None = None,
        page_offset: int = 0,
        page_size: int = 20,
        filter: str | None = None,
        order: str | None = None,
    ) -> RelayPaginatedResult[dict]:
        """
        Fetches the list of created networks in this cluster.

        :param fields: The fields to fetch. Falls back to the module default fields when omitted or empty.
        :param page_offset: The offset of the first item of the page.
        :param page_size: The maximum number of items in the page.
        :param filter: (Optional) The query filter expression.
        :param order: (Optional) The query ordering expression.
        :return: The result of the paginated ``networks`` query.
        """
        return await execute_paginated_relay_query(
            "networks",
            {
                "filter": (filter, "String"),
                "order": (order, "String"),
            },
            fields or _default_list_fields,
            limit=page_size,
            offset=page_offset,
        )

    @api_function
    @classmethod
    async def create(
        cls,
        project_id: str,
        name: str,
        *,
        driver: str | None = None,
    ) -> "Network":
        """
        Creates a new network.

        :param project_id: The ID of the project to which the network belongs.
        :param name: The name of the network.
        :param driver: (Optional) The driver of the network. If not specified, the default driver will be used.
        :return: The created network.
        """
        q = (
            "mutation($name: String!, $project_id: UUID!, $driver: String) {"
            " create_network(name: $name, project_id: $project_id, driver: $driver) {"
            " network {"
            " row_id"
            " }"
            " }"
            "}"
        )
        data = await api_session.get().Admin._query(
            q,
            {
                "name": name,
                "project_id": project_id,
                "driver": driver,
            },
        )
        # Only `row_id` is selected by the mutation; wrap it so the caller gets a usable handle.
        return cls(network_id=UUID(data["create_network"]["network"]["row_id"]))

    def __init__(self, network_id: UUID) -> None:
        """
        :param network_id: The ID of the network. Pass `row_id` value (not `id`) of the network info fetched by `paginated_list`.
        """
        super().__init__()
        self.network_id = network_id

    @api_function
    async def get(
        self,
        fields: Sequence[FieldSpec] | None = None,
    ) -> dict:
        """
        Fetches the information of the network.

        :param fields: The fields to fetch. Falls back to the module default fields when omitted or empty.
        :return: The ``network`` object of the query response.
        """
        q = (
            "query($id: String!) {"
            " network(id: $id) {"
            " $fields"
            " }"
            "}"
        )
        q = q.replace("$fields", " ".join(f.field_ref for f in (fields or _default_list_fields)))
        data = await api_session.get().Admin._query(q, {"id": str(self.network_id)})
        # The query selects the `network` root field, so that is the key to read from
        # the response (the previous `data["images"]` was a copy-paste leftover).
        return data["network"]

    @api_function
    async def update(self, name: str) -> dict:
        """
        Updates network.

        :param name: The new name of the network.
        :return: The ``modify_network`` payload of the response (the mutation selects ``ok`` and ``msg``).
        """
        q = (
            "mutation($network: String!, $props: UpdateNetworkInput!) {"
            " modify_network(network: $network, props: $props) {"
            " ok msg"
            " }"
            "}"
        )
        variables = {
            "network": str(self.network_id),
            "props": {"name": name},
        }
        data = await api_session.get().Admin._query(q, variables)
        return data["modify_network"]

    @api_function
    async def delete(self) -> dict:
        """
        Deletes network. Delete only works for networks that are not attached to active session.

        :return: The ``delete_network`` payload of the response (the mutation selects ``ok`` and ``msg``).
        """
        q = (
            "mutation($network: String!) {"
            " delete_network(network: $network) {"
            " ok msg"
            " }"
            "}"
        )
        variables = {
            "network": str(self.network_id),
        }
        data = await api_session.get().Admin._query(q, variables)
        return data["delete_network"]
Loading

0 comments on commit 8904ac4

Please sign in to comment.