Skip to content

Commit

Permalink
Fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-c committed Nov 13, 2021
1 parent 44740a9 commit 7bf67a7
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 26 deletions.
13 changes: 7 additions & 6 deletions lib/galaxy/datatypes/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,13 @@ def sniff_prefix(self, file_prefix):
('header_size', 'i4'),
]
np_dtype = np.dtype(header_def)
header = np.ndarray(
shape=(),
dtype=np_dtype,
buffer=header_raw)
if header['header_size'] == 1000 and b'TRACK' in header['magic'] and \
header['version'] == 2 and len(header['dim']) == 3:
header: np.ndarray = np.ndarray(shape=(), dtype=np_dtype, buffer=header_raw)
if (
header["header_size"] == 1000
and b"TRACK" in header["magic"]
and header["version"] == 2
and len(header["dim"]) == 3
):
return True
return False

Expand Down
20 changes: 15 additions & 5 deletions lib/galaxy/tools/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
cast,
Dict,
ItemsView,
Iterable,
Iterator,
KeysView,
Expand Down Expand Up @@ -56,7 +57,7 @@ class ToolParameterValueWrapper:
Base class for object that Wraps a Tool Parameter and Value.
"""

value: Union[str, List[str]]
value: Union[None, str, List[str], Dict[str, str]]
input: "ToolParameter"

def __bool__(self) -> bool:
Expand Down Expand Up @@ -103,10 +104,12 @@ class InputValueWrapper(ToolParameterValueWrapper):
Wraps an input so that __str__ gives the "param_dict" representation.
"""

value: Optional[Dict[str, str]]

