Skip to content

Commit

Permalink
refactor(model): update whole model design
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jul 4, 2024
1 parent 2400cdd commit 3a1c021
Show file tree
Hide file tree
Showing 57 changed files with 2,030 additions and 1,163 deletions.
6 changes: 3 additions & 3 deletions backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@

# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name) # PGH003
fileConfig(config.config_file_name) # type: ignore # noqa: PGH003

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from src.core.models.base import Base # noqa: E402
from src.core.database.base import Base
from src.features.admin.models import *
from src.features.dcim.models import *
from src.features.ipam.models import *
from src.features.circuit.models import *
from src.features.netconfig.models import *
from src.features.arch.models import *
from src.features.intend.models import *
from src.features.org.models import *

target_metadata = Base.metadata
Expand Down
1,122 changes: 1,122 additions & 0 deletions backend/alembic/versions/2024_07_04_2338-c09c3268e2b9_init_db.py

Large diffs are not rendered by default.

20 changes: 0 additions & 20 deletions backend/deploy/init_db_ext.py

This file was deleted.

34 changes: 10 additions & 24 deletions backend/deploy/init_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from sqlalchemy import select

from src.core.config import PROJECT_DIR
from src.core.database.session import async_session
from src.core.utils.context import request_id_ctx, user_ctx
from src.db import Block, CircuitType, DeviceRole, DeviceType, Group, IPRole, Platform, RackRole, Role, User, Vendor
from src.db.database import sessionmanager
from src.features.admin.models import Group, Role, User
from src.features.admin.security import get_password_hash
from src.features.consts import ReservedRoleSlug
from src.features.dcim.models import DeviceType, Manufacturer
from src.features.intend.models import CircuitType, DeviceRole, IPRole, Platform
from src.features.ipam.models import Block

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -70,12 +73,12 @@ async def init_platform(session: "AsyncSession") -> None:
await session.commit()


async def init_vendor(session: "AsyncSession") -> None:
async def init_manufacturer(session: "AsyncSession") -> None:
async with await open_file(f"{PROJECT_DIR}/deploy/collections/metadata/vendor.json") as f:
contents = await f.read()
vendors = json.loads(contents)
new_vendors = [Vendor(**p) for p in vendors]
db_objs = (await session.scalars(select(Vendor))).all()
new_vendors = [Manufacturer(**p) for p in vendors]
db_objs = (await session.scalars(select(Manufacturer))).all()
if not db_objs:
session.add_all(new_vendors)
else:
Expand Down Expand Up @@ -136,22 +139,6 @@ async def init_device_role(session: "AsyncSession") -> None:
await session.commit()


async def init_rack_role(session: "AsyncSession") -> None:
async with await open_file(f"{PROJECT_DIR}/deploy/collections/metadata/rack_role.json") as f:
contents = await f.read()
rack_roles = json.loads(contents)
new_rack_roles = [RackRole(**r) for r in rack_roles]
db_objs = (await session.scalars(select(RackRole))).all()
if not db_objs:
session.add_all(new_rack_roles)
else:
slugs = [r.slug for r in db_objs]
for new_r in new_rack_roles:
if new_r.slug not in slugs:
session.add(new_r)
await session.commit()


async def init_ip_role(session: "AsyncSession") -> None:
async with await open_file(f"{PROJECT_DIR}/deploy/collections/metadata/ip_role.json") as f:
contents = await f.read()
Expand Down Expand Up @@ -186,14 +173,13 @@ async def init_block(session: "AsyncSession") -> None:

async def init_meta() -> None:
request_id_ctx.set(str(uuid4()))
async with sessionmanager.session() as session:
async with async_session() as session:
await init_user(session)
await init_platform(session)
await init_vendor(session)
await init_manufacturer(session)
await init_device_type(session)
await init_circuit_type(session)
await init_device_role(session)
await init_rack_role(session)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fixable = ["ALL"]
"deps.py" = ["B008"]
"src/features/internal/api.py" = ["ARG001"]
"src/features/admin/schemas.py" = ["N815"] # frontend menu
"alembic/*.py" = ["INP001", "UP007"]
"alembic/*.py" = ["INP001", "UP007", "PLR0915", "E402", "F403"]
"__init__.py" = ["F403"]

[tool.ruff.lint.flake8-bugbear]
Expand Down
8 changes: 4 additions & 4 deletions backend/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from starlette.middleware.errors import ServerErrorMiddleware

