diff --git a/src/backend/app/db/models.py b/src/backend/app/db/models.py index cbc7bc409..aff4408dd 100644 --- a/src/backend/app/db/models.py +++ b/src/backend/app/db/models.py @@ -78,7 +78,7 @@ ProjectUpdate, ) from app.tasks.task_schemas import TaskEventIn - from app.users.user_schemas import UserIn + from app.users.user_schemas import UserIn, UserUpdate def dump_and_check_model(db_model: BaseModel): @@ -314,6 +314,36 @@ async def create( return new_user + @classmethod + async def update( + cls, db: Connection, user_id: int, user_update: "UserUpdate" + ) -> Self: + """Update the role of a specific user.""" + model_dump = dump_and_check_model(user_update) + placeholders = [f"{key} = %({key})s" for key in model_dump.keys()] + sql = f""" + UPDATE users + SET {", ".join(placeholders)} + WHERE id = %(user_id)s + RETURNING *; + """ + + async with db.cursor(row_factory=class_row(cls)) as cur: + await cur.execute( + sql, + {"user_id": user_id, **model_dump}, + ) + updated_user = await cur.fetchone() + + if updated_user is None: + msg = f"Failed to update user with ID: {user_id}" + log.error(msg) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=msg + ) + + return updated_user + class DbOrganisation(BaseModel): """Table organisations.""" diff --git a/src/backend/app/users/user_routes.py b/src/backend/app/users/user_routes.py index 67bbdb0c3..9f1389790 100644 --- a/src/backend/app/users/user_routes.py +++ b/src/backend/app/users/user_routes.py @@ -58,6 +58,17 @@ async def get_user_roles(current_user: Annotated[DbUser, Depends(mapper)]): return user_roles +@router.patch("/{user_id}", response_model=user_schemas.UserOut) +async def update_existing_user( + user_id: int, + new_user_data: user_schemas.UserUpdate, + current_user: Annotated[DbUser, Depends(super_admin)], + db: Annotated[Connection, Depends(db_conn)], +): + """Change the role of a user.""" + return await DbUser.update(db=db, user_id=user_id, user_update=new_user_data) + + @router.get("/{id}", response_model=user_schemas.UserOut) async def get_user_by_identifier( user: Annotated[DbUser, Depends(get_user)], diff --git a/src/backend/app/users/user_schemas.py b/src/backend/app/users/user_schemas.py index 0084f3caa..58585597e 100644 --- a/src/backend/app/users/user_schemas.py +++ b/src/backend/app/users/user_schemas.py @@ -19,9 +19,9 @@ from typing import Annotated, Optional -from pydantic import BaseModel, Field +from pydantic import AwareDatetime, BaseModel, Field -from app.db.enums import UserRole +from app.db.enums import ProjectRole, UserRole from app.db.models import DbUser, DbUserRole @@ -34,6 +34,22 @@ class UserIn(DbUser): pass +class UserUpdate(DbUser): + """User details for update in DB.""" + + # Exclude (do not allow update) + id: Annotated[Optional[int], Field(exclude=True)] = None + username: Annotated[Optional[str], Field(exclude=True)] = None + registered_at: Annotated[Optional[AwareDatetime], Field(exclude=True)] = None + tasks_mapped: Annotated[Optional[int], Field(exclude=True)] = None + tasks_validated: Annotated[Optional[int], Field(exclude=True)] = None + tasks_invalidated: Annotated[Optional[int], Field(exclude=True)] = None + project_roles: Annotated[Optional[dict[int, ProjectRole]], Field(exclude=True)] = ( + None + ) + orgs_managed: Annotated[Optional[list[int]], Field(exclude=True)] = None + + class UserOut(DbUser): """User with ID and role."""