diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index f1dba385..2e5bc4eb 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -3,10 +3,10 @@ from fastapi.routing import APIRoute from codegate.api import v1_models -from codegate.pipeline.workspace import commands as wscmd +from codegate.workspaces.crud import WorkspaceCrud v1 = APIRouter() -wscrud = wscmd.WorkspaceCrud() +wscrud = WorkspaceCrud() def uniq_name(route: APIRoute): @@ -61,9 +61,12 @@ async def create_workspace(request: v1_models.CreateWorkspaceRequest): return v1_models.Workspace(name=request.name) - -@v1.delete("/workspaces/{workspace_name}", tags=["Workspaces"], - generate_unique_id_function=uniq_name, status_code=204) +@v1.delete( + "/workspaces/{workspace_name}", + tags=["Workspaces"], + generate_unique_id_function=uniq_name, + status_code=204, +) async def delete_workspace(workspace_name: str): """Delete a workspace by name.""" raise NotImplementedError diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 55418e2b..cf789a6c 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -9,36 +9,44 @@ class Workspace(pydantic.BaseModel): name: str is_active: bool + class ActiveWorkspace(Workspace): # TODO: use a more specific type for last_updated last_updated: Any + class ListWorkspacesResponse(pydantic.BaseModel): workspaces: list[Workspace] @classmethod def from_db_workspaces( - cls, db_workspaces: List[db_models.WorkspaceActive])-> "ListWorkspacesResponse": - return cls(workspaces=[ - Workspace(name=ws.name, is_active=ws.active_workspace_id is not None) - for ws in db_workspaces]) + cls, db_workspaces: List[db_models.WorkspaceActive] + ) -> "ListWorkspacesResponse": + return cls( + workspaces=[ + Workspace(name=ws.name, is_active=ws.active_workspace_id is not None) + for ws in db_workspaces + ] + ) + class ListActiveWorkspacesResponse(pydantic.BaseModel): workspaces: list[ActiveWorkspace] @classmethod def from_db_workspaces( - cls, ws: Optional[db_models.ActiveWorkspace]) -> "ListActiveWorkspacesResponse": + cls, ws: Optional[db_models.ActiveWorkspace] + ) -> "ListActiveWorkspacesResponse": if ws is None: return cls(workspaces=[]) - return cls(workspaces=[ - ActiveWorkspace(name=ws.name, - is_active=True, - last_updated=ws.last_update) - ]) + return cls( + workspaces=[ActiveWorkspace(name=ws.name, is_active=True, last_updated=ws.last_update)] + ) + class CreateWorkspaceRequest(pydantic.BaseModel): name: str + class ActivateWorkspaceRequest(pydantic.BaseModel): name: str diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 006616ba..2086039d 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -30,6 +30,7 @@ alert_queue = asyncio.Queue() fim_cache = FimCache() + class DbCodeGate: _instance = None @@ -256,8 +257,7 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: """ ) try: - added_workspace = await self._execute_update_pydantic_model( - workspace, sql) + added_workspace = await self._execute_update_pydantic_model(workspace, sql) except Exception as e: logger.error(f"Failed to add workspace: {workspace_name}.", error=str(e)) return None diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py new file mode 100644 index 00000000..08e33fb4 --- /dev/null +++ b/src/codegate/pipeline/cli/cli.py @@ -0,0 +1,92 @@ +import shlex + +from litellm import ChatCompletionRequest + +from codegate.pipeline.base import ( + PipelineContext, + PipelineResponse, + PipelineResult, + PipelineStep, +) +from codegate.pipeline.cli.commands import Version, Workspace + +HELP_TEXT = """ +## CodeGate CLI\n +**Usage**: `codegate [-h] [args]`\n +Check the help of each command by running `codegate -h`\n +Available commands: +- `version`: Show the version of CodeGate +- `workspace`: Perform different operations on workspaces +""" + +NOT_FOUND_TEXT = "Command not found. Use `codegate -h` to see available commands." + + +async def codegate_cli(command): + """ + Process the 'codegate' command. + """ + if len(command) == 0: + return HELP_TEXT + + available_commands = { + "version": Version().exec, + "workspace": Workspace().exec, + "-h": lambda _: HELP_TEXT, + } + out_func = available_commands.get(command[0]) + if out_func is None: + return NOT_FOUND_TEXT + + return await out_func(command[1:]) + + +class CodegateCli(PipelineStep): + """Pipeline step that handles codegate cli.""" + + @property + def name(self) -> str: + """ + Returns the name of this pipeline step. + + Returns: + str: The identifier 'codegate-cli' + """ + return "codegate-cli" + + async def process( + self, request: ChatCompletionRequest, context: PipelineContext + ) -> PipelineResult: + """ + Checks if the last user message contains "codegate" and process the command. + This short-circuits the pipeline if the message is found. + + Args: + request (ChatCompletionRequest): The chat completion request to process + context (PipelineContext): The current pipeline context + + Returns: + PipelineResult: Contains the response if triggered, otherwise continues + pipeline + """ + last_user_message = self.get_last_user_message(request) + + if last_user_message is not None: + last_user_message_str, _ = last_user_message + splitted_message = last_user_message_str.lower().split(" ") + # We expect codegate as the first word in the message + if splitted_message[0] == "codegate": + context.shortcut_response = True + args = shlex.split(last_user_message_str) + cmd_out = await codegate_cli(args[1:]) + return PipelineResult( + response=PipelineResponse( + step_name=self.name, + content=cmd_out, + model=request["model"], + ), + context=context, + ) + + # Fall through + return PipelineResult(request=request, context=context) diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py new file mode 100644 index 00000000..7e783113 --- /dev/null +++ b/src/codegate/pipeline/cli/commands.py @@ -0,0 +1,125 @@ +from abc import ABC, abstractmethod +from typing import List + +from codegate import __version__ +from codegate.workspaces.crud import WorkspaceCrud + + +class CodegateCommand(ABC): + @abstractmethod + async def run(self, args: List[str]) -> str: + pass + + @property + @abstractmethod + def help(self) -> str: + pass + + async def exec(self, args: List[str]) -> str: + if args and args[0] == "-h": + return self.help + return await self.run(args) + + +class Version(CodegateCommand): + async def run(self, args: List[str]) -> str: + return f"CodeGate version: {__version__}" + + @property + def help(self) -> str: + return ( + "### CodeGate Version\n\n" + "Prints the version of CodeGate.\n\n" + "**Usage**: `codegate version`\n\n" + "*args*: None" + ) + + +class Workspace(CodegateCommand): + + def __init__(self): + self.workspace_crud = WorkspaceCrud() + self.commands = { + "list": self._list_workspaces, + "add": self._add_workspace, + "activate": self._activate_workspace, + } + + async def _list_workspaces(self, *args: List[str]) -> str: + """ + List all workspaces + """ + workspaces = await self.workspace_crud.get_workspaces() + respond_str = "" + for workspace in workspaces: + respond_str += f"- {workspace.name}" + if workspace.active_workspace_id: + respond_str += " **(active)**" + respond_str += "\n" + return respond_str + + async def _add_workspace(self, args: List[str]) -> str: + """ + Add a workspace + """ + if args is None or len(args) == 0: + return "Please provide a name. Use `codegate-workspace add your_workspace_name`" + + new_workspace_name = args[0] + if not new_workspace_name: + return "Please provide a name. Use `codegate-workspace add your_workspace_name`" + + workspace_created = await self.workspace_crud.add_workspace(new_workspace_name) + if not workspace_created: + return ( + "Something went wrong. Workspace could not be added.\n" + "1. Check if the name is alphanumeric and only contains dashes, and underscores.\n" + "2. Check if the workspace already exists." + ) + return f"Workspace **{new_workspace_name}** has been added" + + async def _activate_workspace(self, args: List[str]) -> str: + """ + Activate a workspace + """ + if args is None or len(args) == 0: + return "Please provide a name. Use `codegate-workspace activate workspace_name`" + + workspace_name = args[0] + if not workspace_name: + return "Please provide a name. Use `codegate-workspace activate workspace_name`" + + was_activated = await self.workspace_crud.activate_workspace(workspace_name) + if not was_activated: + return ( + f"Workspace **{workspace_name}** does not exist or was already active. " + f"Use `codegate-workspace add {workspace_name}` to add it" + ) + return f"Workspace **{workspace_name}** has been activated" + + async def run(self, args: List[str]) -> str: + if not args: + return "Please provide a command. Use `codegate workspace -h` to see available commands" + command = args[0] + command_to_execute = self.commands.get(command) + if command_to_execute is not None: + return await command_to_execute(args[1:]) + else: + return "Command not found. Use `codegate workspace -h` to see available commands" + + @property + def help(self) -> str: + return ( + "### CodeGate Workspace\n\n" + "Manage workspaces.\n\n" + "**Usage**: `codegate workspace [args]`\n\n" + "Available commands:\n\n" + "- `list`: List all workspaces\n\n" + " - *args*: None\n\n" + "- `add`: Add a workspace\n\n" + " - *args*:\n\n" + " - `workspace_name`\n\n" + "- `activate`: Activate a workspace\n\n" + " - *args*:\n\n" + " - `workspace_name`" + ) diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index cd14df8e..0cefb689 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -2,6 +2,7 @@ from codegate.config import Config from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor +from codegate.pipeline.cli.cli import CodegateCli from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor from codegate.pipeline.extract_snippets.output import CodeCommentStep @@ -13,8 +14,6 @@ SecretUnredactionStep, ) from codegate.pipeline.system_prompt.codegate import SystemPrompt -from codegate.pipeline.version.version import CodegateVersion -from codegate.pipeline.workspace.workspace import CodegateWorkspace class PipelineFactory: @@ -28,8 +27,7 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor: # and without obfuscating the secrets, we'd leak the secrets during those # later steps CodegateSecrets(), - CodegateVersion(), - CodegateWorkspace(), + CodegateCli(), CodeSnippetExtractor(), CodegateContextRetriever(), SystemPrompt(Config.get_config().prompts.default_chat), diff --git a/src/codegate/pipeline/version/version.py b/src/codegate/pipeline/version/version.py deleted file mode 100644 index a42011b8..00000000 --- a/src/codegate/pipeline/version/version.py +++ /dev/null @@ -1,58 +0,0 @@ -from litellm import ChatCompletionRequest - -from codegate import __version__ -from codegate.pipeline.base import ( - PipelineContext, - PipelineResponse, - PipelineResult, - PipelineStep, -) - - -class CodegateVersion(PipelineStep): - """Pipeline step that handles version information requests.""" - - @property - def name(self) -> str: - """ - Returns the name of this pipeline step. - - Returns: - str: The identifier 'codegate-version' - """ - return "codegate-version" - - async def process( - self, request: ChatCompletionRequest, context: PipelineContext - ) -> PipelineResult: - """ - Checks if the last user message contains "codegate-version" and - responds with the current version. - This short-circuits the pipeline if the message is found. - - Args: - request (ChatCompletionRequest): The chat completion request to process - context (PipelineContext): The current pipeline context - - Returns: - PipelineResult: Contains version response if triggered, otherwise continues - pipeline - """ - last_user_message = self.get_last_user_message(request) - - if last_user_message is not None: - last_user_message_str, _ = last_user_message - if "codegate-version" in last_user_message_str.lower(): - context.shortcut_response = True - context.add_alert(self.name, trigger_string=last_user_message_str) - return PipelineResult( - response=PipelineResponse( - step_name=self.name, - content="CodeGate version: {}".format(__version__), - model=request["model"], - ), - context=context, - ) - - # Fall through - return PipelineResult(request=request, context=context) diff --git a/src/codegate/pipeline/workspace/__init__.py b/src/codegate/pipeline/workspace/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/codegate/pipeline/workspace/commands.py b/src/codegate/pipeline/workspace/commands.py deleted file mode 100644 index 812d6c85..00000000 --- a/src/codegate/pipeline/workspace/commands.py +++ /dev/null @@ -1,164 +0,0 @@ -import datetime -from typing import List, Optional, Tuple - -from codegate.db.connection import DbReader, DbRecorder -from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive - - -class WorkspaceCrud: - - def __init__(self): - self._db_reader = DbReader() - - async def add_workspace(self, new_workspace_name: str) -> bool: - """ - Add a workspace - - Args: - name (str): The name of the workspace - """ - db_recorder = DbRecorder() - workspace_created = await db_recorder.add_workspace( - new_workspace_name) - return bool(workspace_created) - - async def get_workspaces(self) -> List[WorkspaceActive]: - """ - Get all workspaces - """ - return await self._db_reader.get_workspaces() - - async def get_active_workspace(self) -> Optional[ActiveWorkspace]: - """ - Get the active workspace - """ - return await self._db_reader.get_active_workspace() - - async def _is_workspace_active_or_not_exist( - self, workspace_name: str - ) -> Tuple[bool, Optional[Session], Optional[Workspace]]: - """ - Check if the workspace is active - - Will return: - - True if the workspace was activated - - False if the workspace is already active or does not exist - """ - selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) - if not selected_workspace: - return True, None, None - - sessions = await self._db_reader.get_sessions() - # The current implementation expects only one active session - if len(sessions) != 1: - raise RuntimeError("Something went wrong. No active session found.") - - session = sessions[0] - if session.active_workspace_id == selected_workspace.id: - return True, None, None - return False, session, selected_workspace - - async def activate_workspace(self, workspace_name: str) -> bool: - """ - Activate a workspace - - Will return: - - True if the workspace was activated - - False if the workspace is already active or does not exist - """ - is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name) - if is_active: - return False - - session.active_workspace_id = workspace.id - session.last_update = datetime.datetime.now(datetime.timezone.utc) - db_recorder = DbRecorder() - await db_recorder.update_session(session) - return True - - -class WorkspaceCommands: - - def __init__(self): - self.workspace_crud = WorkspaceCrud() - self.commands = { - "list": self._list_workspaces, - "add": self._add_workspace, - "activate": self._activate_workspace, - } - - async def _list_workspaces(self, *args) -> str: - """ - List all workspaces - """ - workspaces = await self.workspace_crud.get_workspaces() - respond_str = "" - for workspace in workspaces: - respond_str += f"- {workspace.name}" - if workspace.active_workspace_id: - respond_str += " **(active)**" - respond_str += "\n" - return respond_str - - async def _add_workspace(self, *args) -> str: - """ - Add a workspace - """ - if args is None or len(args) == 0: - return "Please provide a name. Use `codegate-workspace add your_workspace_name`" - - new_workspace_name = args[0] - if not new_workspace_name: - return "Please provide a name. Use `codegate-workspace add your_workspace_name`" - - workspace_created = await self.workspace_crud.add_workspace(new_workspace_name) - if not workspace_created: - return ( - "Something went wrong. Workspace could not be added.\n" - "1. Check if the name is alphanumeric and only contains dashes, and underscores.\n" - "2. Check if the workspace already exists." - ) - return f"Workspace **{new_workspace_name}** has been added" - - async def _activate_workspace(self, *args) -> str: - """ - Activate a workspace - """ - if args is None or len(args) == 0: - return "Please provide a name. Use `codegate-workspace activate workspace_name`" - - workspace_name = args[0] - if not workspace_name: - return "Please provide a name. Use `codegate-workspace activate workspace_name`" - - was_activated = await self.workspace_crud.activate_workspace(workspace_name) - if not was_activated: - return ( - f"Workspace **{workspace_name}** does not exist or was already active. " - f"Use `codegate-workspace add {workspace_name}` to add it" - ) - return f"Workspace **{workspace_name}** has been activated" - - async def execute(self, command: str, *args) -> str: - """ - Execute the given command - - Args: - command (str): The command to execute - """ - command_to_execute = self.commands.get(command) - if command_to_execute is not None: - return await command_to_execute(*args) - else: - return "Command not found" - - async def parse_execute_cmd(self, last_user_message: str) -> str: - """ - Parse the last user message and execute the command - - Args: - last_user_message (str): The last user message - """ - command_and_args = last_user_message.lower().split("codegate-workspace ")[1] - command, *args = command_and_args.split(" ") - return await self.execute(command, *args) diff --git a/src/codegate/pipeline/workspace/workspace.py b/src/codegate/pipeline/workspace/workspace.py deleted file mode 100644 index 5cde5177..00000000 --- a/src/codegate/pipeline/workspace/workspace.py +++ /dev/null @@ -1,58 +0,0 @@ -from litellm import ChatCompletionRequest - -from codegate.pipeline.base import ( - PipelineContext, - PipelineResponse, - PipelineResult, - PipelineStep, -) -from codegate.pipeline.workspace.commands import WorkspaceCommands - - -class CodegateWorkspace(PipelineStep): - """Pipeline step that handles workspace information requests.""" - - @property - def name(self) -> str: - """ - Returns the name of this pipeline step. - - Returns: - str: The identifier 'codegate-workspace' - """ - return "codegate-workspace" - - async def process( - self, request: ChatCompletionRequest, context: PipelineContext - ) -> PipelineResult: - """ - Checks if the last user message contains "codegate-workspace" and - responds with command specified. - This short-circuits the pipeline if the message is found. - - Args: - request (ChatCompletionRequest): The chat completion request to process - context (PipelineContext): The current pipeline context - - Returns: - PipelineResult: Contains workspace response if triggered, otherwise continues - pipeline - """ - last_user_message = self.get_last_user_message(request) - - if last_user_message is not None: - last_user_message_str, _ = last_user_message - if "codegate-workspace" in last_user_message_str.lower(): - context.shortcut_response = True - command_output = await WorkspaceCommands().parse_execute_cmd(last_user_message_str) - return PipelineResult( - response=PipelineResponse( - step_name=self.name, - content=command_output, - model=request["model"], - ), - context=context, - ) - - # Fall through - return PipelineResult(request=request, context=context) diff --git a/src/codegate/pipeline/version/__init__.py b/src/codegate/workspaces/__init__.py similarity index 100% rename from src/codegate/pipeline/version/__init__.py rename to src/codegate/workspaces/__init__.py diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py new file mode 100644 index 00000000..9902e80b --- /dev/null +++ b/src/codegate/workspaces/crud.py @@ -0,0 +1,70 @@ +import datetime +from typing import Optional, Tuple + +from codegate.db.connection import DbReader, DbRecorder +from codegate.db.models import Session, Workspace + + +class WorkspaceCrud: + + def __init__(self): + self._db_reader = DbReader() + + async def add_workspace(self, new_workspace_name: str) -> bool: + """ + Add a workspace + + Args: + name (str): The name of the workspace + """ + db_recorder = DbRecorder() + workspace_created = await db_recorder.add_workspace(new_workspace_name) + return bool(workspace_created) + + async def get_workspaces(self): + """ + Get all workspaces + """ + return await self._db_reader.get_workspaces() + + async def _is_workspace_active_or_not_exist( + self, workspace_name: str + ) -> Tuple[bool, Optional[Session], Optional[Workspace]]: + """ + Check if the workspace is active + + Will return: + - True if the workspace was activated + - False if the workspace is already active or does not exist + """ + selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not selected_workspace: + return True, None, None + + sessions = await self._db_reader.get_sessions() + # The current implementation expects only one active session + if len(sessions) != 1: + raise RuntimeError("Something went wrong. No active session found.") + + session = sessions[0] + if session.active_workspace_id == selected_workspace.id: + return True, None, None + return False, session, selected_workspace + + async def activate_workspace(self, workspace_name: str) -> bool: + """ + Activate a workspace + + Will return: + - True if the workspace was activated + - False if the workspace is already active or does not exist + """ + is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name) + if is_active: + return False + + session.active_workspace_id = workspace.id + session.last_update = datetime.datetime.now(datetime.timezone.utc) + db_recorder = DbRecorder() + await db_recorder.update_session(session) + return True diff --git a/tests/pipeline/workspace/test_workspace.py b/tests/pipeline/workspace/test_workspace.py index d67c1f6c..61475102 100644 --- a/tests/pipeline/workspace/test_workspace.py +++ b/tests/pipeline/workspace/test_workspace.py @@ -3,7 +3,7 @@ import pytest from codegate.db.models import WorkspaceActive -from codegate.pipeline.workspace.commands import WorkspaceCommands +from codegate.pipeline.cli.commands import Workspace @pytest.mark.asyncio @@ -35,7 +35,7 @@ async def test_list_workspaces(mock_workspaces, expected_output): """ Test _list_workspaces with different sets of returned workspaces. """ - workspace_commands = WorkspaceCommands() + workspace_commands = Workspace() # Mock DbReader inside workspace_commands mock_get_workspaces = AsyncMock(return_value=mock_workspaces) @@ -69,7 +69,7 @@ async def test_add_workspaces(args, existing_workspaces, expected_message): - workspace already exists - workspace successfully added """ - workspace_commands = WorkspaceCommands() + workspace_commands = Workspace() # Mock the DbReader to return existing_workspaces mock_db_reader = AsyncMock() @@ -78,13 +78,14 @@ async def test_add_workspaces(args, existing_workspaces, expected_message): # We'll also patch DbRecorder to ensure no real DB operations happen with patch( - "codegate.pipeline.workspace.commands.DbRecorder", autospec=True + "codegate.pipeline.cli.commands.WorkspaceCrud", autospec=True ) as mock_recorder_cls: mock_recorder = mock_recorder_cls.return_value + workspace_commands.workspace_crud = mock_recorder mock_recorder.add_workspace = AsyncMock() # Call the method - result = await workspace_commands._add_workspace(*args) + result = await workspace_commands._add_workspace(args) # Assertions assert result == expected_message @@ -100,9 +101,9 @@ async def test_add_workspaces(args, existing_workspaces, expected_message): @pytest.mark.parametrize( "user_message, expected_command, expected_args, mocked_execute_response", [ - ("codegate-workspace list", "list", [], "List workspaces output"), - ("codegate-workspace add myws", "add", ["myws"], "Added workspace"), - ("codegate-workspace activate myws", "activate", ["myws"], "Activated workspace"), + (["list"], "list", ["list"], "List workspaces output"), + (["add", "myws"], "add", ["add", "myws"], "Added workspace"), + (["activate", "myws"], "activate", ["activate", "myws"], "Activated workspace"), ], ) async def test_parse_execute_cmd( @@ -112,13 +113,13 @@ async def test_parse_execute_cmd( Test parse_execute_cmd to ensure it parses the user message and calls the correct command with the correct args. """ - workspace_commands = WorkspaceCommands() + workspace_commands = Workspace() with patch.object( - workspace_commands, "execute", return_value=mocked_execute_response - ) as mock_execute: - result = await workspace_commands.parse_execute_cmd(user_message) + workspace_commands, "run", return_value=mocked_execute_response + ) as mock_run: + result = await workspace_commands.exec(user_message) assert result == mocked_execute_response # Verify 'execute' was called with the expected command and args - mock_execute.assert_awaited_once_with(expected_command, *expected_args) + mock_run.assert_awaited_once_with(expected_args)