from src.core.config import _Env, settings
from src.core.errors.exceptions import default_exception_handler, exception_handlers, sentry_ignore_errors
from src.libs.redis import cache
from src.core.errors.exception_handlers import default_exception_handler, exception_handlers, sentry_ignore_errors
from src.libs.redis import session
from src.register.middlewares import RequestMiddleware
from src.register.openapi import get_open_api_intro, get_stoplight_elements_html
from src.register.routers import router
Expand All @@ -20,9 +20,9 @@ def create_app() -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]: # noqa: ARG001
pool = aioreids.ConnectionPool.from_url(
settings.REDIS_DSN, encoding="utf-8", db=cache.RedisDBType.DEFAULT, decode_response=True
settings.REDIS_DSN, encoding="utf-8", db=session.RedisDBType.DEFAULT, decode_response=True
)
cache.redis_client = cache.FastapiCache(connection_pool=pool)
session.redis_client = session.FastapiCache(connection_pool=pool)
yield
await pool.disconnect()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, TypedDict, TypeVar

from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import DeclarativeBase

from src.features._types import VisibleName

class VisibleName(TypedDict, total=True):
en_US: str
zh_CN: str


class Base(DeclarativeBase):
Expand Down
5 changes: 5 additions & 0 deletions backend/src/core/database/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from src.core.database.mixins.audit_log import AuditLog, AuditLogMixin
from src.core.database.mixins.audit_time import AuditTimeMixin
from src.core.database.mixins.audit_user import AuditUserMixin

__all__ = ("AuditLogMixin", "AuditTimeMixin", "AuditUserMixin", "AuditLog")
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
from sqlalchemy.orm import Mapped, Mapper, class_mapper, mapped_column, relationship
from sqlalchemy.orm.attributes import get_history

from src.core.database.base import Base
from src.core.database.types import DateTimeTZ, int_pk
from src.core.models.base import Base
from src.core.utils.context import orm_diff_ctx, request_id_ctx, user_ctx

if TYPE_CHECKING:

from src.core.models.base import ModelT
from src.core.database.base import ModelT
from src.features.admin.models import User


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@

class AuditTimeMixin:
created_at: Mapped[datetime] = mapped_column(DateTimeTZ, default=func.now(), index=True)
updated_at: Mapped[datetime] = mapped_column(DateTimeTZ, default=func.now(), onupdate=func.now())
updated_at: Mapped[datetime | None] = mapped_column(DateTimeTZ, onupdate=func.now())
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import TYPE_CHECKING

from sqlalchemy import Integer, func
Expand All @@ -8,14 +9,12 @@
from src.core.utils.context import user_ctx

if TYPE_CHECKING:
from datetime import datetime

from src.features.admin.models import User


class AuditUserMixin:
created_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), index=True)
updated_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), onupdate=func.now())
created_at: Mapped[datetime] = mapped_column(DateTimeTZ, default=func.now(), index=True)
updated_at: Mapped[datetime] = mapped_column(DateTimeTZ, default=func.now(), onupdate=func.now())

@declared_attr
@classmethod
Expand Down
2 changes: 1 addition & 1 deletion backend/src/core/database/types/annotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
datetime_optional = Annotated[datetime, mapped_column(DateTimeTZ, nullable=True)]
date_required = Annotated[date, mapped_column(Date, nullable=False)]
date_optional = Annotated[date | None, mapped_column(Date, nullable=True)]
i18n_name = Annotated[dict, mapped_column(MutableDict.as_mutable(HSTORE))]
i18n_name = Annotated[dict, mapped_column(MutableDict.as_mutable(HSTORE))] # type: ignore # noqa: PGH003
179 changes: 179 additions & 0 deletions backend/src/core/errors/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import logging
import sys
import traceback
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, NewType
from uuid import UUID

from fastapi import Request, status
from fastapi.responses import JSONResponse

from src.core.errors import err_codes
from src.core.errors.err_codes import ErrorCode
from src.core.utils.context import locale_ctx, request_id_ctx
from src.core.utils.i18n import _

_E = NewType("_E", Exception)
logger = logging.getLogger(__name__)


def error_message_value_handler(value: Any) -> Any:
if isinstance(value, dict) and "en_US" in value:
return value[locale_ctx.get()]
if isinstance(value, IPv4Address | IPv6Address | IPv4Network | IPv6Network | IPv4Interface | IPv6Interface | UUID):
return str(value)
if isinstance(value, list):
return [str(_v) for _v in value]
return value


