Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/sqlite-vec' into sqlite-vec
Browse files Browse the repository at this point in the history
  • Loading branch information
lukehinds committed Jan 21, 2025
2 parents 4ba9c4b + bb63b3c commit ae612ab
Show file tree
Hide file tree
Showing 19 changed files with 555 additions and 113 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
python-version: ["3.12"]

steps:
- name: Checkout github repo
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
lfs: true
Expand Down Expand Up @@ -47,9 +47,6 @@ jobs:
- name: Run linting
run: make lint

- name: Run formatting
run: make format

- name: Run tests
run: make test

Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ jobs:
with:
lfs: true

- name: Checkout LFS objects
run: git lfs pull

- name: Ensure file permissions for mounted volume
run: |
mkdir -p ./codegate_volume/certs ./codegate_volume/models ./codegate_volume/db
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ format:
poetry run ruff check --fix .

lint:
poetry run black --check .
poetry run ruff check .

test:
Expand Down
13 changes: 10 additions & 3 deletions api/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
"schema": {
"$ref": "#/components/schemas/Workspace"
}
}
}
},
Expand Down Expand Up @@ -283,8 +285,13 @@
}
],
"responses": {
"204": {
"description": "Successful Response"
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
}
}
},
"422": {
"description": "Validation Error",
Expand Down
Binary file modified codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf
Binary file not shown.
30 changes: 30 additions & 0 deletions migrations/versions/8e4b4b8d1a88_add_soft_delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""add soft delete
Revision ID: 8e4b4b8d1a88
Revises: 5c2f3eee5f90
Create Date: 2025-01-20 14:08:40.851647
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "8e4b4b8d1a88"
down_revision: Union[str, None] = "5c2f3eee5f90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.execute(
"""
ALTER TABLE workspaces
ADD COLUMN deleted_at DATETIME DEFAULT NULL;
"""
)


def downgrade() -> None:
op.execute("ALTER TABLE workspaces DROP COLUMN deleted_at;")
26 changes: 26 additions & 0 deletions migrations/versions/a692c8b52308_add_workspace_system_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add_workspace_system_prompt
Revision ID: a692c8b52308
Revises: 5c2f3eee5f90
Create Date: 2025-01-17 16:33:58.464223
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a692c8b52308"
down_revision: Union[str, None] = "5c2f3eee5f90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add column to workspaces table
op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;")


