diff --git a/changes/2566.feature.md b/changes/2566.feature.md new file mode 100644 index 0000000000..6308d7a782 --- /dev/null +++ b/changes/2566.feature.md @@ -0,0 +1 @@ +Add background task that reports manager DB status. diff --git a/src/ai/backend/manager/api/manager.py b/src/ai/backend/manager/api/manager.py index 027bfdd7b6..32deec1bde 100644 --- a/src/ai/backend/manager/api/manager.py +++ b/src/ai/backend/manager/api/manager.py @@ -6,7 +6,8 @@ import json import logging import socket -from typing import TYPE_CHECKING, Any, Final, FrozenSet, Iterable, Tuple +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Final, FrozenSet, Optional, Tuple, cast import aiohttp_cors import attrs @@ -24,6 +25,9 @@ from .. import __version__ from ..defs import DEFAULT_ROLE from ..models import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, agents, kernels +from ..models.health import ( + report_manager_status, +) from . import ManagerStatus, SchedulerEvent from .auth import superadmin_required from .exceptions import ( @@ -34,7 +38,10 @@ ServiceUnavailable, ) from .types import CORSOptions, WebMiddleware -from .utils import check_api_params, set_handler_attr +from .utils import ( + check_api_params, + set_handler_attr, +) if TYPE_CHECKING: from ai.backend.manager.models.gql import GraphQueryContext @@ -101,6 +108,22 @@ async def detect_status_update(root_ctx: RootContext) -> None: pass +async def report_status_bgtask(root_ctx: RootContext) -> None: + interval = cast(Optional[float], root_ctx.local_config["manager"]["status-update-interval"]) + if interval is None: + # Do not run bgtask if interval is not set + return + try: + while True: + await asyncio.sleep(interval) + try: + await report_manager_status(root_ctx) + except Exception as e: + log.exception(f"Failed to report manager health status (e:{str(e)})") + except asyncio.CancelledError: + pass + + async def fetch_manager_status(request: web.Request) -> web.Response: root_ctx: RootContext = request.app["_root.context"] log.info("MANAGER.FETCH_MANAGER_STATUS ()") @@ -275,19 +298,28 @@ async def scheduler_healthcheck(request: web.Request) -> web.Response: @attrs.define(slots=True, auto_attribs=True, init=False) class PrivateContext: status_watch_task: asyncio.Task + db_status_report_task: asyncio.Task async def init(app: web.Application) -> None: root_ctx: RootContext = app["_root.context"] app_ctx: PrivateContext = app["manager.context"] app_ctx.status_watch_task = asyncio.create_task(detect_status_update(root_ctx)) + app_ctx.db_status_report_task = asyncio.create_task(report_status_bgtask(root_ctx)) async def shutdown(app: web.Application) -> None: app_ctx: PrivateContext = app["manager.context"] if app_ctx.status_watch_task is not None: app_ctx.status_watch_task.cancel() - await app_ctx.status_watch_task + await asyncio.sleep(0) + if not app_ctx.status_watch_task.done(): + await app_ctx.status_watch_task + if app_ctx.db_status_report_task is not None: + app_ctx.db_status_report_task.cancel() + await asyncio.sleep(0) + if not app_ctx.db_status_report_task.done(): + await app_ctx.db_status_report_task def create_app( diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index 2896ba1483..d580ad1ede 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -290,8 +290,8 @@ ], t.Key("aiomonitor-webui-port", default=49100): t.ToInt[1:65535], t.Key("use-experimental-redis-event-dispatcher", default=False): t.ToBool, - t.Key("status-update-interval", default=10.0): t.ToFloat[0:], - t.Key("status-lifetime", default=None): t.Null | t.ToFloat[0:], + t.Key("status-update-interval", default=None): t.Null | t.ToFloat[0:], # second + t.Key("status-lifetime", default=None): t.Null | t.ToInt[0:], # second t.Key("public-metrics-port", default=None): t.Null | t.ToInt[1:65535], }).allow_extra("*"), t.Key("docker-registry"): t.Dict({ # deprecated in v20.09 diff --git a/src/ai/backend/manager/models/health.py b/src/ai/backend/manager/models/health.py index ed05d154df..71662e3656 100644 --- a/src/ai/backend/manager/models/health.py +++ b/src/ai/backend/manager/models/health.py @@ -2,7 +2,7 @@ import os import socket -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast from pydantic import ( BaseModel, @@ -11,6 +11,7 @@ from redis.asyncio import ConnectionPool from sqlalchemy.pool import Pool +from ai.backend.common import msgpack, redis_helper from ai.backend.common.types import ( RedisConnectionInfo, RedisHelperConfig, @@ -26,6 +27,7 @@ "get_sqlalchemy_connection_info", "get_redis_object_info_list", "_get_connnection_info", + "report_manager_status", ) _sqlalchemy_pool_type_names = ( @@ -42,6 +44,10 @@ MANAGER_STATUS_KEY = "manager.status" +def _get_connection_status_key(node_id: str, pid: int) -> str: + return f"{MANAGER_STATUS_KEY}.{node_id}:{pid}" + + class SQLAlchemyConnectionInfo(BaseModel): pool_type: str = Field( description=f"Connection pool type of SQLAlchemy engine. One of {_sqlalchemy_pool_type_names}.", @@ -61,11 +67,11 @@ def total_cxn(self) -> int: class RedisObjectConnectionInfo(BaseModel): name: str - num_connections: int | None = Field( + num_connections: Optional[int] = Field( description="The number of connections in Redis Client's connection pool." ) max_connections: int - err_msg: str | None = Field( + err_msg: Optional[str] = Field( description="Error message occurred when fetch connection info from Redis client objects.", default=None, ) @@ -133,3 +139,18 @@ async def _get_connnection_info(root_ctx: RootContext) -> ConnectionInfoOfProces return ConnectionInfoOfProcess( node_id=node_id, pid=pid, sqlalchemy_info=sqlalchemy_info, redis_connection_info=redis_infos ) + + +async def report_manager_status(root_ctx: RootContext) -> None: + lifetime = cast(Optional[int], root_ctx.local_config["manager"]["status-lifetime"]) + cxn_info = await _get_connnection_info(root_ctx) + _data = msgpack.packb(cxn_info.model_dump(mode="json")) + + await redis_helper.execute( + root_ctx.redis_stat, + lambda r: r.set( + _get_connection_status_key(cxn_info.node_id, cxn_info.pid), + _data, + ex=lifetime, + ), + )