diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 2e5bc4eb..37f3a619 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,8 +1,10 @@ from fastapi import APIRouter, Response from fastapi.exceptions import HTTPException from fastapi.routing import APIRoute +from pydantic import ValidationError from codegate.api import v1_models +from codegate.db.connection import AlreadyExistsError from codegate.workspaces.crud import WorkspaceCrud v1 = APIRouter() @@ -52,13 +54,17 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status async def create_workspace(request: v1_models.CreateWorkspaceRequest): """Create a new workspace.""" # Input validation is done in the model - created = await wscrud.add_workspace(request.name) - - # TODO: refactor to use a more specific exception - if not created: - raise HTTPException(status_code=400, detail="Failed to create workspace") - - return v1_models.Workspace(name=request.name) + try: + created = await wscrud.add_workspace(request.name) + except AlreadyExistsError: + raise HTTPException(status_code=409, detail="Workspace already exists") + except ValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + if created: + return v1_models.Workspace(name=created.name) @v1.delete( diff --git a/src/codegate/dashboard/dashboard.py b/src/codegate/dashboard/dashboard.py index 89e15314..ee71424f 100644 --- a/src/codegate/dashboard/dashboard.py +++ b/src/codegate/dashboard/dashboard.py @@ -6,8 +6,8 @@ import structlog from fastapi import APIRouter, Depends, FastAPI from fastapi.responses import StreamingResponse -from codegate import __version__ +from codegate import __version__ from codegate.dashboard.post_processing import ( parse_get_alert_conversation, parse_messages_in_conversations, @@ -82,7 +82,7 @@ def version_check(): latest_version_stripped = latest_version.lstrip('v') is_latest: bool = latest_version_stripped == current_version - + return { "current_version": current_version, "latest_version": latest_version_stripped, diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 2086039d..b83ceb7c 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -7,9 +7,9 @@ import structlog from alembic import command as alembic_command from alembic.config import Config as AlembicConfig -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from sqlalchemy import CursorResult, TextClause, text -from sqlalchemy.exc import OperationalError +from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine from codegate.db.fim_cache import FimCache @@ -30,6 +30,8 @@ alert_queue = asyncio.Queue() fim_cache = FimCache() +class AlreadyExistsError(Exception): + pass class DbCodeGate: _instance = None @@ -70,11 +72,11 @@ def __init__(self, sqlite_path: Optional[str] = None): super().__init__(sqlite_path) async def _execute_update_pydantic_model( - self, model: BaseModel, sql_command: TextClause + self, model: BaseModel, sql_command: TextClause, should_raise: bool = False ) -> Optional[BaseModel]: """Execute an update or insert command for a Pydantic model.""" - async with self._async_db_engine.begin() as conn: - try: + try: + async with self._async_db_engine.begin() as conn: result = await conn.execute(sql_command, model.model_dump()) row = result.first() if row is None: @@ -83,9 +85,11 @@ async def _execute_update_pydantic_model( # Get the class of the Pydantic object to create a new object model_class = model.__class__ return model_class(**row._asdict()) - except Exception as e: - logger.error(f"Failed to update model: {model}.", error=str(e)) - return None + except Exception as e: + logger.error(f"Failed to update model: {model}.", error=str(e)) + if should_raise: + raise e + return None async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: if prompt_params is None: @@ -243,11 +247,14 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: logger.error(f"Failed to record context: {context}.", error=str(e)) async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: - try: - workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name) - except ValidationError as e: - logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}") - return None + """Add a new workspace to the DB. + + This handles validation and insertion of a new workspace. + + It may raise a ValidationError if the workspace name is invalid. + or a AlreadyExistsError if the workspace already exists. + """ + workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name) sql = text( """ @@ -256,12 +263,13 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: RETURNING * """ ) - try: - added_workspace = await self._execute_update_pydantic_model(workspace, sql) - except Exception as e: - logger.error(f"Failed to add workspace: {workspace_name}.", error=str(e)) - return None + try: + added_workspace = await self._execute_update_pydantic_model( + workspace, sql, should_raise=True) + except IntegrityError as e: + logger.debug(f"Exception type: {type(e)}") + raise AlreadyExistsError(f"Workspace {workspace_name} already exists.") return added_workspace async def update_session(self, session: Session) -> Optional[Session]: diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index b3a6db7e..45141ea9 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -1,7 +1,10 @@ from abc import ABC, abstractmethod from typing import List +from pydantic import ValidationError + from codegate import __version__ +from codegate.db.connection import AlreadyExistsError from codegate.workspaces.crud import WorkspaceCrud @@ -69,14 +72,16 @@ async def _add_workspace(self, args: List[str]) -> str: if not new_workspace_name: return "Please provide a name. Use `codegate workspace add your_workspace_name`" - workspace_created = await self.workspace_crud.add_workspace(new_workspace_name) - if not workspace_created: - return ( - "Something went wrong. Workspace could not be added.\n" - "1. Check if the name is alphanumeric and only contains dashes, and underscores.\n" - "2. Check if the workspace already exists." - ) - return f"Workspace **{new_workspace_name}** has been added" + try: + workspace_created = await self.workspace_crud.add_workspace(new_workspace_name) + except ValidationError as e: + return f"Invalid workspace name: {e}" + except AlreadyExistsError: + return f"Workspace **{new_workspace_name}** already exists" + except Exception: + return "An error occurred while adding the workspace" + + return f"Workspace **{workspace_created.name}** has been added" async def _activate_workspace(self, args: List[str]) -> str: """ diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 9d2beef6..6097c395 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -1,16 +1,19 @@ import datetime -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple from codegate.db.connection import DbReader, DbRecorder -from codegate.db.models import Session, Workspace, WorkspaceActive, ActiveWorkspace +from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive +class WorkspaceCrudError(Exception): + pass + class WorkspaceCrud: def __init__(self): self._db_reader = DbReader() - async def add_workspace(self, new_workspace_name: str) -> bool: + async def add_workspace(self, new_workspace_name: str) -> Workspace: """ Add a workspace @@ -19,7 +22,7 @@ async def add_workspace(self, new_workspace_name: str) -> bool: """ db_recorder = DbRecorder() workspace_created = await db_recorder.add_workspace(new_workspace_name) - return bool(workspace_created) + return workspace_created async def get_workspaces(self)-> List[WorkspaceActive]: """