diff --git a/poetry.lock b/poetry.lock index 94b79d8..ffe8475 100644 --- a/poetry.lock +++ b/poetry.lock @@ -129,6 +129,17 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cachetools" +version = "5.4.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.4.0-py3-none-any.whl", hash = "sha256:3ae3b49a3d5e28a77a0be2b37dbcb89005058959cb2323858c2657c4a8cab474"}, + {file = "cachetools-5.4.0.tar.gz", hash = "sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"}, +] + [[package]] name = "certifi" version = "2024.2.2" @@ -958,4 +969,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "0f24c309c7c5fdaf66c578049ca814891416d95e8626f398ece37d31988cc584" +content-hash = "4de07851382896bbc885fc723f808163482baa26fae903d86626594f450c8e34" diff --git a/pyproject.toml b/pyproject.toml index 8d96fef..bdcda90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ python = "^3.12" fastapi = "^0.104.1" httpx = "^0.27.0" asyncpg = "^0.29.0" +cachetools = "^5.4.0" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" diff --git a/src/app/api/v1.py b/src/app/api/api.py similarity index 53% rename from src/app/api/v1.py rename to src/app/api/api.py index c36c17f..7ef35a6 100644 --- a/src/app/api/v1.py +++ b/src/app/api/api.py @@ -1,9 +1,10 @@ import fastapi -from app.api import users +from app.api import users, ratecards -def setup_api_v1(app: fastapi.FastAPI): +def setup_api(app: fastapi.FastAPI): api_router = fastapi.APIRouter() api_router.include_router(users.router, prefix="/users", tags=["users"]) + api_router.include_router(ratecards.router, prefix="/ratecards", tags=["ratecards"]) app.include_router(api_router, prefix="/api/v1") diff --git a/src/app/api/ratecards.py b/src/app/api/ratecards.py new file mode 100644 index 0000000..e01dc87 --- /dev/null +++ b/src/app/api/ratecards.py @@ -0,0 +1,52 @@ +from uuid import UUID + +import fastapi + +from app.datalayer.ratecard import ( + RateCardRecord, + RateCardRepository, + RateCardPatchPayload, + RateCardCreatePayload, +) + +router = fastapi.APIRouter() + + +@router.get("") +async def get_all( + data_service: RateCardRepository = fastapi.Depends(), +) -> list[RateCardRecord]: + return await data_service.get_all() + + +@router.post("") +async def create( + payload: RateCardCreatePayload, + data_service: RateCardRepository = fastapi.Depends(), +) -> UUID: + return await data_service.create(payload) + + +@router.get("/{pk}") +async def get_by_id( + pk: UUID, + data_service: RateCardRepository = fastapi.Depends(), +) -> None: + return await data_service.get_by_id(pk) + + +@router.patch("/{pk}") +async def patch_by_id( + pk: UUID, + payload: RateCardPatchPayload, + data_service: RateCardRepository = fastapi.Depends(), +) -> None: + await data_service.patch_by_id(pk, payload) + + +@router.delete("/{pk}") +async def delete_by_id( + pk: UUID, + data_service: RateCardRepository = fastapi.Depends(), +) -> None: + await data_service.delete_by_id(pk) diff --git a/src/app/api/users.py b/src/app/api/users.py index 2588514..ebaba43 100644 --- a/src/app/api/users.py +++ b/src/app/api/users.py @@ -1,38 +1,12 @@ import fastapi -from app.datalayer.model import UserAccount -from app.datalayer.users import UserService +from app.datalayer.user_account import UserAccountRepository, UserAccountRecord router = 
fastapi.APIRouter() @router.get("") async def get_users( - data_service: UserService = fastapi.Depends(), - page: int = 0, - size: int = 10, -) -> list[UserAccount]: - # Read users - users = await data_service.read_users() - - # Create users - if not users: - await data_service.create_user("Alice", "alice@example.com") - await data_service.create_user("Bob", "bob@example.com") - - # Read users - users = await data_service.read_users() - - # # Update user - # user_id = users[0].id - # await data_service.update_user(user_id, "John") - # - # # Read users again - # await data_service.read_users() - # - # # Delete user - # await data_service.delete_user(2) - - # Read users again - - return users + data_service: UserAccountRepository = fastapi.Depends(), +) -> list[UserAccountRecord]: + return await data_service.get_all() diff --git a/src/app/app.py b/src/app/app.py index 907ce84..d3da087 100644 --- a/src/app/app.py +++ b/src/app/app.py @@ -5,9 +5,9 @@ import fastapi from starlette.responses import RedirectResponse -from app.api.v1 import setup_api_v1 +from app.api.api import setup_api from app.core.cors import setup_cors -from app.core.datasource import setup_datasource +from app.core.datasource import create_connection_pool_manager from app.core.error_handlers import setup_error_handlers from app.core.logging import setup_logging from app.core.migration import run_migrations @@ -18,14 +18,20 @@ @contextlib.asynccontextmanager async def _lifespan(app: "App"): app.logger.info(f"Starting ...") - await app.datasource.connect() - await run_migrations(app.datasource) + + await app.cpm.get_pool() + await run_migrations(app.cpm) # ...other startup code + app.logger.info("✅ Started") - yield + + yield # this is where the app is running + app.logger.info("Shutting down ...") - await app.datasource.disconnect() + + await app.cpm.close() # ...other shutdown code + app.logger.info("🛑 Shutdown") @@ -53,10 +59,10 @@ def __init__( setup_cors(self) # Setup API routes - setup_api_v1(self) + setup_api(self) # Setup DB data source - self.datasource = setup_datasource(self) + self.cpm = create_connection_pool_manager(self) if hasattr(self, "docs_url") and self.docs_url: diff --git a/src/app/core/datasource.py b/src/app/core/datasource.py index 3261806..21f1e22 100644 --- a/src/app/core/datasource.py +++ b/src/app/core/datasource.py @@ -1,39 +1,9 @@ -import logging -import os - -import asyncpg import fastapi - -# Simple class that takes care of setting up and tearing down the connection pool. -# Useful to decouple for the actual connection pool object. -# It's meant to be used in FastAPI lifespan (see https://fastapi.tiangolo.com/advanced/events/#lifespan). 
-class DataSource: - def __init__(self, app: fastapi.FastAPI): - self._pool: asyncpg.Pool | None = None - self.logger = logging.getLogger(__name__) - - async def connect(self): - if not self._pool: - postgres_url = os.getenv("POSTGRES_URL") - assert postgres_url, "missing POSTGRES_URL" - self._pool = await asyncpg.create_pool(dsn=postgres_url) - self.logger.info("DataSource created") - - async def disconnect(self): - if self._pool: - await self._pool.close() - self.logger.info("DataSource closed") - - @property - def pool(self) -> asyncpg.Pool: - if not self._pool: - raise RuntimeError("DataSource not initialized, you need to call `await connect()` in the lifespan event, " - "see see https://fastapi.tiangolo.com/advanced/events/#lifespan.") - return self._pool +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager -def setup_datasource(app: fastapi.FastAPI) -> DataSource: - return DataSource(app) +def create_connection_pool_manager(app: fastapi.FastAPI) -> ConnectionPoolManager: + return ConnectionPoolManager() diff --git a/src/app/core/errors.py b/src/app/core/errors.py new file mode 100644 index 0000000..ff29154 --- /dev/null +++ b/src/app/core/errors.py @@ -0,0 +1,34 @@ +import fastapi +from starlette import status + + +class BadRequestException(fastapi.HTTPException): + def __init__(self, message: str) -> None: + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=message, + ) + + +class UnauthorizedException(fastapi.HTTPException): + def __init__(self, message: str) -> None: + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=message, + ) + + +class ForbiddenException(fastapi.HTTPException): + def __init__(self, message: str) -> None: + super().__init__( + status_code=status.HTTP_403_FORBIDDEN, + detail=message, + ) + + +class NotFoundException(fastapi.HTTPException): + def __init__(self, message: str) -> None: + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=message, + ) diff --git a/src/app/core/migration.py b/src/app/core/migration.py index e584d58..5e06818 100644 --- a/src/app/core/migration.py +++ b/src/app/core/migration.py @@ -1,20 +1,12 @@ -import os.path +import os -from app.core.datasource import DataSource - - -async def run_migrations(datasource: DataSource): - async with datasource.pool.acquire(timeout=2) as connection: - await connection.execute("create table if not exists migration (name text primary key, executed_at timestamp)") - records = await connection.fetch("select * from migration") - print(records) - if len(records) > 0: - print("Migration table already exists") - else: - with open(os.path.join(os.path.dirname(__file__), "../../schema.sql")) as f: - content = f.read() - print(content) - await connection.execute(content) - await connection.execute("insert into migration (name, executed_at) values ('migration.sql', now())") +from app.dirs import resources_dir +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager +from app.libs.migrationtool import MigrationTool +async def run_migrations(provider: ConnectionPoolManager): + pool = await provider.get_pool() + migration_tool = MigrationTool(pool) + migrations_dir = os.path.join(resources_dir, "migrations") + await migration_tool.migrate(migrations_dir) diff --git a/src/app/datalayer/db.py b/src/app/datalayer/db.py deleted file mode 100644 index b0aebc6..0000000 --- a/src/app/datalayer/db.py +++ /dev/null @@ -1,19 +0,0 @@ -import contextlib - -import fastapi - - -class DB: - def __init__(self, request: 
fastapi.Request): - self.pool = request.app.datasource.pool - - @contextlib.asynccontextmanager - async def transaction(self): - async with self.pool.acquire(timeout=2) as connection: - async with connection.transaction(): - yield connection - - @contextlib.asynccontextmanager - async def connection(self): - async with self.pool.acquire(timeout=2) as connection: - yield connection diff --git a/src/app/datalayer/model.py b/src/app/datalayer/model.py index 15eb52e..f0ebaa5 100644 --- a/src/app/datalayer/model.py +++ b/src/app/datalayer/model.py @@ -1,127 +1,127 @@ +import enum from datetime import datetime -from enum import Enum from typing import Optional from uuid import UUID import pydantic -class AdjustedHourTypeEnum(str, Enum): +class AdjustedHourTypeEnum(enum.StrEnum): hours = "hours" percentage = "percentage" -class BenefitTypeEnum(str, Enum): +class BenefitTypeEnum(enum.StrEnum): usd_value = "usd_value" percentage = "percentage" -class TaxTypeEnum(str, Enum): +class TaxTypeEnum(enum.StrEnum): usd_value = "usd_value" percentage = "percentage" -class UserRoleEnum(str, Enum): +class UserRoleEnum(enum.StrEnum): client = "client" admin = "admin" super_admin = "super_admin" -class AdjustedHour(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class AdjustedHourRecord(pydantic.BaseModel): + id: UUID title: str value: float type: Optional[AdjustedHourTypeEnum] = None - company_id: Optional[UUID] = None + company_id: UUID | None = None _table_name: str = "adjusted_hours" -class Benefit(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class BenefitRecord(pydantic.BaseModel): + id: UUID title: str value: float type: BenefitTypeEnum = BenefitTypeEnum.usd_value - company_id: Optional[UUID] = None + company_id: UUID | None = None _table_name: str = "benefits" -class BenefitProfile(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class BenefitProfileRecord(pydantic.BaseModel): + id: UUID title: str - company_id: Optional[UUID] = None + company_id: UUID | None = None _table_name: str = "benefit_profiles" -class BenefitProfileBenefit(pydantic.BaseModel): +class BenefitProfileBenefitRecord(pydantic.BaseModel): benefit_profile_id: UUID benefit_id: UUID _table_name: str = "benefit_profile_benefits" -class Company(pydantic.BaseModel): +class CompanyRecord(pydantic.BaseModel): name: str company_code: str - id: UUID = pydantic.Field(default_factory=UUID) + id: UUID _table_name: str = "companies" -class JobTitle(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class JobTitleRecord(pydantic.BaseModel): + id: UUID title: str - company_id: Optional[UUID] = None + company_id: UUID | None = None _table_name: str = "job_titles" -class Location(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class LocationRecord(pydantic.BaseModel): + id: UUID country: str state: str - company_id: Optional[UUID] = None + company_id: UUID | None = None _table_name: str = "locations" -class LocationAdjustedHour(pydantic.BaseModel): +class LocationAdjustedHourRecord(pydantic.BaseModel): location_id: UUID adjusted_hour_id: UUID _table_name: str = "location_adjusted_hours" -class LocationBenefit(pydantic.BaseModel): +class LocationBenefitRecord(pydantic.BaseModel): location_id: UUID benefit_id: UUID _table_name: str = "location_benefits" -class LocationTax(pydantic.BaseModel): +class LocationTaxRecord(pydantic.BaseModel): location_id: UUID tax_id: UUID _table_name: str = "location_taxes" -class 
PasswordResetToken(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class PasswordResetTokenRecord(pydantic.BaseModel): + id: UUID token: str created_at: datetime = pydantic.Field(default_factory=datetime.utcnow) expires_at: datetime - user_id: Optional[UUID] = None + user_id: UUID | None = None _table_name: str = "password_reset_tokens" -class Position(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class PositionRecord(pydantic.BaseModel): + id: UUID title: str - company_id: Optional[UUID] = None + company_id: UUID | None = None hourly_rate: Optional[float] = None headcount: Optional[int] = None is_manual: bool = False @@ -129,41 +129,30 @@ class Position(pydantic.BaseModel): _table_name: str = "positions" -class QuickQuote(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class QuickQuoteRecord(pydantic.BaseModel): + id: UUID title: str created_at: datetime = pydantic.Field(default_factory=datetime.utcnow) updated_at: datetime = pydantic.Field(default_factory=datetime.utcnow) - company_id: Optional[UUID] = None - created_by_user_id: Optional[UUID] = None - ratecard_id: Optional[UUID] = None + company_id: UUID | None = None + created_by_user_id: UUID | None = None + ratecard_id: UUID | None = None _table_name: str = "quick_quotes" -class QuickQuoteItem(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class QuickQuoteItemRecord(pydantic.BaseModel): + id: UUID resource_id: str dedication_percentage: float - quick_quote_id: Optional[UUID] = None + quick_quote_id: UUID | None = None billable_rate: float = 0.0 _table_name: str = "quick_quote_items" -class RateCard(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) - net_margin: float = 0.0 - overhead: float = 0.0 - selling_cost: float = 0.0 - company_id: Optional[UUID] = None - title: Optional[str] = None - - _table_name: str = "rate_cards" - - -class Resource(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class ResourceRecord(pydantic.BaseModel): + id: UUID first_name: str last_name: str salary: float @@ -173,90 +162,76 @@ class Resource(pydantic.BaseModel): weekly_hours: float = 40.0 total_adjusted_hours: Optional[float] = None resource_cost_per_adjusted_hour: Optional[float] = None - position_id: Optional[UUID] = None - location_id: Optional[UUID] = None - company_id: Optional[UUID] = None - team_id: Optional[UUID] = None - manager_id: Optional[UUID] = None - employee_code: Optional[str] = None + position_id: UUID | None = None + location_id: UUID | None = None + company_id: UUID | None = None + team_id: UUID | None = None + manager_id: UUID | None = None + employee_code: str | None = None total_annual_hours: Optional[float] = None - location_title: Optional[str] = None + location_title: str | None = None total_working_hours: Optional[float] = None _table_name: str = "resources" -class ResourceSkill(pydantic.BaseModel): +class ResourceSkillRecord(pydantic.BaseModel): resource_id: UUID skill_id: UUID _table_name: str = "resource_skills" -class ResourceTax(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class ResourceTaxRecord(pydantic.BaseModel): + id: UUID total_annual_amount: float - resource_id: Optional[UUID] = None - tax_id: Optional[UUID] = None + resource_id: UUID | None = None + tax_id: UUID | None = None _table_name: str = "resource_taxes" -class Skill(pydantic.BaseModel): - id: UUID = pydantic.Field(default_factory=UUID) +class SkillRecord(pydantic.BaseModel): + id: UUID 
     title: str
-    company_id: Optional[UUID] = None
+    company_id: UUID | None = None
 
     _table_name: str = "skills"
 
 
-class Tax(pydantic.BaseModel):
-    id: UUID = pydantic.Field(default_factory=UUID)
+class TaxRecord(pydantic.BaseModel):
+    id: UUID
     title: str
     value: float
     type: TaxTypeEnum = TaxTypeEnum.usd_value
-    company_id: Optional[UUID] = None
+    company_id: UUID | None = None
 
     _table_name: str = "taxes"
 
 
-class Team(pydantic.BaseModel):
-    id: UUID = pydantic.Field(default_factory=UUID)
+class TeamRecord(pydantic.BaseModel):
+    id: UUID
     title: str
-    company_id: Optional[UUID] = None
+    company_id: UUID | None = None
 
     _table_name: str = "teams"
 
 
-class TeamBundle(pydantic.BaseModel):
-    id: UUID = pydantic.Field(default_factory=UUID)
+class TeamBundleRecord(pydantic.BaseModel):
+    id: UUID
     title: str
     created_at: datetime = pydantic.Field(default_factory=datetime.utcnow)
     updated_at: datetime = pydantic.Field(default_factory=datetime.utcnow)
-    company_id: Optional[UUID] = None
-    created_by_user_id: Optional[UUID] = None
+    company_id: UUID | None = None
+    created_by_user_id: UUID | None = None
 
     _table_name: str = "team_bundles"
 
 
-class TeamBundleItem(pydantic.BaseModel):
-    id: UUID = pydantic.Field(default_factory=UUID)
+class TeamBundleItemRecord(pydantic.BaseModel):
+    id: UUID
     position_id: str
-    team_bundle_id: Optional[UUID] = None
+    team_bundle_id: UUID | None = None
     head_count: Optional[float] = None
 
     _table_name: str = "team_bundle_items"
-
-
-class UserAccount(pydantic.BaseModel):
-    id: UUID = pydantic.Field(default_factory=UUID)
-    full_name: Optional[str] = None
-    email: str
-    password: str
-    phone_number: Optional[str] = None
-    role: UserRoleEnum = UserRoleEnum.client
-    created_at: datetime = pydantic.Field(default_factory=datetime.utcnow)
-    updated_at: datetime = pydantic.Field(default_factory=datetime.utcnow)
-    company_id: Optional[UUID] = None
-
-    _table_name: str = "user_accounts"
diff --git a/src/app/datalayer/modelutils.py b/src/app/datalayer/modelutils.py
new file mode 100644
index 0000000..d4be8f4
--- /dev/null
+++ b/src/app/datalayer/modelutils.py
@@ -0,0 +1,132 @@
+import typing
+from typing import Any
+
+import pydantic
+
+from app.libs.datalayer.metadata.provider import TableMetadata
+
+
+def get_columns_from_type(model_cls: typing.Type[pydantic.BaseModel]):
+    return [f for f in model_cls.__annotations__.keys() if f and not f.startswith("_")]
+
+
+def get_columns(model: pydantic.BaseModel):
+    return [f for f in model.model_fields_set if f and not f.startswith("_")]
+
+
+async def select(table_name: str, record_cls: typing.Type[pydantic.BaseModel]) -> str:
+    columns = ",".join(get_columns_from_type(record_cls))
+    return f"SELECT {columns} FROM {table_name}"
+
+
+async def insert(table_name: str, model: pydantic.BaseModel) -> tuple[str, list[Any]]:
+    fields = get_columns(model)
+    columns = ",".join(fields)
+    values = []
+    placeholders = []
+    raw = model.model_dump()
+    # Positional asyncpg-style placeholders: $1, $2, ...
+    for i, field in enumerate(fields, start=1):
+        values.append(raw.get(field))
+        placeholders.append(f"${i}")
+    return f"INSERT INTO {table_name}({columns}) VALUES({','.join(placeholders)})", values
+
+
+async def update(table_name: str, model: pydantic.BaseModel) -> tuple[str, list[Any]]:
+    fields = get_columns(model)
+    values = []
+    setters = []
+    raw = model.model_dump()
+    # Build "column=$n" assignments in the same order as the bound values.
+    for i, field in enumerate(fields, start=1):
+        values.append(raw.get(field))
+        setters.append(f"{field}=${i}")
+    return f"UPDATE {table_name} SET {','.join(setters)}", values
+
+
+async def delete(table_name: str, pk_name: str) -> str:
+    # The primary key value is bound as $1 when the statement is executed.
+    return f"DELETE FROM {table_name} WHERE {pk_name}=$1"
+
+
+def prepare_insert_from_meta(
+    meta: TableMetadata,
+    obj: typing.Any,
+) -> tuple[str, list[Any]]:
+
+    column_names = []
+    flat_values = []
+    row_values = []
+
+    for column_name, column_value in obj.items():
+        if column_name not in meta.column_names:
+            raise ValueError(
+                f"Column {column_name} not found in table {meta.table_name}"
+            )
+        if column_name == meta.pk_name:
+            continue
+        column_names.append(column_name)
+        flat_values.append(column_value)
+        row_values.append(f"${len(flat_values)}")
+
+    if not row_values:
+        raise ValueError(f"Nothing to insert")
+
+    return (
+        f"INSERT INTO {meta.table_name}({','.join(column_names)}) VALUES ({','.join(row_values)});",
+        flat_values,
+    )
+
+
+def prepare_update_from_meta(
+    meta: TableMetadata,
+    obj: dict,
+) -> tuple[str, list[Any]]:
+
+    pk_value = obj.get(meta.pk_name, None)
+    if not pk_value:
+        raise ValueError(f"Missing PK value")
+
+    flat_values = [pk_value]
+    setters = []
+
+    for column_name, column_value in obj.items():
+        if column_name not in meta.column_names:
+            raise ValueError(
+                f"Column {column_name} not found in table {meta.table_name}"
+            )
+        if column_name == meta.pk_name:
+            continue
+        flat_values.append(column_value)
+        setters.append(f"{column_name}=${len(flat_values)}")
+
+    if not setters:
+        raise ValueError(f"Nothing to set")
+
+    return (
+        f"UPDATE {meta.table_name} SET {','.join(setters)} WHERE {meta.pk_name} = $1;",
+        flat_values,
+    )
+
+
+def prepare_insert_many(
+    meta: TableMetadata,
+    values: list[typing.Any],
+) -> tuple[str, list[Any]]:
+    flat_values = []
+    row_values = []
+    for row_i, row in enumerate(values, start=0):
+        values_placeholders = []
+        for col_i, col in enumerate(row, start=1):
+            i = col_i + (row_i * len(row))
+            flat_values.append(col)
+            values_placeholders.append(f"${i}")
+        row_values.append(f"({','.join(values_placeholders)})")
+    return (
+        f"INSERT INTO {meta.table_name}({','.join(meta.column_names)}) VALUES {','.join(row_values)};",
+        flat_values,
+    )
diff --git a/src/app/datalayer/ratecard.py b/src/app/datalayer/ratecard.py
new file mode 100644
index 0000000..61b16aa
--- /dev/null
+++ b/src/app/datalayer/ratecard.py
@@ -0,0 +1,43 @@
+import uuid
+
+import fastapi
+import pydantic
+
+from app.libs.datalayer.base_repository import BaseRepository
+from app.libs.datalayer.fastapi.ddb import DependableDataLayer
+
+
+class RateCardRecord(pydantic.BaseModel):
+    id: uuid.UUID
+    net_margin: float
+    overhead: float
+    selling_cost: float
+    company_id: uuid.UUID | None
+    title: str | None
+
+
+class RateCardCreatePayload(pydantic.BaseModel):
+    net_margin: float
+    overhead: float
+    selling_cost: float
+    company_id: uuid.UUID | None = None
+    title: str | None = None
+
+
+class RateCardPatchPayload(pydantic.BaseModel):
+    net_margin: float | None = None
+    overhead: float | None = None
+    selling_cost: float | None = None
+    company_id: uuid.UUID | None = None
+    title: str | None = None
+
+
+class RateCardRepository(
+    BaseRepository[RateCardRecord, RateCardCreatePayload, RateCardPatchPayload]
+):
+    def __init__(self, ddb: DependableDataLayer = fastapi.Depends()) -> None:
+        super().__init__(
+            ddb.datalayer,
+            "ratecard",
+            RateCardRecord,
+        )
diff --git a/src/app/datalayer/user_account.py b/src/app/datalayer/user_account.py
new file mode 100644
index 0000000..bdee798
--- /dev/null
+++ b/src/app/datalayer/user_account.py
@@ -0,0 +1,51 @@
+import datetime
+import uuid
+
+import fastapi
+import pydantic
+
+from
app.datalayer.model import UserRoleEnum +from app.libs.datalayer.base_repository import BaseRepository +from app.libs.datalayer.fastapi.ddb import DependableDataLayer + + +class UserAccountRecord(pydantic.BaseModel): + user_id: uuid.UUID + full_name: str | None + email: str + password: str | None + phone_number: str | None + role: UserRoleEnum = UserRoleEnum.client + created_at: datetime.datetime + last_updated_at: datetime.datetime | None + company_id: uuid.UUID | None + + +class UserAccountCreatePayload(pydantic.BaseModel): + user_id: uuid.UUID + full_name: str | None = None + email: str + password: str | None + phone_number: str | None = None + role: UserRoleEnum = UserRoleEnum.client + company_id: uuid.UUID | None = None + + +class UserAccountPatchPayload(pydantic.BaseModel): + full_name: str | None + email: str | None + password: str | None + phone_number: str | None + role: UserRoleEnum | None + company_id: uuid.UUID | None + + +class UserAccountRepository( + BaseRepository[UserAccountRecord, UserAccountCreatePayload, UserAccountPatchPayload] +): + def __init__(self, ddb: DependableDataLayer = fastapi.Depends()) -> None: + super().__init__( + ddb.datalayer, + "user_account", + UserAccountRecord, + ) diff --git a/src/app/datalayer/users.py b/src/app/datalayer/users.py deleted file mode 100644 index 55f65c1..0000000 --- a/src/app/datalayer/users.py +++ /dev/null @@ -1,52 +0,0 @@ -import fastapi -import pydantic - -from app.datalayer.db import DB -from app.datalayer.model import UserAccount - -IGNORE_FIELDS = {"_table_name"} - - -class UserService: - - def __init__(self, db: DB = fastapi.Depends()): - super().__init__() - self.db = db - self.record_cls: pydantic.BaseModel = UserAccount - self.field_names = [ - f for f in self.record_cls.__annotations__.keys() if f not in IGNORE_FIELDS - ] - - async def read_users(self) -> list[UserAccount]: - async with self.db.connection() as connection: - columns = ",".join(self.field_names) - records = await connection.fetch(f"SELECT {columns} FROM user_account") - print("Users:", records) - return [self.record_cls(**dict(row)) for row in records] - - async def create_user(self, full_name, email): - async with self.db.transaction() as connection: - await connection.execute( - "INSERT INTO user_account(full_name, email, password) VALUES($1, $2, $3)", - full_name, - email, - "secret", - ) - print(f"User {full_name} added.") - - async def update_user(self, user_id, full_name): - async with self.db.transaction() as connection: - await connection.execute( - "UPDATE user_account SET full_name=$1 WHERE id=$2", - full_name, - user_id, - ) - print(f"User {user_id} updated.") - - async def delete_user(self, user_id): - async with self.db.transaction() as connection: - await connection.execute( - "DELETE FROM user_account WHERE id=$1", - user_id, - ) - print(f"User {user_id} deleted.") diff --git a/src/app/dirs.py b/src/app/dirs.py new file mode 100644 index 0000000..2e5a966 --- /dev/null +++ b/src/app/dirs.py @@ -0,0 +1,4 @@ +import os + +app_dir = os.path.dirname(__file__) +resources_dir = os.path.join(app_dir, "resources") diff --git a/src/app/libs/NOTES.md b/src/app/libs/NOTES.md new file mode 100644 index 0000000..ef48312 --- /dev/null +++ b/src/app/libs/NOTES.md @@ -0,0 +1,48 @@ +## Postgresql Python Libraries + +## [psycopg2](https://www.psycopg.org/docs/) + +psycopg2 is a robust, stable, and widely-used library ideal for synchronous applications but may face performance +limitations in high-concurrency scenarios. 
+
+### Pros
+
+- well known and widely used
+
+### Cons
+
+- mostly used together with SQLAlchemy
+- mostly blocking; async communication is only possible via polling (very cumbersome), see
+  the docs: https://www.psycopg.org/docs/advanced.html#async-support
+
+## [psycopg3](https://www.psycopg.org/psycopg3/)
+
+psycopg3 aims to modernize and improve upon psycopg2 with asynchronous support and better performance, making it a
+strong candidate for future projects, though it may still be catching up in terms of stability and community support.
+
+### Pros
+
+- async support
+- API similar to psycopg2 (familiar to most developers)
+
+### Cons
+
+- not yet mature
+
+## [asyncpg](https://magicstack.github.io/asyncpg/)
+
+asyncpg is a high-performance, asynchronous library tailored for modern, high-concurrency applications, though it may
+pose compatibility issues with synchronous tools.
+
+### Pros
+
+- async by design
+- lightweight
+
+### Cons
+
+- API differs from psycopg2 (the one most developers are familiar with)
+
+## Performance Benchmark
+
+See https://fernandoarteaga.dev/blog/psycopg-vs-asyncpg/
diff --git a/src/app/libs/README.md b/src/app/libs/README.md
new file mode 100644
index 0000000..cac9ef5
--- /dev/null
+++ b/src/app/libs/README.md
@@ -0,0 +1,17 @@
+# Libs
+
+This module contains code that we may want to eventually move to an external library.
+
+## [migrationtool](migrationtool)
+
+A lightweight migration tool inspired by well-known tools like [Flyway](https://flywaydb.org/)
+or [Liquibase](https://www.liquibase.com/).
+
+Essentially, it keeps your database up to date with a set of SQL scripts that you ship with the code.
+
+It is meant to run when the application starts up, which is very useful for integration tests
+with a dockerized database, since it creates the schema from scratch at startup.
+
+## [datalayer](datalayer)
+
+A lightweight data access layer that abstracts database access.
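+
+A repository is defined per table by subclassing `BaseRepository` and wiring it up through FastAPI dependency
+injection. As a rough usage sketch (the `widget` table and the models below are made-up examples, not part of
+this change):
+
+```python
+import uuid
+
+import fastapi
+import pydantic
+
+from app.libs.datalayer.base_repository import BaseRepository
+from app.libs.datalayer.fastapi.ddb import DependableDataLayer
+
+
+class WidgetRecord(pydantic.BaseModel):
+    id: uuid.UUID
+    title: str
+
+
+class WidgetCreatePayload(pydantic.BaseModel):
+    title: str
+
+
+class WidgetPatchPayload(pydantic.BaseModel):
+    title: str | None = None
+
+
+class WidgetRepository(
+    BaseRepository[WidgetRecord, WidgetCreatePayload, WidgetPatchPayload]
+):
+    def __init__(self, ddb: DependableDataLayer = fastapi.Depends()) -> None:
+        super().__init__(ddb.datalayer, "widget", WidgetRecord)
+```
+
+The same pattern is used by `RateCardRepository` and `UserAccountRepository` in `src/app/datalayer/`.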
diff --git a/src/app/libs/__init__.py b/src/app/libs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/libs/datalayer/__init__.py b/src/app/libs/datalayer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/libs/datalayer/base_repository.py b/src/app/libs/datalayer/base_repository.py new file mode 100644 index 0000000..76a44ff --- /dev/null +++ b/src/app/libs/datalayer/base_repository.py @@ -0,0 +1,140 @@ +import datetime +import typing +import uuid + +import pydantic + +from app.libs.datalayer.datalayer import DataLayer +from app.libs.datalayer.dbutils import DBUtils +from app.libs.datalayer.metadata.provider import TableMetadata, DBMetadataProvider + +_LAST_MODIFIED_AT = "last_modified_at" +_LAST_MODIFIED_BY = "last_modified_by" +_LAST_UPDATED_AT = "last_updated_at" +_LAST_UPDATED_BY = "last_updated_by" +_CREATED_AT = "created_at" +_CREATED_BY = "created_by" +_UNMODIFIABLE_FIELDS = [ + _LAST_MODIFIED_AT, + _LAST_MODIFIED_BY, + _LAST_UPDATED_AT, + _LAST_UPDATED_BY, + _CREATED_AT, + _CREATED_BY, +] + + +Record = typing.TypeVar("Record", bound=pydantic.BaseModel) +PatchPayload = typing.TypeVar("PatchPayload", bound=pydantic.BaseModel) +CreatePayload = typing.TypeVar("CreatePayload", bound=pydantic.BaseModel) + + +class BaseRepository(typing.Generic[Record, CreatePayload, PatchPayload]): + def __init__( + self, + datalayer: DataLayer, + table_name: str, + record_cls: typing.Type[Record], + ) -> None: + super().__init__() + self.dbutils: DBUtils = datalayer.dbutils + self.meta_provider: DBMetadataProvider = datalayer.meta_provider + self.table_name = table_name + self.record_cls = record_cls + + def _build_record(self, **kwargs) -> Record: + try: + return self.record_cls(**kwargs) + except Exception as e: + raise RuntimeError(f"Invalid {self.record_cls.__name__}: {kwargs}") from e + + async def _get_meta(self) -> TableMetadata: + return await self.meta_provider.get_table_metadata(self.table_name) + + async def get_all(self, **kwargs) -> list[Record]: + meta = await self._get_meta() + results = await self.dbutils.select( + table=meta.table_name, + columns=meta.column_names, + filters=kwargs, + ) + results = [self._build_record(**r) for r in results] + return results + + async def get_by_id(self, entity_id: uuid.UUID) -> Record | None: + meta = await self._get_meta() + results = await self.dbutils.select( + table=meta.table_name, + columns=meta.column_names, + filters={meta.pk_name: entity_id}, + ) + if results: + return self._build_record(**results[0]) + return None + + async def delete_by_id(self, entity_id: uuid.UUID) -> int: + meta = await self._get_meta() + return await self.dbutils.delete( + table=meta.table_name, + filters={meta.pk_name: entity_id}, + ) + + async def patch_by_id( + self, + entity_id: uuid.UUID, + payload: PatchPayload, + user_id: str = None, + ) -> int: + + meta = await self._get_meta() + obj = payload.model_dump(exclude_unset=True) + + # set audit fields + if _LAST_UPDATED_BY in meta.column_names: + assert user_id, "user_id is required" + obj[_LAST_UPDATED_BY] = user_id + if _LAST_UPDATED_AT in meta.column_names: + obj[_LAST_UPDATED_AT] = datetime.datetime.now() + if _LAST_MODIFIED_BY in meta.column_names: + assert user_id, "user_id is required" + obj[_LAST_MODIFIED_BY] = user_id + if _LAST_MODIFIED_AT in meta.column_names: + obj[_LAST_MODIFIED_AT] = datetime.datetime.now() + + return await self.dbutils.update( + table=meta.table_name, + obj=obj, + filters={meta.pk_name: entity_id}, + ) + + async def create( + self, + 
payload: CreatePayload, + user_id: str = None, + ) -> uuid.UUID: + + meta = await self._get_meta() + obj = payload.model_dump() + + if meta.pk_name not in obj: + assert meta.pk.data_type == "uuid", "pk must be of type uuid" + obj[meta.pk_name] = uuid.uuid4() + + # set audit fields + if _CREATED_BY in meta.column_names: + assert user_id, "user_id is required" + obj[_CREATED_BY] = user_id + if _CREATED_AT in meta.column_names: + obj[_CREATED_AT] = datetime.datetime.now() + if _LAST_MODIFIED_BY in meta.column_names: + assert user_id, "user_id is required" + obj[_LAST_MODIFIED_BY] = user_id + if _LAST_MODIFIED_AT in meta.column_names: + obj[_LAST_MODIFIED_AT] = datetime.datetime.now() + + await self.dbutils.insert( + table=meta.table_name, + obj=obj, + ) + + return obj[meta.pk_name] diff --git a/src/app/libs/datalayer/connectionpool/__init__.py b/src/app/libs/datalayer/connectionpool/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/libs/datalayer/connectionpool/manager.py b/src/app/libs/datalayer/connectionpool/manager.py new file mode 100644 index 0000000..9637b6e --- /dev/null +++ b/src/app/libs/datalayer/connectionpool/manager.py @@ -0,0 +1,44 @@ +import os + +import asyncpg + +from app.libs.datalayer.connectionpool.wrapper import ConnectionPoolWrapper + + +class ConnectionPoolManager: + """ + This class takes care of opening and closing the connection pool (and related resources). + It's meant to be used in FastAPI lifespan (see https://fastapi.tiangolo.com/advanced/events/#lifespan). + + """ + + def __init__(self, postgres_url: str = None): + # TODO add possible other params we want to pass to the connection pool + self._postgres_url = postgres_url or os.getenv("POSTGRES_URL") + assert self._postgres_url, ( + f"Missing postgres_url, you need to pass it as constructor argument to " + f"{type(self).__name__} or define POSTGRES_URL env var." + ) + self._low_level_pool: asyncpg.Pool | None = None + + async def get_low_level_pool(self) -> asyncpg.Pool: + if not self._low_level_pool: + assert self._postgres_url, "missing POSTGRES_URL" + self._low_level_pool = await asyncpg.create_pool( + dsn=self._postgres_url, + min_size=2, + max_size=10, + ) + return self._low_level_pool + + async def close(self): + if self._low_level_pool: + await self._low_level_pool.close() + + async def get_pool(self) -> ConnectionPoolWrapper: + pool = await self.get_low_level_pool() + return ConnectionPoolWrapper(pool) + + def get_pool_unsafe(self) -> ConnectionPoolWrapper: + assert self._low_level_pool, "pool not initialized yet" + return ConnectionPoolWrapper(self._low_level_pool) diff --git a/src/app/libs/datalayer/connectionpool/wrapper.py b/src/app/libs/datalayer/connectionpool/wrapper.py new file mode 100644 index 0000000..daf5d1d --- /dev/null +++ b/src/app/libs/datalayer/connectionpool/wrapper.py @@ -0,0 +1,25 @@ +import contextlib + +import asyncpg + + +class ConnectionPoolWrapper: + """ + This is wrapper class to decouple for the actual connection pool object. 
+ + """ + + def __init__(self, pool: asyncpg.Pool): + assert pool, "missing pool" + self.pool = pool + + @contextlib.asynccontextmanager + async def transaction(self, timeout: float = None): + async with self.pool.acquire(timeout=timeout) as connection: + async with connection.transaction(): + yield connection + + @contextlib.asynccontextmanager + async def connection(self, timeout: float = None): + async with self.pool.acquire(timeout=timeout) as connection: + yield connection diff --git a/src/app/libs/datalayer/datalayer.py b/src/app/libs/datalayer/datalayer.py new file mode 100644 index 0000000..130adbf --- /dev/null +++ b/src/app/libs/datalayer/datalayer.py @@ -0,0 +1,13 @@ +from .connectionpool.wrapper import ConnectionPoolWrapper +from .dbutils import DBUtils +from .metadata.provider_with_cache import DBMetadataProviderWithCache + + +class DataLayer: + def __init__( + self, + pool: ConnectionPoolWrapper, + ) -> None: + super().__init__() + self.dbutils = DBUtils(pool) + self.meta_provider = DBMetadataProviderWithCache(pool) diff --git a/src/app/libs/datalayer/dbutils.py b/src/app/libs/datalayer/dbutils.py new file mode 100644 index 0000000..342efbb --- /dev/null +++ b/src/app/libs/datalayer/dbutils.py @@ -0,0 +1,78 @@ +from .connectionpool.wrapper import ConnectionPoolWrapper +from .sqlutils import ( + prepare_insert, + prepare_update, + prepare_delete, + prepare_select, +) + + +class DBUtils: + def __init__( + self, + pool: ConnectionPoolWrapper, + ) -> None: + super().__init__() + self.pool = pool + + async def insert_and_return( + self, + table: str, + obj: dict, + ) -> dict: + async with self.pool.transaction() as connection: + sql, values = prepare_insert(table, obj, returning=True) + row = await connection.fetchrow(sql, *values) + return dict(row) + + async def insert( + self, + table: str, + obj: dict, + ) -> int: + async with self.pool.transaction() as connection: + sql, values = prepare_insert(table, obj) + result = await connection.execute(sql, *values) + # returns sth like `INSERT 0 1` + return int(result.split(" ")[-1]) + + async def update( + self, + table: str, + obj: dict, + filters: dict, + ) -> int: + async with self.pool.transaction() as connection: + sql, values = prepare_update(table, obj, filters) + result = await connection.execute(sql, *values) + # returns sth like `UPDATE 1` + return int(result.split(" ")[-1]) + + async def delete( + self, + table: str, + filters: dict, + ) -> int: + async with self.pool.transaction() as connection: + sql, values = prepare_delete(table, filters) + result = await connection.execute(sql, *values) + # returns sth like `DELETE 1` + return int(result.split(" ")[-1]) + + async def select( + self, table: str, filters: dict, columns: list[str] = None + ) -> list[dict]: + async with self.pool.connection() as connection: + sql, values = prepare_select(table, filters, columns) + records = await connection.fetch(sql, *values) + return [dict(r) for r in records] + + async def select_one( + self, table: str, filters: dict, columns: list[str] = None + ) -> dict | None: + async with self.pool.connection() as connection: + sql, values = prepare_select(table, filters, columns) + row = await connection.fetchrow(sql, *values) + if not row: + return None + return dict(row) diff --git a/src/app/libs/datalayer/fastapi/__init__.py b/src/app/libs/datalayer/fastapi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/libs/datalayer/fastapi/ddb.py b/src/app/libs/datalayer/fastapi/ddb.py new file mode 100644 index 0000000..c61c7b2 
--- /dev/null +++ b/src/app/libs/datalayer/fastapi/ddb.py @@ -0,0 +1,16 @@ +import fastapi + +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager +from app.libs.datalayer.datalayer import DataLayer + + +class DependableDataLayer: + """ + DB class that supports fastapi dependency injection. + """ + + def __init__(self, request: fastapi.Request): + cpm: ConnectionPoolManager = request.app.cpm + # This is safe because the pool is created at app startup. + pool = cpm.get_pool_unsafe() + self.datalayer = DataLayer(pool) diff --git a/src/app/libs/datalayer/metadata/__init__.py b/src/app/libs/datalayer/metadata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/libs/datalayer/metadata/provider.py b/src/app/libs/datalayer/metadata/provider.py new file mode 100644 index 0000000..6caed79 --- /dev/null +++ b/src/app/libs/datalayer/metadata/provider.py @@ -0,0 +1,82 @@ +import enum +import logging +import os +import typing + +import pydantic + +from app.libs.datalayer.connectionpool.wrapper import ConnectionPoolWrapper + + +class ConstraintType(enum.StrEnum): + PRIMARY_KEY = "PRIMARY KEY" + FOREIGN_KEY = "FOREIGN KEY" + UNIQUE = "UNIQUE" + + +class ColumnMetadata(pydantic.BaseModel): + ordinal_position: int + column_name: str + data_type: str + is_nullable: bool + constraint_type: ConstraintType | None + fk_target_table: str | None + fk_target_column: str | None + + +class TableMetadata(pydantic.BaseModel): + table_name: str + columns_meta: typing.List[ColumnMetadata] + pks: typing.List[ColumnMetadata] = pydantic.Field(default_factory=list) + fks: typing.List[ColumnMetadata] = pydantic.Field(default_factory=list) + + def model_post_init(self, __context): + self.pks = [ + c + for c in self.columns_meta + if c.constraint_type == ConstraintType.PRIMARY_KEY + ] + self.fks = [ + c + for c in self.columns_meta + if c.constraint_type == ConstraintType.FOREIGN_KEY + ] + + @property + def column_names(self) -> typing.List[str]: + return [c.column_name for c in self.columns_meta] + + @property + def pk(self) -> ColumnMetadata | None: + return self.pks[0] if self.pks else None + + @property + def pk_name(self) -> str | None: + return self.pk.column_name if self.pk else None + + +class DBMetadataProvider: + + def __init__(self, pool: ConnectionPoolWrapper): + super().__init__() + + assert pool, "pool is required" + + path = os.path.join(os.path.dirname(__file__), "select.sql") + with open(path, "r") as f: + sql_for_table_metadata = f.read() + assert sql_for_table_metadata, "select.sql file is empty" + + self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + self.pool = pool + self.sql_for_table_metadata = sql_for_table_metadata + + async def get_table_metadata(self, table_name: str) -> TableMetadata: + async with self.pool.connection() as conn: + results = await conn.fetch(self.sql_for_table_metadata, table_name) + if not results: + raise RuntimeError(f"table '{table_name}' not found") + return TableMetadata( + table_name=table_name, + columns_meta=[ColumnMetadata(**dict(row)) for row in results], + ) diff --git a/src/app/libs/datalayer/metadata/provider_with_cache.py b/src/app/libs/datalayer/metadata/provider_with_cache.py new file mode 100644 index 0000000..70c926f --- /dev/null +++ b/src/app/libs/datalayer/metadata/provider_with_cache.py @@ -0,0 +1,47 @@ +import asyncio +import logging +import threading + +import cachetools + +from app.libs.datalayer.connectionpool.wrapper import ConnectionPoolWrapper +from app.libs.datalayer.metadata.provider 
import DBMetadataProvider, TableMetadata + + +class DBMetadataProviderWithCache(DBMetadataProvider): + + def __init__( + self, + pool: ConnectionPoolWrapper, + cache_ttl: int = None, + ): + # Do not call init on super class, as it is not needed. + super().__init__(pool) + self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + self.cache = cachetools.TTLCache( + maxsize=1000, + ttl=cache_ttl or 300.0, # ttl is in seconds + ) + self._lock = threading.Lock() + + async def get_table_metadata( + self, table_name: str, pool: ConnectionPoolWrapper = None + ) -> TableMetadata: + key = f"{hash(self.pool)}/{table_name}" + if key not in self.cache: + self.logger.debug(f"key {key} not found") + # The asyncio lock assures that no other coroutine has already invoked self._load() and it is now idle + # awaiting the result. + async with asyncio.Lock(): + self.logger.debug(f"key {key} async lock acquired") + # IMPORTANT: The threading lock has to be acquired after asyncio lock to prevent deadlock, or use + # a reentrant lock instead. + # For info on how to prevent deadlocks, see https://superfastpython.com/asyncio-use-threading-lock/ + with self._lock: + self.logger.debug(f"key {key} threading lock acquired") + if key not in self.cache: + self.logger.debug(f"key {key} not found after acquiring locks") + result = await super().get_table_metadata(table_name) + self.cache[key] = result + self.logger.debug(f"key {key} added to cache") + return self.cache[key] diff --git a/src/app/libs/datalayer/metadata/select.sql b/src/app/libs/datalayer/metadata/select.sql new file mode 100644 index 0000000..2909b0e --- /dev/null +++ b/src/app/libs/datalayer/metadata/select.sql @@ -0,0 +1,55 @@ +SELECT + c.ordinal_position, + c.column_name, + c.data_type, + CASE c.is_nullable + WHEN 'YES' THEN true + ELSE false + END AS is_nullable, + tc.constraint_type, + fk.fk_target_table, + fk.fk_target_column +FROM + information_schema.columns c +LEFT JOIN ( + SELECT + kcu.table_schema, + kcu.table_name, + kcu.column_name, + tc.constraint_type + FROM + information_schema.table_constraints tc + JOIN + information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name +) tc +ON c.table_schema = tc.table_schema + AND c.table_name = tc.table_name + AND c.column_name = tc.column_name +LEFT JOIN ( + SELECT + kcu.table_schema, + kcu.table_name, + kcu.column_name, + rc.update_rule, + rc.delete_rule, + kcu2.table_name AS fk_target_table, + kcu2.column_name AS fk_target_column + FROM + information_schema.referential_constraints rc + JOIN + information_schema.key_column_usage kcu + ON rc.constraint_name = kcu.constraint_name + JOIN + information_schema.key_column_usage kcu2 + ON rc.unique_constraint_name = kcu2.constraint_name + AND kcu.ordinal_position = kcu2.ordinal_position +) fk +ON c.table_schema = fk.table_schema + AND c.table_name = fk.table_name + AND c.column_name = fk.column_name +WHERE + c.table_schema = 'public' + AND c.table_name = $1 +ORDER BY + c.ordinal_position; \ No newline at end of file diff --git a/src/app/libs/datalayer/sqlutils.py b/src/app/libs/datalayer/sqlutils.py new file mode 100644 index 0000000..7d91031 --- /dev/null +++ b/src/app/libs/datalayer/sqlutils.py @@ -0,0 +1,105 @@ +import typing +from typing import Any + + +def prepare_select( + table_name: str, + filters: dict | None, + columns: list[str] | None, +) -> tuple[str, list[Any]]: + + columns = ",".join(columns) if columns else "*" + + if not filters: + return ( + f"SELECT {columns} FROM {table_name};", + [], + ) + + 
    flat_values = []
+    where_conditions = []
+
+    for column_name, column_value in filters.items():
+        flat_values.append(column_value)
+        where_conditions.append(f"{column_name}=${len(flat_values)}")
+
+    return (
+        f"SELECT {columns} FROM {table_name} WHERE {' AND '.join(where_conditions)};",
+        flat_values,
+    )
+
+
+def prepare_insert(
+    table_name: str,
+    obj: typing.Any,
+    returning: bool = False,
+) -> tuple[str, list[Any]]:
+
+    column_names = []
+    flat_values = []
+    row_values = []
+
+    for column_name, column_value in obj.items():
+        column_names.append(column_name)
+        flat_values.append(column_value)
+        row_values.append(f"${len(flat_values)}")
+
+    if not row_values:
+        raise ValueError(f"Nothing to insert")
+
+    sql = f"INSERT INTO {table_name}({','.join(column_names)}) VALUES ({','.join(row_values)})"
+    if returning:
+        sql += " RETURNING *"
+
+    return sql, flat_values
+
+
+def prepare_update(
+    table_name: str,
+    obj: dict,
+    filters: dict,
+) -> tuple[str, list[Any]]:
+
+    if not filters:
+        raise ValueError(f"Missing filters")
+
+    flat_values = []
+    where_conditions = []
+    setters = []
+
+    for column_name, column_value in filters.items():
+        flat_values.append(column_value)
+        where_conditions.append(f"{column_name}=${len(flat_values)}")
+
+    for column_name, column_value in obj.items():
+        flat_values.append(column_value)
+        setters.append(f"{column_name}=${len(flat_values)}")
+
+    if not setters:
+        raise ValueError(f"Nothing to set")
+
+    return (
+        f"UPDATE {table_name} SET {','.join(setters)} WHERE {' AND '.join(where_conditions)};",
+        flat_values,
+    )
+
+
+def prepare_delete(
+    table_name: str,
+    filters: dict,
+) -> tuple[str, list[Any]]:
+
+    if not filters:
+        raise ValueError(f"Missing filters")
+
+    flat_values = []
+    where_conditions = []
+
+    for column_name, column_value in filters.items():
+        flat_values.append(column_value)
+        where_conditions.append(f"{column_name}=${len(flat_values)}")
+
+    return (
+        f"DELETE FROM {table_name} WHERE {' AND '.join(where_conditions)};",
+        flat_values,
+    )
diff --git a/src/app/libs/migrationtool/__init__.py b/src/app/libs/migrationtool/__init__.py
new file mode 100644
index 0000000..f9684d7
--- /dev/null
+++ b/src/app/libs/migrationtool/__init__.py
@@ -0,0 +1,111 @@
+import hashlib
+import logging
+import os
+from typing import Dict, List
+
+import asyncpg
+
+from app.libs.datalayer.connectionpool.wrapper import ConnectionPoolWrapper
+
+
+class MigrationTool:
+    """
+    MigrationTool Class
+
+    Manages and applies SQL migrations to a PostgreSQL database using asyncpg.
+
+    Key functionalities:
+    - Checks and creates a `migration_history` table if it doesn't exist.
+    - Computes SHA-256 hashes of migration files to ensure integrity.
+    - Retrieves applied migrations and applies new ones in sorted order.
+    - Handles transactions to ensure atomic application of migrations.
+
+    Attributes:
+    - pool (ConnectionPoolWrapper): A wrapper around the asyncpg connection pool.
+
+    Methods:
+    - `migrate(migrations_dir: str) -> None`: Applies new migrations from the specified directory; see the usage sketch below.
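+
+    Usage (a minimal sketch; the migrations path is just an example, any directory of
+    `.sql` files applied in filename order works):
+
+        pool = await ConnectionPoolManager().get_pool()
+        tool = MigrationTool(pool)
+        await tool.migrate("/path/to/migrations")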
+ """ + + def __init__(self, pool: ConnectionPoolWrapper) -> None: + self.pool = pool + self.logger = logging.getLogger("migrationtool") + + def _calculate_file_hash(self, filepath: str) -> str: + """Calculate SHA-256 hash of the given file.""" + sha256 = hashlib.sha256() + with open(filepath, "rb") as file: + while chunk := file.read(8192): + sha256.update(chunk) + return sha256.hexdigest() + + async def _get_applied_migrations(self, conn: asyncpg.Connection) -> Dict[str, str]: + records = await conn.fetch( + "SELECT filename, hash FROM migration_history ORDER BY applied_at" + ) + return {record["filename"]: record["hash"] for record in records} + + async def _apply_migration( + self, conn: asyncpg.Connection, filename: str, filepath: str + ) -> None: + file_hash = self._calculate_file_hash(filepath) + with open(filepath, "r") as file: + sql_content = file.read() + async with conn.transaction(): + await conn.execute(sql_content) + await conn.execute( + "INSERT INTO migration_history (filename, hash) VALUES ($1, $2)", + filename, + file_hash, + ) + + async def _check_and_create_history_table(self, conn: asyncpg.Connection) -> None: + result = await conn.fetchval( + """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'migration_history' + ) + """ + ) + if not result: + await conn.execute( + """ + CREATE TABLE migration_history ( + id SERIAL PRIMARY KEY, + filename TEXT NOT NULL, + hash TEXT NOT NULL, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + def _get_sql_migration_files(self, migrations_dir: str) -> List[str]: + """Retrieve .sql files sorted by name from the migrations directory.""" + return sorted( + [file for file in os.listdir(migrations_dir) if file.endswith(".sql")] + ) + + async def migrate(self, migrations_dir: str) -> None: + async with self.pool.connection() as conn: + await self._check_and_create_history_table(conn) + applied_migrations = await self._get_applied_migrations(conn) + migration_files = self._get_sql_migration_files(migrations_dir) + + applied_any = False + for filename in migration_files: + filepath = os.path.join(migrations_dir, filename) + + if filename in applied_migrations: + file_hash = self._calculate_file_hash(filepath) + if applied_migrations[filename] != file_hash: + raise ValueError( + f"Hash mismatch for {filename}. Migration file may have been altered." 
+ ) + else: + await self._apply_migration(conn, filename, filepath) + applied_any = True + self.logger.info(f"Applied migration: {filename}") + + if not applied_any: + self.logger.info(f"No migrations applied, database is up-to-date.") diff --git a/src/schema.sql b/src/app/resources/migrations/v001-initial-schema.sql similarity index 97% rename from src/schema.sql rename to src/app/resources/migrations/v001-initial-schema.sql index 715b3c8..e89b522 100644 --- a/src/schema.sql +++ b/src/app/resources/migrations/v001-initial-schema.sql @@ -105,7 +105,7 @@ CREATE TABLE quick_quote_item quick_quote_id uuid, billable_rate double precision DEFAULT '0'::double precision NOT NULL ); -CREATE TABLE rate_card +CREATE TABLE ratecard ( id uuid DEFAULT uuid_generate_v4() NOT NULL, net_margin double precision DEFAULT '0'::double precision NOT NULL, @@ -207,7 +207,7 @@ ALTER TABLE ONLY quick_quote ADD CONSTRAINT pk_quick_quote PRIMARY KEY (id); ALTER TABLE ONLY team_bundle ADD CONSTRAINT pk_team_bundle PRIMARY KEY (id); ALTER TABLE ONLY quick_quote_item ADD CONSTRAINT pk_quick_quote_item PRIMARY KEY (id); ALTER TABLE ONLY adjusted_hour ADD CONSTRAINT pk_adjusted_hour PRIMARY KEY (id); -ALTER TABLE ONLY rate_card ADD CONSTRAINT pk_rate_card PRIMARY KEY (id); +ALTER TABLE ONLY ratecard ADD CONSTRAINT pk_ratecard PRIMARY KEY (id); ALTER TABLE ONLY location_benefit ADD CONSTRAINT pk_location_benefit PRIMARY KEY (location_id, benefit_id); ALTER TABLE ONLY benefit_profile ADD CONSTRAINT pk_benefit_profile PRIMARY KEY (id); ALTER TABLE ONLY location ADD CONSTRAINT pk_location PRIMARY KEY (id); @@ -249,7 +249,7 @@ ALTER TABLE ONLY resource_tax ADD CONSTRAINT fk_resource_tax_tax_id FOREIGN KEY ALTER TABLE ONLY location_tax ADD CONSTRAINT fk_location_tax_location_id FOREIGN KEY (location_id) REFERENCES location (id); ALTER TABLE ONLY location_adjusted_hour ADD CONSTRAINT fk_location_adjusted_hour_adjusted_hour_id FOREIGN KEY (adjusted_hour_id) REFERENCES adjusted_hour (id); ALTER TABLE ONLY location_benefit ADD CONSTRAINT fk_location_benefit_benefit_id FOREIGN KEY (benefit_id) REFERENCES benefit (id); -ALTER TABLE ONLY rate_card ADD CONSTRAINT fk_rate_card_company_id FOREIGN KEY (company_id) REFERENCES company (id); +ALTER TABLE ONLY ratecard ADD CONSTRAINT fk_ratecard_company_id FOREIGN KEY (company_id) REFERENCES company (id); ALTER TABLE ONLY resource ADD CONSTRAINT fk_resource_team_id FOREIGN KEY (team_id) REFERENCES team (id); ALTER TABLE ONLY benefit_profile ADD CONSTRAINT fk_benefit_profile_company_id FOREIGN KEY (company_id) REFERENCES company (id); ALTER TABLE ONLY resource ADD CONSTRAINT fk_resource_position_id FOREIGN KEY (position_id) REFERENCES position (id); @@ -261,7 +261,7 @@ ALTER TABLE ONLY user_account ADD CONSTRAINT fk_user_company_id FOREIGN KEY (com ALTER TABLE ONLY password_reset_token ADD CONSTRAINT fk_password_reset_token_user_id FOREIGN KEY (user_id) REFERENCES user_account (id) ON DELETE CASCADE; ALTER TABLE ONLY team ADD CONSTRAINT fk_team_company_id FOREIGN KEY (company_id) REFERENCES company (id); ALTER TABLE ONLY team_bundle ADD CONSTRAINT fk_team_bundle_created_by_user_id FOREIGN KEY (created_by_user_id) REFERENCES user_account (id); -ALTER TABLE ONLY quick_quote ADD CONSTRAINT fk_quick_quote_ratecard_id FOREIGN KEY (ratecard_id) REFERENCES rate_card (id) ON DELETE CASCADE; +ALTER TABLE ONLY quick_quote ADD CONSTRAINT fk_quick_quote_ratecard_id FOREIGN KEY (ratecard_id) REFERENCES ratecard (id) ON DELETE CASCADE; ALTER TABLE ONLY quick_quote ADD CONSTRAINT fk_quick_quote_company_id 
FOREIGN KEY (company_id) REFERENCES company (id); ALTER TABLE ONLY quick_quote ADD CONSTRAINT fk_quick_quote_created_by_user_id FOREIGN KEY (created_by_user_id) REFERENCES user_account (id); ALTER TABLE ONLY team_bundle ADD CONSTRAINT fk_team_bundle_company_id FOREIGN KEY (company_id) REFERENCES company (id); diff --git a/src/tests/app/api/ratecards.py b/src/tests/app/api/ratecards.py new file mode 100644 index 0000000..a4b0633 --- /dev/null +++ b/src/tests/app/api/ratecards.py @@ -0,0 +1,4 @@ +def test_get_all(client): + res = client.get("/api/v1/ratecards") + print(f"{res.request.method} {res.url} >> {res.status_code} {res.text}") + assert res.status_code == 200 diff --git a/src/tests/app/api/users.py b/src/tests/app/api/users.py index 08e38eb..bc018be 100644 --- a/src/tests/app/api/users.py +++ b/src/tests/app/api/users.py @@ -1,4 +1,4 @@ -def test_v1_user(client): +def test_get_all(client): res = client.get("/api/v1/users") print(f"{res.request.method} {res.url} >> {res.status_code} {res.text}") assert res.status_code == 200 diff --git a/src/tests/libs/__init__.py b/src/tests/libs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/libs/datalayer/__init__.py b/src/tests/libs/datalayer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/libs/datalayer/conftest.py b/src/tests/libs/datalayer/conftest.py new file mode 100644 index 0000000..d208601 --- /dev/null +++ b/src/tests/libs/datalayer/conftest.py @@ -0,0 +1,16 @@ +import pytest +from testcontainers.postgres import PostgresContainer + +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager + + +@pytest.fixture +def fresh_postgres_url(): + with PostgresContainer("postgres:15", driver=None) as postgres: + yield postgres.get_connection_url() + + +@pytest.fixture +async def fresh_pool(fresh_postgres_url): + cpm = ConnectionPoolManager(fresh_postgres_url) + yield await cpm.get_pool() diff --git a/src/tests/libs/datalayer/test_base_repository.py b/src/tests/libs/datalayer/test_base_repository.py new file mode 100644 index 0000000..84b0711 --- /dev/null +++ b/src/tests/libs/datalayer/test_base_repository.py @@ -0,0 +1,91 @@ +import datetime +from uuid import UUID + +import pydantic +import pytest + +from app.libs.datalayer.base_repository import BaseRepository +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager +from app.libs.datalayer.connectionpool.wrapper import ConnectionPoolWrapper +from app.libs.datalayer.datalayer import DataLayer + + +class FoobarRecord(pydantic.BaseModel): + foo_id: UUID + email: str + full_name: str | None + bar_id: UUID | None + created_at: datetime.datetime + last_updated_at: datetime.datetime | None + + +class FoobarCreatePayload(pydantic.BaseModel): + email: str + full_name: str | None = None + bar_id: UUID | None = None + + +class FoobarPatchPayload(pydantic.BaseModel): + email: str | None = None + full_name: str | None = None + bar_id: UUID | None = None + + +class FoobarRepository( + BaseRepository[FoobarRecord, FoobarCreatePayload, FoobarPatchPayload] +): + def __init__(self, pool: ConnectionPoolWrapper) -> None: + super().__init__( + DataLayer(pool), + "foobar", + FoobarRecord, + ) + + +@pytest.mark.asyncio +async def test_crud(fresh_postgres_url): + cpm = ConnectionPoolManager(fresh_postgres_url) + pool = await cpm.get_pool() + async with pool.connection() as conn: + await conn.execute( + """ + CREATE TABLE foobar ( + foo_id UUID PRIMARY KEY, + email VARCHAR(50) UNIQUE NOT NULL, + full_name VARCHAR(50), + 
bar_id UUID, + created_at TIMESTAMP, + last_updated_at TIMESTAMP + ); + """ + ) + + repository = FoobarRepository(pool) + + new_foo_id = await repository.create(FoobarCreatePayload(email="foo@example.com")) + jd_foo_id = await repository.create( + FoobarCreatePayload(email="jdoe@example.com", name="John Doe") + ) + + with pytest.raises(Exception): + await repository.create(FoobarCreatePayload(email="foo@example.com")) + + all_users = await repository.get_all() + assert all_users + assert len(all_users) == 2 + + filtered_users = await repository.get_all(email="jdoe@example.com") + assert len(filtered_users) == 1 + + new_user = await repository.get_by_id(new_foo_id) + assert new_user.email == "foo@example.com" + + jd_user = await repository.get_by_id(jd_foo_id) + assert jd_user.email == "jdoe@example.com" + + await repository.patch_by_id( + new_foo_id, FoobarPatchPayload(email="bar@example.com") + ) + new_user = await repository.get_by_id(new_foo_id) + assert new_user.email == "bar@example.com" + assert new_user.created_at < new_user.last_updated_at diff --git a/src/tests/libs/datalayer/test_datalayer.py b/src/tests/libs/datalayer/test_datalayer.py new file mode 100644 index 0000000..028f2b9 --- /dev/null +++ b/src/tests/libs/datalayer/test_datalayer.py @@ -0,0 +1,65 @@ +import uuid + +import pytest + +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager +from app.libs.datalayer.dbutils import DBUtils + + +@pytest.mark.asyncio +async def test_crud(fresh_postgres_url): + cpm = ConnectionPoolManager(fresh_postgres_url) + pool = await cpm.get_pool() + async with pool.connection() as conn: + await conn.execute( + """ + CREATE TABLE user_account ( + id UUID PRIMARY KEY, + email VARCHAR(50) UNIQUE NOT NULL, + full_name VARCHAR(50), + company_id UUID + ); + """ + ) + base_repo = DBUtils(pool) + + company_id = uuid.uuid4() + user_id = uuid.uuid4() + await base_repo.insert( + "user_account", + { + "id": user_id, + "email": "foo@example.com", + "company_id": company_id, + }, + ) + + user_accounts = await base_repo.select("user_account", {}) + assert len(user_accounts) == 1 + + user_account = await base_repo.select_one("user_account", {"id": user_id}) + assert user_account["id"] == user_id + assert user_account["full_name"] is None + assert user_account["email"] == "foo@example.com" + assert user_account["company_id"] == company_id + + await base_repo.update( + "user_account", + {"email": "bar@example.com"}, + {"id": user_id}, + ) + + user_account = await base_repo.select_one("user_account", {"id": user_id}) + assert user_account["email"] == "bar@example.com" + assert user_account["company_id"] == company_id + + await base_repo.delete( + "user_account", + {"id": user_id}, + ) + + user_account = await base_repo.select_one("user_account", {"id": user_id}) + assert user_account is None + + user_accounts = await base_repo.select("user_account", {}) + assert len(user_accounts) == 0 diff --git a/src/tests/libs/datalayer/test_metadata.py b/src/tests/libs/datalayer/test_metadata.py new file mode 100644 index 0000000..be0d926 --- /dev/null +++ b/src/tests/libs/datalayer/test_metadata.py @@ -0,0 +1,39 @@ +import pytest + +from app.libs.datalayer.connectionpool.manager import ConnectionPoolManager +from app.libs.datalayer.metadata.provider import TableMetadata, DBMetadataProvider + + +@pytest.mark.asyncio +async def test_get_table_metadata(fresh_postgres_url): + cpm = ConnectionPoolManager(fresh_postgres_url) + pool = await cpm.get_pool() + async with pool.connection() as conn: + await 
conn.execute( + """ + CREATE TABLE user_account ( + id UUID PRIMARY KEY, + email VARCHAR(50) UNIQUE NOT NULL, + full_name VARCHAR(50), + company_id UUID + ); + CREATE TABLE company ( + id UUID PRIMARY KEY, + name VARCHAR(50) UNIQUE NOT NULL, + address VARCHAR(50) + ); + ALTER TABLE ONLY user_account + ADD CONSTRAINT fk_user_account_company_id FOREIGN KEY (company_id) REFERENCES company (id); + """ + ) + metadata_service = DBMetadataProvider(pool) + meta = await metadata_service.get_table_metadata("user_account") + assert isinstance(meta, TableMetadata) + assert meta.table_name == "user_account" + assert meta.column_names == ["id", "email", "full_name", "company_id"] + assert meta.pk_name == "id" + assert len(meta.fks) == 1 + fk = meta.fks[0] + assert fk.column_name == "company_id" + assert fk.fk_target_table == "company" + assert fk.fk_target_column == "id"
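+
+
+# A rough sketch of a companion test for the cached metadata provider. It only checks
+# that repeated lookups hit the TTL cache, so the `gadget` table here is an arbitrary example.
+@pytest.mark.asyncio
+async def test_get_table_metadata_with_cache(fresh_postgres_url):
+    # Imported locally to keep this sketch self-contained.
+    from app.libs.datalayer.metadata.provider_with_cache import DBMetadataProviderWithCache
+
+    cpm = ConnectionPoolManager(fresh_postgres_url)
+    pool = await cpm.get_pool()
+    async with pool.connection() as conn:
+        await conn.execute("CREATE TABLE gadget (id UUID PRIMARY KEY, label VARCHAR(50));")
+
+    cached_provider = DBMetadataProviderWithCache(pool, cache_ttl=60)
+    first = await cached_provider.get_table_metadata("gadget")
+    second = await cached_provider.get_table_metadata("gadget")
+
+    assert first.table_name == "gadget"
+    assert first.column_names == ["id", "label"]
+    assert second is first  # second lookup is served from the cache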