Skip to content

Commit

Permalink
chore: Add Domain & Project RBAC apis
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 17, 2024
1 parent eed9d3c commit 067ea0b
Show file tree
Hide file tree
Showing 3 changed files with 430 additions and 3 deletions.
209 changes: 207 additions & 2 deletions src/ai/backend/manager/models/domain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypedDict
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
NamedTuple,
Optional,
Sequence,
TypeAlias,
TypedDict,
cast,
override,
)

import graphene
import sqlalchemy as sa
Expand All @@ -10,7 +23,8 @@
from sqlalchemy.engine.result import Result
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.orm import relationship
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, relationship

from ai.backend.common import msgpack
from ai.backend.common.types import ResourceSlot
Expand All @@ -29,6 +43,17 @@
simple_db_mutate,
simple_db_mutate_returning_item,
)
from .rbac import (
AbstractPermissionContext,
AbstractPermissionContextBuilder,
DomainScope,
ProjectScope,
ScopeType,
UserScope,
get_predefined_roles_in_scope,
)
from .rbac.context import ClientContext
from .rbac.permission_defs import DomainPermission
from .scaling_group import ScalingGroup
from .user import UserRole

Expand Down Expand Up @@ -415,3 +440,183 @@ def verify_dotfile_name(dotfile: str) -> bool:
if dotfile in RESERVED_DOTFILES:
return False
return True


OWNER_PERMISSIONS: frozenset[DomainPermission] = frozenset([perm for perm in DomainPermission])
ADMIN_PERMISSIONS: frozenset[DomainPermission] = frozenset([
DomainPermission.READ_ATTRIBUTE,
DomainPermission.UPDATE_ATTRIBUTE,
DomainPermission.CREATE_USER,
DomainPermission.CREATE_PROJECT,
])
MONITOR_PERMISSIONS: frozenset[DomainPermission] = frozenset([
DomainPermission.READ_ATTRIBUTE,
DomainPermission.UPDATE_ATTRIBUTE,
])
PRIVILEGED_MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset()
MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset()

WhereClauseType: TypeAlias = (
sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList
)


@dataclass
class DomainPermissionContext(AbstractPermissionContext[DomainPermission, DomainRow, str]):
@property
def query_condition(self) -> WhereClauseType | None:
cond: WhereClauseType | None = None

def _OR_coalesce(
base_cond: WhereClauseType | None,
_cond: sa.sql.expression.BinaryExpression,
) -> WhereClauseType:
return base_cond | _cond if base_cond is not None else _cond

if self.object_id_to_additional_permission_map:
cond = _OR_coalesce(
cond, DomainRow.id.in_(self.object_id_to_additional_permission_map.keys())
)
if self.object_id_to_overriding_permission_map:
cond = _OR_coalesce(
cond, DomainRow.id.in_(self.object_id_to_overriding_permission_map.keys())
)
return cond

async def build_query(self) -> sa.sql.Select | None:
cond = self.query_condition
if cond is None:
return None
return sa.select(DomainRow).where(cond)

async def calculate_final_permission(self, rbac_obj: DomainRow) -> frozenset[DomainPermission]:
domain_row = rbac_obj
domain_name = cast(str, domain_row.name)
permissions: frozenset[DomainPermission] = frozenset()

if (
overriding_perm := self.object_id_to_overriding_permission_map.get(domain_name)
) is not None:
permissions = overriding_perm
else:
permissions |= self.object_id_to_additional_permission_map.get(domain_name, set())
return permissions


class DomainPermissionContextBuilder(
AbstractPermissionContextBuilder[DomainPermission, DomainPermissionContext]
):
db_session: SASession

def __init__(self, db_session: SASession) -> None:
self.db_session = db_session

@override
async def calculate_permission(
self,
ctx: ClientContext,
target_scope: ScopeType,
) -> frozenset[DomainPermission]:
roles = await get_predefined_roles_in_scope(ctx, target_scope, self.db_session)
permissions = await self._calculate_permission_by_predefined_roles(roles)
return permissions

@override
async def build_ctx_in_system_scope(
self,
ctx: ClientContext,
) -> DomainPermissionContext:
from .domain import DomainRow

perm_ctx = DomainPermissionContext()
_domain_query_stmt = sa.select(DomainRow).options(load_only(DomainRow.name))
for row in await self.db_session.scalars(_domain_query_stmt):
to_be_merged = await self.build_ctx_in_domain_scope(ctx, DomainScope(row.name))
perm_ctx.merge(to_be_merged)
return perm_ctx

@override
async def build_ctx_in_domain_scope(
self,
ctx: ClientContext,
scope: DomainScope,
) -> DomainPermissionContext:
permissions = await self.calculate_permission(ctx, scope)
return DomainPermissionContext(
object_id_to_additional_permission_map={scope.domain_name: permissions}
)

@override
async def build_ctx_in_project_scope(
self, ctx: ClientContext, scope: ProjectScope
) -> DomainPermissionContext:
return DomainPermissionContext()

@override
async def build_ctx_in_user_scope(
self, ctx: ClientContext, scope: UserScope
) -> DomainPermissionContext:
return DomainPermissionContext()

@override
@classmethod
async def _permission_for_owner(
cls,
) -> frozenset[DomainPermission]:
return OWNER_PERMISSIONS

@override
@classmethod
async def _permission_for_admin(
cls,
) -> frozenset[DomainPermission]:
return ADMIN_PERMISSIONS

@override
@classmethod
async def _permission_for_monitor(
cls,
) -> frozenset[DomainPermission]:
return MONITOR_PERMISSIONS

@override
@classmethod
async def _permission_for_privileged_member(
cls,
) -> frozenset[DomainPermission]:
return PRIVILEGED_MEMBER_PERMISSIONS

@override
@classmethod
async def _permission_for_member(
cls,
) -> frozenset[DomainPermission]:
return MEMBER_PERMISSIONS


class DomainWithPermissionSet(NamedTuple):
domain_row: DomainRow
permissions: frozenset[DomainPermission]


async def get_domains(
target_scope: ScopeType,
requested_permission: DomainPermission,
domain_name: Optional[str] = None,
*,
ctx: ClientContext,
db_conn: SAConnection,
) -> list[DomainWithPermissionSet]:
async with ctx.db.begin_readonly_session(db_conn) as db_session:
builder = DomainPermissionContextBuilder(db_session)
permission_ctx = await builder.build(ctx, target_scope, requested_permission)
query_stmt = await permission_ctx.build_query()
if query_stmt is None:
return []
if domain_name is not None:
query_stmt = query_stmt.where(DomainRow.name == domain_name)
result: list[DomainWithPermissionSet] = []
async for row in await db_session.stream_scalars(query_stmt):
permissions = await permission_ctx.calculate_final_permission(row)
result.append(DomainWithPermissionSet(row, permissions))
return result
Loading

0 comments on commit 067ea0b

Please sign in to comment.