def downgrade() -> None:
op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;")
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""merging system prompt and soft-deletes
Revision ID: e6227073183d
Revises: 8e4b4b8d1a88, a692c8b52308
Create Date: 2025-01-20 16:08:40.645298
"""

from typing import Sequence, Union

# revision identifiers, used by Alembic.
revision: str = "e6227073183d"
down_revision: Union[str, None] = ("8e4b4b8d1a88", "a692c8b52308")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
47 changes: 28 additions & 19 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from fastapi import APIRouter, Response
from fastapi.exceptions import HTTPException
from fastapi import APIRouter, HTTPException, Response
from fastapi.routing import APIRoute
from pydantic import ValidationError

from codegate.api import v1_models
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud
from codegate.api.dashboard.dashboard import dashboard_router
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces import crud

v1 = APIRouter()
v1.include_router(dashboard_router)

wscrud = WorkspaceCrud()
wscrud = crud.WorkspaceCrud()


def uniq_name(route: APIRoute):
Expand Down Expand Up @@ -44,40 +42,51 @@ async def list_active_workspaces() -> v1_models.ListActiveWorkspacesResponse:
@v1.post("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name)
async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status_code=204):
"""Activate a workspace by name."""
activated = await wscrud.activate_workspace(request.name)

# TODO: Refactor
if not activated:
try:
await wscrud.activate_workspace(request.name)
except crud.WorkspaceAlreadyActiveError:
return HTTPException(status_code=409, detail="Workspace already active")
except crud.WorkspaceDoesNotExistError:
return HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
return HTTPException(status_code=500, detail="Internal server error")

return Response(status_code=204)


@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
async def create_workspace(request: v1_models.CreateWorkspaceRequest) -> v1_models.Workspace:
"""Create a new workspace."""
# Input validation is done in the model
try:
created = await wscrud.add_workspace(request.name)
_ = await wscrud.add_workspace(request.name)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
raise HTTPException(status_code=400,
detail=("Invalid workspace name. "
"Please use only alphanumeric characters and dashes"))
raise HTTPException(
status_code=400,
detail=(
"Invalid workspace name. " "Please use only alphanumeric characters and dashes"
),
)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if created:
return v1_models.Workspace(name=created.name)
return v1_models.Workspace(name=request.name, is_active=False)


@v1.delete(
"/workspaces/{workspace_name}",
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
status_code=204,
)
async def delete_workspace(workspace_name: str):
"""Delete a workspace by name."""
raise NotImplementedError
try:
_ = await wscrud.soft_delete_workspace(workspace_name)
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return Response(status_code=204)
69 changes: 57 additions & 12 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
alert_queue = asyncio.Queue()
fim_cache = FimCache()


class AlreadyExistsError(Exception):
pass


class DbCodeGate:
_instance = None

Expand Down Expand Up @@ -246,16 +248,15 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
async def add_workspace(self, workspace_name: str) -> Workspace:
"""Add a new workspace to the DB.
This handles validation and insertion of a new workspace.
It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, system_prompt=None)
sql = text(
"""
INSERT INTO workspaces (id, name)
Expand All @@ -266,12 +267,28 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:

try:
added_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True)
workspace, sql, should_raise=True
)
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_workspace(self, workspace: Workspace) -> Workspace:
sql = text(
"""
UPDATE workspaces SET
name = :name,
system_prompt = :system_prompt
WHERE id = :id
RETURNING *
"""
)
updated_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True
)
return updated_workspace

async def update_session(self, session: Session) -> Optional[Session]:
sql = text(
"""
Expand All @@ -284,9 +301,23 @@ async def update_session(self, session: Session) -> Optional[Session]:
"""
)
# We only pass an object to respect the signature of the function
active_session = await self._execute_update_pydantic_model(session, sql)
active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True)
return active_session

async def soft_delete_workspace(self, workspace: Workspace) -> Optional[Workspace]:
sql = text(
"""
UPDATE workspaces
SET deleted_at = CURRENT_TIMESTAMP
WHERE id = :id
RETURNING *
"""
)
deleted_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True
)
return deleted_workspace


class DbReader(DbCodeGate):

Expand Down Expand Up @@ -317,14 +348,21 @@ async def _execute_select_pydantic_model(
return None

async def _exec_select_conditions_to_pydantic(
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
self,
model_type: Type[BaseModel],
sql_command: TextClause,
conditions: dict,
should_raise: bool = False,
) -> Optional[List[BaseModel]]:
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(sql_command, conditions)
return await self._dump_result_to_pydantic_model(model_type, result)
except Exception as e:
logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
# Exposes errors to the caller
if should_raise:
raise e
return None

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
Expand Down Expand Up @@ -377,22 +415,25 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
w.id, w.name, s.active_workspace_id
FROM workspaces w
LEFT JOIN sessions s ON w.id = s.active_workspace_id
WHERE w.deleted_at IS NULL
"""
)
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
return workspaces

async def get_workspace_by_name(self, name: str) -> List[Workspace]:
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
sql = text(
"""
SELECT
id, name
id, name, system_prompt
FROM workspaces
WHERE name = :name
WHERE name = :name AND deleted_at IS NULL
"""
)
conditions = {"name": name}
workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
workspaces = await self._exec_select_conditions_to_pydantic(
Workspace, sql, conditions, should_raise=True
)
return workspaces[0] if workspaces else None

async def get_sessions(self) -> List[Session]:
Expand All @@ -410,7 +451,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
sql = text(
"""
SELECT
w.id, w.name, s.id as session_id, s.last_update
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
FROM sessions s
INNER JOIN workspaces w ON w.id = s.active_workspace_id
"""
Expand Down Expand Up @@ -453,7 +494,11 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
last_update=datetime.datetime.now(datetime.timezone.utc),
)
db_recorder = DbRecorder(db_path)
asyncio.run(db_recorder.update_session(session))
try:
asyncio.run(db_recorder.update_session(session))
except Exception as e:
logger.error(f"Failed to initialize session in DB: {e}")
return
logger.info("Session in DB initialized successfully.")


Expand Down
Loading

0 comments on commit ae612ab

Please sign in to comment.