Skip to content

Commit

Permalink
Introduces update_current_group method for better group management
Browse files Browse the repository at this point in the history
Ensures default group handling and validation in credential deletion
  • Loading branch information
arash77 committed Jan 15, 2025
1 parent 1bdacdb commit 1833a0e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
34 changes: 18 additions & 16 deletions lib/galaxy/managers/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -94,18 +92,13 @@ def create_or_update_credentials(
trans: ProvidesUserContext,
payload: CreateSourceCredentialsPayload,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
credentials_dict: Dict[int, Dict[str, Any]],
) -> None:
session = trans.sa_session
existing_groups = {
cred["reference"]: {group["name"]: group["id"] for group in cred["groups"].values()}
for cred in credentials_dict.values()
}
for service_payload in payload.credentials:
reference = service_payload.reference
current_group_name = service_payload.current_group
current_group_id = existing_groups.get(reference, {}).get(current_group_name)

if not current_group_name:
current_group_name = "default"
user_credentials = next((uc[0] for uc in db_user_credentials if uc[0].reference == reference), None)
if not user_credentials:
user_credentials = UserCredentials(
Expand All @@ -129,8 +122,6 @@ def create_or_update_credentials(
session.add(credentials_group)
session.flush()
user_credential_group_id = credentials_group.id
if current_group_name == group_name:
current_group_id = user_credential_group_id
variables, secrets = self.fetch_credentials(trans.sa_session, user_credential_group_id)
user_vault = UserVaultWrapper(self._app.vault, trans.user)
for variable_payload in group.variables:
Expand Down Expand Up @@ -169,14 +160,25 @@ def create_or_update_credentials(
session.add(secret)
vault_ref = f"{payload.source_type}|{payload.source_id}|{reference}|{group_name}|{secret_name}"
user_vault.write_secret(vault_ref, secret_value)
if not current_group_id:
raise RequestParameterInvalidException("No current group selected.")
user_credentials.current_group_id = current_group_id
session.add(user_credentials)

self.update_current_group(trans, user_credentials_id, current_group_name)
with transaction(session):
session.commit()

def update_current_group(
self,
trans: ProvidesUserContext,
user_credentials_id: DecodedDatabaseIdField,
group_name: str,
) -> None:
db_user_credentials = self.get_user_credentials(trans, trans.user.id, user_credentials_id=user_credentials_id)
for user_credentials, credentials_group in db_user_credentials:
if credentials_group.name == group_name:
user_credentials.current_group_id = credentials_group.id
trans.sa_session.add(user_credentials)
break
else:
raise RequestParameterInvalidException("Group not found to set as current.")

def delete_rows(
self,
session: galaxy_scoped_session,
Expand Down
13 changes: 10 additions & 3 deletions lib/galaxy/webapps/galaxy/services/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
Union,
)

from galaxy.exceptions import ObjectNotFound
from galaxy.exceptions import (
ObjectNotFound,
RequestParameterInvalidException,
)
from galaxy.managers.context import ProvidesUserContext
from galaxy.managers.credentials import CredentialsManager
from galaxy.model import (
Expand Down Expand Up @@ -54,8 +57,7 @@ def provide_credential(
"""Allows users to provide credentials for a group of secrets and variables."""
source_type, source_id = payload.source_type, payload.source_id
db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id)
credentials_dict = self._map_user_credentials(db_user_credentials)
self._credentials_manager.create_or_update_credentials(trans, payload, db_user_credentials, credentials_dict)
self._credentials_manager.create_or_update_credentials(trans, payload, db_user_credentials)
return self._list_user_credentials(trans, user_id, source_type, source_id)

def delete_credentials(
Expand All @@ -75,6 +77,11 @@ def delete_credentials(
for uc, credentials_group in db_user_credentials:
if not group_id:
rows_to_delete.append(uc)
else:
if credentials_group.name == "default":
raise RequestParameterInvalidException("Cannot delete the default group.")
if credentials_group.id == uc.current_group_id:
self._credentials_manager.update_current_group(trans, uc.id, "default")
variables, secrets = self._credentials_manager.fetch_credentials(trans.sa_session, credentials_group.id)
rows_to_delete.extend([credentials_group, *variables, *secrets])
self._credentials_manager.delete_rows(trans.sa_session, rows_to_delete)
Expand Down

0 comments on commit 1833a0e

Please sign in to comment.