Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metrics collector - tool calls and requests #934

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,4 @@ logs/
credentials.json
token.json
src/interfaces/assistants_web/bun.lockb
/metrics/
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ services:
# Mount configurations
- ./src/backend/config/secrets.yaml:/workspace/src/backend/config/secrets.yaml
- ./src/backend/config/configuration.yaml:/workspace/src/backend/config/configuration.yaml
- ./metrics:/workspace/metrics
# network_mode: host
networks:
- proxynet
Expand Down
8 changes: 8 additions & 0 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,13 @@ class LoggerSettings(BaseSettings, BaseModel):
)


class MetricsSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
enabled: Optional[bool] = Field(
default=False, validation_alias=AliasChoices("METRICS_ENABLED", "enabled")
)


class Settings(BaseSettings):
"""
Settings class used to grab environment variables from configuration.yaml
Expand All @@ -467,6 +474,7 @@ class Settings(BaseSettings):
google_cloud: Optional[GoogleCloudSettings] = Field(default=GoogleCloudSettings())
deployments: Optional[DeploymentSettings] = Field(default=DeploymentSettings())
logger: Optional[LoggerSettings] = Field(default=LoggerSettings())
metrics: Optional[MetricsSettings] = Field(default=MetricsSettings())

def get(self, path: str) -> Any:
keys = path.split('.')
Expand Down
7 changes: 7 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from backend.config.routers import ROUTER_DEPENDENCIES, RouterName
from backend.config.settings import Settings
from backend.exceptions import DeploymentNotFoundError
from backend.metrics import RequestMetricsMiddleware
from backend.routers.agent import router as agent_router
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
Expand All @@ -33,6 +34,7 @@
from backend.routers.user import router as user_router
from backend.services.context import ContextMiddleware, get_context
from backend.services.logger.middleware import LoggingMiddleware
from backend.services.logger.utils import LoggerFactory

# Only show errors for Pydantic
logging.getLogger('pydantic').setLevel(logging.ERROR)
Expand Down Expand Up @@ -106,6 +108,10 @@ def create_app() -> FastAPI:
allow_headers=["*"],
)
app.add_middleware(LoggingMiddleware)
if settings.get("metrics.enabled"):
logger = LoggerFactory().get_logger()
logger.info(event="Metrics enabled")
app.add_middleware(RequestMetricsMiddleware)
app.add_middleware(ContextMiddleware) # This should be the first middleware
app.add_exception_handler(SCIMException, scim_exception_handler) # pyright: ignore

Expand All @@ -114,6 +120,7 @@ def create_app() -> FastAPI:

app = create_app()


@app.exception_handler(Exception)
async def validation_exception_handler(request: Request, exc: Exception):
ctx = get_context(request)
Expand Down
13 changes: 13 additions & 0 deletions src/backend/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from backend.metrics.middleware import (
MONITORED_PATHS,
RequestMetricsMiddleware,
collector,
)
from backend.metrics.tool_call_decorator import track_tool_call_time

__all__ = [
"RequestMetricsMiddleware",
"collector",
"MONITORED_PATHS",
"track_tool_call_time",
]
64 changes: 64 additions & 0 deletions src/backend/metrics/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import csv
import os
import time
from typing import Any, Dict, List

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

MONITORED_PATHS = ["/v1/conversations", "/v1/chat-stream"]

class RequestMetricsMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if any(request.url.path.startswith(prefix) for prefix in MONITORED_PATHS):
start_time = time.time()
response = await call_next(request)
latency = time.time() - start_time
collector.add_metric("request", request.url.path, latency)
else:
response = await call_next(request)

return response



class MetricsCollector:
def __init__(self):
self.metrics: List[Dict[str, Any]] = []

def add_metric(
self,
metric_type: str,
name: str,
latency: float,
class_name: str = "",
method_name: str = "",
method_params: Dict[str, Any] = {},
timestamp: float = time.time()
):
self.metrics.append({
"timestamp": timestamp or time.time(),
"type": metric_type,
"name": name,
"class_name": class_name,
"method_name": method_name,
"method_params": method_params,
"latency": latency
})
self.save_to_csv()

def save_to_csv(self, filename: str = "./metrics/metrics.csv"):
if not self.metrics:
return
keys = self.metrics[0].keys()
file_exists = os.path.isfile(filename)
with open(filename, mode='a' if file_exists else 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=keys)
if not file_exists:
writer.writeheader()
writer.writerows(self.metrics)
self.metrics.clear() # Clear after saving


# Singleton instance
collector = MetricsCollector()
26 changes: 26 additions & 0 deletions src/backend/metrics/tool_call_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import time
from typing import Any, Callable

from backend.metrics import collector


def track_tool_call_time() -> Callable:
"""
Decorator to track the execution time of a method and log it to a metrics collector.
Handles both instance and class methods.
"""
def decorator(func):
async def wrapper(self, *args, **kwargs) -> Any:
class_name = self.__class__.__name__
passed_method_params = kwargs.get("parameters", {}) or (args[0] if args else {})
start_time = time.time()
result = await func(self, *args, **kwargs)
end_time = time.time()
time_taken = end_time - start_time
collector.add_metric('call', 'tool_call', class_name=class_name, method_name='call', method_params=passed_method_params,
latency=time_taken)
return result

return wrapper

return decorator
14 changes: 9 additions & 5 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from backend.crud import tool_auth as tool_auth_crud
from backend.database_models.database import DBSessionDep
from backend.database_models.tool_auth import ToolAuth
from backend.metrics import track_tool_call_time
from backend.schemas.context import Context
from backend.schemas.tool import ToolDefinition
from backend.services.logger.utils import LoggerFactory
from backend.tools.utils.tools_checkers import check_tool_parameters

logger = LoggerFactory().get_logger()


class ToolErrorCode(StrEnum):
HTTP_ERROR = "http_error"
AUTH = "auth"
Expand Down Expand Up @@ -47,10 +47,14 @@ class ParametersValidationMeta(type):
def __new__(cls, name, bases, class_dict):
for attr_name, attr_value in class_dict.items():
if callable(attr_value) and attr_name == "call":
# Decorate methods with the parameter checker
class_dict[attr_name] = check_tool_parameters(
lambda self: self.__class__.get_tool_definition()
)(attr_value)
metrics_enabled = Settings().get('metrics.enabled')
if metrics_enabled:
# Decorate methods with the metrics collector and parameter checker
class_dict[attr_name] = track_tool_call_time()(
check_tool_parameters(lambda self: self.__class__.get_tool_definition())(attr_value))
else:
# Decorate methods with the parameter checker
class_dict[attr_name] = check_tool_parameters(lambda self: self.__class__.get_tool_definition())(attr_value)
return super().__new__(cls, name, bases, class_dict)


Expand Down
Loading