def __init__(
self,
input: "ToolParameter",
value: str,
value: Dict[str, str],
other_values: Optional[Dict[str, str]] = None,
) -> None:
self.input = input
Expand Down Expand Up @@ -172,6 +175,7 @@ class SelectToolParameterWrapper(ToolParameterValueWrapper):
"""

input: "SelectToolParameter"
value: Union[str, List[str]]

class SelectToolParameterFieldWrapper:
"""
Expand Down Expand Up @@ -625,9 +629,13 @@ def __init__(
self.collection = collection

elements = collection.elements
element_instances = {}
element_instances: Dict[
str, Union[DatasetCollectionWrapper, DatasetFilenameWrapper]
] = {}

element_instance_list = []
element_instance_list: List[
Union[DatasetCollectionWrapper, DatasetFilenameWrapper]
] = []
for dataset_collection_element in elements:
element_object = dataset_collection_element.element_object
element_identifier = dataset_collection_element.element_identifier
Expand Down Expand Up @@ -662,7 +670,9 @@ def keys(self) -> Union[List[str], KeysView[Any]]:
return []
return self.__element_instances.keys()

def items(self):
def items(
self,
) -> ItemsView[str, Union["DatasetCollectionWrapper", DatasetFilenameWrapper]]:
return self.__element_instances.items()

@property
Expand Down
28 changes: 19 additions & 9 deletions lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def get_inputs(self):
cases = []

for param_type in ["text", "integer", "float", "boolean", "color", "field"]:
default_source: Dict[str, Union[int, float, bool, str]] = dict(
default_source: Dict[str, Union[None, int, float, bool, str]] = dict(
name="default", label="Default Value", type=param_type
)
if param_type == "text":
Expand All @@ -885,6 +885,7 @@ def get_inputs(self):
input_default_value: Union[
TextToolParameter,
IntegerToolParameter,
FieldTypeToolParameter,
FloatToolParameter,
BooleanToolParameter,
ColorToolParameter,
Expand Down Expand Up @@ -1722,7 +1723,7 @@ def decode_runtime_state(self, runtime_state):

def evaluate_value_from_expressions(self, progress, step, execution_state, extra_step_state):
value_from_expressions = {}
replacements = {}
replacements: Dict[str, str] = {}

for key in execution_state.inputs.keys():
step_input = step.inputs_by_name.get(key)
Expand Down Expand Up @@ -1876,8 +1877,17 @@ def callback(input, prefixed_name, **kwargs):
replacement = json.load(f)
found_replacement_keys.add(prefixed_name)

is_data = isinstance(input, DataToolParameter) or isinstance(input, DataCollectionToolParameter) or isinstance(input, FieldTypeToolParameter)
if not is_data and getattr(replacement, "history_content_type", None) == "dataset" and getattr(replacement, "ext", None) == "expression.json":
is_data = (
isinstance(input, DataToolParameter)
or isinstance(input, DataCollectionToolParameter)
or isinstance(input, FieldTypeToolParameter)
)
if (
not is_data
and not isinstance(replacement, NoReplacement)
and getattr(replacement, "history_content_type", None) == "dataset"
and getattr(replacement, "ext", None) == "expression.json"
):
if isinstance(replacement, model.HistoryDatasetAssociation):
if not replacement.dataset.in_ready_state():
why = "dataset [%s] is needed for non-data connection and is non-ready" % replacement.id
Expand All @@ -1891,11 +1901,11 @@ def callback(input, prefixed_name, **kwargs):

if isinstance(input, FieldTypeToolParameter):
if isinstance(replacement, model.HistoryDatasetAssociation):
replacement = {"src": "hda", "value": replacement}
return {"src": "hda", "value": replacement}
elif isinstance(replacement, model.HistoryDatasetCollectionAssociation):
replacement = {"src": "hdca", "value": replacement}
return {"src": "hdca", "value": replacement}
elif replacement is not NO_REPLACEMENT:
replacement = {"src": "json", "value": replacement}
return {"src": "json", "value": replacement}

return replacement

Expand Down Expand Up @@ -1938,9 +1948,9 @@ def expression_callback(input, prefixed_name, **kwargs):
if prefixed_name in expression_replacements:
expression_replacement = expression_replacements[prefixed_name]
if isinstance(input, FieldTypeToolParameter):
replacement = {"src": "json", "value": expression_replacement}
return {"src": "json", "value": expression_replacement}
else:
replacement = expression_replacement
return expression_replacement

return replacement

Expand Down
3 changes: 1 addition & 2 deletions lib/galaxy/workflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,10 @@ def replacement_for_input_connections(self, step, input_dict, connections):
else:
raise NotImplementedError()

ephemeral_collection = modules.EphemeralCollection(
return modules.EphemeralCollection(
collection=collection,
history=self.workflow_invocation.history,
)
replacement = ephemeral_collection

return replacement

Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy_test/api/test_tools_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_any1_file(self):
test_data_directory="test/functional/tools/cwl_tools/v1.0/v1.0/",
)
output1_content = self.dataset_populator.get_history_dataset_content(run_object.history_id)
self.dataset_populator._summarize_history_errors(run_object.history_id)
self.dataset_populator._summarize_history(run_object.history_id)
assert output1_content == '"File"', "[%s]" % output1_content

@skip_without_tool("any1")
Expand Down
7 changes: 7 additions & 0 deletions lib/galaxy_test/api/test_workflows_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_simplest_wf(self):
if "tool_representation" in step:
del step["tool_representation"]

assert self.history_id
hda1 = self.dataset_populator.new_dataset(self.history_id, content="hello world\nhello all\nhello all in world\nhello")
inputs_map = {
"file1": {"src": "hda", "id": hda1["id"]}
Expand Down Expand Up @@ -80,6 +81,7 @@ def test_count_lines3_v1(self):
}
invocation_id = self._invoke(inputs_map, workflow_id)
self.wait_for_invocation_and_jobs(self.history_id, workflow_id, invocation_id)
assert self.history_id
hdca = self.dataset_populator.get_history_collection_details(self.history_id, hid=5)
assert hdca["collection_type"] == "list"
elements = hdca["elements"]
Expand All @@ -92,6 +94,7 @@ def test_count_lines3_v1(self):

def test_count_lines4_v1(self):
workflow_id = self._load_workflow("v1.0/v1.0/count-lines4-wf.cwl")
assert self.history_id
hda1 = self.dataset_populator.new_dataset(self.history_id, content="hello world\nhello all\nhello all in world\nhello")
hda2 = self.dataset_populator.new_dataset(self.history_id, content="moo\ncow\nthat\nis\nall")
inputs_map = {
Expand All @@ -104,14 +107,17 @@ def test_count_lines4_v1(self):

def test_count_lines4_json(self):
self.cwl_populator.run_workflow_job("v1.0/v1.0/count-lines4-wf.cwl", "v1.0/v1.0/count-lines4-job.json", history_id=self.history_id)
assert self.history_id
self.dataset_populator.get_history_collection_details(self.history_id, hid=4)

def test_scatter_wf1_v1(self):
self.cwl_populator.run_workflow_job("v1.0/v1.0/scatter-wf1.cwl", "v1.0/v1.0/scatter-job1.json", history_id=self.history_id)
assert self.history_id
self.dataset_populator.get_history_collection_details(self.history_id, hid=5)

def _run_count_lines_wf(self, wf_path):
workflow_id = self._load_workflow(wf_path)
assert self.history_id
hda1 = self.dataset_populator.new_dataset(self.history_id, content="hello world\nhello all\nhello all in world\nhello")
inputs_map = {
"file1": {"src": "hda", "id": hda1["id"]}
Expand All @@ -121,6 +127,7 @@ def _run_count_lines_wf(self, wf_path):

def _check_countlines_wf(self, invocation_id, workflow_id, expected_count=4):
self.wait_for_invocation_and_jobs(self.history_id, workflow_id, invocation_id)
assert self.history_id
output = self.dataset_populator.get_history_dataset_content(self.history_id, hid=3)
assert str(expected_count) == output, output

Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy_test/base/populators.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def run_conformance_test(self, version, doc):
directory = os.path.join(CWL_TOOL_DIRECTORY, version)
tool = os.path.join(directory, test["tool"])
job_path = test.get("job")
job = None
job: Optional[Dict[str, str]] = None
if job_path is not None:
job_path = os.path.join(directory, job_path)
else:
Expand Down
5 changes: 3 additions & 2 deletions test/unit/app/tools/test_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
import tempfile
from typing import Dict
from unittest import TestCase
from uuid import uuid4

Expand Down Expand Up @@ -546,12 +547,12 @@ def tearDown(self):
def test_default_data_inputs(self):
self._init_tool(tool_path=_cwl_tool_path("v1.0/v1.0/default_path.cwl"))
hda = self._new_hda()
errors = {}
errors: Dict[str, str] = {}
cwl_inputs = {
"file1": {"src": "hda", "id": self.app.security.encode_id(hda.id)}
}
inputs = self.tool.inputs_from_dict({"inputs": cwl_inputs, "inputs_representation": "cwl"})
populated_state = {}
populated_state: Dict[str, str] = {}
populate_state(self.trans, self.tool.inputs, inputs, populated_state, errors)
wrapped_params = WrappedParameters(self.trans, self.tool, populated_state)
input_json = to_cwl_job(self.tool, wrapped_params.params, self.test_directory)
Expand Down

0 comments on commit 7bf67a7

Please sign in to comment.