From 1833a0e4759b082962932b648ea2272af7cdb5fd Mon Sep 17 00:00:00 2001 From: Arash Date: Wed, 15 Jan 2025 17:16:28 +0100 Subject: [PATCH] Introduces update_current_group method for better group management Ensures default group handling and validation in credential deletion --- lib/galaxy/managers/credentials.py | 34 ++++++++++--------- .../webapps/galaxy/services/credentials.py | 13 +++++-- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/lib/galaxy/managers/credentials.py b/lib/galaxy/managers/credentials.py index 782b20cf519b..d550d15904f2 100644 --- a/lib/galaxy/managers/credentials.py +++ b/lib/galaxy/managers/credentials.py @@ -1,6 +1,4 @@ from typing import ( - Any, - Dict, List, Optional, Tuple, @@ -94,18 +92,13 @@ def create_or_update_credentials( trans: ProvidesUserContext, payload: CreateSourceCredentialsPayload, db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]], - credentials_dict: Dict[int, Dict[str, Any]], ) -> None: session = trans.sa_session - existing_groups = { - cred["reference"]: {group["name"]: group["id"] for group in cred["groups"].values()} - for cred in credentials_dict.values() - } for service_payload in payload.credentials: reference = service_payload.reference current_group_name = service_payload.current_group - current_group_id = existing_groups.get(reference, {}).get(current_group_name) - + if not current_group_name: + current_group_name = "default" user_credentials = next((uc[0] for uc in db_user_credentials if uc[0].reference == reference), None) if not user_credentials: user_credentials = UserCredentials( @@ -129,8 +122,6 @@ def create_or_update_credentials( session.add(credentials_group) session.flush() user_credential_group_id = credentials_group.id - if current_group_name == group_name: - current_group_id = user_credential_group_id variables, secrets = self.fetch_credentials(trans.sa_session, user_credential_group_id) user_vault = UserVaultWrapper(self._app.vault, trans.user) for variable_payload in group.variables: @@ -169,14 +160,25 @@ def create_or_update_credentials( session.add(secret) vault_ref = f"{payload.source_type}|{payload.source_id}|{reference}|{group_name}|{secret_name}" user_vault.write_secret(vault_ref, secret_value) - if not current_group_id: - raise RequestParameterInvalidException("No current group selected.") - user_credentials.current_group_id = current_group_id - session.add(user_credentials) - + self.update_current_group(trans, user_credentials_id, current_group_name) with transaction(session): session.commit() + def update_current_group( + self, + trans: ProvidesUserContext, + user_credentials_id: DecodedDatabaseIdField, + group_name: str, + ) -> None: + db_user_credentials = self.get_user_credentials(trans, trans.user.id, user_credentials_id=user_credentials_id) + for user_credentials, credentials_group in db_user_credentials: + if credentials_group.name == group_name: + user_credentials.current_group_id = credentials_group.id + trans.sa_session.add(user_credentials) + break + else: + raise RequestParameterInvalidException("Group not found to set as current.") + def delete_rows( self, session: galaxy_scoped_session, diff --git a/lib/galaxy/webapps/galaxy/services/credentials.py b/lib/galaxy/webapps/galaxy/services/credentials.py index 52afaaa0de82..cd59be68ae83 100644 --- a/lib/galaxy/webapps/galaxy/services/credentials.py +++ b/lib/galaxy/webapps/galaxy/services/credentials.py @@ -7,7 +7,10 @@ Union, ) -from galaxy.exceptions import ObjectNotFound +from galaxy.exceptions import ( + ObjectNotFound, + RequestParameterInvalidException, +) from galaxy.managers.context import ProvidesUserContext from galaxy.managers.credentials import CredentialsManager from galaxy.model import ( @@ -54,8 +57,7 @@ def provide_credential( """Allows users to provide credentials for a group of secrets and variables.""" source_type, source_id = payload.source_type, payload.source_id db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id) - credentials_dict = self._map_user_credentials(db_user_credentials) - self._credentials_manager.create_or_update_credentials(trans, payload, db_user_credentials, credentials_dict) + self._credentials_manager.create_or_update_credentials(trans, payload, db_user_credentials) return self._list_user_credentials(trans, user_id, source_type, source_id) def delete_credentials( @@ -75,6 +77,11 @@ def delete_credentials( for uc, credentials_group in db_user_credentials: if not group_id: rows_to_delete.append(uc) + else: + if credentials_group.name == "default": + raise RequestParameterInvalidException("Cannot delete the default group.") + if credentials_group.id == uc.current_group_id: + self._credentials_manager.update_current_group(trans, uc.id, "default") variables, secrets = self._credentials_manager.fetch_credentials(trans.sa_session, credentials_group.id) rows_to_delete.extend([credentials_group, *variables, *secrets]) self._credentials_manager.delete_rows(trans.sa_session, rows_to_delete)