Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add Domain & Project RBAC apis #2920

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 207 additions & 2 deletions src/ai/backend/manager/models/domain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypedDict
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
NamedTuple,
Optional,
Sequence,
TypeAlias,
TypedDict,
cast,
override,
)

import graphene
import sqlalchemy as sa
Expand All @@ -10,7 +23,8 @@
from sqlalchemy.engine.result import Result
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.orm import relationship
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, relationship

from ai.backend.common import msgpack
from ai.backend.common.types import ResourceSlot
Expand All @@ -29,6 +43,17 @@
simple_db_mutate,
simple_db_mutate_returning_item,
)
from .rbac import (
AbstractPermissionContext,
AbstractPermissionContextBuilder,
DomainScope,
ProjectScope,
ScopeType,
UserScope,
get_predefined_roles_in_scope,
)
from .rbac.context import ClientContext
from .rbac.permission_defs import DomainPermission
from .scaling_group import ScalingGroup
from .user import UserRole

Expand Down Expand Up @@ -415,3 +440,183 @@ def verify_dotfile_name(dotfile: str) -> bool:
if dotfile in RESERVED_DOTFILES:
return False
return True


OWNER_PERMISSIONS: frozenset[DomainPermission] = frozenset([perm for perm in DomainPermission])
ADMIN_PERMISSIONS: frozenset[DomainPermission] = frozenset([
DomainPermission.READ_ATTRIBUTE,
DomainPermission.UPDATE_ATTRIBUTE,
DomainPermission.CREATE_USER,
DomainPermission.CREATE_PROJECT,
])
MONITOR_PERMISSIONS: frozenset[DomainPermission] = frozenset([
DomainPermission.READ_ATTRIBUTE,
DomainPermission.UPDATE_ATTRIBUTE,
])
PRIVILEGED_MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset()
MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset()

WhereClauseType: TypeAlias = (
sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList
)


@dataclass
class DomainPermissionContext(AbstractPermissionContext[DomainPermission, DomainRow, str]):
@property
def query_condition(self) -> WhereClauseType | None:
cond: WhereClauseType | None = None

def _OR_coalesce(
base_cond: WhereClauseType | None,
_cond: sa.sql.expression.BinaryExpression,
) -> WhereClauseType:
return base_cond | _cond if base_cond is not None else _cond

if self.object_id_to_additional_permission_map:
cond = _OR_coalesce(
cond, DomainRow.id.in_(self.object_id_to_additional_permission_map.keys())
)
if self.object_id_to_overriding_permission_map:
cond = _OR_coalesce(
cond, DomainRow.id.in_(self.object_id_to_overriding_permission_map.keys())
)
return cond

async def build_query(self) -> sa.sql.Select | None:
cond = self.query_condition
if cond is None:
return None
return sa.select(DomainRow).where(cond)

async def calculate_final_permission(self, rbac_obj: DomainRow) -> frozenset[DomainPermission]:
domain_row = rbac_obj
domain_name = cast(str, domain_row.name)
permissions: frozenset[DomainPermission] = frozenset()

if (
overriding_perm := self.object_id_to_overriding_permission_map.get(domain_name)
) is not None:
permissions = overriding_perm
else:
permissions |= self.object_id_to_additional_permission_map.get(domain_name, set())
return permissions


class DomainPermissionContextBuilder(
AbstractPermissionContextBuilder[DomainPermission, DomainPermissionContext]
):
db_session: SASession

def __init__(self, db_session: SASession) -> None:
self.db_session = db_session

@override
async def calculate_permission(
self,
ctx: ClientContext,
target_scope: ScopeType,
) -> frozenset[DomainPermission]:
roles = await get_predefined_roles_in_scope(ctx, target_scope, self.db_session)
permissions = await self._calculate_permission_by_predefined_roles(roles)
return permissions

@override
async def build_ctx_in_system_scope(
self,
ctx: ClientContext,
) -> DomainPermissionContext:
from .domain import DomainRow

perm_ctx = DomainPermissionContext()
_domain_query_stmt = sa.select(DomainRow).options(load_only(DomainRow.name))
for row in await self.db_session.scalars(_domain_query_stmt):
to_be_merged = await self.build_ctx_in_domain_scope(ctx, DomainScope(row.name))
perm_ctx.merge(to_be_merged)
return perm_ctx

@override
async def build_ctx_in_domain_scope(
self,
ctx: ClientContext,
scope: DomainScope,
) -> DomainPermissionContext:
permissions = await self.calculate_permission(ctx, scope)
return DomainPermissionContext(
object_id_to_additional_permission_map={scope.domain_name: permissions}
)

@override
async def build_ctx_in_project_scope(
self, ctx: ClientContext, scope: ProjectScope
) -> DomainPermissionContext:
return DomainPermissionContext()

@override
async def build_ctx_in_user_scope(
self, ctx: ClientContext, scope: UserScope
) -> DomainPermissionContext:
return DomainPermissionContext()

@override
@classmethod
async def _permission_for_owner(
cls,
) -> frozenset[DomainPermission]:
return OWNER_PERMISSIONS

@override
@classmethod
async def _permission_for_admin(
cls,
) -> frozenset[DomainPermission]:
return ADMIN_PERMISSIONS

@override
@classmethod
async def _permission_for_monitor(
cls,
) -> frozenset[DomainPermission]:
return MONITOR_PERMISSIONS

@override
@classmethod
async def _permission_for_privileged_member(
cls,
) -> frozenset[DomainPermission]:
return PRIVILEGED_MEMBER_PERMISSIONS

@override
@classmethod
async def _permission_for_member(
cls,
) -> frozenset[DomainPermission]:
return MEMBER_PERMISSIONS


class DomainWithPermissionSet(NamedTuple):
domain_row: DomainRow
permissions: frozenset[DomainPermission]


async def get_domains(
target_scope: ScopeType,
requested_permission: DomainPermission,
domain_name: Optional[str] = None,
*,
ctx: ClientContext,
db_conn: SAConnection,
) -> list[DomainWithPermissionSet]:
async with ctx.db.begin_readonly_session(db_conn) as db_session:
builder = DomainPermissionContextBuilder(db_session)
permission_ctx = await builder.build(ctx, target_scope, requested_permission)
query_stmt = await permission_ctx.build_query()
if query_stmt is None:
return []
if domain_name is not None:
query_stmt = query_stmt.where(DomainRow.name == domain_name)
result: list[DomainWithPermissionSet] = []
async for row in await db_session.stream_scalars(query_stmt):
permissions = await permission_ctx.calculate_final_permission(row)
result.append(DomainWithPermissionSet(row, permissions))
return result
Loading