Skip to content

Commit

Permalink
refactor: Improve entity loading with Pydantic model validation
Browse files Browse the repository at this point in the history
  • Loading branch information
arkohut committed Feb 21, 2025
1 parent 805c4d9 commit 44c5a69
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions memos/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Libra
db.refresh(db_folder)

db_library = db.query(LibraryModel).filter(LibraryModel.id == library_id).first()
return Library(**db_library.__dict__)
return Library.model_validate(db_library, from_attributes=True)


def create_entity(
Expand Down Expand Up @@ -148,9 +148,8 @@ def create_entity(
)
db.add(entity_metadata)
db.commit()
db.refresh(db_entity)

return Entity(**db_entity.__dict__)

return get_entity_by_id(db_entity.id, db, include_relationships=True)


def _load_entity_with_relationships(db: Session, entity_id: int) -> EntityModel:
Expand All @@ -177,7 +176,7 @@ def get_entity_by_id(entity_id: int, db: Session, include_relationships: bool =
if db_entity is None:
return None

return Entity(**db_entity.__dict__)
return Entity.model_validate(db_entity, from_attributes=True)


def get_entities_of_folder(
Expand Down Expand Up @@ -357,11 +356,15 @@ def add_plugin_to_library(library_id: int, plugin_id: int, db: Session):
def find_entities_by_ids(entity_ids: List[int], db: Session) -> List[Entity]:
db_entities = (
db.query(EntityModel)
.options(joinedload(EntityModel.metadata_entries), joinedload(EntityModel.tags))
.options(
joinedload(EntityModel.metadata_entries),
joinedload(EntityModel.tags),
joinedload(EntityModel.plugin_status),
)
.filter(EntityModel.id.in_(entity_ids))
.all()
)
return [Entity(**entity.__dict__) for entity in db_entities]
return [Entity.model_validate(entity, from_attributes=True) for entity in db_entities]


def update_entity(
Expand Down Expand Up @@ -586,7 +589,11 @@ def list_entities(
) -> List[Entity]:
query = (
db.query(EntityModel)
.options(joinedload(EntityModel.metadata_entries), joinedload(EntityModel.tags))
.options(
joinedload(EntityModel.metadata_entries),
joinedload(EntityModel.tags),
joinedload(EntityModel.plugin_status),
)
.filter(EntityModel.file_type_group == "image")
)

Expand All @@ -604,7 +611,7 @@ def list_entities(

entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()

return [Entity(**entity.__dict__) for entity in entities]
return [Entity.model_validate(entity, from_attributes=True) for entity in entities]


def get_entity_context(
Expand All @@ -617,6 +624,11 @@ def get_entity_context(
# First get the target entity to get its timestamp
target_entity = (
db.query(EntityModel)
.options(
joinedload(EntityModel.metadata_entries),
joinedload(EntityModel.tags),
joinedload(EntityModel.plugin_status),
)
.filter(
EntityModel.id == entity_id,
EntityModel.library_id == library_id,
Expand All @@ -641,7 +653,7 @@ def get_entity_context(
.all()
)
# Reverse the list to get chronological order and convert to Entity models
prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
prev_entities = [Entity.model_validate(entity, from_attributes=True) for entity in prev_entities][::-1]

# Get next entities
next_entities = []
Expand All @@ -657,7 +669,7 @@ def get_entity_context(
.all()
)
# Convert to Entity models
next_entities = [Entity(**entity.__dict__) for entity in next_entities]
next_entities = [Entity.model_validate(entity, from_attributes=True) for entity in next_entities]

return prev_entities, next_entities

Expand Down

0 comments on commit 44c5a69

Please sign in to comment.