From 6553fad1b0b0aad08a46594fb66a914adb5a7d7a Mon Sep 17 00:00:00 2001 From: heisner-tillman Date: Sun, 7 Apr 2024 16:17:07 +0200 Subject: [PATCH] Rename ToolResponse model and only use it in execute operation of the tool API --- lib/galaxy/schema/tools.py | 2 +- lib/galaxy/webapps/galaxy/api/tools.py | 10 ++++----- lib/galaxy/webapps/galaxy/services/tools.py | 25 ++++++++++++--------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/lib/galaxy/schema/tools.py b/lib/galaxy/schema/tools.py index 31787df7fc56..353109b177eb 100644 --- a/lib/galaxy/schema/tools.py +++ b/lib/galaxy/schema/tools.py @@ -137,7 +137,7 @@ class HDCADetailedWithOutputName(HDCADetailed): AnyHDCAWithOutputName = Union[HDCADetailedWithOutputName, HDCASummaryWithOutputName] -class ToolResponse(Model): +class ExecuteToolResponse(Model): outputs: List[AnyHDAWithOutputName] = Field( default=[], title="Outputs", diff --git a/lib/galaxy/webapps/galaxy/api/tools.py b/lib/galaxy/webapps/galaxy/api/tools.py index f4535d19cc55..b743fa50e00c 100644 --- a/lib/galaxy/webapps/galaxy/api/tools.py +++ b/lib/galaxy/webapps/galaxy/api/tools.py @@ -33,7 +33,7 @@ ) from galaxy.schema.tools import ( ExecuteToolPayload, - ToolResponse, + ExecuteToolResponse, ) from galaxy.tools.evaluation import global_tool_errors from galaxy.util.zipstream import ZipstreamWrapper @@ -80,9 +80,7 @@ class FetchTools: service: ToolsService = depends(ToolsService) @router.post("/api/tools/fetch", summary="Upload files to Galaxy", route_class_override=JsonApiRoute) - async def fetch_json( - self, payload: FetchDataPayload = Body(...), trans: ProvidesHistoryContext = DependsOnTrans - ) -> ToolResponse: + async def fetch_json(self, payload: FetchDataPayload = Body(...), trans: ProvidesHistoryContext = DependsOnTrans): return self.service.create_fetch(trans, payload) @router.post( @@ -96,7 +94,7 @@ async def fetch_form( payload: FetchDataFormPayload = Depends(FetchDataForm.as_form), files: Optional[List[UploadFile]] = None, trans: ProvidesHistoryContext = DependsOnTrans, - ) -> ToolResponse: + ): files2: List[StarletteUploadFile] = cast(List[StarletteUploadFile], files or []) # FastAPI's UploadFile is a very light wrapper around starlette's UploadFile @@ -115,7 +113,7 @@ def execute( self, payload: ExecuteToolPayload = depend_on_either_json_or_form_data(ExecuteToolPayload), trans: ProvidesHistoryContext = DependsOnTrans, - ) -> ToolResponse: + ) -> ExecuteToolResponse: return self.service.execute(trans, payload) diff --git a/lib/galaxy/webapps/galaxy/services/tools.py b/lib/galaxy/webapps/galaxy/services/tools.py index 7b3ebb371a8c..f33932bcd743 100644 --- a/lib/galaxy/webapps/galaxy/services/tools.py +++ b/lib/galaxy/webapps/galaxy/services/tools.py @@ -36,7 +36,7 @@ ) from galaxy.schema.tools import ( ExecuteToolPayload, - ToolResponse, + ExecuteToolResponse, ) from galaxy.security.idencoding import IdEncodingHelper from galaxy.tools import Tool @@ -98,7 +98,7 @@ def create_fetch( trans: ProvidesHistoryContext, fetch_payload: Union[FetchDataFormPayload, FetchDataPayload], files: Optional[List[UploadFile]] = None, - ) -> ToolResponse: + ): payload = fetch_payload.model_dump(exclude_unset=True) request_version = "1" history_id = payload.pop("history_id") @@ -132,7 +132,7 @@ def execute( self, trans: ProvidesHistoryContext, payload: ExecuteToolPayload, - ) -> ToolResponse: + ) -> ExecuteToolResponse: tool_id = payload.tool_id tool_uuid = payload.tool_uuid if tool_id in PROTECTED_TOOLS: @@ -149,13 +149,14 @@ def execute( if key.startswith("files_") and isinstance(create_payload[key], UploadFile): files[key] = self.create_temp_file_execute(trans, create_payload.pop(key)) create_payload.update(files) - return self._create(trans, create_payload) + return self._create(trans, create_payload, encode_ids=False) def _create( self, trans: ProvidesHistoryContext, payload: Dict[str, Any], - ) -> ToolResponse: + encode_ids: bool = True, + ): if trans.user_is_bootstrap_admin: raise exceptions.RealUserRequiredException("Only real users can execute tools or run jobs.") action = payload.get("action") @@ -249,9 +250,9 @@ def _create( with transaction(trans.sa_session): trans.sa_session.commit() - return self._handle_inputs_output_to_api_response(trans, tool, target_history, vars) + return self._handle_inputs_output_to_api_response(trans, tool, target_history, vars, encode_ids) - def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars) -> ToolResponse: + def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars, encode_ids=True): # TODO: check for errors and ensure that output dataset(s) are available. output_datasets = vars.get("out_data", []) rval: Dict[str, Any] = {"outputs": [], "output_collections": [], "jobs": [], "implicit_collections": []} @@ -270,6 +271,9 @@ def _handle_inputs_output_to_api_response(self, trans, tool, target_history, var output_dict["output_name"] = output_name outputs.append(output_dict) + for job in vars.get("jobs", []): + rval["jobs"].append(job.to_dict(view="collection")) + for output_name, collection_instance in vars.get("output_collections", []): history = target_history or trans.history output_dict = dictify_dataset_collection_instance( @@ -292,11 +296,10 @@ def _handle_inputs_output_to_api_response(self, trans, tool, target_history, var output_dict["output_name"] = output_name rval["implicit_collections"].append(output_dict) - # Encoding the job ids is handled by the pydantic model - for job in vars.get("jobs", []): - rval["jobs"].append(job.to_dict(view="collection")) + if encode_ids: + rval = self.encode_all_ids(rval, recursive=True) - return ToolResponse(**rval) + return rval def _search(self, q, view): """