Rename ToolResponse model and only use it in execute operation of the tool API
heisner-tillman committed Apr 7, 2024
1 parent c7738d8 commit 6553fad
Showing 3 changed files with 19 additions and 18 deletions.
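In short, the commit keeps a single typed response model, renamed to ExecuteToolResponse, for the tool execute operation only, and drops the return annotation from the fetch paths, which now return plain dicts whose ids are encoded in the service layer. A minimal, self-contained sketch of that split (not Galaxy code; the simplified field types and the encode_all_ids stand-in below are assumptions for illustration):

from typing import Any, Dict, List

from pydantic import BaseModel, Field


class ExecuteToolResponse(BaseModel):
    # Mirrors the renamed Galaxy model; the real fields use AnyHDAWithOutputName etc.
    outputs: List[Dict[str, Any]] = Field(default=[], title="Outputs")
    output_collections: List[Dict[str, Any]] = Field(default=[], title="Output collections")
    jobs: List[Dict[str, Any]] = Field(default=[], title="Jobs")
    implicit_collections: List[Dict[str, Any]] = Field(default=[], title="Implicit collections")


def encode_all_ids(rval: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for ServiceBase.encode_all_ids(rval, recursive=True),
    # only here so the sketch runs on its own.
    return rval


def handle_inputs_output_to_api_response(rval: Dict[str, Any], encode_ids: bool = True) -> Dict[str, Any]:
    # The new encode_ids switch: fetch callers get ids encoded here, while the
    # execute path passes encode_ids=False and relies on its typed response model.
    if encode_ids:
        rval = encode_all_ids(rval)
    return rval


raw = {"outputs": [], "output_collections": [], "jobs": [], "implicit_collections": []}
print(handle_inputs_output_to_api_response(raw))    # fetch-style: id-encoded dict
print(ExecuteToolResponse(**raw).model_dump())      # execute-style: validated by the model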
2 changes: 1 addition & 1 deletion lib/galaxy/schema/tools.py
@@ -137,7 +137,7 @@ class HDCADetailedWithOutputName(HDCADetailed):
 AnyHDCAWithOutputName = Union[HDCADetailedWithOutputName, HDCASummaryWithOutputName]


-class ToolResponse(Model):
+class ExecuteToolResponse(Model):
     outputs: List[AnyHDAWithOutputName] = Field(
         default=[],
         title="Outputs",
10 changes: 4 additions & 6 deletions lib/galaxy/webapps/galaxy/api/tools.py
@@ -33,7 +33,7 @@
 )
 from galaxy.schema.tools import (
     ExecuteToolPayload,
-    ToolResponse,
+    ExecuteToolResponse,
 )
 from galaxy.tools.evaluation import global_tool_errors
 from galaxy.util.zipstream import ZipstreamWrapper
@@ -80,9 +80,7 @@ class FetchTools:
     service: ToolsService = depends(ToolsService)

     @router.post("/api/tools/fetch", summary="Upload files to Galaxy", route_class_override=JsonApiRoute)
-    async def fetch_json(
-        self, payload: FetchDataPayload = Body(...), trans: ProvidesHistoryContext = DependsOnTrans
-    ) -> ToolResponse:
+    async def fetch_json(self, payload: FetchDataPayload = Body(...), trans: ProvidesHistoryContext = DependsOnTrans):
         return self.service.create_fetch(trans, payload)

     @router.post(
@@ -96,7 +94,7 @@ async def fetch_form(
         payload: FetchDataFormPayload = Depends(FetchDataForm.as_form),
         files: Optional[List[UploadFile]] = None,
         trans: ProvidesHistoryContext = DependsOnTrans,
-    ) -> ToolResponse:
+    ):
         files2: List[StarletteUploadFile] = cast(List[StarletteUploadFile], files or [])

         # FastAPI's UploadFile is a very light wrapper around starlette's UploadFile
@@ -115,7 +113,7 @@ def execute(
         self,
         payload: ExecuteToolPayload = depend_on_either_json_or_form_data(ExecuteToolPayload),
         trans: ProvidesHistoryContext = DependsOnTrans,
-    ) -> ToolResponse:
+    ) -> ExecuteToolResponse:
         return self.service.execute(trans, payload)


25 changes: 14 additions & 11 deletions lib/galaxy/webapps/galaxy/services/tools.py
@@ -36,7 +36,7 @@
 )
 from galaxy.schema.tools import (
     ExecuteToolPayload,
-    ToolResponse,
+    ExecuteToolResponse,
 )
 from galaxy.security.idencoding import IdEncodingHelper
 from galaxy.tools import Tool
@@ -98,7 +98,7 @@ def create_fetch(
         trans: ProvidesHistoryContext,
         fetch_payload: Union[FetchDataFormPayload, FetchDataPayload],
         files: Optional[List[UploadFile]] = None,
-    ) -> ToolResponse:
+    ):
         payload = fetch_payload.model_dump(exclude_unset=True)
         request_version = "1"
         history_id = payload.pop("history_id")
@@ -132,7 +132,7 @@ def execute(
         self,
         trans: ProvidesHistoryContext,
         payload: ExecuteToolPayload,
-    ) -> ToolResponse:
+    ) -> ExecuteToolResponse:
         tool_id = payload.tool_id
         tool_uuid = payload.tool_uuid
         if tool_id in PROTECTED_TOOLS:
@@ -149,13 +149,14 @@ def execute(
            if key.startswith("files_") and isinstance(create_payload[key], UploadFile):
                files[key] = self.create_temp_file_execute(trans, create_payload.pop(key))
        create_payload.update(files)
-        return self._create(trans, create_payload)
+        return self._create(trans, create_payload, encode_ids=False)

     def _create(
         self,
         trans: ProvidesHistoryContext,
         payload: Dict[str, Any],
-    ) -> ToolResponse:
+        encode_ids: bool = True,
+    ):
         if trans.user_is_bootstrap_admin:
             raise exceptions.RealUserRequiredException("Only real users can execute tools or run jobs.")
         action = payload.get("action")
@@ -249,9 +249,9 @@ def _create(
         with transaction(trans.sa_session):
             trans.sa_session.commit()

-        return self._handle_inputs_output_to_api_response(trans, tool, target_history, vars)
+        return self._handle_inputs_output_to_api_response(trans, tool, target_history, vars, encode_ids)

-    def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars) -> ToolResponse:
+    def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars, encode_ids=True):
         # TODO: check for errors and ensure that output dataset(s) are available.
         output_datasets = vars.get("out_data", [])
         rval: Dict[str, Any] = {"outputs": [], "output_collections": [], "jobs": [], "implicit_collections": []}
@@ -270,6 +271,9 @@ def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars
             output_dict["output_name"] = output_name
             outputs.append(output_dict)

+        for job in vars.get("jobs", []):
+            rval["jobs"].append(job.to_dict(view="collection"))
+
         for output_name, collection_instance in vars.get("output_collections", []):
             history = target_history or trans.history
             output_dict = dictify_dataset_collection_instance(
@@ -292,11 +296,10 @@ def _handle_inputs_output_to_api_response(self, trans, tool, target_history, vars
             output_dict["output_name"] = output_name
             rval["implicit_collections"].append(output_dict)

-        # Encoding the job ids is handled by the pydantic model
-        for job in vars.get("jobs", []):
-            rval["jobs"].append(job.to_dict(view="collection"))
+        if encode_ids:
+            rval = self.encode_all_ids(rval, recursive=True)

-        return ToolResponse(**rval)
+        return rval

     def _search(self, q, view):
         """
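At the route layer, the practical effect is that only the execute endpoint keeps a response annotation, so FastAPI validates and serializes the service's result through ExecuteToolResponse, while the untyped fetch endpoints return their already-encoded dicts unchanged. A toy illustration, assuming plain FastAPI and hypothetical paths rather than Galaxy's actual routers:

from typing import Any, Dict, List

from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()


class ExecuteToolResponse(BaseModel):
    # Trimmed-down version of the renamed model, for illustration only.
    outputs: List[Dict[str, Any]] = Field(default=[])
    jobs: List[Dict[str, Any]] = Field(default=[])


@app.post("/api/tools/fetch")
async def fetch_json() -> Dict[str, Any]:
    # No response model: the dict was already id-encoded by the service.
    return {"outputs": [], "jobs": []}


@app.post("/api/tools/execute")  # hypothetical path, not Galaxy's actual route
async def execute() -> ExecuteToolResponse:
    # The return annotation doubles as the response model, so the service
    # output is validated and serialized as ExecuteToolResponse.
    return ExecuteToolResponse(outputs=[], jobs=[])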
