Skip to content

Commit

Permalink
Merge pull request #70 from arkohut/test-fix
Browse files Browse the repository at this point in the history
Test fix
  • Loading branch information
arkohut authored Feb 20, 2025
2 parents 9e924de + 8125150 commit ec169fb
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 57 deletions.
45 changes: 32 additions & 13 deletions memos/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,31 @@ def create_entity(
return Entity(**db_entity.__dict__)


def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.id == entity_id).first()
def _load_entity_with_relationships(db: Session, entity_id: int) -> EntityModel:
"""Helper function to load entity with all relationships"""
return (
db.query(EntityModel)
.options(
joinedload(EntityModel.metadata_entries),
joinedload(EntityModel.tags),
joinedload(EntityModel.plugin_status),
)
.filter(EntityModel.id == entity_id)
.first()
)


def get_entity_by_id(entity_id: int, db: Session, include_relationships: bool = False) -> Entity | None:
"""Get entity by ID with optional relationships"""
if include_relationships:
db_entity = _load_entity_with_relationships(db, entity_id)
else:
db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()

if db_entity is None:
return None

return Entity(**db_entity.__dict__)


def get_entities_of_folder(
Expand Down Expand Up @@ -400,9 +423,8 @@ def update_entity(
db_entity.metadata_entries.append(entity_metadata)

db.commit()
db.refresh(db_entity)

return Entity(**db_entity.__dict__)
return get_entity_by_id(entity_id, db, include_relationships=True)


def touch_entity(entity_id: int, db: Session) -> bool:
Expand All @@ -421,7 +443,7 @@ def update_entity_tags(
tags: List[str],
db: Session,
) -> Entity:
db_entity = get_entity_by_id(entity_id, db)
db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
if not db_entity:
raise ValueError(f"Entity with id {entity_id} not found")

Expand All @@ -446,13 +468,12 @@ def update_entity_tags(
db_entity.last_scan_at = func.now()

db.commit()
db.refresh(db_entity)

return Entity(**db_entity.__dict__)
return get_entity_by_id(entity_id, db, include_relationships=True)


def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
db_entity = get_entity_by_id(entity_id, db)
db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
if not db_entity:
raise ValueError(f"Entity with id {entity_id} not found")

Expand All @@ -477,17 +498,16 @@ def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
db_entity.last_scan_at = func.now()

db.commit()
db.refresh(db_entity)

return Entity(**db_entity.__dict__)
return get_entity_by_id(entity_id, db, include_relationships=True)


def update_entity_metadata_entries(
entity_id: int,
updated_metadata: List[EntityMetadataParam],
db: Session,
) -> Entity:
db_entity = get_entity_by_id(entity_id, db)
db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()

existing_metadata_entries = (
db.query(EntityMetadataModel)
Expand Down Expand Up @@ -532,9 +552,8 @@ def update_entity_metadata_entries(
db_entity.last_scan_at = func.now()

db.commit()
db.refresh(db_entity)

return Entity(**db_entity.__dict__)
return get_entity_by_id(entity_id, db, include_relationships=True)


def get_plugin_by_id(plugin_id: int, db: Session) -> Plugin | None:
Expand Down
3 changes: 2 additions & 1 deletion memos/fixtures/patch_entity_metadata_response.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
}
],
"size": 5678,
"tags": []
"tags": [],
"plugin_status": []
}
7 changes: 6 additions & 1 deletion memos/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,15 @@ class Entity(BaseModel):
metadata_entries: List[EntityMetadata] = []
plugin_status: List[EntityPluginStatus] = []

@property
def tag_names(self) -> List[str]:
"""Get list of tag names from the tags relationship"""
return [tag.name for tag in self.tags]

model_config = ConfigDict(
from_attributes=True,
json_encoders={
datetime: lambda dt: dt.isoformat(),
datetime: lambda dt: dt.replace(tzinfo=None).isoformat(),
}
)

Expand Down
4 changes: 2 additions & 2 deletions memos/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def update_entity_index(self, entity_id: int, db: Session):
try:
from .crud import get_entity_by_id

entity = get_entity_by_id(entity_id, db)
entity = get_entity_by_id(entity_id, db, include_relationships=True)
if not entity:
raise ValueError(f"Entity with id {entity_id} not found")

Expand Down Expand Up @@ -693,7 +693,7 @@ def update_entity_index(self, entity_id: int, db: Session):
try:
from .crud import get_entity_by_id

entity = get_entity_by_id(entity_id, db)
entity = get_entity_by_id(entity_id, db, include_relationships=True)
if not entity:
raise ValueError(f"Entity with id {entity_id} not found")

Expand Down
20 changes: 16 additions & 4 deletions memos/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

# Initialize search provider based on database URL
search_provider = create_search_provider(settings.database_url)
app.state.search_provider = search_provider

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Expand Down Expand Up @@ -249,6 +250,7 @@ async def new_entity(
plugins: Annotated[List[int] | None, Query()] = None,
trigger_webhooks_flag: bool = True,
update_index: bool = False,
search_provider=Depends(lambda: app.state.search_provider),
):
library = crud.get_library_by_id(library_id, db)
if library is None:
Expand Down Expand Up @@ -347,7 +349,7 @@ def get_entities_by_filepaths(

@app.get("/entities/{entity_id}", response_model=Entity, tags=["entity"])
def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)):
entity = crud.get_entity_by_id(entity_id, db)
entity = crud.get_entity_by_id(entity_id, db, include_relationships=True)
if entity is None:
return JSONResponse(
content={"detail": "Entity not found"},
Expand All @@ -364,7 +366,7 @@ def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)):
def get_entity_by_id_in_library(
library_id: int, entity_id: int, db: Session = Depends(get_db)
):
entity = crud.get_entity_by_id(entity_id, db)
entity = crud.get_entity_by_id(entity_id, db, include_relationships=True)
if entity is None or entity.library_id != library_id:
return JSONResponse(
content={"detail": "Entity not found"},
Expand All @@ -382,6 +384,7 @@ async def update_entity(
trigger_webhooks_flag: bool = False,
plugins: Annotated[List[int] | None, Query()] = None,
update_index: bool = False,
search_provider=Depends(lambda: app.state.search_provider),
):
with logfire.span("fetch entity {entity_id=}", entity_id=entity_id):
entity = crud.get_entity_by_id(entity_id, db)
Expand Down Expand Up @@ -430,7 +433,11 @@ def update_entity_last_scan_at(entity_id: int, db: Session = Depends(get_db)):
status_code=status.HTTP_204_NO_CONTENT,
tags=["entity"],
)
def update_index(entity_id: int, db: Session = Depends(get_db)):
def update_index(
entity_id: int,
db: Session = Depends(get_db),
search_provider=Depends(lambda: app.state.search_provider),
):
"""
Update the FTS and vector indexes for an entity.
"""
Expand All @@ -449,7 +456,11 @@ def update_index(entity_id: int, db: Session = Depends(get_db)):
status_code=status.HTTP_204_NO_CONTENT,
tags=["entity"],
)
async def batch_update_index(request: BatchIndexRequest, db: Session = Depends(get_db)):
async def batch_update_index(
request: BatchIndexRequest,
db: Session = Depends(get_db),
search_provider=Depends(lambda: app.state.search_provider),
):
"""
Batch update the FTS and vector indexes for multiple entities.
"""
Expand Down Expand Up @@ -688,6 +699,7 @@ async def search_entities_v2(
app_names: str = Query(None, description="Comma-separated list of app names"),
facet: bool = Query(None, description="Include facet in the search results"),
db: Session = Depends(get_db),
search_provider=Depends(lambda: app.state.search_provider),
):
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
app_name_list = (
Expand Down
82 changes: 46 additions & 36 deletions memos/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pytest
from datetime import datetime
from copy import deepcopy

from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event, text
Expand All @@ -27,22 +28,62 @@
from memos.models import Base
from memos.databases.initializers import SQLiteInitializer
from memos.config import settings
from memos.search import create_search_provider


# Create a test settings object by copying the original settings
test_settings = deepcopy(settings)
test_settings.database_path = "sqlite:///:memory:"

# Use SQLite for testing by default
test_engine = create_engine(
"sqlite:///:memory:",
test_settings.database_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)

# Initialize SQLite with the test engine
test_initializer = SQLiteInitializer(test_engine, settings)
test_initializer = SQLiteInitializer(test_engine, test_settings)
test_initializer.init_extensions()

# Create and override the search provider with test database URL
test_search_provider = create_search_provider(test_settings.database_url)
app.state.search_provider = test_search_provider # Store in app state

TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)


