Skip to content

Commit

Permalink
Upgrade to Pydantic v2 (#18)
Browse files Browse the repository at this point in the history
* [change] bump pydantic to v2 and sqlmodel to latest

* [change] bump ver

* [change] appease tests

* [change] use sa_type, sa_column_args and sa_column_kwargs instead of sa_column

* [change] pydantic ver to 2.5
  • Loading branch information
duynguyen158 authored Dec 4, 2023
1 parent cd6a12e commit 2b4f13a
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 98 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""column changes following pydantic upgrade to 2.0
Revision ID: 70e5e82e94c0
Revises: 46c271f8f06f # noqa: W291
Create Date: 2023-12-04 16:19:45.970626
"""
import pgvector.sqlalchemy
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "70e5e82e94c0"
down_revision = "46c271f8f06f"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("article", "site", existing_type=sa.VARCHAR(), nullable=False)
op.alter_column("embedding", "vector", existing_type=pgvector.sqlalchemy.Vector(dim=384), nullable=False)
op.alter_column("page", "url", existing_type=sa.VARCHAR(), nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("page", "url", existing_type=sa.VARCHAR(), nullable=True)
op.alter_column("embedding", "vector", existing_type=pgvector.sqlalchemy.Vector(dim=384), nullable=True)
op.alter_column("article", "site", existing_type=sa.VARCHAR(), nullable=True)
# ### end Alembic commands ###
9 changes: 5 additions & 4 deletions article_rec_db/models/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Annotated
from uuid import UUID

from sqlmodel import Column, Field, Relationship, String, UniqueConstraint
from sqlmodel import Field, Relationship, String, UniqueConstraint

from article_rec_db.sites import SiteName

from .helpers import SQLModel, UpdateTracked
from .helpers import UpdateTracked
from .page import Page


Expand All @@ -21,11 +21,12 @@ class Language(StrEnum):
SPANISH = "es"


class Article(SQLModel, UpdateTracked, table=True):
class Article(UpdateTracked, table=True):
__table_args__ = (UniqueConstraint("site", "id_in_site", name="article_site_idinsite_unique"),)
__mapper_args__ = {"polymorphic_identity": "article"}

page_id: Annotated[UUID, Field(primary_key=True, foreign_key="page.id")]
site: Annotated[SiteName, Field(sa_column=Column(String))]
site: Annotated[SiteName, Field(sa_type=String)]
id_in_site: str # ID of article in the partner site's internal system
title: str
published_at: datetime
Expand Down
8 changes: 4 additions & 4 deletions article_rec_db/models/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
from uuid import UUID

from pgvector.sqlalchemy import Vector # type: ignore
from sqlmodel import Column, Field, Relationship
from sqlmodel import Field, Relationship

from .article import Article
from .execution import Execution
from .helpers import AutoUUIDPrimaryKey, CreationTracked, SQLModel
from .helpers import AutoUUIDPrimaryKey, CreationTracked

# The maximum number of dimensions that the vector can have. Vectors with fewer dimensions will be padded with zeros.
MAX_EMBEDDING_DIMENSIONS = 384


class Embedding(SQLModel, AutoUUIDPrimaryKey, CreationTracked, table=True):
class Embedding(AutoUUIDPrimaryKey, CreationTracked, table=True):
article_id: Annotated[UUID, Field(foreign_key="article.page_id")]
execution_id: Annotated[UUID, Field(foreign_key="execution.id")]
vector: Annotated[list[float], Field(sa_column=Column(Vector(MAX_EMBEDDING_DIMENSIONS)))]
vector: Annotated[list[float], Field(sa_type=Vector(MAX_EMBEDDING_DIMENSIONS))]

# An embedding always corresonds to an article
article: Article = Relationship(back_populates="embeddings")
Expand Down
4 changes: 2 additions & 2 deletions article_rec_db/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlmodel import Relationship

from .helpers import AutoUUIDPrimaryKey, CreationTracked, SQLModel
from .helpers import AutoUUIDPrimaryKey, CreationTracked


class StrategyType(StrEnum):
Expand All @@ -17,7 +17,7 @@ class StrategyRecommendationType(StrEnum):
SOURCE_TARGET_NOT_INTERCHANGEABLE = "source_target_not_interchangeable"


class Execution(SQLModel, AutoUUIDPrimaryKey, CreationTracked, table=True):
class Execution(AutoUUIDPrimaryKey, CreationTracked, table=True):
"""
Log of training job task executions, each with respect to a single strategy.
"""
Expand Down
11 changes: 5 additions & 6 deletions article_rec_db/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
from typing import Annotated, Optional
from uuid import UUID, uuid4

from pydantic import BaseModel
from sqlmodel import Column, DateTime, Field, SQLModel # noqa: F401
from sqlmodel import Field, SQLModel # noqa: F401


# Common fields as Pydantic model mixins
class AutoUUIDPrimaryKey(BaseModel):
class AutoUUIDPrimaryKey(SQLModel, table=False):
id: Annotated[UUID, Field(default_factory=uuid4, primary_key=True)]


class CreationTracked(BaseModel):
class CreationTracked(SQLModel, table=False):
db_created_at: Annotated[datetime, Field(default_factory=datetime.utcnow)]


class UpdateTracked(CreationTracked):
db_updated_at: Annotated[Optional[datetime], Field(sa_column=Column(DateTime, onupdate=datetime.utcnow))]
class UpdateTracked(CreationTracked, table=False):
db_updated_at: Annotated[Optional[datetime], Field(sa_column_kwargs={"onupdate": datetime.utcnow})]
8 changes: 4 additions & 4 deletions article_rec_db/models/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from uuid import UUID, uuid4

from pydantic import HttpUrl
from sqlmodel import Column, Field, Relationship, String
from sqlmodel import Field, Relationship, String

from .helpers import AutoUUIDPrimaryKey, SQLModel, UpdateTracked
from .helpers import AutoUUIDPrimaryKey, UpdateTracked


class Page(SQLModel, AutoUUIDPrimaryKey, UpdateTracked, table=True):
class Page(AutoUUIDPrimaryKey, UpdateTracked, table=True):
id: Annotated[UUID, Field(default_factory=uuid4, primary_key=True)]
url: Annotated[HttpUrl, Field(sa_column=Column(String, unique=True))]
url: Annotated[HttpUrl, Field(sa_type=String, unique=True)]

# An article is always a page, but a page is not always an article
# Techinically SQLModel considers Page the "many" in the many-to-one relationship, so this list will only ever have at most one element
Expand Down
19 changes: 4 additions & 15 deletions article_rec_db/models/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,14 @@
from uuid import UUID

from sqlalchemy import event
from sqlmodel import (
CheckConstraint,
Column,
Field,
Float,
Relationship,
UniqueConstraint,
)
from sqlmodel import CheckConstraint, Field, Relationship, UniqueConstraint

from .article import Article
from .execution import Execution, StrategyRecommendationType
from .helpers import AutoUUIDPrimaryKey, CreationTracked, SQLModel
from .helpers import AutoUUIDPrimaryKey, CreationTracked


class Recommendation(SQLModel, AutoUUIDPrimaryKey, CreationTracked, table=True):
class Recommendation(AutoUUIDPrimaryKey, CreationTracked, table=True):
"""
Usual recommendations have a source article (i.e., the one the reader is reading)
and a target article (i.e., the one the reader is recommended upon/after reading the source).
Expand All @@ -37,11 +30,7 @@ class Recommendation(SQLModel, AutoUUIDPrimaryKey, CreationTracked, table=True):
# Recommendation score, between 0 and 1. Top recs should have higher scores
score: Annotated[
float,
Field(
sa_column=Column(
Float, CheckConstraint("score >= 0 AND score <= 1", name="recommendation_score_between_0_and_1")
)
),
Field(sa_column_args=[CheckConstraint("score >= 0 AND score <= 1", name="recommendation_score_between_0_and_1")]),
]

# A recommendation always corresponds to a task execution
Expand Down
5 changes: 3 additions & 2 deletions article_rec_db/sites/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from .helpers import SiteName

Expand All @@ -10,7 +10,8 @@
class Site(BaseModel):
name: SiteName

@validator("name")
@field_validator("name")
@classmethod
def name_must_be_kebabcase(cls, value: SiteName) -> SiteName:
assert PATTERN_SITE_NAME_KEBAB.fullmatch(value) is not None, "Site name must be kebab-case"
return value
Expand Down
Loading

0 comments on commit 2b4f13a

Please sign in to comment.