Skip to content

Commit

Permalink
feat: Add manager status report bgtask (#2566)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Oct 17, 2024
1 parent 2db3331 commit eb498be
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/2566.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a background task that periodically reports the manager's DB connection status.
38 changes: 35 additions & 3 deletions src/ai/backend/manager/api/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import json
import logging
import socket
from typing import TYPE_CHECKING, Any, Final, FrozenSet, Iterable, Tuple
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Final, FrozenSet, Optional, Tuple, cast

import aiohttp_cors
import attrs
Expand All @@ -24,6 +25,9 @@
from .. import __version__
from ..defs import DEFAULT_ROLE
from ..models import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, agents, kernels
from ..models.health import (
report_manager_status,
)
from . import ManagerStatus, SchedulerEvent
from .auth import superadmin_required
from .exceptions import (
Expand All @@ -34,7 +38,10 @@
ServiceUnavailable,
)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params, set_handler_attr
from .utils import (
check_api_params,
set_handler_attr,
)

if TYPE_CHECKING:
from ai.backend.manager.models.gql import GraphQueryContext
Expand Down Expand Up @@ -101,6 +108,22 @@ async def detect_status_update(root_ctx: RootContext) -> None:
pass


async def report_status_bgtask(root_ctx: RootContext) -> None:
    """Periodically publish this manager's status until the task is cancelled.

    The period is read from ``local_config["manager"]["status-update-interval"]``.
    When that value is ``None`` the coroutine returns immediately and no
    reporting takes place.

    A failure of a single report is logged and does not stop the loop.
    Cancellation ends the loop silently.
    """
    manager_config = root_ctx.local_config["manager"]
    period = cast(Optional[float], manager_config["status-update-interval"])
    if period is not None:
        # NOTE(review): a period of 0 would make this loop report on every
        # event-loop iteration — confirm whether the config schema should
        # reject it.
        try:
            while True:
                # Sleep first, so the first report is sent one period after startup.
                await asyncio.sleep(period)
                try:
                    await report_manager_status(root_ctx)
                except Exception as exc:
                    # Keep the task alive; the next tick will retry.
                    log.exception(f"Failed to report manager health status (e:{str(exc)})")
        except asyncio.CancelledError:
            return


async def fetch_manager_status(request: web.Request) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
log.info("MANAGER.FETCH_MANAGER_STATUS ()")
Expand Down Expand Up @@ -275,19 +298,28 @@ async def scheduler_healthcheck(request: web.Request) -> web.Response:
@attrs.define(init=False, auto_attribs=True, slots=True)
class PrivateContext:
    """Per-application holder for the manager API's background tasks.

    No ``__init__`` is generated (``init=False``); the fields are assigned
    by ``init()`` and read back by ``shutdown()``.
    """

    # Task running detect_status_update().
    status_watch_task: asyncio.Task
    # Task running report_status_bgtask().
    db_status_report_task: asyncio.Task


async def init(app: web.Application) -> None:
    """Spawn the manager's background tasks and keep them on the private context.

    Two tasks are created, both receiving the root context:
    ``detect_status_update`` (stored as ``status_watch_task``) and
    ``report_status_bgtask`` (stored as ``db_status_report_task``).
    """
    root_ctx: RootContext = app["_root.context"]
    app_ctx: PrivateContext = app["manager.context"]
    background_jobs = (
        ("status_watch_task", detect_status_update),
        ("db_status_report_task", report_status_bgtask),
    )
    for attr_name, job in background_jobs:
        setattr(app_ctx, attr_name, asyncio.create_task(job(root_ctx)))


async def shutdown(app: web.Application) -> None:
    """Cancel the background tasks held by the private context and wait for them.

    Both ``status_watch_task`` and ``db_status_report_task`` are handled the
    same way: cancel, yield once to the event loop, and await the task only
    if it has not finished by then.

    A task attribute that was never assigned (e.g. ``init()`` did not run)
    or that is ``None`` is skipped instead of raising ``AttributeError``.
    """
    app_ctx: PrivateContext = app["manager.context"]
    for attr_name in ("status_watch_task", "db_status_report_task"):
        # PrivateContext has no generated __init__, so an attribute may be
        # unset rather than None; getattr() covers both cases.
        task = getattr(app_ctx, attr_name, None)
        if task is None:
            continue
        task.cancel()
        # Give the cancelled task one loop iteration to process the
        # cancellation before deciding whether we still need to wait for it.
        await asyncio.sleep(0)
        if not task.done():
            await task


def create_app(
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@
],
t.Key("aiomonitor-webui-port", default=49100): t.ToInt[1:65535],
t.Key("use-experimental-redis-event-dispatcher", default=False): t.ToBool,
t.Key("status-update-interval", default=10.0): t.ToFloat[0:],
t.Key("status-lifetime", default=None): t.Null | t.ToFloat[0:],
t.Key("status-update-interval", default=None): t.Null | t.ToFloat[0:], # second
t.Key("status-lifetime", default=None): t.Null | t.ToInt[0:], # second
t.Key("public-metrics-port", default=None): t.Null | t.ToInt[1:65535],
}).allow_extra("*"),
t.Key("docker-registry"): t.Dict({ # deprecated in v20.09
Expand Down
27 changes: 24 additions & 3 deletions src/ai/backend/manager/models/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import socket
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast

from pydantic import (
BaseModel,
Expand All @@ -11,6 +11,7 @@
from redis.asyncio import ConnectionPool
from sqlalchemy.pool import Pool

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import (
RedisConnectionInfo,
RedisHelperConfig,
Expand All @@ -26,6 +27,7 @@
"get_sqlalchemy_connection_info",
"get_redis_object_info_list",
"_get_connnection_info",
"report_manager_status",
)

_sqlalchemy_pool_type_names = (
Expand All @@ -42,6 +44,10 @@
MANAGER_STATUS_KEY = "manager.status"


def _get_connection_status_key(node_id: str, pid: int) -> str:
    """Build the Redis key under which one manager process stores its status.

    The key has the form ``<MANAGER_STATUS_KEY>.<node_id>:<pid>``.
    """
    return "{prefix}.{node}:{pid}".format(
        prefix=MANAGER_STATUS_KEY,
        node=node_id,
        pid=pid,
    )


class SQLAlchemyConnectionInfo(BaseModel):
pool_type: str = Field(
description=f"Connection pool type of SQLAlchemy engine. One of {_sqlalchemy_pool_type_names}.",
Expand All @@ -61,11 +67,11 @@ def total_cxn(self) -> int:

class RedisObjectConnectionInfo(BaseModel):
name: str
num_connections: int | None = Field(
num_connections: Optional[int] = Field(
description="The number of connections in Redis Client's connection pool."
)
max_connections: int
err_msg: str | None = Field(
err_msg: Optional[str] = Field(
description="Error message occurred when fetch connection info from Redis client objects.",
default=None,
)
Expand Down Expand Up @@ -133,3 +139,18 @@ async def _get_connnection_info(root_ctx: RootContext) -> ConnectionInfoOfProces
return ConnectionInfoOfProcess(
node_id=node_id, pid=pid, sqlalchemy_info=sqlalchemy_info, redis_connection_info=redis_infos
)


async def report_manager_status(root_ctx: RootContext) -> None:
    """Store this process's connection info in Redis.

    The connection info gathered by ``_get_connnection_info()`` is dumped in
    JSON mode, packed with msgpack and written with ``SET`` through
    ``root_ctx.redis_stat``. The key is derived from the node ID and PID.
    The expiry is taken from ``local_config["manager"]["status-lifetime"]``;
    ``None`` means the key is written without an expiry.
    """
    # NOTE(review): a configured lifetime of 0 is passed through as ex=0 —
    # confirm that Redis accepts a zero expire time for SET.
    ttl = cast(Optional[int], root_ctx.local_config["manager"]["status-lifetime"])
    process_info = await _get_connnection_info(root_ctx)
    status_key = _get_connection_status_key(process_info.node_id, process_info.pid)
    payload = msgpack.packb(process_info.model_dump(mode="json"))

    def _store(r):
        return r.set(status_key, payload, ex=ttl)

    await redis_helper.execute(root_ctx.redis_stat, _store)

0 comments on commit eb498be

Please sign in to comment.