def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()


app.dependency_overrides[get_db] = override_get_db


# Setup a fixture for the FastAPI test client
@pytest.fixture
def client():
# Create all base tables
Base.metadata.create_all(bind=test_engine)

# Create FTS and Vec tables for SQLite
test_initializer.init_specific_features()

with TestClient(app) as client:
yield client

# Clean up database
Base.metadata.drop_all(bind=test_engine)
with test_engine.connect() as conn:
conn.execute(text("DROP TABLE IF EXISTS entities_fts"))
conn.execute(text("DROP TABLE IF EXISTS entities_vec_v2"))
conn.commit()


def load_fixture(filename):
with open(Path(__file__).parent / "fixtures" / filename, "r") as file:
return json.load(file)
Expand Down Expand Up @@ -95,37 +136,6 @@ def setup_library_with_entity(client):
return library_id, folder_id, entity_id


def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()


app.dependency_overrides[get_db] = override_get_db


# Setup a fixture for the FastAPI test client
@pytest.fixture
def client():
# Create all base tables
Base.metadata.create_all(bind=test_engine)

# Create FTS and Vec tables for SQLite
test_initializer.init_specific_features()

with TestClient(app) as client:
yield client

# Clean up database
Base.metadata.drop_all(bind=test_engine)
with test_engine.connect() as conn:
conn.execute(text("DROP TABLE IF EXISTS entities_fts"))
conn.execute(text("DROP TABLE IF EXISTS entities_vec_v2"))
conn.commit()


# Test the new_library endpoint
def test_new_library(client):
library_param = NewLibraryParam(name="Test Library")
Expand Down Expand Up @@ -682,10 +692,10 @@ def test_update_entity_tags(client):

# Check the response data
updated_entity_data = update_response.json()
print("\nResponse data:", json.dumps(updated_entity_data, indent=2))
assert "tags" in updated_entity_data
assert sorted([t["name"] for t in updated_entity_data["tags"]]) == sorted(
tags, key=str
)
received_tags = sorted([t["name"] for t in updated_entity_data["tags"]])
assert received_tags == sorted(tags)


def test_patch_entity_metadata_entries(client):
Expand Down

0 comments on commit ec169fb

Please sign in to comment.