class TokenInvalidForRefreshError(Exception): ...


class TokenInvalidError(Exception): ...


class TokenExpireError(Exception): ...


class PermissionDenyError(Exception): ...


class NotFoundError(Exception):
def __init__(self, name: str, field: str, value: Any) -> None:
self.name = name
self.field = field
self.value = error_message_value_handler(value)

def __repr__(self) -> str:
return f"Object:{self.name} with field:{self.field}-value:{self.value} not found."


class ExistError(Exception):
def __init__(self, name: str, field: str, value: Any) -> None:
self.name = name
self.field = field
self.value = error_message_value_handler(value)

def __repr__(self) -> str:
return f"Object:{self.name} with field:{self.field}-value:{self.value} already exist."


class GenerError(Exception):
def __init__(
self,
error: ErrorCode,
params: dict[str, Any] | None = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
) -> None:
self.error = error
self.params = params
self.status_code = status_code

def __repr__(self) -> str:
return f"Gener Error Occurred: ErrCode: {self.error.error}, Message: {self.error.message}"


def log_exception(exc: type[BaseException] | Exception, logger_trace_info: bool) -> None:
"""
Logs an exception.
Args:
exc (Type[BaseException] | Exception): The exception to be logged.
logger_trace_info (bool): Indicates whether to include detailed trace information in the log.
Returns:
None
Raises:
N/A
"""
logger = logging.getLogger(__name__)
ex_type, _, ex_traceback = sys.exc_info()
trace_back = traceback.format_list(traceback.extract_tb(ex_traceback)[-1:])[-1]

logger.warning(f"ErrorMessage: {exc!s}")
logger.warning(f"Exception Type {ex_type.__name__}: ")

if not logger_trace_info:
logger.warning(f"Stack trace: {trace_back}")
else:
logger.exception(f"Stack trace: {trace_back}")


async def token_invalid_handler(request: Request, exc: TokenInvalidError) -> JSONResponse:
log_exception(exc, False)
response_content = err_codes.ERR_10002.dict()
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=response_content)


async def invalid_token_for_refresh_handler(request: Request, exc: TokenInvalidForRefreshError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=err_codes.ERR_10004.dict())


async def token_expired_handler(request: Request, exc: TokenExpireError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=err_codes.ERR_10003.dict())


async def permission_deny_handler(request: Request, exc: PermissionDenyError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=err_codes.ERR_10004.dict())


async def resource_not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
log_exception(exc, True)
error_message = _(err_codes.ERR_404.message, name=exc.name, filed=exc.field, value=exc.value)
content = {"error": err_codes.ERR_404.error, "message": error_message}
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content)


async def resource_exist_handler(request: Request, exc: ExistError) -> JSONResponse:
log_exception(exc, True)
error_message = _(err_codes.ERR_409.message, name=exc.name, filed=exc.field, value=exc.value)
content = {"error": err_codes.ERR_409.error, "message": error_message}
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content)


def gener_error_handler(request: Request, exc: GenerError) -> JSONResponse:
log_exception(exc, True)
return JSONResponse(
status_code=exc.status_code,
content={
"error": exc.error.error,
"message": _(exc.error.message, **exc.params) if exc.params else _(exc.error.message),
},
)


def default_exception_handler(request: Request, exc: Exception) -> JSONResponse:
log_exception(exc, logger_trace_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"error": err_codes.ERR_500.error,
"message": _(err_codes.ERR_500.message, request_id=request_id_ctx.get()),
},
)


exception_handlers = [
{"exception": TokenInvalidError, "handler": token_invalid_handler},
{"exception": TokenExpireError, "handler": token_expired_handler},
{"exception": TokenInvalidForRefreshError, "handler": invalid_token_for_refresh_handler},
{"exception": PermissionDenyError, "handler": permission_deny_handler},
{"exception": NotFoundError, "handler": resource_not_found_handler},
{"exception": ExistError, "handler": resource_exist_handler},
{"exception": GenerError, "handler": gener_error_handler},
]


sentry_ignore_errors = [
TokenExpireError,
TokenInvalidError,
TokenInvalidForRefreshError,
PermissionDenyError,
NotFoundError,
ExistError,
]
Loading

0 comments on commit 3a1c021

Please sign in to comment.