From c97e489a18979f2db6cbd4960cfde31d0d63a1d6 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 29 May 2024 15:26:21 +0930 Subject: [PATCH 01/21] debugging running tests --- nipype2pydra/cli/convert.py | 5 +++- nipype2pydra/workflow.py | 50 ++++++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/nipype2pydra/cli/convert.py b/nipype2pydra/cli/convert.py index a50c27e..fe1c295 100644 --- a/nipype2pydra/cli/convert.py +++ b/nipype2pydra/cli/convert.py @@ -69,7 +69,10 @@ def convert( shutil.rmtree(package_dir / "auto") else: for fspath in package_dir.iterdir(): - if fspath == package_dir / "__init__.py": + if fspath.parent == package_dir and fspath.name in ( + "_version.py", + "__init__.py", + ): continue if fspath.is_dir(): shutil.rmtree(fspath) diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index 8250537..bdd56a6 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -71,6 +71,7 @@ class WorkflowInterfaceField: }, ) node_name: ty.Optional[str] = attrs.field( + default=None, metadata={ "help": "The name of the node that the input/output is connected to", }, @@ -130,7 +131,7 @@ def type_repr_(t): elif issubclass(t, Field): return t.primitive.__name__ elif issubclass(t, FileSet): - return t.__name__ + return t.type_name elif t.__module__ == "builtins": return t.__name__ else: @@ -764,7 +765,9 @@ def write( ), converted_code=self.test_code, used=self.test_used, - additional_imports=self.input_output_imports, + additional_imports=( + self.input_output_imports + parse_imports("import pytest") + ), ) conftest_fspath = test_module_fspath.parent / "conftest.py" @@ -931,22 +934,51 @@ def parsed_statements(self): def test_code(self): args_str = ", ".join(f"{n}={v}" for n, v in self.test_inputs.items()) - return f""" + code_str = f""" -def test_{self.name}(): + +def test_{self.name}_build(): workflow = {self.name}({args_str}) assert isinstance(workflow, Workflow) """ + inputs_dict = {} + for inpt 
in self.inputs.values(): + if issubclass(inpt.type, FileSet): + inputs_dict[inpt.name] = inpt.type.type_name + ".sample()" + elif inpt.name in self.test_inputs: + inputs_dict[inpt.name] = self.test_inputs[inpt.name] + args_str = ", ".join(f"{n}={v}" for n, v in inputs_dict.items()) + + code_str += f""" + +@pytest.mark.skip(reason="Appropriate inputs for this workflow haven't been specified yet") +def test_{self.name}_run(): + workflow = {self.name}({args_str}) + result = workflow(plugin='serial') + print(result.out) +""" + return code_str + @property def test_used(self): + nonstd_types = [ + i.type for i in self.inputs.values() if issubclass(i.type, FileSet) + ] + nonstd_type_imports = [] + for tp in itertools.chain(*(unwrap_nested_type(t) for t in nonstd_types)): + nonstd_type_imports.append(ImportStatement.from_object(tp)) + return UsedSymbols( module_name=self.nipype_module.__name__, - imports=parse_imports( - [ - f"from {self.output_module} import {self.name}", - "from pydra.engine import Workflow", - ] + imports=( + nonstd_type_imports + + parse_imports( + [ + f"from {self.output_module} import {self.name}", + "from pydra.engine import Workflow", + ] + ) ), ) From caee1bc192013279713466aca8c0ba203f215842 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 29 May 2024 23:57:52 +0930 Subject: [PATCH 02/21] fixing defaults for function interfaces --- nipype2pydra/interface/base.py | 2 +- nipype2pydra/interface/function.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index f9b7314..acc5574 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -575,7 +575,7 @@ def pydra_fld_input(self, field, nm): pos = pydra_metadata.get("position", None) - if pydra_default is not None and not pydra_metadata.get("mandatory", None): + if pydra_default is not None: # and not pydra_metadata.get("mandatory", None): return (pydra_type, pydra_default, 
pydra_metadata), pos else: return (pydra_type, pydra_metadata), pos diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index f2bc153..45dd3de 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -105,7 +105,14 @@ def types_to_names(spec_fields): spec_str += ", ".join(f"'{n}': {t}" for n, t, _ in output_fields_str) spec_str += "}})\n" spec_str += f"def {self.task_name}(" - spec_str += ", ".join(f"{i[0]}: {i[1]}" for i in input_fields_str) + spec_str += ", ".join( + ( + f"{i[0]}: {i[1]} = {i[2]!r}" + if len(i) == 4 + else f"{i[0]}: {i[1]} = attrs.NOTHING" + ) + for i in input_fields_str + ) spec_str += ")" if output_type_names: spec_str += "-> " From 5e8da87bfacd13a3be81d27b829b51c356171277 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 30 May 2024 00:43:34 +0930 Subject: [PATCH 03/21] set mandatory to False when there is a valid default --- nipype2pydra/interface/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index acc5574..8af64df 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -539,6 +539,8 @@ def pydra_fld_input(self, field, nm): if val is not None: if key == "argstr" and "%" in val: val = self.string_formats(argstr=val, name=nm) + elif key == "mandatory" and pydra_default is not None: + val = False # Overwrite mandatory to False if default is provided pydra_metadata[pydra_key_nm] = val if getattr(field, "name_template"): From 843a83a5dea26e6d1b8a5e75f62a72ca1f5925e3 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Fri, 31 May 2024 17:08:37 +0930 Subject: [PATCH 04/21] added support for _format_arg and _parse_inputs methods --- nipype2pydra/cli/convert.py | 4 +- nipype2pydra/interface/base.py | 299 ++++++++++++++++++++++++ nipype2pydra/interface/function.py | 280 +--------------------- nipype2pydra/interface/shell_command.py | 173 +++++++++++++- nipype2pydra/utils/__init__.py | 
1 + nipype2pydra/utils/misc.py | 10 + 6 files changed, 488 insertions(+), 279 deletions(-) diff --git a/nipype2pydra/cli/convert.py b/nipype2pydra/cli/convert.py index fe1c295..9c3a8d6 100644 --- a/nipype2pydra/cli/convert.py +++ b/nipype2pydra/cli/convert.py @@ -66,7 +66,9 @@ def convert( # Clean previous version of output dir package_dir = converter.package_dir(package_root) if converter.interface_only: - shutil.rmtree(package_dir / "auto") + auto_dir = package_dir / "auto" + if auto_dir.exists(): + shutil.rmtree(auto_dir) else: for fspath in package_dir.iterdir(): if fspath.parent == package_dir and fspath.name in ( diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 8af64df..57696fa 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -24,6 +24,12 @@ types_converter, from_dict_converter, unwrap_nested_type, + get_local_functions, + get_local_constants, + get_return_line, + cleanup_function_body, + insert_args_in_signature, + extract_args, ) from ..statements import ( ImportStatement, @@ -397,6 +403,10 @@ def nipype_output_spec(self) -> nipype.interfaces.base.BaseTraitedSpec: def input_fields(self): return self._convert_input_fields[0] + @property + def input_names(self): + return [f[0] for f in self.input_fields] + @cached_property def input_templates(self): return self._convert_input_fields[1] @@ -440,6 +450,65 @@ def _converted(self): self.input_fields, self.nonstd_types, self.output_fields ) + @property + def referenced_local_functions(self): + return self._referenced_funcs_and_methods[0] + + @property + def referenced_methods(self): + return self._referenced_funcs_and_methods[1] + + @property + def method_args(self): + return self._referenced_funcs_and_methods[2] + + @property + def method_returns(self): + return self._referenced_funcs_and_methods[3] + + @cached_property + def _referenced_funcs_and_methods(self): + referenced_funcs = set() + referenced_methods = set() + method_args = {} + 
method_returns = {} + already_processed = set( + getattr(self.nipype_interface, m) for m in self.INCLUDED_METHODS + ) + for method_name in self.INCLUDED_METHODS: + if method_name not in self.nipype_interface.__dict__: + continue # Don't include base methods + self._get_referenced( + getattr(self.nipype_interface, method_name), + referenced_funcs, + referenced_methods, + method_args, + method_returns, + already_processed=already_processed, + ) + return referenced_funcs, referenced_methods, method_args, method_returns + + @cached_property + def source_code(self): + with open(inspect.getsourcefile(self.nipype_interface)) as f: + return f.read() + + @cached_property + def methods(self): + """Get the methods defined in the interface""" + methods = [] + for attr_name in dir(self.nipype_interface): + if attr_name.startswith("__"): + continue + attr = getattr(self.nipype_interface, attr_name) + if inspect.isfunction(attr): + methods.append(attr) + return methods + + @cached_property + def local_function_names(self): + return [f.__name__ for f in self.local_functions] + def write( self, package_root: Path, @@ -653,6 +722,8 @@ def function_callables(self): for fun_nm in fun_names: fun = getattr(self.callables_module, fun_nm) fun_str += inspect.getsource(fun) + "\n" + list_outputs = getattr(self.callables_module, "_list_outputs") + fun_str += inspect.getsource(list_outputs) + "\n" return fun_str def pydra_type_converter(self, field, spec_type, name): @@ -904,6 +975,234 @@ def create_doctests(self, input_fields, nonstd_types): return " Examples\n -------\n\n" + doctest_str + def _get_referenced( + self, + method: ty.Callable, + referenced_funcs: ty.Set[ty.Callable], + referenced_methods: ty.Set[ty.Callable] = None, + method_args: ty.Dict[str, ty.List[str]] = None, + method_returns: ty.Dict[str, ty.List[str]] = None, + already_processed: ty.Set[ty.Callable] = None, + ) -> ty.Tuple[ty.Set, ty.Set]: + """Get the local functions referenced in the source code + + Parameters + 
---------- + src: str + the source of the file to extract the import statements from + referenced_funcs: set[function] + the set of local functions that have been referenced so far + referenced_methods: set[function] + the set of methods that have been referenced so far + method_args: dict[str, list[str]] + a dictionary to hold additional arguments that need to be added to each method, + where the dictionary key is the names of the methods + method_returns: dict[str, list[str]] + a dictionary to hold the return values of each method, + where the dictionary key is the names of the methods + + Returns + ------- + referenced_inputs: set[str] + inputs that have been referenced + referenced_outputs: set[str] + outputs that have been referenced + """ + if already_processed: + already_processed.add(method) + else: + already_processed = {method} + method_body = inspect.getsource(method) + method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments + return_value = get_return_line(method_body) + ref_local_func_names = re.findall(r"(? str: + return_value = get_return_line(method_body) + method_body = method_body.replace("if self.output_spec:", "if True:") + # Replace self.inputs. 
with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(method_body) if m not in input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + method_body = input_re.sub(r"\1", method_body) + + if return_value: + output_re = re.compile(return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") + unrecognised_outputs = set( + m for m in output_re.findall(method_body) if m not in output_names + ) + if unrecognised_outputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) outputs %s in " + "'%s' task", + unrecognised_outputs, + self.task_name, + ) + method_body = output_re.sub(r"\1", method_body) + # Strip initialisation of outputs + method_body = re.sub( + r"outputs = self.output_spec().*", r"outputs = {}", method_body + ) + return self.unwrap_nested_methods(method_body) + + def unwrap_nested_methods(self, method_body): + """ + Converts nested method calls into function calls + """ + # Add args to the function signature of method calls + method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) + method_names = [m.__name__ for m in self.referenced_methods] + unrecognised_methods = set( + m for m in method_re.findall(method_body) if m not in method_names + ) + assert ( + not unrecognised_methods + ), f"Found the following unrecognised methods {unrecognised_methods}" + splits = method_re.split(method_body) + new_body = splits[0] + for name, args in zip(splits[1::2], splits[2::2]): + # Assign additional return values (which were previously saved to member + # attributes) to new variables from the method call + if self.method_returns[name]: + last_line = new_body.splitlines()[-1] + match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? 
*$", last_line) + if match: + if match.group(1): + new_body_lines = new_body.splitlines() + new_body = "\n".join(new_body_lines[:-1]) + last_line = new_body_lines[-1] + new_body += "\n" + re.sub( + r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$", + r"\1\2, " + ",".join(self.method_returns[name]) + " = ", + last_line, + flags=re.MULTILINE, + ) + else: + new_body += ",".join(self.method_returns[name]) + " = " + else: + raise NotImplementedError( + "Could not augment the return value of the method converted from " + "a function with the previously assigned attributes as it is used " + "directly. Need to replace the method call with a variable and " + "assign the return value to it on a previous line" + ) + # Insert additional arguments to the method call (which were previously + # accessed via member attributes) + new_body += name + insert_args_in_signature( + args, [f"{a}={a}" for a in self.method_args[name]] + ) + method_body = new_body + # Convert assignment to self attributes into method-scoped variables (hopefully + # there aren't any name clashes) + method_body = re.sub( + r"self\.(\w+ *)(?==)", r"\1", method_body, flags=re.MULTILINE | re.DOTALL + ) + return cleanup_function_body(method_body) + INPUT_KEYS = [ "allowed_values", "argstr", diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 45dd3de..ed72651 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -8,14 +8,7 @@ import attrs from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter -from ..utils import ( - extract_args, - UsedSymbols, - get_local_functions, - get_local_constants, - cleanup_function_body, - insert_args_in_signature, -) +from ..utils import UsedSymbols, get_return_line logger = logging.getLogger("nipype2pydra") @@ -24,6 +17,8 @@ @attrs.define(slots=False) class FunctionInterfaceConverter(BaseInterfaceConverter): + INCLUDED_METHODS = ("_run_interface", "_list_outputs") + def 
generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, UsedSymbols, @@ -150,278 +145,17 @@ def types_to_names(spec_fields): return spec_str, used - def process_method( - self, - method: str, - input_names: ty.List[str], - output_names: ty.List[str], - method_args: ty.Dict[str, ty.List[str]] = None, - method_returns: ty.Dict[str, ty.List[str]] = None, - ): - src = inspect.getsource(method) - pre, args, post = extract_args(src) - args.remove("self") - if "runtime" in args: - args.remove("runtime") - if method.__name__ in self.method_args: - args += [f"{a}=None" for a in self.method_args[method.__name__]] - # Insert method args in signature if present - return_types, method_body = post.split(":", maxsplit=1) - method_body = method_body.split("\n", maxsplit=1)[1] - method_body = self.process_method_body(method_body, input_names, output_names) - if self.method_returns.get(method.__name__): - return_args = self.method_returns[method.__name__] - method_body = ( - " " + " = ".join(return_args) + " = attrs.NOTHING\n" + method_body - ) - method_lines = method_body.rstrip().splitlines() - method_body = "\n".join(method_lines[:-1]) - last_line = method_lines[-1] - if "return" in last_line: - method_body += "\n" + last_line + "," + ",".join(return_args) - else: - method_body += ( - "\n" + last_line + "\n return " + ",".join(return_args) - ) - return f"{pre.strip()}{', '.join(args)}{return_types}:\n{method_body}" - - def process_method_body( - self, method_body: str, input_names: ty.List[str], output_names: ty.List[str] - ) -> str: - # Replace self.inputs. 
with in the function body - input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") - unrecognised_inputs = set( - m for m in input_re.findall(method_body) if m not in input_names - ) - if unrecognised_inputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) inputs %s in " - "'%s' task", - unrecognised_inputs, - self.task_name, - ) - method_body = input_re.sub(r"\1", method_body) - - output_re = re.compile(self.return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") - unrecognised_outputs = set( - m for m in output_re.findall(method_body) if m not in output_names - ) - if unrecognised_outputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) outputs %s in " - "'%s' task", - unrecognised_outputs, - self.task_name, - ) - method_body = output_re.sub(r"\1", method_body) - # Strip initialisation of outputs - method_body = re.sub( - r"outputs = self.output_spec().*", r"outputs = {}", method_body - ) - # Add args to the function signature of method calls - method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) - method_names = [m.__name__ for m in self.referenced_methods] - unrecognised_methods = set( - m for m in method_re.findall(method_body) if m not in method_names - ) - assert ( - not unrecognised_methods - ), f"Found the following unrecognised methods {unrecognised_methods}" - splits = method_re.split(method_body) - new_body = splits[0] - for name, args in zip(splits[1::2], splits[2::2]): - # Assign additional return values (which were previously saved to member - # attributes) to new variables from the method call - if self.method_returns[name]: - last_line = new_body.splitlines()[-1] - match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? 
*$", last_line) - if match: - if match.group(1): - new_body_lines = new_body.splitlines() - new_body = "\n".join(new_body_lines[:-1]) - last_line = new_body_lines[-1] - new_body += "\n" + re.sub( - r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$", - r"\1\2, " + ",".join(self.method_returns[name]) + " = ", - last_line, - flags=re.MULTILINE, - ) - else: - new_body += ",".join(self.method_returns[name]) + " = " - else: - raise NotImplementedError( - "Could not augment the return value of the method converted from " - "a function with the previously assigned attributes as it is used " - "directly. Need to replace the method call with a variable and " - "assign the return value to it on a previous line" - ) - # Insert additional arguments to the method call (which were previously - # accessed via member attributes) - new_body += name + insert_args_in_signature( - args, [f"{a}={a}" for a in self.method_args[name]] - ) - method_body = new_body - # Convert assignment to self attributes into method-scoped variables (hopefully - # there aren't any name clashes) - method_body = re.sub( - r"self\.(\w+ *)(?==)", r"\1", method_body, flags=re.MULTILINE | re.DOTALL - ) - return cleanup_function_body(method_body) - - @property - def referenced_local_functions(self): - return self._referenced_funcs_and_methods[0] - - @property - def referenced_methods(self): - return self._referenced_funcs_and_methods[1] - - @property - def method_args(self): - return self._referenced_funcs_and_methods[2] - - @property - def method_returns(self): - return self._referenced_funcs_and_methods[3] - - @cached_property - def _referenced_funcs_and_methods(self): - referenced_funcs = set() - referenced_methods = set() - method_args = {} - method_returns = {} - self._get_referenced( - self.nipype_interface._run_interface, - referenced_funcs, - referenced_methods, - method_args, - method_returns, - ) - self._get_referenced( - self.nipype_interface._list_outputs, - referenced_funcs, - referenced_methods, - method_args, - 
method_returns, - ) - return referenced_funcs, referenced_methods, method_args, method_returns - def replace_attributes(self, function_body: ty.Callable) -> str: """Replace self.inputs. with in the function body and add args to the function signature""" function_body = re.sub(r"self\.inputs\.(\w+)", r"\1", function_body) - def _get_referenced( - self, - method: ty.Callable, - referenced_funcs: ty.Set[ty.Callable], - referenced_methods: ty.Set[ty.Callable] = None, - method_args: ty.Dict[str, ty.List[str]] = None, - method_returns: ty.Dict[str, ty.List[str]] = None, - ): - """Get the local functions referenced in the source code - - Parameters - ---------- - src: str - the source of the file to extract the import statements from - referenced_funcs: set[function] - the set of local functions that have been referenced so far - referenced_methods: set[function] - the set of methods that have been referenced so far - method_args: dict[str, list[str]] - a dictionary to hold additional arguments that need to be added to each method, - where the dictionary key is the names of the methods - method_returns: dict[str, list[str]] - a dictionary to hold the return values of each method, - where the dictionary key is the names of the methods - """ - method_body = inspect.getsource(method) - method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments - ref_local_func_names = re.findall(r"(? 
ty.Tuple[ str, UsedSymbols, @@ -80,10 +92,16 @@ def types_to_names(spec_fields): spec_fields_str.append(tuple(el)) return spec_fields_str - input_fields_str = types_to_names(spec_fields=input_fields) - output_fields_str = types_to_names(spec_fields=output_fields) + input_names = [i[0] for i in input_fields] + output_names = [o[0] for o in output_fields] + input_fields_str = str(types_to_names(spec_fields=input_fields)) + input_fields_str = re.sub( + r"'formatter': '(\w+)'", r"'formatter': \1", input_fields_str + ) + output_fields_str = str(types_to_names(spec_fields=output_fields)) functions_str = self.function_callables() spec_str = functions_str + spec_str += self.format_arg_code + self.parse_inputs_code spec_str += f"input_fields = {input_fields_str}\n" spec_str += f"{self.task_name}_input_spec = specs.SpecInfo(name='Input', fields=input_fields, bases=(specs.ShellSpec,))\n\n" spec_str += f"output_fields = {output_fields_str}\n" @@ -101,6 +119,9 @@ def types_to_names(spec_fields): spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str) + for m in sorted(self.referenced_methods, key=attrgetter("__name__")): + spec_str += "\n\n" + self.process_method(m, input_names, output_names) + imports = self.construct_imports( nonstd_types, spec_str, @@ -109,6 +130,148 @@ def types_to_names(spec_fields): ) # spec_str = "\n".join(str(i) for i in imports) + "\n\n" + spec_str - return spec_str, UsedSymbols( - module_name=self.nipype_module.__name__, imports=imports + used = UsedSymbols.find( + self.nipype_module, + [self.format_arg_code, self.parse_inputs_code, self.function_callables()], + omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + absolute_imports=True, + ) + used.imports.update(imports) + + return spec_str, used + + 
@cached_property + def _convert_input_fields(self): + pydra_fields_l, has_template = super()._convert_input_fields + for field in pydra_fields_l: + if field[0] in self.formatted_input_fields: + field[-1]["formatter"] = f"{field[0]}_formatter" + self._format_argstrs[field[0]] = field[-1].pop("argstr") + return pydra_fields_l, has_template + + @property + def formatted_input_fields(self): + return re.findall(r"name == \"(\w+)\"", self._format_arg_body) + + @cached_property + def _format_arg_body(self): + if "_format_arg" not in self.nipype_interface.__dict__: + return "" + return _strip_doc_string( + inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1] + ) + + @property + def format_arg_code(self): + if not self._format_arg_body: + return "" + body = self._format_arg_body + # Replace self.inputs. with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(body) if m not in self.input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + + existing_args = list( + inspect.signature(self.nipype_interface._format_arg).parameters + )[1:] + name_arg, _, val_arg = existing_args + body = input_re.sub(r"inputs.\1", body) + body = re.sub(r"self\.(?!inputs)(\w+)", r"parsed_inputs['\1']", body) + body = re.sub( + r"trait_spec\.argstr % (.*)", + r"argstr.format(**{" + name_arg + r": \1})", + body, ) + + # Strip out return value + body = re.sub( + ( + r"\s*return super\((\w+,\s*self)?\)\._format_arg\(" + + ", ".join(existing_args) + + r"\)\n" + ), + "", + body, + ) + if not body: + return "" + body = self.unwrap_nested_methods(body) + + code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, parsed_inputs, argstr): +{body} + raise ValueError(f"Unrecognised field {{{name_arg}}}") + + +""" + for field_name in self.formatted_input_fields: + 
code_str += f"def {field_name}_formatter(field, inputs):\n" + if self.parse_inputs_code: + code_str += " parsed_inputs = _parse_inputs(inputs)\n" + else: + code_str += " parsed_inputs = {}\n" + + code_str += ( + f" return _format_arg({field_name!r}, field, inputs, " + f"parsed_inputs, argstr={self._format_argstrs[field_name]!r})\n\n\n" + ) + return code_str + + @property + def parse_inputs_code(self) -> str: + if "_parse_inputs" not in self.nipype_interface.__dict__: + return "" + body = _strip_doc_string( + inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1] + ) + # Replace self.inputs. with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(body) if m not in self.input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + body = re.sub(r"self\.inputs", r"inputs", body) + body = re.sub(r"self\.(\w+)\b(?!\()", r"parsed_inputs['\1']", body) + body = re.sub( + r"self.\_format_arg\((\w+), (\w+), (\w+)\)", + r"_format_arg(\1, \3, inputs, parsed_inputs, argstrs.get(\1))", + body, + ) + + # Strip out return value + body = re.sub(r"\s*return .*\n", "", body) + body = self.unwrap_nested_methods(body) + + return f"""def _parse_inputs(inputs): + parsed_inputs = {{}} + argstrs = {self._format_argstrs!r} + skip = [] +{body} + return parsed_inputs + + +""" + + +def _strip_doc_string(body: str) -> str: + if re.match(r"\s*(\"|')", body): + body = "\n".join(split_source_into_statements(body)[1:]) + return body diff --git a/nipype2pydra/utils/__init__.py b/nipype2pydra/utils/__init__.py index 58d7df7..51bb486 100644 --- a/nipype2pydra/utils/__init__.py +++ b/nipype2pydra/utils/__init__.py @@ -19,6 +19,7 @@ str_to_type, types_converter, unwrap_nested_type, + get_return_line, INBUILT_NIPYPE_TRAIT_NAMES, ) from .symbols import ( # noqa: F401 diff 
--git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 56254d3..08525f6 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -528,3 +528,13 @@ def unwrap_nested_type(t: type) -> ty.List[type]: unwrapped.extend(unwrap_nested_type(c)) return unwrapped return [t] + + +def get_return_line(func: ty.Union[str, ty.Callable]) -> str: + if not isinstance(func, str): + func = inspect.getsource(func) + return_line = func.strip().split("\n")[-1] + match = re.match(r"\s*return(.*)", return_line) + if not match: + return None + return match.group(1).strip() From 07700f2db95fe28f866b029df6de0707d1e2158a Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 01:30:49 +0930 Subject: [PATCH 05/21] debugging reworked _list_outputs handling --- nipype2pydra/interface/base.py | 197 +++++++++++++++++--- nipype2pydra/interface/function.py | 3 + nipype2pydra/interface/shell_command.py | 232 +++++++++++++++++++----- nipype2pydra/pkg_gen/__init__.py | 16 +- nipype2pydra/utils/misc.py | 3 +- 5 files changed, 379 insertions(+), 72 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 57696fa..e439771 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod from importlib import import_module from types import ModuleType +from collections import defaultdict import itertools import inspect import traits.trait_types @@ -13,7 +14,7 @@ import attrs from attrs.converters import default_if_none import nipype.interfaces.base -from nipype.interfaces.base import traits_extension +from nipype.interfaces.base import traits_extension, CommandLine from pydra.engine import specs from pydra.engine.helpers import ensure_list from ..utils import ( @@ -459,34 +460,66 @@ def referenced_methods(self): return self._referenced_funcs_and_methods[1] @property - def method_args(self): + def referenced_supers(self): return 
self._referenced_funcs_and_methods[2] @property - def method_returns(self): + def method_args(self): return self._referenced_funcs_and_methods[3] + @property + def method_returns(self): + return self._referenced_funcs_and_methods[4] + + @property + def method_stacks(self): + return self._referenced_funcs_and_methods[5] + + @property + def method_supers(self): + return self._referenced_funcs_and_methods[6] + @cached_property def _referenced_funcs_and_methods(self): referenced_funcs = set() referenced_methods = set() + referenced_supers = {} method_args = {} method_returns = {} + method_stacks = {} + method_supers = defaultdict(dict) already_processed = set( getattr(self.nipype_interface, m) for m in self.INCLUDED_METHODS ) + for method_name in self.INCLUDED_METHODS: + method_args[method_name] = [] + method_returns[method_name] = [] + method_stacks[method_name] = () for method_name in self.INCLUDED_METHODS: if method_name not in self.nipype_interface.__dict__: continue # Don't include base methods + method = getattr(self.nipype_interface, method_name) + referenced_methods.add(method) self._get_referenced( - getattr(self.nipype_interface, method_name), - referenced_funcs, - referenced_methods, - method_args, - method_returns, + method, + referenced_funcs=referenced_funcs, + referenced_methods=referenced_methods, + referenced_supers=referenced_supers, + method_args=method_args, + method_returns=method_returns, + method_stacks=method_stacks, + method_supers=method_supers, already_processed=already_processed, ) - return referenced_funcs, referenced_methods, method_args, method_returns + return ( + referenced_funcs, + referenced_methods, + referenced_supers, + method_args, + method_returns, + method_stacks, + method_supers, + ) @cached_property def source_code(self): @@ -717,13 +750,14 @@ def function_callables(self): "callables module must be provided if output_callables are set in the spec file" ) fun_str = "" - fun_names = list(set(self.outputs.callables.values())) - 
fun_names.sort() - for fun_nm in fun_names: - fun = getattr(self.callables_module, fun_nm) - fun_str += inspect.getsource(fun) + "\n" - list_outputs = getattr(self.callables_module, "_list_outputs") - fun_str += inspect.getsource(list_outputs) + "\n" + if list(set(self.outputs.callables.values())): + fun_str = inspect.getsource(self.callables_module) + # fun_names.sort() + # for fun_nm in fun_names: + # fun = getattr(self.callables_module, fun_nm) + # fun_str += inspect.getsource(fun) + "\n" + # list_outputs = getattr(self.callables_module, "_list_outputs") + # fun_str += inspect.getsource(list_outputs) + "\n" return fun_str def pydra_type_converter(self, field, spec_type, name): @@ -975,14 +1009,33 @@ def create_doctests(self, input_fields, nonstd_types): return " Examples\n -------\n\n" + doctest_str + def _misc_cleanups(self, body: str) -> str: + if hasattr(self.nipype_interface, "_cmd"): + body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"') + + body = re.sub( + r"outputs = self\.(output_spec|_outputs)\(\).*$", + r"outputs = {}", + body, + flags=re.MULTILINE, + ) + body = re.sub(r"\w+runtime\.(stdout|stderr)", r"\1", body) + body = body.replace("os.getcwd()", "output_dir") + return body + def _get_referenced( self, method: ty.Callable, referenced_funcs: ty.Set[ty.Callable], - referenced_methods: ty.Set[ty.Callable] = None, + referenced_methods: ty.Set[ty.Callable], + referenced_supers: ty.Dict[str, ty.Tuple[ty.Callable, type]], method_args: ty.Dict[str, ty.List[str]] = None, method_returns: ty.Dict[str, ty.List[str]] = None, + method_stacks: ty.Dict[str, ty.Tuple[ty.Callable]] = None, + method_supers: ty.Dict[type, ty.Dict[str, str]] = None, already_processed: ty.Set[ty.Callable] = None, + method_stack: ty.Optional[ty.Tuple[ty.Callable]] = None, + super_base: ty.Optional[type] = None, ) -> ty.Tuple[ty.Set, ty.Set]: """Get the local functions referenced in the source code @@ -1012,6 +1065,12 @@ def _get_referenced( already_processed.add(method) 
else: already_processed = {method} + if method_stack is None: + method_stack = (method,) + else: + method_stack += (method,) + if super_base is None: + super_base = self.nipype_interface method_body = inspect.getsource(method) method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments return_value = get_return_line(method_body) @@ -1034,6 +1093,39 @@ def _get_referenced( referenced_outputs.update( re.findall(return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=", method_body) ) + for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): + super_method = None + for base in self.nipype_interface.__mro__[1:]: + if match in base.__dict__: # Found the match + super_method = getattr(base, match) + break + assert super_method is not None, ( + f"Could not find super of '{match}' method in base classes of " + f"{self.nipype_interface}" + ) + func_name = self._common_parent_pkg_prefix(base) + match + if func_name not in referenced_supers: + referenced_supers[func_name] = (super_method, base) + method_supers[super_base][match] = func_name + method_stacks[func_name] = method_stack + rf_inputs, rf_outputs = self._get_referenced( + super_method, + referenced_funcs, + referenced_methods, + referenced_supers=referenced_supers, + method_args=method_args, + method_returns=method_returns, + method_stacks=method_stacks, + method_supers=method_supers, + already_processed=already_processed, + method_stack=method_stack, + super_base=base, + ) + referenced_inputs.update(rf_inputs) + referenced_outputs.update(rf_outputs) + method_args[func_name] = rf_inputs + method_returns[func_name] = rf_outputs + method_stacks[func_name] = method_stack for func in ref_local_funcs: if func in already_processed: continue @@ -1041,7 +1133,12 @@ def _get_referenced( func, referenced_funcs, referenced_methods, + referenced_supers=referenced_supers, + method_stacks=method_stacks, + method_supers=method_supers, already_processed=already_processed, + method_stack=method_stack, + 
super_base=super_base, ) referenced_inputs.update(rf_inputs) referenced_outputs.update(rf_outputs) @@ -1052,16 +1149,36 @@ def _get_referenced( meth, referenced_funcs, referenced_methods, + referenced_supers=referenced_supers, method_args=method_args, method_returns=method_returns, + method_stacks=method_stacks, + method_supers=method_supers, already_processed=already_processed, + method_stack=method_stack, + super_base=super_base, ) method_args[meth.__name__] = ref_inputs method_returns[meth.__name__] = ref_outputs + method_stacks[meth.__name__] = method_stack referenced_inputs.update(ref_inputs) referenced_outputs.update(ref_outputs) return referenced_inputs, sorted(referenced_outputs) + def _common_parent_pkg_prefix(self, base: type) -> str: + """Return the common part of two package names""" + ref_parts = self.nipype_interface.__module__.split(".") + mod_parts = base.__module__.split(".") + common = [] + for r_part, m_part in zip(ref_parts, mod_parts): + if r_part == m_part: + common.append(r_part) + else: + break + if not common: + return "" + return "_".join(common + [base.__name__]) + "__" + @cached_property def local_functions(self): """Get the functions defined in the same file as the interface""" @@ -1078,7 +1195,12 @@ def process_method( output_names: ty.List[str], method_args: ty.Dict[str, ty.List[str]] = None, method_returns: ty.Dict[str, ty.List[str]] = None, + additional_args: ty.Sequence[str] = (), + new_name: ty.Optional[str] = None, + super_base: ty.Optional[type] = None, ): + if super_base is None: + super_base = self.nipype_interface src = inspect.getsource(method) pre, args, post = extract_args(src) try: @@ -1088,11 +1210,16 @@ def process_method( if "runtime" in args: args.remove("runtime") if method.__name__ in self.method_args: - args += [f"{a}=None" for a in self.method_args[method.__name__]] + args += [ + f"{a}=None" + for a in (list(self.method_args[method.__name__]) + additional_args) + ] # Insert method args in signature if present 
return_types, method_body = post.split(":", maxsplit=1) method_body = method_body.split("\n", maxsplit=1)[1] - method_body = self.process_method_body(method_body, input_names, output_names) + method_body = self.process_method_body( + method_body, input_names, output_names, super_base + ) if self.method_returns.get(method.__name__): return_args = self.method_returns[method.__name__] method_body = ( @@ -1109,11 +1236,19 @@ def process_method( ) pre = re.sub(r"^\s*", "", pre, flags=re.MULTILINE) pre = pre.replace("@staticmethod\n", "") + if new_name: + pre = re.sub(r"^def (\w+)\(", f"def {new_name}(", pre, flags=re.MULTILINE) return f"{pre}{', '.join(args)}{return_types}:\n{method_body}" def process_method_body( - self, method_body: str, input_names: ty.List[str], output_names: ty.List[str] + self, + method_body: str, + input_names: ty.List[str], + output_names: ty.List[str], + super_base: ty.Optional[type] = None, ) -> str: + if super_base is None: + super_base = self.nipype_interface return_value = get_return_line(method_body) method_body = method_body.replace("if self.output_spec:", "if True:") # Replace self.inputs. 
with in the function body @@ -1129,6 +1264,7 @@ def process_method_body( self.task_name, ) method_body = input_re.sub(r"\1", method_body) + method_body = self.replace_supers(method_body, super_base) if return_value: output_re = re.compile(return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") @@ -1147,9 +1283,20 @@ def process_method_body( method_body = re.sub( r"outputs = self.output_spec().*", r"outputs = {}", method_body ) + method_body = self._misc_cleanups(method_body) return self.unwrap_nested_methods(method_body) - def unwrap_nested_methods(self, method_body): + def replace_supers(self, method_body, super_base=None): + if super_base is None: + super_base = self.nipype_interface + super_name_map = self.method_supers[super_base] + return re.sub( + r"super\([^\)]*\)\.(\w+)\(", + lambda m: super_name_map[m.group(1)] + "(", + method_body, + ) + + def unwrap_nested_methods(self, method_body, additional_args=()): """ Converts nested method calls into function calls """ @@ -1193,7 +1340,11 @@ def unwrap_nested_methods(self, method_body): # Insert additional arguments to the method call (which were previously # accessed via member attributes) new_body += name + insert_args_in_signature( - args, [f"{a}={a}" for a in self.method_args[name]] + args, + [ + f"{a}={a}" + for a in (list(self.method_args[name]) + list(additional_args)) + ], ) method_body = new_body # Convert assignment to self attributes into method-scoped variables (hopefully @@ -1203,6 +1354,8 @@ def unwrap_nested_methods(self, method_body): ) return cleanup_function_body(method_body) + SUPER_MAPPINGS = {CommandLine: {"_list_outputs": "{}"}} + INPUT_KEYS = [ "allowed_values", "argstr", diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index ed72651..60cadb4 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -76,6 +76,9 @@ def types_to_names(spec_fields): lo_src = "\n".join(lo_lines) method_body += "\n" + lo_src method_body = 
self.process_method_body(method_body, input_names, output_names) + method_body = re.sub( + r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body + ) used = UsedSymbols.find( self.nipype_module, diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index e74667e..a197742 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -5,23 +5,34 @@ import logging from functools import cached_property from copy import copy -from operator import attrgetter +from operator import attrgetter, itemgetter from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter -from ..utils import UsedSymbols, split_source_into_statements +from ..utils import ( + UsedSymbols, + split_source_into_statements, + INBUILT_NIPYPE_TRAIT_NAMES, +) from fileformats.core.mixin import WithClassifiers from fileformats.generic import File, Directory logger = logging.getLogger("nipype2pydra") +CALLABLES_ARGS = ["inputs", "stdout", "stderr", "output_dir"] + @attrs.define(slots=False) class ShellCommandInterfaceConverter(BaseInterfaceConverter): _format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict) - INCLUDED_METHODS = ("_parse_inputs", "_format_arg", "_list_outputs") + INCLUDED_METHODS = ( + "_parse_inputs", + "_format_arg", + "_list_outputs", + "_gen_filename", + ) def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, @@ -99,9 +110,15 @@ def types_to_names(spec_fields): r"'formatter': '(\w+)'", r"'formatter': \1", input_fields_str ) output_fields_str = str(types_to_names(spec_fields=output_fields)) - functions_str = self.function_callables() - spec_str = functions_str - spec_str += self.format_arg_code + self.parse_inputs_code + # functions_str = self.function_callables() + # functions_imports, functions_str = functions_str.split("\n\n", 1) + # spec_str = functions_str + spec_str = ( + self.format_arg_code + + 
self.parse_inputs_code + + self.callables_code + + self.defaults_code + ) spec_str += f"input_fields = {input_fields_str}\n" spec_str += f"{self.task_name}_input_spec = specs.SpecInfo(name='Input', fields=input_fields, bases=(specs.ShellSpec,))\n\n" spec_str += f"output_fields = {output_fields_str}\n" @@ -120,7 +137,30 @@ def types_to_names(spec_fields): spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str) for m in sorted(self.referenced_methods, key=attrgetter("__name__")): - spec_str += "\n\n" + self.process_method(m, input_names, output_names) + if m.__name__ in self.INCLUDED_METHODS: + continue + if self.method_stacks[m.__name__][0] == self.nipype_interface._list_outputs: + additional_args = CALLABLES_ARGS + else: + additional_args = [] + spec_str += "\n\n" + self.process_method( + m, input_names, output_names, additional_args=additional_args + ) + + for new_name, (m, _) in sorted( + self.referenced_supers.items(), key=itemgetter(0) + ): + if self.method_stacks[new_name][0] == self.nipype_interface._list_outputs: + additional_args = CALLABLES_ARGS + else: + additional_args = [] + spec_str += "\n\n" + self.process_method( + m, + input_names, + output_names, + additional_args=additional_args, + new_name=new_name, + ) imports = self.construct_imports( nonstd_types, @@ -132,7 +172,12 @@ def types_to_names(spec_fields): used = UsedSymbols.find( self.nipype_module, - [self.format_arg_code, self.parse_inputs_code, self.function_callables()], + [ + self.format_arg_code, + self.parse_inputs_code, + self.callables_code, + self.defaults_code, + ], omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, @@ -141,6 +186,20 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) + for super_method, base in self.referenced_supers.values(): + super_used = UsedSymbols.find( + base, + [super_method], + 
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + absolute_imports=True, + ) + used.update(super_used) + used.imports.update(imports) return spec_str, used @@ -158,6 +217,21 @@ def _convert_input_fields(self): def formatted_input_fields(self): return re.findall(r"name == \"(\w+)\"", self._format_arg_body) + @property + def callable_default_input_fields(self): + return re.findall(r"name == \"(\w+)\"", self._gen_filename_body) + + @property + def callable_output_fields(self): + return [ + f + for f in self.output_fields + if ( + "output_file_template" not in f[-1] + and f[0] not in INBUILT_NIPYPE_TRAIT_NAMES + ) + ] + @cached_property def _format_arg_body(self): if "_format_arg" not in self.nipype_interface.__dict__: @@ -166,30 +240,25 @@ def _format_arg_body(self): inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1] ) + @cached_property + def _gen_filename_body(self): + if "_gen_filename" not in self.nipype_interface.__dict__: + return "" + return _strip_doc_string( + inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] + ) + @property def format_arg_code(self): if not self._format_arg_body: return "" body = self._format_arg_body - # Replace self.inputs. 
with in the function body - input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") - unrecognised_inputs = set( - m for m in input_re.findall(body) if m not in self.input_names - ) - if unrecognised_inputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) inputs %s in " - "'%s' task", - unrecognised_inputs, - self.task_name, - ) - + body = self._process_inputs(body) + body = self._misc_cleanups(body) existing_args = list( inspect.signature(self.nipype_interface._format_arg).parameters )[1:] name_arg, _, val_arg = existing_args - body = input_re.sub(r"inputs.\1", body) - body = re.sub(r"self\.(?!inputs)(\w+)", r"parsed_inputs['\1']", body) body = re.sub( r"trait_spec\.argstr % (.*)", r"argstr.format(**{" + name_arg + r": \1})", @@ -209,23 +278,22 @@ def format_arg_code(self): if not body: return "" body = self.unwrap_nested_methods(body) + body = self.replace_supers(body) - code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, parsed_inputs, argstr): + code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr): + parsed_inputs = _parse_inputs(inputs) if inputs else {{}} + if {val_arg} is None: + return "" {body} - raise ValueError(f"Unrecognised field {{{name_arg}}}") + return argstr.format(**inputs) """ for field_name in self.formatted_input_fields: - code_str += f"def {field_name}_formatter(field, inputs):\n" - if self.parse_inputs_code: - code_str += " parsed_inputs = _parse_inputs(inputs)\n" - else: - code_str += " parsed_inputs = {}\n" - code_str += ( + f"def {field_name}_formatter(field, inputs):\n" f" return _format_arg({field_name!r}, field, inputs, " - f"parsed_inputs, argstr={self._format_argstrs[field_name]!r})\n\n\n" + f"argstr={self._format_argstrs[field_name]!r})\n\n\n" ) return code_str @@ -236,20 +304,8 @@ def parse_inputs_code(self) -> str: body = _strip_doc_string( inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1] ) - # Replace self.inputs. 
with in the function body - input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") - unrecognised_inputs = set( - m for m in input_re.findall(body) if m not in self.input_names - ) - if unrecognised_inputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) inputs %s in " - "'%s' task", - unrecognised_inputs, - self.task_name, - ) - body = re.sub(r"self\.inputs", r"inputs", body) - body = re.sub(r"self\.(\w+)\b(?!\()", r"parsed_inputs['\1']", body) + body = self._process_inputs(body) + body = self._misc_cleanups(body) body = re.sub( r"self.\_format_arg\((\w+), (\w+), (\w+)\)", r"_format_arg(\1, \3, inputs, parsed_inputs, argstrs.get(\1))", @@ -259,6 +315,7 @@ def parse_inputs_code(self) -> str: # Strip out return value body = re.sub(r"\s*return .*\n", "", body) body = self.unwrap_nested_methods(body) + body = self.replace_supers(body) return f"""def _parse_inputs(inputs): parsed_inputs = {{}} @@ -270,6 +327,89 @@ def parse_inputs_code(self) -> str: """ + @cached_property + def defaults_code(self): + if not self.callable_default_input_fields: + return "" + + body = _strip_doc_string( + inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] + ) + body = self._process_inputs(body) + body = self._misc_cleanups(body) + + if not body: + return "" + body = self.unwrap_nested_methods(body) + body = self.replace_supers(body) + + code_str = f"""def _gen_filename(name, inputs): + parsed_inputs = _parse_inputs(inputs) if inputs else {{}} + {body} +""" + # Create separate default function for each input field with genfile, which + # reference the magic "_gen_filename" method + for inpt_name, inpt in sorted( + self.nipype_interface.input_spec().traits().items() + ): + if inpt.genfile: + code_str += ( + f"\n\n\ndef {inpt_name}_default(inputs):\n" + f' return _gen_filename("{inpt_name}", inputs=inputs)\n\n' + ) + return code_str + + @cached_property + def callables_code(self): + + if not self.callable_output_fields: + return "" + + body = 
_strip_doc_string( + inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[-1] + ) + body = self._process_inputs(body) + body = self._misc_cleanups(body) + + if not body: + return "" + body = self.unwrap_nested_methods( + body, + additional_args=CALLABLES_ARGS, + ) + body = self.replace_supers(body) + + code_str = f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): + parsed_inputs = _parse_inputs(inputs) if inputs else {{}} +{body} +""" + # Create separate function for each output field in the "callables" section + for output_field in self.callable_output_fields: + output_name = output_field[0] + code_str += ( + f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" + " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" + ' return outputs["' + output_name + '"]\n\n' + ) + return code_str + + def _process_inputs(self, body: str) -> str: + # Replace self.inputs. with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(body) if m not in self.input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + body = input_re.sub(r"inputs['\1']", body) + body = re.sub(r"self\.(?!inputs)(\w+)", r"parsed_inputs['\1']", body) + return body + def _strip_doc_string(body: str) -> str: if re.match(r"\s*(\"|')", body): diff --git a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index bb12826..4f43640 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -403,7 +403,8 @@ def generate_callables(self, nipype_interface) -> str: if output_name not in INBUILT_NIPYPE_TRAIT_NAMES: callables_str += ( f"def {output_name}_callable(output_dir, inputs, stdout, stderr):\n" - " outputs = _list_outputs(output_dir=output_dir, 
inputs=inputs, stdout=stdout, stderr=stderr)\n" + " parsed_inputs = {}" + " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr, parsed_inputs=parsed_inputs)\n" ' return outputs["' + output_name + '"]\n\n' ) @@ -922,7 +923,7 @@ def test_generate_sample_{frmt.lower()}_data(): def get_callable_sources( - nipype_interface, + nipype_interface, attrs_as_parsed_inputs: bool = False ) -> ty.Tuple[ty.Set[str], ty.List[str], ty.Set[str], ty.Set[ty.Tuple[str, str]]]: """ Convert the _gen_filename method of a nipype interface into a function that can be @@ -1025,7 +1026,16 @@ def process_method( ) if hasattr(nipype_interface, "_cmd"): body = body.replace("self.cmd", f'"{nipype_interface._cmd}"') - body = body.replace("self.", "") + body = re.sub( + r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body + ) + body = re.sub(r"getattr\(self\.inputs, (\w+)\)", r"inputs[\1]", body) + if attrs_as_parsed_inputs: + body = re.sub( + r"self\.(?!inputs)(\w+)\b(?!\()", r"parsed_inputs['\1']", body + ) + else: + body = body.replace("self.", "") body = re.sub( r"super\([^\)]*\)\.(\w+)\(", lambda m: name_map[m.group(1)] + "(", body ) diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 08525f6..ed578e0 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -317,6 +317,7 @@ def cleanup_function_body(function_body: str) -> str: function_body = re.sub( r"^" + " " * indent_reduction, "", function_body, flags=re.MULTILINE ) + # Other misc replacements # function_body = function_body.replace("LOGGER.", "logger.") return replace_undefined(function_body) @@ -415,7 +416,7 @@ def split_source_into_statements(source_code: str) -> ty.List[str]: else: # Handle dictionary assignments where the first open-closing bracket is # before the assignment, e.g. outputs["out_file"] = [..." 
- if post and re.match(r"\s*=", post[1:]): + if post and re.match(r"\s*=|.*[\(\[\{\"'].*", post[1:]): try: extract_args(post[1:]) except (UnmatchedParensException, UnmatchedQuoteException): From 985ab46ab7ac247fd78260d4deb3c6c936906d96 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 09:24:31 +0930 Subject: [PATCH 06/21] handled super methods properly --- nipype2pydra/interface/base.py | 44 ++++++++++++++++--------- nipype2pydra/interface/shell_command.py | 21 ++++++------ 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index e439771..7d038eb 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -1094,15 +1094,11 @@ def _get_referenced( re.findall(return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=", method_body) ) for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): - super_method = None - for base in self.nipype_interface.__mro__[1:]: - if match in base.__dict__: # Found the match - super_method = getattr(base, match) - break - assert super_method is not None, ( - f"Could not find super of '{match}' method in base classes of " - f"{self.nipype_interface}" - ) + super_method, base = find_super_method(super_base, match) + if any( + base.__module__.startswith(m) for m in UsedSymbols.ALWAYS_OMIT_MODULES + ): + continue func_name = self._common_parent_pkg_prefix(base) + match if func_name not in referenced_supers: referenced_supers[func_name] = (super_method, base) @@ -1289,12 +1285,16 @@ def process_method_body( def replace_supers(self, method_body, super_base=None): if super_base is None: super_base = self.nipype_interface - super_name_map = self.method_supers[super_base] - return re.sub( - r"super\([^\)]*\)\.(\w+)\(", - lambda m: super_name_map[m.group(1)] + "(", - method_body, - ) + name_map = self.method_supers[super_base] + + def replace_super(match): + super_method = find_super_method(super_base, match.group(1))[0] + try: + return 
self.SPECIAL_SUPER_MAPPINGS[super_method] + except KeyError: + return name_map[match.group(1)] + "(" + match.group(2) + ")" + + return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body) def unwrap_nested_methods(self, method_body, additional_args=()): """ @@ -1354,7 +1354,7 @@ def unwrap_nested_methods(self, method_body, additional_args=()): ) return cleanup_function_body(method_body) - SUPER_MAPPINGS = {CommandLine: {"_list_outputs": "{}"}} + SPECIAL_SUPER_MAPPINGS = {CommandLine._list_outputs: "{}"} INPUT_KEYS = [ "allowed_values", @@ -1407,3 +1407,15 @@ def pytest_configure(config): else: CATCH_CLI_EXCEPTIONS = True """ + + +def find_super_method( + super_base: type, method_name: str +) -> ty.Tuple[ty.Callable, type]: + for base in super_base.__mro__[1:]: + if method_name in base.__dict__: # Found the match + return getattr(base, method_name), base + raise RuntimeError( + f"Could not find super of '{method_name}' method in base classes of " + f"{super_base}" + ) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index a197742..4f4cae2 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -6,6 +6,7 @@ from functools import cached_property from copy import copy from operator import attrgetter, itemgetter +from importlib import import_module from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter from ..utils import ( @@ -162,14 +163,6 @@ def types_to_names(spec_fields): new_name=new_name, ) - imports = self.construct_imports( - nonstd_types, - spec_str, - include_task=False, - base=base_imports, - ) - # spec_str = "\n".join(str(i) for i in imports) + "\n\n" + spec_str - used = UsedSymbols.find( self.nipype_module, [ @@ -188,7 +181,7 @@ def types_to_names(spec_fields): ) for super_method, base in self.referenced_supers.values(): super_used = UsedSymbols.find( - base, + import_module(base.__module__), 
[super_method], omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, @@ -197,10 +190,18 @@ def types_to_names(spec_fields): always_include=self.package.all_explicit, translations=self.package.all_import_translations, absolute_imports=True, + collapse_intra_pkg=True, ) used.update(super_used) - used.imports.update(imports) + used.imports.update( + self.construct_imports( + nonstd_types, + spec_str, + include_task=False, + base=base_imports, + ) + ) return spec_str, used From 1bcc2ec57f9950f1b88ff89bb5ed6f67665ce5c8 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 09:53:19 +0930 Subject: [PATCH 07/21] added callables to outputs --- nipype2pydra/interface/base.py | 19 ++++++++++--- nipype2pydra/interface/shell_command.py | 38 ++++++++++++++++--------- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 7d038eb..7aab577 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -414,7 +414,7 @@ def input_templates(self): @cached_property def output_fields(self): - return self.convert_output_spec(fields_from_template=self.input_templates) + return self._convert_output_fields(fields_from_template=self.input_templates) @cached_property def nonstd_types(self): @@ -684,7 +684,7 @@ def pydra_fld_input(self, field, nm): else: return (pydra_type, pydra_metadata), pos - def convert_output_spec(self, fields_from_template): + def _convert_output_fields(self, fields_from_template): """creating fields list for pydra input spec""" pydra_fields_l = [] if not self.nipype_output_spec: @@ -1288,11 +1288,22 @@ def replace_supers(self, method_body, super_base=None): name_map = self.method_supers[super_base] def replace_super(match): - super_method = find_super_method(super_base, match.group(1))[0] + super_method, base = find_super_method(super_base, match.group(1)) try: return 
self.SPECIAL_SUPER_MAPPINGS[super_method] except KeyError: - return name_map[match.group(1)] + "(" + match.group(2) + ")" + try: + return name_map[match.group(1)] + "(" + match.group(2) + ")" + except KeyError: + if any( + base.__module__.startswith(m) + for m in UsedSymbols.ALWAYS_OMIT_MODULES + ): + raise KeyError( + f"Require special mapping for {match.group(1)} in {base} class " + "as methods in that module are being omitted from the conversion" + ) from None + raise return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 4f4cae2..16c7851 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -111,6 +111,9 @@ def types_to_names(spec_fields): r"'formatter': '(\w+)'", r"'formatter': \1", input_fields_str ) output_fields_str = str(types_to_names(spec_fields=output_fields)) + output_fields_str = re.sub( + r"'callable': '(\w+)'", r"'callable': \1", output_fields_str + ) # functions_str = self.function_callables() # functions_imports, functions_str = functions_str.split("\n\n", 1) # spec_str = functions_str @@ -206,27 +209,34 @@ def types_to_names(spec_fields): return spec_str, used @cached_property - def _convert_input_fields(self): - pydra_fields_l, has_template = super()._convert_input_fields - for field in pydra_fields_l: - if field[0] in self.formatted_input_fields: + def input_fields(self): + input_fields = super().input_fields + for field in input_fields: + if field[0] in self.formatted_input_field_names: field[-1]["formatter"] = f"{field[0]}_formatter" self._format_argstrs[field[0]] = field[-1].pop("argstr") - return pydra_fields_l, has_template + return input_fields + + @cached_property + def output_fields(self): + output_fields = super().output_fields + for field in self.callable_output_fields: + field[-1]["callable"] = f"{field[0]}_callable" + return output_fields @property - def 
formatted_input_fields(self): + def formatted_input_field_names(self): return re.findall(r"name == \"(\w+)\"", self._format_arg_body) @property - def callable_default_input_fields(self): + def callable_default_input_field_names(self): return re.findall(r"name == \"(\w+)\"", self._gen_filename_body) @property def callable_output_fields(self): return [ f - for f in self.output_fields + for f in super().output_fields if ( "output_file_template" not in f[-1] and f[0] not in INBUILT_NIPYPE_TRAIT_NAMES @@ -290,7 +300,7 @@ def format_arg_code(self): """ - for field_name in self.formatted_input_fields: + for field_name in self.formatted_input_field_names: code_str += ( f"def {field_name}_formatter(field, inputs):\n" f" return _format_arg({field_name!r}, field, inputs, " @@ -318,19 +328,21 @@ def parse_inputs_code(self) -> str: body = self.unwrap_nested_methods(body) body = self.replace_supers(body) - return f"""def _parse_inputs(inputs): - parsed_inputs = {{}} - argstrs = {self._format_argstrs!r} + code_str = "def _parse_inputs(inputs):\n parsed_inputs = {{}}" + if re.findall(r"\bargstrs\b", body): + code_str += f"\n argstrs = {self._format_argstrs!r}" + code_str += f""" skip = [] {body} return parsed_inputs """ + return code_str @cached_property def defaults_code(self): - if not self.callable_default_input_fields: + if not self.callable_default_input_field_names: return "" body = _strip_doc_string( From 68e09c7b23322d0d52bb00ef2b76917c8c3d29eb Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 10:07:22 +0930 Subject: [PATCH 08/21] touching up format_arg super handling --- nipype2pydra/interface/base.py | 7 +++++-- nipype2pydra/interface/shell_command.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 7aab577..700b232 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -1300,7 +1300,7 @@ def replace_super(match): for m in 
UsedSymbols.ALWAYS_OMIT_MODULES ): raise KeyError( - f"Require special mapping for {match.group(1)} in {base} class " + f"Require special mapping for '{match.group(1)}' in {base} class " "as methods in that module are being omitted from the conversion" ) from None raise @@ -1365,7 +1365,10 @@ def unwrap_nested_methods(self, method_body, additional_args=()): ) return cleanup_function_body(method_body) - SPECIAL_SUPER_MAPPINGS = {CommandLine._list_outputs: "{}"} + SPECIAL_SUPER_MAPPINGS = { + CommandLine._list_outputs: "{}", + CommandLine._format_arg: "argstr.format(**inputs)", + } INPUT_KEYS = [ "allowed_values", diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 16c7851..a55f856 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -296,7 +296,6 @@ def format_arg_code(self): if {val_arg} is None: return "" {body} - return argstr.format(**inputs) """ @@ -358,7 +357,7 @@ def defaults_code(self): code_str = f"""def _gen_filename(name, inputs): parsed_inputs = _parse_inputs(inputs) if inputs else {{}} - {body} +{body} """ # Create separate default function for each input field with genfile, which # reference the magic "_gen_filename" method From b550966171dedeff5157c324a30730b64fae9cd8 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 11:27:52 +0930 Subject: [PATCH 09/21] handling special super methods --- nipype2pydra/interface/base.py | 40 ++++++++++++++------- nipype2pydra/interface/shell_command.py | 46 ++++++++++++++++--------- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 700b232..116e6b9 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -14,7 +14,7 @@ import attrs from attrs.converters import default_if_none import nipype.interfaces.base -from nipype.interfaces.base import traits_extension, CommandLine +from nipype.interfaces.base 
import traits_extension, CommandLine, BaseInterface from pydra.engine import specs from pydra.engine.helpers import ensure_list from ..utils import ( @@ -1019,7 +1019,8 @@ def _misc_cleanups(self, body: str) -> str: body, flags=re.MULTILINE, ) - body = re.sub(r"\w+runtime\.(stdout|stderr)", r"\1", body) + body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body) + body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body) body = body.replace("os.getcwd()", "output_dir") return body @@ -1307,22 +1308,30 @@ def replace_super(match): return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body) - def unwrap_nested_methods(self, method_body, additional_args=()): + def unwrap_nested_methods( + self, method_body, additional_args=(), inputs_as_dict: bool = False + ): """ Converts nested method calls into function calls """ # Add args to the function signature of method calls method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) method_names = [m.__name__ for m in self.referenced_methods] - unrecognised_methods = set( + method_body = strip_comments(method_body) + omitted_methods = {} + for method_name in set( m for m in method_re.findall(method_body) if m not in method_names - ) - assert ( - not unrecognised_methods - ), f"Found the following unrecognised methods {unrecognised_methods}" + ): + omitted_methods[method_name] = find_super_method( + self.nipype_interface, method_name + )[0] splits = method_re.split(method_body) new_body = splits[0] for name, args in zip(splits[1::2], splits[2::2]): + if name in omitted_methods: + new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]] + new_body += extract_args(args)[-1][1:] + continue # Assign additional return values (which were previously saved to member # attributes) to new variables from the method call if self.method_returns[name]: @@ -1342,18 +1351,18 @@ def unwrap_nested_methods(self, method_body, additional_args=()): else: new_body += ",".join(self.method_returns[name]) 
+ " = " else: - raise NotImplementedError( + logger.warning( "Could not augment the return value of the method converted from " - "a function with the previously assigned attributes as it is used " - "directly. Need to replace the method call with a variable and " - "assign the return value to it on a previous line" + f"a function '{name}' with the previously assigned attributes " + f"{self.method_returns[name]} as the method doesn't have a " + f"singular return statement at the end of the method" ) # Insert additional arguments to the method call (which were previously # accessed via member attributes) new_body += name + insert_args_in_signature( args, [ - f"{a}={a}" + f"{a}=inputs['{a}']" if inputs_as_dict else f"{a}={a}" for a in (list(self.method_args[name]) + list(additional_args)) ], ) @@ -1368,6 +1377,7 @@ def unwrap_nested_methods(self, method_body, additional_args=()): SPECIAL_SUPER_MAPPINGS = { CommandLine._list_outputs: "{}", CommandLine._format_arg: "argstr.format(**inputs)", + BaseInterface._check_version_requirements: "[]", } INPUT_KEYS = [ @@ -1433,3 +1443,7 @@ def find_super_method( f"Could not find super of '{method_name}' method in base classes of " f"{super_base}" ) + + +def strip_comments(src: str) -> str: + return re.sub(r"\s*#.*", "", src) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index a55f856..59c50e0 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -288,11 +288,10 @@ def format_arg_code(self): ) if not body: return "" - body = self.unwrap_nested_methods(body) + body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers(body) - code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr): - parsed_inputs = _parse_inputs(inputs) if inputs else {{}} + code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} if {val_arg} is None: return "" {body} @@ -324,7 +323,7 @@ 
def parse_inputs_code(self) -> str: # Strip out return value body = re.sub(r"\s*return .*\n", "", body) - body = self.unwrap_nested_methods(body) + body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers(body) code_str = "def _parse_inputs(inputs):\n parsed_inputs = {{}}" @@ -352,11 +351,10 @@ def defaults_code(self): if not body: return "" - body = self.unwrap_nested_methods(body) + body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers(body) - code_str = f"""def _gen_filename(name, inputs): - parsed_inputs = _parse_inputs(inputs) if inputs else {{}} + code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call} {body} """ # Create separate default function for each input field with genfile, which @@ -376,23 +374,31 @@ def callables_code(self): if not self.callable_output_fields: return "" - - body = _strip_doc_string( - inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[-1] - ) + if hasattr(self.nipype_interface, "aggregate_outputs"): + func_name = "aggregate_outputs" + body = _strip_doc_string( + inspect.getsource(self.nipype_interface.aggregate_outputs).split( + "\n", 1 + )[-1] + ) + else: + func_name = "_list_outputs" + body = _strip_doc_string( + inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[ + -1 + ] + ) body = self._process_inputs(body) body = self._misc_cleanups(body) if not body: return "" body = self.unwrap_nested_methods( - body, - additional_args=CALLABLES_ARGS, + body, additional_args=CALLABLES_ARGS, inputs_as_dict=True ) body = self.replace_supers(body) - code_str = f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): - parsed_inputs = _parse_inputs(inputs) if inputs else {{}} + code_str = f"""def {func_name}(inputs=None, stdout=None, stderr=None, output_dir=None):{self.parse_inputs_call} {body} """ # Create separate function for each output field in the "callables" section @@ -400,7 +406,7 @@ def 
callables_code(self): output_name = output_field[0] code_str += ( f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" - " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" + f" outputs = {func_name}(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" ' return outputs["' + output_name + '"]\n\n' ) return code_str @@ -419,9 +425,15 @@ def _process_inputs(self, body: str) -> str: self.task_name, ) body = input_re.sub(r"inputs['\1']", body) - body = re.sub(r"self\.(?!inputs)(\w+)", r"parsed_inputs['\1']", body) + body = re.sub(r"self\.(?!inputs)(\w+)\b(?!\()", r"parsed_inputs['\1']", body) return body + @property + def parse_inputs_call(self): + if not self.parse_inputs_code: + return "" + return "\n _parse_inputs(inputs) if inputs else {}" + def _strip_doc_string(body: str) -> str: if re.match(r"\s*(\"|')", body): From 28659bc7a61339261adc2744097d4c4a791dc5c7 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 14:13:54 +0930 Subject: [PATCH 10/21] fixed handling of super methods, added supported for aggregate_outputs and fix bug with inputs handling in _list_outputs --- nipype2pydra/interface/base.py | 45 +++++++++--------- nipype2pydra/interface/shell_command.py | 63 +++++++++++++++++++------ nipype2pydra/utils/__init__.py | 2 + nipype2pydra/utils/misc.py | 16 +++++++ 4 files changed, 90 insertions(+), 36 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 116e6b9..35af106 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -31,6 +31,8 @@ cleanup_function_body, insert_args_in_signature, extract_args, + strip_comments, + find_super_method, ) from ..statements import ( ImportStatement, @@ -1021,6 +1023,11 @@ def _misc_cleanups(self, body: str) -> str: ) body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body) body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body) + body = 
re.sub(r"getattr\(inputs, ([^)]+)\)", r"inputs[\1]", body) + body = re.sub( + r"setattr\(outputs, ([^,]+), ([^)]+)\)", r"outputs[\1] = \2", body + ) + body = body.replace("TraitError", "KeyError") body = body.replace("os.getcwd()", "output_dir") return body @@ -1209,7 +1216,9 @@ def process_method( if method.__name__ in self.method_args: args += [ f"{a}=None" - for a in (list(self.method_args[method.__name__]) + additional_args) + for a in ( + list(self.method_args[method.__name__]) + list(additional_args) + ) ] # Insert method args in signature if present return_types, method_body = post.split(":", maxsplit=1) @@ -1291,7 +1300,9 @@ def replace_supers(self, method_body, super_base=None): def replace_super(match): super_method, base = find_super_method(super_base, match.group(1)) try: - return self.SPECIAL_SUPER_MAPPINGS[super_method] + return self.SPECIAL_SUPER_MAPPINGS[super_method].format( + args=match.group(2) + ) except KeyError: try: return name_map[match.group(1)] + "(" + match.group(2) + ")" @@ -1316,7 +1327,9 @@ def unwrap_nested_methods( """ # Add args to the function signature of method calls method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) - method_names = [m.__name__ for m in self.referenced_methods] + method_names = [m.__name__ for m in self.referenced_methods] + list( + self.INCLUDED_METHODS + ) method_body = strip_comments(method_body) omitted_methods = {} for method_name in set( @@ -1329,8 +1342,11 @@ def unwrap_nested_methods( new_body = splits[0] for name, args in zip(splits[1::2], splits[2::2]): if name in omitted_methods: - new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]] - new_body += extract_args(args)[-1][1:] + args, post = extract_args(args)[1:] + new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]].format( + args=", ".join(args) + ) + new_body += post[1:] # drop the leading parenthesis continue # Assign additional return values (which were previously saved to member # attributes) to new 
variables from the method call @@ -1375,8 +1391,9 @@ def unwrap_nested_methods( return cleanup_function_body(method_body) SPECIAL_SUPER_MAPPINGS = { - CommandLine._list_outputs: "{}", + CommandLine._list_outputs: "{{}}", CommandLine._format_arg: "argstr.format(**inputs)", + CommandLine._filename_from_source: "{args} + '_generated'", BaseInterface._check_version_requirements: "[]", } @@ -1431,19 +1448,3 @@ def pytest_configure(config): else: CATCH_CLI_EXCEPTIONS = True """ - - -def find_super_method( - super_base: type, method_name: str -) -> ty.Tuple[ty.Callable, type]: - for base in super_base.__mro__[1:]: - if method_name in base.__dict__: # Found the match - return getattr(base, method_name), base - raise RuntimeError( - f"Could not find super of '{method_name}' method in base classes of " - f"{super_base}" - ) - - -def strip_comments(src: str) -> str: - return re.sub(r"\s*#.*", "", src) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 59c50e0..d1091b4 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -13,6 +13,7 @@ UsedSymbols, split_source_into_statements, INBUILT_NIPYPE_TRAIT_NAMES, + find_super_method, ) from fileformats.core.mixin import WithClassifiers from fileformats.generic import File, Directory @@ -243,6 +244,10 @@ def callable_output_fields(self): ) ] + @property + def callable_output_field_names(self): + return [f[0] for f in self.callable_output_fields] + @cached_property def _format_arg_body(self): if "_format_arg" not in self.nipype_interface.__dict__: @@ -295,7 +300,7 @@ def format_arg_code(self): if {val_arg} is None: return "" {body} - + return argstr.format(**inputs) """ for field_name in self.formatted_input_field_names: @@ -326,7 +331,7 @@ def parse_inputs_code(self) -> str: body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers(body) - code_str = "def _parse_inputs(inputs):\n parsed_inputs = {{}}" + 
code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}" if re.findall(r"\bargstrs\b", body): code_str += f"\n argstrs = {self._format_argstrs!r}" code_str += f""" @@ -374,32 +379,62 @@ def callables_code(self): if not self.callable_output_fields: return "" - if hasattr(self.nipype_interface, "aggregate_outputs"): + code_str = "" + if ( + find_super_method(self.nipype_interface, "aggregate_outputs")[1] + is not BaseInterface + ): func_name = "aggregate_outputs" body = _strip_doc_string( inspect.getsource(self.nipype_interface.aggregate_outputs).split( "\n", 1 )[-1] ) + need_list_outputs = bool(re.findall(r"\b_list_outputs\b", body)) + body = self._process_inputs(body) + body = self._misc_cleanups(body) + + if not body: + return "" + body = self.unwrap_nested_methods( + body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + ) + body = self.replace_supers(body) + + code_str += f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): + inputs = attrs.asdict(inputs){self.parse_inputs_call} + needed_outputs = {self.callable_output_field_names!r} +{body} + + +""" + inputs_as_dict_call = "" + else: func_name = "_list_outputs" + inputs_as_dict_call = "\n inputs = attrs.asdict(inputs)" + need_list_outputs = True + + if need_list_outputs: body = _strip_doc_string( inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[ -1 ] ) - body = self._process_inputs(body) - body = self._misc_cleanups(body) + body = self._process_inputs(body) + body = self._misc_cleanups(body) - if not body: - return "" - body = self.unwrap_nested_methods( - body, additional_args=CALLABLES_ARGS, inputs_as_dict=True - ) - body = self.replace_supers(body) + if not body: + return "" + body = self.unwrap_nested_methods( + body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + ) + body = self.replace_supers(body) - code_str = f"""def {func_name}(inputs=None, stdout=None, stderr=None, output_dir=None):{self.parse_inputs_call} + code_str += f"""def 
_list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call} {body} + + """ # Create separate function for each output field in the "callables" section for output_field in self.callable_output_fields: @@ -407,7 +442,7 @@ def callables_code(self): code_str += ( f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" f" outputs = {func_name}(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" - ' return outputs["' + output_name + '"]\n\n' + ' return outputs.get("' + output_name + '", attrs.NOTHING)\n\n' ) return code_str @@ -432,7 +467,7 @@ def _process_inputs(self, body: str) -> str: def parse_inputs_call(self): if not self.parse_inputs_code: return "" - return "\n _parse_inputs(inputs) if inputs else {}" + return "\n parsed_inputs = _parse_inputs(inputs) if inputs else {}" def _strip_doc_string(body: str) -> str: diff --git a/nipype2pydra/utils/__init__.py b/nipype2pydra/utils/__init__.py index 51bb486..000f78a 100644 --- a/nipype2pydra/utils/__init__.py +++ b/nipype2pydra/utils/__init__.py @@ -20,6 +20,8 @@ types_converter, unwrap_nested_type, get_return_line, + find_super_method, + strip_comments, INBUILT_NIPYPE_TRAIT_NAMES, ) from .symbols import ( # noqa: F401 diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index ed578e0..ee6b013 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -539,3 +539,19 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str: if not match: return None return match.group(1).strip() + + +def find_super_method( + super_base: type, method_name: str +) -> ty.Tuple[ty.Callable, type]: + for base in super_base.__mro__[1:]: + if method_name in base.__dict__: # Found the match + return getattr(base, method_name), base + raise RuntimeError( + f"Could not find super of '{method_name}' method in base classes of " + f"{super_base}" + ) + + +def strip_comments(src: str) -> str: + return re.sub(r"\s+#.*", 
"", src) From 1e92a403501f25f7e2db255315e309bbdd1dad85 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 16:58:32 +0930 Subject: [PATCH 11/21] debugged changes so that all packages build successfully again --- nipype2pydra/interface/base.py | 60 +++++++++++++++---------- nipype2pydra/interface/shell_command.py | 21 ++++++--- nipype2pydra/package.py | 14 +++++- nipype2pydra/utils/misc.py | 9 ++-- 4 files changed, 69 insertions(+), 35 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 35af106..e4460c8 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -42,6 +42,7 @@ ) from fileformats.generic import File import nipype2pydra.package +from nipype2pydra.exceptions import UnmatchedParensException logger = logging.getLogger("nipype2pydra") @@ -498,7 +499,10 @@ def _referenced_funcs_and_methods(self): method_returns[method_name] = [] method_stacks[method_name] = () for method_name in self.INCLUDED_METHODS: - if method_name not in self.nipype_interface.__dict__: + base = find_super_method( + self.nipype_interface, method_name, include_class=True + )[1] + if self.package.is_omitted(base): continue # Don't include base methods method = getattr(self.nipype_interface, method_name) referenced_methods.add(method) @@ -1103,9 +1107,7 @@ def _get_referenced( ) for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): super_method, base = find_super_method(super_base, match) - if any( - base.__module__.startswith(m) for m in UsedSymbols.ALWAYS_OMIT_MODULES - ): + if self.package.is_omitted(super_method): continue func_name = self._common_parent_pkg_prefix(base) + match if func_name not in referenced_supers: @@ -1296,28 +1298,28 @@ def replace_supers(self, method_body, super_base=None): if super_base is None: super_base = self.nipype_interface name_map = self.method_supers[super_base] - - def replace_super(match): - super_method, base = find_super_method(super_base, match.group(1)) + 
splits = re.split(r"super\([^\)]*\)\.(\w+)(?=\()", method_body) + new_body = splits[0] + for name, block in zip(splits[1::2], splits[2::2]): + super_method, base = find_super_method(super_base, name) + _, args, post = extract_args(block) + arg_str = ", ".join(args) try: - return self.SPECIAL_SUPER_MAPPINGS[super_method].format( - args=match.group(2) + new_body += self.SPECIAL_SUPER_MAPPINGS[super_method].format( + args=arg_str ) except KeyError: try: - return name_map[match.group(1)] + "(" + match.group(2) + ")" + new_body += name_map[name] + "(" + arg_str + ")" except KeyError: - if any( - base.__module__.startswith(m) - for m in UsedSymbols.ALWAYS_OMIT_MODULES - ): + if self.package.is_omitted(base): raise KeyError( - f"Require special mapping for '{match.group(1)}' in {base} class " + f"Require special mapping for '{name}' in {base} class " "as methods in that module are being omitted from the conversion" ) from None raise - - return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body) + new_body += post[1:] + return new_body def unwrap_nested_methods( self, method_body, additional_args=(), inputs_as_dict: bool = False @@ -1375,13 +1377,22 @@ def unwrap_nested_methods( ) # Insert additional arguments to the method call (which were previously # accessed via member attributes) - new_body += name + insert_args_in_signature( - args, - [ - f"{a}=inputs['{a}']" if inputs_as_dict else f"{a}={a}" - for a in (list(self.method_args[name]) + list(additional_args)) - ], - ) + args_to_be_inserted = list(self.method_args[name]) + list(additional_args) + try: + new_body += name + insert_args_in_signature( + args, + [ + f"{a}=inputs['{a}']" if inputs_as_dict else f"{a}={a}" + for a in args_to_be_inserted + ], + ) + except UnmatchedParensException: + logger.warning( + f"Nested method call inside '{name}' in {self.full_address}, " + "the following args will need to be manually inserted up after the " + f"conversion: {args_to_be_inserted}" + ) + new_body += 
name + args method_body = new_body # Convert assignment to self attributes into method-scoped variables (hopefully # there aren't any name clashes) @@ -1395,6 +1406,7 @@ def unwrap_nested_methods( CommandLine._format_arg: "argstr.format(**inputs)", CommandLine._filename_from_source: "{args} + '_generated'", BaseInterface._check_version_requirements: "[]", + CommandLine._parse_inputs: "{{}}", } INPUT_KEYS = [ diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index d1091b4..87d70b0 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -174,7 +174,8 @@ def types_to_names(spec_fields): self.parse_inputs_code, self.callables_code, self.defaults_code, - ], + ] + + list(self.referenced_methods), omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, @@ -215,7 +216,7 @@ def input_fields(self): for field in input_fields: if field[0] in self.formatted_input_field_names: field[-1]["formatter"] = f"{field[0]}_formatter" - self._format_argstrs[field[0]] = field[-1].pop("argstr") + self._format_argstrs[field[0]] = field[-1].pop("argstr", "") return input_fields @cached_property @@ -284,12 +285,13 @@ def format_arg_code(self): # Strip out return value body = re.sub( ( - r"\s*return super\((\w+,\s*self)?\)\._format_arg\(" + r"^ return super\((\w+,\s*self)?\)\._format_arg\(" + ", ".join(existing_args) + r"\)\n" ), - "", + "return argstr.format(**inputs)", body, + flags=re.MULTILINE, ) if not body: return "" @@ -299,10 +301,13 @@ def format_arg_code(self): code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} if {val_arg} is None: return "" -{body} - return argstr.format(**inputs) +{body}""" + + if not code_str.rstrip().endswith("return argstr.format(**inputs)"): + code_str += "\n return argstr.format(**inputs)" + + code_str += "\n\n" -""" for 
field_name in self.formatted_input_field_names: code_str += ( f"def {field_name}_formatter(field, inputs):\n" @@ -328,6 +333,8 @@ def parse_inputs_code(self) -> str: # Strip out return value body = re.sub(r"\s*return .*\n", "", body) + if not body: + return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers(body) diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index 176065e..ca33065 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -310,7 +310,7 @@ def all_import_translations(self) -> ty.List[ty.Tuple[str, str]]: @property def all_omit_modules(self) -> ty.List[str]: - return self.omit_modules + ["nipype.interfaces.utility"] + return self.omit_modules + UsedSymbols.ALWAYS_OMIT_MODULES @property def all_explicit(self): @@ -344,6 +344,17 @@ def config_defaults(self) -> ty.Dict[str, ty.Dict[str, str]]: all_defaults[name] = defaults return all_defaults + def is_omitted(self, obj: ty.Any) -> bool: + if full_address(obj) in self.omit_classes + self.omit_functions: + return True + if inspect.ismodule(obj): + mod_name = obj.__name__ + else: + mod_name = obj.__module__ + if any(re.match(m + r"\b", mod_name) for m in self.all_omit_modules): + return True + return False + def write(self, package_root: Path, to_include: ty.List[str] = None): """Writes the package to the specified package root""" @@ -886,6 +897,7 @@ def write_to_module( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(converted_code) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index ee6b013..0360923 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -542,9 +542,12 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str: def find_super_method( - 
super_base: type, method_name: str + super_base: type, method_name: str, include_class: bool = False ) -> ty.Tuple[ty.Callable, type]: - for base in super_base.__mro__[1:]: + mro = super_base.__mro__ + if not include_class: + mro = mro[1:] + for base in mro: if method_name in base.__dict__: # Found the match return getattr(base, method_name), base raise RuntimeError( @@ -554,4 +557,4 @@ def find_super_method( def strip_comments(src: str) -> str: - return re.sub(r"\s+#.*", "", src) + return re.sub(r"^\s+#.*", "", src, flags=re.MULTILINE) From 63a407707f339b107652a056613544ece827f321 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sat, 1 Jun 2024 22:28:38 +0930 Subject: [PATCH 12/21] finally got all packages to convert again! --- nipype2pydra/interface/base.py | 65 +++++++++++++-------- nipype2pydra/interface/function.py | 23 +++++++- nipype2pydra/interface/shell_command.py | 78 ++++++++++++++++++------- nipype2pydra/utils/misc.py | 6 +- 4 files changed, 124 insertions(+), 48 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index e4460c8..1468737 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -492,20 +492,21 @@ def _referenced_funcs_and_methods(self): method_stacks = {} method_supers = defaultdict(dict) already_processed = set( - getattr(self.nipype_interface, m) for m in self.INCLUDED_METHODS + getattr(self.nipype_interface, m) for m in self.included_methods ) - for method_name in self.INCLUDED_METHODS: + for method_name in self.included_methods: method_args[method_name] = [] method_returns[method_name] = [] method_stacks[method_name] = () - for method_name in self.INCLUDED_METHODS: - base = find_super_method( + for method_name in self.included_methods: + method = getattr(self.nipype_interface, method_name) + super_base = find_super_method( self.nipype_interface, method_name, include_class=True )[1] - if self.package.is_omitted(base): - continue # Don't include base methods - method = 
getattr(self.nipype_interface, method_name) - referenced_methods.add(method) + # if super_base is not self.nipype_interface: + # method_supers[self.nipype_interface][method_name] = ( + # self._common_parent_pkg_prefix(super_base) + method_name + # ) self._get_referenced( method, referenced_funcs=referenced_funcs, @@ -516,6 +517,7 @@ def _referenced_funcs_and_methods(self): method_stacks=method_stacks, method_supers=method_supers, already_processed=already_processed, + super_base=super_base, ) return ( referenced_funcs, @@ -1095,7 +1097,14 @@ def _get_referenced( ref_method_names = re.findall(r"(?<=self\.)(\w+)\(", method_body) ref_methods = set(m for m in self.methods if m.__name__ in ref_method_names) - + # Filter methods in omitted common base-classes like BaseInterface & CommandLine + ref_methods = [ + m + for m in ref_methods + if not self.package.is_omitted( + find_super_method(super_base, m.__name__, include_class=True)[1] + ) + ] referenced_funcs.update(ref_local_funcs) referenced_methods.update(ref_methods) @@ -1107,7 +1116,7 @@ def _get_referenced( ) for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): super_method, base = find_super_method(super_base, match) - if self.package.is_omitted(super_method): + if self.package.is_omitted(base): continue func_name = self._common_parent_pkg_prefix(base) + match if func_name not in referenced_supers: @@ -1144,7 +1153,6 @@ def _get_referenced( method_supers=method_supers, already_processed=already_processed, method_stack=method_stack, - super_base=super_base, ) referenced_inputs.update(rf_inputs) referenced_outputs.update(rf_outputs) @@ -1162,7 +1170,6 @@ def _get_referenced( method_supers=method_supers, already_processed=already_processed, method_stack=method_stack, - super_base=super_base, ) method_args[meth.__name__] = ref_inputs method_returns[meth.__name__] = ref_outputs @@ -1215,13 +1222,12 @@ def process_method( pass if "runtime" in args: args.remove("runtime") - if method.__name__ in 
self.method_args: - args += [ - f"{a}=None" - for a in ( - list(self.method_args[method.__name__]) + list(additional_args) - ) - ] + args_to_add = list(self.method_args.get(method.__name__, [])) + list( + additional_args + ) + if args_to_add: + kwargs = [args.pop()] if args and args[-1].startswith("**") else [] + args += [f"{a}=None" for a in args_to_add] + kwargs # Insert method args in signature if present return_types, method_body = post.split(":", maxsplit=1) method_body = method_body.split("\n", maxsplit=1)[1] @@ -1255,6 +1261,8 @@ def process_method_body( output_names: ty.List[str], super_base: ty.Optional[type] = None, ) -> str: + if not method_body: + return "" if super_base is None: super_base = self.nipype_interface return_value = get_return_line(method_body) @@ -1330,7 +1338,7 @@ def unwrap_nested_methods( # Add args to the function signature of method calls method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) method_names = [m.__name__ for m in self.referenced_methods] + list( - self.INCLUDED_METHODS + self.included_methods ) method_body = strip_comments(method_body) omitted_methods = {} @@ -1345,9 +1353,16 @@ def unwrap_nested_methods( for name, args in zip(splits[1::2], splits[2::2]): if name in omitted_methods: args, post = extract_args(args)[1:] - new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]].format( - args=", ".join(args) - ) + omitted_method = omitted_methods[name] + try: + new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_method].format( + args=", ".join(args) + ) + except KeyError: + raise KeyError( + f"Require special mapping for {omitted_methods[name]} method " + "as methods in that module are being omitted from the conversion" + ) from None new_body += post[1:] # drop the leading parenthesis continue # Assign additional return values (which were previously saved to member @@ -1407,6 +1422,10 @@ def unwrap_nested_methods( CommandLine._filename_from_source: "{args} + '_generated'", 
BaseInterface._check_version_requirements: "[]", CommandLine._parse_inputs: "{{}}", + CommandLine._gen_filename: "", + BaseInterface.aggregate_outputs: "{{}}", + BaseInterface.run: "None", + BaseInterface._list_outputs: "{{}}", } INPUT_KEYS = [ diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 60cadb4..c2f9f98 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -8,7 +8,7 @@ import attrs from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter -from ..utils import UsedSymbols, get_return_line +from ..utils import UsedSymbols, get_return_line, find_super_method logger = logging.getLogger("nipype2pydra") @@ -17,7 +17,9 @@ @attrs.define(slots=False) class FunctionInterfaceConverter(BaseInterfaceConverter): - INCLUDED_METHODS = ("_run_interface", "_list_outputs") + @property + def included_methods(self) -> ty.Tuple[str, ...]: + return ("_run_interface", "_list_outputs") def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, @@ -68,14 +70,29 @@ def types_to_names(spec_fields): if re.match(r"\s*return", method_lines[-1]): method_lines = method_lines[:-1] method_body = "\n".join(method_lines) + method_body = self.process_method_body( + method_body, + input_names, + output_names, + super_base=find_super_method( + self.nipype_interface, "_run_interface", include_class=True + )[1], + ) lo_src = inspect.getsource(self.nipype_interface._list_outputs).strip() # Strip out method def and return statement lo_lines = lo_src.strip().split("\n")[1:] if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): lo_lines = lo_lines[:-1] lo_src = "\n".join(lo_lines) + lo_src = self.process_method_body( + lo_src, + input_names, + output_names, + super_base=find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1], + ) method_body += "\n" + lo_src - method_body = self.process_method_body(method_body, 
input_names, output_names) method_body = re.sub( r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body ) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 87d70b0..f309c5c 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -29,12 +29,21 @@ class ShellCommandInterfaceConverter(BaseInterfaceConverter): _format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict) - INCLUDED_METHODS = ( - "_parse_inputs", - "_format_arg", - "_list_outputs", - "_gen_filename", - ) + @cached_property + def included_methods(self) -> ty.Tuple[str, ...]: + included = [] + if not self.method_omitted("_parse_inputs"): + included.append("_parse_inputs"), + if not self.method_omitted("_format_arg"): + included.append("_format_arg") + if not self.method_omitted("_gen_filename"): + included.append("_gen_filename") + if self.callable_output_fields: + if not self.method_omitted("aggregate_outputs"): + included.append("aggregate_outputs") + if not self.method_omitted("_list_outputs"): + included.append("_list_outputs") + return tuple(included) def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, @@ -142,7 +151,7 @@ def types_to_names(spec_fields): spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str) for m in sorted(self.referenced_methods, key=attrgetter("__name__")): - if m.__name__ in self.INCLUDED_METHODS: + if m.__name__ in self.included_methods: continue if self.method_stacks[m.__name__][0] == self.nipype_interface._list_outputs: additional_args = CALLABLES_ARGS @@ -251,7 +260,7 @@ def callable_output_field_names(self): @cached_property def _format_arg_body(self): - if "_format_arg" not in self.nipype_interface.__dict__: + if self.method_omitted("_format_arg"): return "" return _strip_doc_string( inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1] @@ -259,7 +268,7 @@ def _format_arg_body(self): @cached_property def 
_gen_filename_body(self): - if "_gen_filename" not in self.nipype_interface.__dict__: + if self.method_omitted("_gen_filename"): return "" return _strip_doc_string( inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] @@ -267,7 +276,7 @@ def _gen_filename_body(self): @property def format_arg_code(self): - if not self._format_arg_body: + if "_format_arg" not in self.included_methods: return "" body = self._format_arg_body body = self._process_inputs(body) @@ -296,7 +305,12 @@ def format_arg_code(self): if not body: return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) - body = self.replace_supers(body) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "_format_arg", include_class=True + )[1], + ) code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} if {val_arg} is None: @@ -318,7 +332,7 @@ def format_arg_code(self): @property def parse_inputs_code(self) -> str: - if "_parse_inputs" not in self.nipype_interface.__dict__: + if "_parse_inputs" not in self.included_methods: return "" body = _strip_doc_string( inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1] @@ -336,7 +350,12 @@ def parse_inputs_code(self) -> str: if not body: return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) - body = self.replace_supers(body) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "_parse_inputs", include_class=True + )[1], + ) code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}" if re.findall(r"\bargstrs\b", body): @@ -352,7 +371,7 @@ def parse_inputs_code(self) -> str: @cached_property def defaults_code(self): - if not self.callable_default_input_field_names: + if "_gen_filename" not in self.included_methods: return "" body = _strip_doc_string( @@ -364,7 +383,12 @@ def defaults_code(self): if not body: return "" body = self.unwrap_nested_methods(body, 
inputs_as_dict=True) - body = self.replace_supers(body) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "_gen_filename", include_class=True + )[1], + ) code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call} {body} @@ -387,10 +411,7 @@ def callables_code(self): if not self.callable_output_fields: return "" code_str = "" - if ( - find_super_method(self.nipype_interface, "aggregate_outputs")[1] - is not BaseInterface - ): + if "aggregate_outputs" in self.included_methods: func_name = "aggregate_outputs" body = _strip_doc_string( inspect.getsource(self.nipype_interface.aggregate_outputs).split( @@ -406,7 +427,12 @@ def callables_code(self): body = self.unwrap_nested_methods( body, additional_args=CALLABLES_ARGS, inputs_as_dict=True ) - body = self.replace_supers(body) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "aggregate_outputs", include_class=True + )[1], + ) code_str += f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): inputs = attrs.asdict(inputs){self.parse_inputs_call} @@ -436,7 +462,12 @@ def callables_code(self): body = self.unwrap_nested_methods( body, additional_args=CALLABLES_ARGS, inputs_as_dict=True ) - body = self.replace_supers(body) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1], + ) code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call} {body} @@ -476,6 +507,11 @@ def parse_inputs_call(self): return "" return "\n parsed_inputs = _parse_inputs(inputs) if inputs else {}" + def method_omitted(self, method_name: str) -> bool: + return self.package.is_omitted( + find_super_method(self.nipype_interface, method_name, include_class=True)[1] + ) + def _strip_doc_string(body: str) -> str: if re.match(r"\s*(\"|')", body): diff --git 
a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 0360923..38b709b 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -361,7 +361,11 @@ def insert_args_in_signature(snippet: str, new_args: ty.Iterable[str]) -> str: pre, args, post = extract_args(snippet) if "runtime" in args: args.remove("runtime") - return pre + ", ".join(args + new_args) + post + if args and args[-1].startswith("**"): + kwargs = [args.pop()] + else: + kwargs = [] + return pre + ", ".join(args + new_args + kwargs) + post def get_source_code(func_or_klass: ty.Union[ty.Callable, ty.Type]) -> str: From 0aa80aaa2bfec19e88930d5dfddce3fea8f3ba7e Mon Sep 17 00:00:00 2001 From: Tom Close Date: Sun, 2 Jun 2024 23:48:50 +1000 Subject: [PATCH 13/21] debugging updates to conversions to handle function task super methods --- nipype2pydra/helpers.py | 2 + nipype2pydra/interface/base.py | 50 +++++-- nipype2pydra/interface/function.py | 176 +++++++++++++++++------- nipype2pydra/interface/shell_command.py | 94 ++++++++----- nipype2pydra/package.py | 3 + nipype2pydra/utils/__init__.py | 1 + nipype2pydra/utils/misc.py | 10 +- nipype2pydra/utils/symbols.py | 29 ++-- nipype2pydra/workflow.py | 1 + 9 files changed, 259 insertions(+), 107 deletions(-) diff --git a/nipype2pydra/helpers.py b/nipype2pydra/helpers.py index 1a3910d..3892cd3 100644 --- a/nipype2pydra/helpers.py +++ b/nipype2pydra/helpers.py @@ -350,6 +350,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.full_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -413,6 +414,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + 
f.write(f"# Attemping to convert {self.full_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 1468737..6ba5a58 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -14,7 +14,12 @@ import attrs from attrs.converters import default_if_none import nipype.interfaces.base -from nipype.interfaces.base import traits_extension, CommandLine, BaseInterface +from nipype.interfaces.base import ( + traits_extension, + CommandLine, + BaseInterface, +) +from nipype.interfaces.base.core import SimpleInterface from pydra.engine import specs from pydra.engine.helpers import ensure_list from ..utils import ( @@ -33,6 +38,7 @@ extract_args, strip_comments, find_super_method, + min_indentation, ) from ..statements import ( ImportStatement, @@ -364,6 +370,8 @@ class BaseInterfaceConverter(metaclass=ABCMeta): }, ) + _output_name_mappings: ty.Dict[str, str] = attrs.field(factory=dict) + def __attrs_post_init__(self): if self.output_module is None: if self.nipype_module.__name__.startswith("nipype.interfaces."): @@ -682,6 +690,7 @@ def pydra_fld_input(self, field, nm): f"the filed {nm} has genfile=True, but no template or " "`callables_default` function in the callables_module provided" ) + self._output_name_mappings[getattr(field, "output_name")] = nm pydra_metadata.update(metadata_extra_spec) @@ -1021,18 +1030,26 @@ def _misc_cleanups(self, body: str) -> str: if hasattr(self.nipype_interface, "_cmd"): body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"') - body = re.sub( - r"outputs = self\.(output_spec|_outputs)\(\).*$", - r"outputs = {}", - body, - flags=re.MULTILINE, - ) + body = body.replace("self.output_spec().get()", "{}") + body = body.replace("self._outputs()", "{}") + # body = re.sub( + # r"outputs = self\.(output_spec|_outputs)\(\).*$", + # r"outputs = {}", + # body, + # 
flags=re.MULTILINE, + # ) body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body) body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body) body = re.sub(r"getattr\(inputs, ([^)]+)\)", r"inputs[\1]", body) body = re.sub( r"setattr\(outputs, ([^,]+), ([^)]+)\)", r"outputs[\1] = \2", body ) + body = re.sub(r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", body) + body = re.sub(r"\s+runtime.returncode = (.*)", "", body) + new_body = re.sub(r"self\.(\w+)\b(?!\()", r"self_dict['\1']", body) + if new_body != body: + body = " " * min_indentation(body) + "self_dict = {}\n" + new_body + body = body.replace("return runtime", "") body = body.replace("TraitError", "KeyError") body = body.replace("os.getcwd()", "output_dir") return body @@ -1237,7 +1254,10 @@ def process_method( if self.method_returns.get(method.__name__): return_args = self.method_returns[method.__name__] method_body = ( - " " + " = ".join(return_args) + " = attrs.NOTHING\n" + method_body + " " * min_indentation(method_body) + + " = ".join(return_args) + + " = attrs.NOTHING\n" + + method_body ) method_lines = method_body.rstrip().splitlines() method_body = "\n".join(method_lines[:-1]) @@ -1295,12 +1315,9 @@ def process_method_body( self.task_name, ) method_body = output_re.sub(r"\1", method_body) - # Strip initialisation of outputs - method_body = re.sub( - r"outputs = self.output_spec().*", r"outputs = {}", method_body - ) - method_body = self._misc_cleanups(method_body) - return self.unwrap_nested_methods(method_body) + method_body = self.unwrap_nested_methods(method_body) + # method_body = self._misc_cleanups(method_body) + return method_body def replace_supers(self, method_body, super_base=None): if super_base is None: @@ -1335,6 +1352,7 @@ def unwrap_nested_methods( """ Converts nested method calls into function calls """ + method_body = self._misc_cleanups(method_body) # Add args to the function signature of method calls method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) 
method_names = [m.__name__ for m in self.referenced_methods] + list( @@ -1426,6 +1444,10 @@ def unwrap_nested_methods( BaseInterface.aggregate_outputs: "{{}}", BaseInterface.run: "None", BaseInterface._list_outputs: "{{}}", + BaseInterface.__init__: "", + SimpleInterface.__init__: "", + BaseInterface._outputs: "{{}}", + None: "", } INPUT_KEYS = [ diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index c2f9f98..ad519cf 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -63,49 +63,9 @@ def types_to_names(spec_fields): output_names = [o[0] for o in output_fields] output_type_names = [o[1] for o in output_fields_str] - # Combined src of run_interface and list_outputs - method_body = inspect.getsource(self.nipype_interface._run_interface).strip() - # Strip out method def and return statement - method_lines = method_body.strip().split("\n")[1:] - if re.match(r"\s*return", method_lines[-1]): - method_lines = method_lines[:-1] - method_body = "\n".join(method_lines) - method_body = self.process_method_body( - method_body, - input_names, - output_names, - super_base=find_super_method( - self.nipype_interface, "_run_interface", include_class=True - )[1], - ) - lo_src = inspect.getsource(self.nipype_interface._list_outputs).strip() - # Strip out method def and return statement - lo_lines = lo_src.strip().split("\n")[1:] - if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): - lo_lines = lo_lines[:-1] - lo_src = "\n".join(lo_lines) - lo_src = self.process_method_body( - lo_src, - input_names, - output_names, - super_base=find_super_method( - self.nipype_interface, "_list_outputs", include_class=True - )[1], - ) - method_body += "\n" + lo_src - method_body = re.sub( - r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body - ) - used = UsedSymbols.find( self.nipype_module, - [method_body] - + [ - inspect.getsource(f) - for f in itertools.chain( - self.referenced_local_functions, 
self.referenced_methods - ) - ], + self.referenced_local_functions, omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, @@ -115,6 +75,128 @@ def types_to_names(spec_fields): absolute_imports=True, ) + for ref_method in self.referenced_methods: + method_module = find_super_method( + self.nipype_interface, ref_method.__name__, include_class=True + )[1].__module__ + method_used = UsedSymbols.find( + method_module, + [ref_method], + omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + absolute_imports=True, + ) + used.update(method_used, from_other_module=False) + + method_body = "" + for field in output_fields: + method_body += f" {field[0]} = attrs.NOTHING\n" + + # Combined src of init and list_outputs + init_code = inspect.getsource(self.nipype_interface.__init__).strip() + init_class = find_super_method( + self.nipype_interface, "__init__", include_class=True + )[1] + if not self.package.is_omitted(init_class): + # Strip out method def and return statement + method_lines = init_code.strip().split("\n")[1:] + if re.match(r"\s*return", method_lines[-1]): + method_lines = method_lines[:-1] + init_code = "\n".join(method_lines) + init_code = self.process_method_body( + init_code, + input_names, + output_names, + super_base=init_class, + ) + + init_used = UsedSymbols.find( + init_class.__module__, + [init_code], + omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + 
absolute_imports=True, + ) + used.update(init_used, from_other_module=False) + method_body += init_code + "\n" + + # Combined src of run_interface and list_outputs + run_interface_code = inspect.getsource( + self.nipype_interface._run_interface + ).strip() + run_interface_class = find_super_method( + self.nipype_interface, "_run_interface", include_class=True + )[1] + if not self.package.is_omitted(run_interface_class): + # Strip out method def and return statement + method_lines = run_interface_code.strip().split("\n")[1:] + if re.match(r"\s*return", method_lines[-1]): + method_lines = method_lines[:-1] + run_interface_code = "\n".join(method_lines) + run_interface_code = self.process_method_body( + run_interface_code, + input_names, + output_names, + super_base=run_interface_class, + ) + + run_interface_used = UsedSymbols.find( + run_interface_class.__module__, + [run_interface_code], + omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + absolute_imports=True, + ) + used.update(run_interface_used, from_other_module=False) + method_body += run_interface_code + "\n" + + list_outputs_code = inspect.getsource( + self.nipype_interface._list_outputs + ).strip() + list_outputs_class = find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1] + if not self.package.is_omitted(list_outputs_class): + # Strip out method def and return statement + lo_lines = list_outputs_code.strip().split("\n")[1:] + if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): + lo_lines = lo_lines[:-1] + list_outputs_code = "\n".join(lo_lines) + list_outputs_code = self.process_method_body( + list_outputs_code, + input_names, + output_names, + super_base=list_outputs_class, + ) + + list_outputs_used = 
UsedSymbols.find( + list_outputs_class.__module__, + [list_outputs_code], + omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_modules=self.package.omit_modules, + omit_functions=self.package.omit_functions, + omit_constants=self.package.omit_constants, + always_include=self.package.all_explicit, + translations=self.package.all_import_translations, + absolute_imports=True, + ) + used.update(list_outputs_used, from_other_module=False) + method_body += list_outputs_code + "\n" + + assert method_body, "Neither `run_interface` and `list_outputs` are defined" + spec_str = "@pydra.mark.task\n" spec_str += "@pydra.mark.annotate({'return': {" spec_str += ", ".join(f"'{n}': {t}" for n, t, _ in output_fields_str) @@ -156,11 +238,13 @@ def types_to_names(spec_fields): additional_imports.add(imprt) spec_str = repl_spec_str - used.imports = self.construct_imports( - nonstd_types, - spec_str, - include_task=False, - base=base_imports + list(used.imports) + list(additional_imports), + used.imports.update( + self.construct_imports( + nonstd_types, + spec_str, + include_task=False, + base=base_imports + list(used.imports) + list(additional_imports), + ) ) return spec_str, used diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index f309c5c..49cb85a 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -280,7 +280,6 @@ def format_arg_code(self): return "" body = self._format_arg_body body = self._process_inputs(body) - body = self._misc_cleanups(body) existing_args = list( inspect.signature(self.nipype_interface._format_arg).parameters )[1:] @@ -311,7 +310,7 @@ def format_arg_code(self): self.nipype_interface, "_format_arg", include_class=True )[1], ) - + # body = self._misc_cleanups(body) code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} if {val_arg} is None: return "" @@ -338,7 +337,6 @@ def parse_inputs_code(self) -> 
str: inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1] ) body = self._process_inputs(body) - body = self._misc_cleanups(body) body = re.sub( r"self.\_format_arg\((\w+), (\w+), (\w+)\)", r"_format_arg(\1, \3, inputs, parsed_inputs, argstrs.get(\1))", @@ -356,6 +354,7 @@ def parse_inputs_code(self) -> str: self.nipype_interface, "_parse_inputs", include_class=True )[1], ) + # body = self._misc_cleanups(body) code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}" if re.findall(r"\bargstrs\b", body): @@ -378,7 +377,6 @@ def defaults_code(self): inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] ) body = self._process_inputs(body) - body = self._misc_cleanups(body) if not body: return "" @@ -389,6 +387,7 @@ def defaults_code(self): self.nipype_interface, "_gen_filename", include_class=True )[1], ) + # body = self._misc_cleanups(body) code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call} {body} @@ -413,22 +412,21 @@ def callables_code(self): code_str = "" if "aggregate_outputs" in self.included_methods: func_name = "aggregate_outputs" - body = _strip_doc_string( + agg_body = _strip_doc_string( inspect.getsource(self.nipype_interface.aggregate_outputs).split( "\n", 1 )[-1] ) - need_list_outputs = bool(re.findall(r"\b_list_outputs\b", body)) - body = self._process_inputs(body) - body = self._misc_cleanups(body) + need_list_outputs = bool(re.findall(r"\b_list_outputs\b", agg_body)) + agg_body = self._process_inputs(agg_body) - if not body: + if not agg_body: return "" - body = self.unwrap_nested_methods( - body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + agg_body = self.unwrap_nested_methods( + agg_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True ) - body = self.replace_supers( - body, + agg_body = self.replace_supers( + agg_body, super_base=find_super_method( self.nipype_interface, "aggregate_outputs", include_class=True )[1], @@ -437,7 +435,7 @@ def callables_code(self): code_str += 
f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): inputs = attrs.asdict(inputs){self.parse_inputs_call} needed_outputs = {self.callable_output_field_names!r} -{body} +{agg_body} """ @@ -449,28 +447,56 @@ def callables_code(self): need_list_outputs = True if need_list_outputs: - body = _strip_doc_string( - inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[ - -1 - ] - ) - body = self._process_inputs(body) - body = self._misc_cleanups(body) + if "_list_outputs" not in self.included_methods: + assert self.callable_output_fields + # Need to reimplemt the base _list_outputs method in Pydra, which maps + # inputs with 'output_name' to outputs + for f in self.callable_output_fields: + output_name = f[0] + code_str += f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" + try: + input_name = self._output_name_mappings[output_name] + except KeyError: + logger.warning( + "Could not find input name with 'output_name' for " + "%s output, attempting to create something that can be worked " + "with", + output_name, + ) + if "_parse_inputs" in self.included_methods: + code_str += ( + f" parsed_inputs = _parse_inputs(inputs)\n" + f" return parsed_inputs.get('{output_name}', attrs.NOTHING)\n" + ) + else: + code_str += " raise NotImplementedError\n" - if not body: - return "" - body = self.unwrap_nested_methods( - body, additional_args=CALLABLES_ARGS, inputs_as_dict=True - ) - body = self.replace_supers( - body, - super_base=find_super_method( - self.nipype_interface, "_list_outputs", include_class=True - )[1], - ) + else: + code_str += f" return inputs.{input_name}\n" - code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call} -{body} + return code_str + else: + lo_body = _strip_doc_string( + inspect.getsource(self.nipype_interface._list_outputs).split( + "\n", 1 + )[-1] + ) + lo_body = self._process_inputs(lo_body) + + if not 
lo_body: + return "" + lo_body = self.unwrap_nested_methods( + lo_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + ) + lo_body = self.replace_supers( + lo_body, + super_base=find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1], + ) + + code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call} +{lo_body} """ diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index ca33065..6f8b454 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -961,6 +961,7 @@ def write_to_module( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): {e}\n\n{code_str}" @@ -1080,6 +1081,7 @@ def write_pkg_inits( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -1104,6 +1106,7 @@ def write_pkg_inits( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " diff --git a/nipype2pydra/utils/__init__.py b/nipype2pydra/utils/__init__.py index 000f78a..0aa572a 100644 --- a/nipype2pydra/utils/__init__.py +++ b/nipype2pydra/utils/__init__.py @@ -22,6 +22,7 @@ get_return_line, find_super_method, strip_comments, + min_indentation, INBUILT_NIPYPE_TRAIT_NAMES, ) from .symbols import ( # noqa: F401 diff --git a/nipype2pydra/utils/misc.py 
b/nipype2pydra/utils/misc.py index 38b709b..4ea71cd 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -305,9 +305,7 @@ def cleanup_function_body(function_body: str) -> str: with_signature = True else: with_signature = False - # Detect the indentation of the source code in src and reduce it to 4 spaces - non_empty_lines = [ln for ln in function_body.splitlines() if ln] - indent_size = len(re.match(r"^( *)", non_empty_lines[0]).group(1)) + indent_size = min_indentation(function_body) indent_reduction = indent_size - (0 if with_signature else 4) assert indent_reduction >= 0, ( "Indentation reduction cannot be negative, probably didn't detect signature of " @@ -323,6 +321,12 @@ def cleanup_function_body(function_body: str) -> str: return replace_undefined(function_body) +def min_indentation(function_body: str) -> int: + # Detect the indentation of the source code in src and reduce it to 4 spaces + non_empty_lines = [ln for ln in function_body.splitlines() if ln] + return len(re.match(r"^( *)", non_empty_lines[0]).group(1)) + + def replace_undefined(function_body: str) -> str: parts = re.split(r"not isdefined\b", function_body, flags=re.MULTILINE) new_function_body = parts[0] diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/utils/symbols.py index 1140163..c2a2a90 100644 --- a/nipype2pydra/utils/symbols.py +++ b/nipype2pydra/utils/symbols.py @@ -74,25 +74,34 @@ def update( other: "UsedSymbols", absolute_imports: bool = False, to_be_inlined: bool = False, + from_other_module: bool = True, ): - if to_be_inlined: + if to_be_inlined or not from_other_module: self.imports.update( i.absolute() if absolute_imports else i for i in other.imports ) self.intra_pkg_funcs.update(other.intra_pkg_funcs) - self.intra_pkg_funcs.update((None, f) for f in other.local_functions) self.intra_pkg_classes.extend( c for c in other.intra_pkg_classes if c not in self.intra_pkg_classes ) - self.intra_pkg_classes.extend( - (None, c) - for c in other.local_classes - 
if (None, c) not in self.intra_pkg_classes - ) - self.intra_pkg_constants.update( - (other.module_name, None, c[0]) for c in other.constants - ) self.intra_pkg_constants.update(other.intra_pkg_constants) + if from_other_module: + self.intra_pkg_funcs.update((None, f) for f in other.local_functions) + self.intra_pkg_classes.extend( + (None, c) + for c in other.local_classes + if (None, c) not in self.intra_pkg_classes + ) + self.intra_pkg_constants.update( + (other.module_name, None, c[0]) for c in other.constants + ) + else: + self.local_functions.update(other.local_functions) + self.intra_pkg_classes.extend( + c for c in other.local_classes if c not in self.local_classes + ) + + self.constants.update(other.constants) DEFAULT_FILTERED_CONSTANTS = ( Undefined, diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index bdd56a6..cbec9e9 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -914,6 +914,7 @@ def add_nonstd_types(tp): # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.full_address}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " From 3591c7ba49bee2bf664e76764450f956dc46a213 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 00:06:37 +1000 Subject: [PATCH 14/21] all packages build again after refactor --- nipype2pydra/interface/function.py | 3 +-- nipype2pydra/utils/misc.py | 11 ++++++----- nipype2pydra/utils/symbols.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index ad519cf..075a038 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -3,7 +3,6 @@ import inspect from operator import attrgetter from functools import cached_property -import itertools import logging import attrs from 
nipype.interfaces.base import BaseInterface, TraitedSpec @@ -19,7 +18,7 @@ class FunctionInterfaceConverter(BaseInterfaceConverter): @property def included_methods(self) -> ty.Tuple[str, ...]: - return ("_run_interface", "_list_outputs") + return ("__init__", "_run_interface", "_list_outputs") def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 4ea71cd..b91300f 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -551,17 +551,18 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str: def find_super_method( super_base: type, method_name: str, include_class: bool = False -) -> ty.Tuple[ty.Callable, type]: +) -> ty.Optional[ty.Tuple[ty.Callable, type]]: mro = super_base.__mro__ if not include_class: mro = mro[1:] for base in mro: if method_name in base.__dict__: # Found the match return getattr(base, method_name), base - raise RuntimeError( - f"Could not find super of '{method_name}' method in base classes of " - f"{super_base}" - ) + return None + # raise RuntimeError( + # f"Could not find super of '{method_name}' method in base classes of " + # f"{super_base}" + # ) def strip_comments(src: str) -> str: diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/utils/symbols.py index c2a2a90..c4524eb 100644 --- a/nipype2pydra/utils/symbols.py +++ b/nipype2pydra/utils/symbols.py @@ -97,7 +97,7 @@ def update( ) else: self.local_functions.update(other.local_functions) - self.intra_pkg_classes.extend( + self.local_classes.extend( c for c in other.local_classes if c not in self.local_classes ) From 23ebaa4c315c042675fda2c53f597e683bfd135b Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 00:41:09 +1000 Subject: [PATCH 15/21] moved constants to bottom of file just in case they rely on a local function --- nipype2pydra/package.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/nipype2pydra/package.py b/nipype2pydra/package.py index 6f8b454..ec21cef 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -871,10 +871,6 @@ def write_to_module( existing_imports = parse_imports(existing_import_strs, relative_to=module_name) converter_imports = [] - for const_name, const_val in sorted(used.constants): - if f"\n{const_name} = " not in code_str: - code_str += f"\n{const_name} = {const_val}\n" - for klass in used.local_classes: if f"\nclass {klass.__name__}(" not in code_str: try: @@ -949,6 +945,10 @@ def write_to_module( code_str += "\n\n" + cleanup_function_body(klass_src) inlined_symbols.append(klass_name) + for const_name, const_val in sorted(used.constants): + if f"\n{const_name} = " not in code_str: + code_str += f"\n{const_name} = {const_val}\n" + # We run the formatter before the find/replace so that the find/replace can be more # predictable try: From dc3214f745da896f94b3073fc17fc8e3c231416e Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 10:40:29 +1000 Subject: [PATCH 16/21] disabled return dict unwrapping (used in _list_outputs) by default. Added copyfile to python functions --- nipype2pydra/interface/base.py | 7 +++++-- nipype2pydra/interface/function.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 6ba5a58..375c4be 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -1280,12 +1280,12 @@ def process_method_body( input_names: ty.List[str], output_names: ty.List[str], super_base: ty.Optional[type] = None, + unwrap_return_dict: bool = False, ) -> str: if not method_body: return "" if super_base is None: super_base = self.nipype_interface - return_value = get_return_line(method_body) method_body = method_body.replace("if self.output_spec:", "if True:") # Replace self.inputs. 
with in the function body input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") @@ -1302,7 +1302,10 @@ def process_method_body( method_body = input_re.sub(r"\1", method_body) method_body = self.replace_supers(method_body, super_base) - if return_value: + if unwrap_return_dict: + return_value = get_return_line(method_body) + if return_value is None: + return_value = "outputs" output_re = re.compile(return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") unrecognised_outputs = set( m for m in output_re.findall(method_body) if m not in output_names diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 075a038..56edf32 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -65,7 +65,7 @@ def types_to_names(spec_fields): used = UsedSymbols.find( self.nipype_module, self.referenced_local_functions, - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -81,7 +81,7 @@ def types_to_names(spec_fields): method_used = UsedSymbols.find( method_module, [ref_method], - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -92,6 +92,9 @@ def types_to_names(spec_fields): used.update(method_used, from_other_module=False) method_body = "" + for field in input_fields: + if field[-1].get("copyfile"): + method_body += f" {field[0]} = {field[0]}.copy(Path.cwd())\n" for field in output_fields: method_body += f" {field[0]} = attrs.NOTHING\n" @@ -150,7 +153,7 @@ def types_to_names(spec_fields): run_interface_used = UsedSymbols.find( run_interface_class.__module__, [run_interface_code], - 
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -178,12 +181,13 @@ def types_to_names(spec_fields): input_names, output_names, super_base=list_outputs_class, + unwrap_return_dict=True, ) list_outputs_used = UsedSymbols.find( list_outputs_class.__module__, [list_outputs_code], - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, From 0cac4cd2d1feadf4b23c13b7c7005910a40568d8 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 16:13:59 +1000 Subject: [PATCH 17/21] added option to explicitly route connections from/to explicit inputs and outputs --- example-specs/pkg-gen/niworkflows.yaml | 1 + nipype2pydra/package.py | 4 +- nipype2pydra/pkg_gen/__init__.py | 8 +-- nipype2pydra/statements/workflow_build.py | 2 +- nipype2pydra/workflow.py | 69 +++++++++++++++++++++-- 5 files changed, 71 insertions(+), 13 deletions(-) diff --git a/example-specs/pkg-gen/niworkflows.yaml b/example-specs/pkg-gen/niworkflows.yaml index 55651ea..7da8e02 100644 --- a/example-specs/pkg-gen/niworkflows.yaml +++ b/example-specs/pkg-gen/niworkflows.yaml @@ -5,6 +5,7 @@ niworkflows: - niworkflows.interfaces.bids.ReadSidecarJSON - niworkflows.interfaces.fixes.FixHeaderApplyTransforms - niworkflows.interfaces.fixes.FixN4BiasFieldCorrection + - niworkflows.interfaces.fixes.FixHeaderRegistration - niworkflows.interfaces.header.SanitizeImage - niworkflows.interfaces.images.RobustAverage - niworkflows.interfaces.morphology.BinaryDilation diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index ec21cef..dfbc01f 100644 --- 
a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -422,10 +422,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): intra_pkg_modules[conv.nipype_module_name].add(conv.nipype_object) collect_intra_pkg_objects(conv.used_symbols) - for converter in tqdm( + for workflow in tqdm( workflows_to_include, "converting workflows from Nipype to Pydra syntax" ): - all_used = converter.write( + all_used = workflow.write( package_root, already_converted=already_converted, ) diff --git a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index 4f43640..9ab8151 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -403,7 +403,7 @@ def generate_callables(self, nipype_interface) -> str: if output_name not in INBUILT_NIPYPE_TRAIT_NAMES: callables_str += ( f"def {output_name}_callable(output_dir, inputs, stdout, stderr):\n" - " parsed_inputs = {}" + " parsed_inputs = {}\n" " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr, parsed_inputs=parsed_inputs)\n" ' return outputs["' + output_name + '"]\n\n' ) @@ -422,7 +422,7 @@ def generate_callables(self, nipype_interface) -> str: callables_str, fast=False, mode=black.FileMode() ) except black.parsing.InvalidInput as e: - with open(Path("~/Desktop/gen-code.py").expanduser(), "w") as f: + with open(Path("~/unparsable-gen-code.py").expanduser(), "w") as f: f.write(callables_str) raise RuntimeError( f"Black could not parse generated code: {e}\n\n{callables_str}" @@ -1026,9 +1026,7 @@ def process_method( ) if hasattr(nipype_interface, "_cmd"): body = body.replace("self.cmd", f'"{nipype_interface._cmd}"') - body = re.sub( - r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body - ) + body = re.sub(r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body) body = re.sub(r"getattr\(self\.inputs, (\w+)\)", r"inputs[\1]", body) if attrs_as_parsed_inputs: body = re.sub( diff --git 
a/nipype2pydra/statements/workflow_build.py b/nipype2pydra/statements/workflow_build.py index 05c0549..ac2d5d1 100644 --- a/nipype2pydra/statements/workflow_build.py +++ b/nipype2pydra/statements/workflow_build.py @@ -106,7 +106,7 @@ class DynamicField(VarField): callable: ty.Callable = attrs.field() def __repr__(self): - return f"DelayedVarField({self.varname}, callable={self.callable})" + return f"DynamicField({self.varname}, callable={self.callable})" @attrs.define diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index cbec9e9..f32544c 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -36,6 +36,7 @@ WorkflowInitStatement, AssignmentStatement, OtherStatement, + DynamicField, ) import nipype2pydra.package @@ -94,8 +95,7 @@ class WorkflowInterfaceField: factory=list, metadata={ "help": ( - "node-name/field-name pairs of other fields that are to be routed to " - "from other node fields to this input/output", + "node-name/field-name pairs of additional fields that this input/output replaces", ) }, ) @@ -159,6 +159,11 @@ def __hash__(self): @attrs.define class WorkflowInput(WorkflowInterfaceField): + connections: ty.Tuple[ty.Tuple[str, str]] = attrs.field( + converter=lambda lst: tuple(sorted(tuple(t) for t in lst)), + factory=list, + metadata={"help": ("Explicit connections to be made from this input field",)}, + ) out_conns: ty.List[ConnectionStatement] = attrs.field( factory=list, eq=False, @@ -170,9 +175,7 @@ class WorkflowInput(WorkflowInterfaceField): ) }, ) - include: bool = attrs.field( - default=False, eq=False, hash=False, metadata={ @@ -183,6 +186,10 @@ class WorkflowInput(WorkflowInterfaceField): }, ) + @include.default + def _include_default(self) -> bool: + return bool(self.connections) + def __hash__(self): return super().__hash__() @@ -190,6 +197,11 @@ def __hash__(self): @attrs.define class WorkflowOutput(WorkflowInterfaceField): + connection: ty.Tuple[str, str] = attrs.field( + converter=tuple, + factory=list, + 
metadata={"help": ("Explicit connection to be made to this output field",)}, + ) in_conns: ty.List[ConnectionStatement] = attrs.field( factory=list, eq=False, @@ -413,6 +425,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput: """ Returns the name of the input field in the workflow for the given node and field escaped by the prefix of the node if present""" + if isinstance(conn.source_out, DynamicField): + logger.warning( + f"Not able to connect inputs from {conn.source_name}:{conn.source_out}->" + f"{conn.target_name}:{conn.target_in} properly due to adynamic-field " + "just connecting to source input for now" + ) try: return self.make_input( field_name=conn.source_out, @@ -1036,7 +1054,7 @@ def prepare_connections(self): # append to parsed statements so set_output can be set self.parsed_statements.append(conn_stmt) while self._unprocessed_connections: - conn = self._unprocessed_connections.pop() + conn = self._unprocessed_connections.pop(0) try: inpt = self.get_input_from_conn(conn) except KeyError: @@ -1056,6 +1074,47 @@ def prepare_connections(self): conn.target_in = outpt.name outpt.in_conns.append(conn) + # Overwrite connections with explict connections + for inpt in list(self.inputs.values()): + for target_name, target_in in inpt.connections: + conn = ConnectionStatement( + indent=" ", + source_name=None, + source_out=inpt.name, + target_name=target_name, + target_in=target_in, + workflow_converter=self, + ) + for tgt_node in self.nodes[conn.target_name]: + try: + existing_conn = next( + c for c in tgt_node.in_conns if c.target_in == target_in + ) + except StopIteration: + pass + else: + tgt_node.in_conns.remove(existing_conn) + self.inputs[existing_conn.source_out].out_conns.remove( + existing_conn + ) + inpt.out_conns.append(conn) + tgt_node.add_input_connection(conn) + + for outpt in list(self.outputs.values()): + if outpt.connection: + source_name, source_out = outpt.connection + conn = ConnectionStatement( + indent=" ", + 
source_name=source_name, + source_out=source_out, + target_name=None, + target_in=outpt.name, + workflow_converter=self, + ) + for src_node in self.nodes[conn.source_name]: + src_node.add_output_connection(conn) + outpt.in_conns.append(conn) + def _parse_statements(self, func_body: str) -> ty.Tuple[ ty.List[ ty.Union[ From 955e6c8ad00ba99a14119b3bc8249424ad637088 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 22:38:09 +1000 Subject: [PATCH 18/21] debugged workflows and spatial normalisation monster interface --- nipype2pydra/interface/base.py | 6 ++++ nipype2pydra/interface/function.py | 19 ++++++++--- nipype2pydra/statements/__init__.py | 1 + nipype2pydra/statements/imports.py | 49 ++++++++++++++++++++--------- 4 files changed, 56 insertions(+), 19 deletions(-) diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index 375c4be..ec424d5 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -45,6 +45,7 @@ parse_imports, ExplicitImport, from_list_to_imports, + make_imports_absolute, ) from fileformats.generic import File import nipype2pydra.package @@ -1319,6 +1320,11 @@ def process_method_body( ) method_body = output_re.sub(r"\1", method_body) method_body = self.unwrap_nested_methods(method_body) + method_body = make_imports_absolute( + method_body, + super_base.__module__, + translations=self.package.all_import_translations, + ) # method_body = self._misc_cleanups(method_body) return method_body diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 56edf32..1f8f6f4 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -16,6 +16,8 @@ @attrs.define(slots=False) class FunctionInterfaceConverter(BaseInterfaceConverter): + converter_type = "function" + @property def included_methods(self) -> ty.Tuple[str, ...]: return ("__init__", "_run_interface", "_list_outputs") @@ -65,7 +67,7 @@ def types_to_names(spec_fields): used = 
UsedSymbols.find( self.nipype_module, self.referenced_local_functions, - omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -81,7 +83,7 @@ def types_to_names(spec_fields): method_used = UsedSymbols.find( method_module, [ref_method], - omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -153,7 +155,7 @@ def types_to_names(spec_fields): run_interface_used = UsedSymbols.find( run_interface_class.__module__, [run_interface_code], - omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -187,7 +189,7 @@ def types_to_names(spec_fields): list_outputs_used = UsedSymbols.find( list_outputs_class.__module__, [list_outputs_code], - omit_classes=self.package.omit_classes, # + [BaseInterface, TraitedSpec], + omit_classes=self.package.omit_classes, omit_modules=self.package.omit_modules, omit_functions=self.package.omit_functions, omit_constants=self.package.omit_constants, @@ -231,7 +233,14 @@ def types_to_names(spec_fields): spec_str += "\n\n# Nipype methods converted into functions\n\n" for m in sorted(self.referenced_methods, key=attrgetter("__name__")): - spec_str += "\n\n" + self.process_method(m, input_names, output_names) + spec_str += "\n\n" + self.process_method( + m, + input_names, + output_names, + super_base=find_super_method( + self.nipype_interface, m.__name__, include_class=True + )[1], + ) # Replace runtime attributes additional_imports = set() diff --git 
a/nipype2pydra/statements/__init__.py b/nipype2pydra/statements/__init__.py index 9300d5d..9a29c72 100644 --- a/nipype2pydra/statements/__init__.py +++ b/nipype2pydra/statements/__init__.py @@ -5,6 +5,7 @@ GENERIC_PYDRA_IMPORTS, ExplicitImport, from_list_to_imports, + make_imports_absolute, ) from .workflow_build import ( # noqa: F401 AddNestedWorkflowStatement, diff --git a/nipype2pydra/statements/imports.py b/nipype2pydra/statements/imports.py index 9639d3c..0e9887a 100644 --- a/nipype2pydra/statements/imports.py +++ b/nipype2pydra/statements/imports.py @@ -465,6 +465,39 @@ def collate( ) +def translate( + module_name: str, translations: ty.Sequence[ty.Tuple[str, str]] +) -> ty.Optional[str]: + for from_pkg, to_pkg in translations: + if re.match(from_pkg, module_name): + return re.sub( + from_pkg, + to_pkg, + module_name, + count=1, + flags=re.MULTILINE | re.DOTALL, + ) + return None + + +def make_imports_absolute( + src: str, modulepath: str, translations: ty.Sequence[ty.Tuple[str, str]] = () +) -> str: + parts = modulepath.split(".") + + def replacer(match) -> str: + levels = len(match.group(2)) + assert levels < len(parts) + abs_modulepath = ".".join(parts[:-levels]) + "." 
+ match.group(3) + if translations: + newpath = translate(abs_modulepath, translations) + if newpath: + abs_modulepath = newpath + return f"{match.group(1)}from {abs_modulepath}" + + return re.sub(r"(\s*)from\s+(\.+)(\w+)", replacer, src) + + def parse_imports( stmts: ty.Union[str, ty.Sequence[str]], relative_to: ty.Union[str, ModuleType, None] = None, @@ -497,18 +530,6 @@ def parse_imports( ".__init__" if relative_to.__file__.endswith("__init__.py") else "" ) - def translate(module_name: str) -> ty.Optional[str]: - for from_pkg, to_pkg in translations: - if re.match(from_pkg, module_name): - return re.sub( - from_pkg, - to_pkg, - module_name, - count=1, - flags=re.MULTILINE | re.DOTALL, - ) - return None - parsed = [] for stmt in stmts: if isinstance(stmt, ImportStatement): @@ -543,7 +564,7 @@ def translate(module_name: str) -> ty.Optional[str]: ) if absolute: import_stmt = import_stmt.absolute() - import_stmt.translation = translate(import_stmt.module_name) + import_stmt.translation = translate(import_stmt.module_name, translations) parsed.append(import_stmt) else: @@ -554,7 +575,7 @@ def translate(module_name: str) -> ty.Optional[str]: ImportStatement( indent=match.group(1), imported={imp.local_name: imp}, - translation=translate(imp.name), + translation=translate(imp.name, translations), ) ) return parsed From abf355140128fcaff353e7da312999b283364fc5 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 3 Jun 2024 22:54:21 +1000 Subject: [PATCH 19/21] unwrapped super methods in shell helper methods --- nipype2pydra/interface/shell_command.py | 64 +++++++++++++++++-------- nipype2pydra/utils/misc.py | 4 +- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 49cb85a..4a551a8 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -13,6 +13,7 @@ UsedSymbols, split_source_into_statements, INBUILT_NIPYPE_TRAIT_NAMES, + 
extract_args, find_super_method, ) from fileformats.core.mixin import WithClassifiers @@ -27,6 +28,7 @@ @attrs.define(slots=False) class ShellCommandInterfaceConverter(BaseInterfaceConverter): + converter_type = "shell_command" _format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict) @cached_property @@ -237,11 +239,19 @@ def output_fields(self): @property def formatted_input_field_names(self): - return re.findall(r"name == \"(\w+)\"", self._format_arg_body) + if not self._format_arg_body: + return [] + sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0] + name_arg = re.match(r"\s*def _format_arg\(self, (\w+),", sig).group(1) + return re.findall(name_arg + r" == \"(\w+)\"", self._format_arg_body) @property def callable_default_input_field_names(self): - return re.findall(r"name == \"(\w+)\"", self._gen_filename_body) + if not self._gen_filename_body: + return [] + sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0] + name_arg = re.match(r"\s*def _gen_filename\((\w+),", sig).group(1) + return re.findall(name_arg + r" == \"(\w+)\"", self._gen_filename_body) @property def callable_output_fields(self): @@ -262,17 +272,13 @@ def callable_output_field_names(self): def _format_arg_body(self): if self.method_omitted("_format_arg"): return "" - return _strip_doc_string( - inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1] - ) + return self._unwrap_supers(self.nipype_interface._format_arg) @cached_property def _gen_filename_body(self): if self.method_omitted("_gen_filename"): return "" - return _strip_doc_string( - inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] - ) + return self._unwrap_supers(self.nipype_interface._gen_filename) @property def format_arg_code(self): @@ -333,9 +339,7 @@ def format_arg_code(self): def parse_inputs_code(self) -> str: if "_parse_inputs" not in self.included_methods: return "" - body = _strip_doc_string( - 
inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1] - ) + body = self._unwrap_supers(self.nipype_interface._parse_inputs) body = self._process_inputs(body) body = re.sub( r"self.\_format_arg\((\w+), (\w+), (\w+)\)", @@ -412,11 +416,7 @@ def callables_code(self): code_str = "" if "aggregate_outputs" in self.included_methods: func_name = "aggregate_outputs" - agg_body = _strip_doc_string( - inspect.getsource(self.nipype_interface.aggregate_outputs).split( - "\n", 1 - )[-1] - ) + agg_body = self._unwrap_supers(self.nipype_interface.aggregate_outputs) need_list_outputs = bool(re.findall(r"\b_list_outputs\b", agg_body)) agg_body = self._process_inputs(agg_body) @@ -476,11 +476,7 @@ def callables_code(self): return code_str else: - lo_body = _strip_doc_string( - inspect.getsource(self.nipype_interface._list_outputs).split( - "\n", 1 - )[-1] - ) + lo_body = self._unwrap_supers(self.nipype_interface._list_outputs) lo_body = self._process_inputs(lo_body) if not lo_body: @@ -538,6 +534,32 @@ def method_omitted(self, method_name: str) -> bool: find_super_method(self.nipype_interface, method_name, include_class=True)[1] ) + def _unwrap_supers( + self, method: ty.Callable, base=None, base_replacement="", arg_names=None + ) -> str: + if base is None: + base = self.nipype_interface + if self.package.is_omitted(base): + return base_replacement + method_name = method.__name__ + sig, body = inspect.getsource(method).split("\n", 1) + body = _strip_doc_string(body) + args = extract_args(sig)[1][1:] + if arg_names: + for new, old in zip(args, arg_names): + if new != old: + body = re.sub(r"\b" + old + r"\b", new, body) + super_re = re.compile( + r"\n\s*(return )?super\([^\)]*\)\." 
+ method_name + r"\([^\)]+\)" + ) + if super_re.search(body): + super_method, base = find_super_method(base, method_name) + super_body = self._unwrap_supers( + super_method, base, base_replacement, arg_names=args + ) + body = super_re.sub("\n" + super_body, body) + return body + def _strip_doc_string(body: str) -> str: if re.match(r"\s*(\"|')", body): diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index b91300f..78c4f99 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -551,14 +551,14 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str: def find_super_method( super_base: type, method_name: str, include_class: bool = False -) -> ty.Optional[ty.Tuple[ty.Callable, type]]: +) -> ty.Tuple[ty.Optional[ty.Callable], ty.Optional[type]]: mro = super_base.__mro__ if not include_class: mro = mro[1:] for base in mro: if method_name in base.__dict__: # Found the match return getattr(base, method_name), base - return None + return None, None # raise RuntimeError( # f"Could not find super of '{method_name}' method in base classes of " # f"{super_base}" From 8c696fd451b1418d7fadd5c9a51b40b2f99db2e9 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Fri, 7 Jun 2024 15:42:51 +1000 Subject: [PATCH 20/21] debugging mriqc/niworkflows conversions --- nipype2pydra/helpers.py | 8 +- nipype2pydra/interface/base.py | 19 ++- nipype2pydra/interface/function.py | 12 +- nipype2pydra/interface/shell_command.py | 2 +- nipype2pydra/package.py | 40 +++---- nipype2pydra/pkg_gen/__init__.py | 10 +- nipype2pydra/statements/imports.py | 2 + nipype2pydra/utils/misc.py | 9 ++ nipype2pydra/utils/symbols.py | 113 +++++++++++------- .../utils/tests/test_utils_imports.py | 28 +++-- nipype2pydra/workflow.py | 8 +- 11 files changed, 155 insertions(+), 96 deletions(-) diff --git a/nipype2pydra/helpers.py b/nipype2pydra/helpers.py index 3892cd3..d344d9d 100644 --- a/nipype2pydra/helpers.py +++ b/nipype2pydra/helpers.py @@ -133,7 +133,7 @@ def 
used_symbols(self) -> UsedSymbols: always_include=self.package.all_explicit, translations=self.package.all_import_translations, ) - used.imports.update(i.to_statement() for i in self.imports) + used.import_stmts.update(i.to_statement() for i in self.imports) return used @cached_property @@ -147,12 +147,10 @@ def converted_code(self) -> ty.List[str]: @cached_property def nested_interfaces(self): potential_classes = { - full_address(c[1]): c[0] - for c in self.used_symbols.intra_pkg_classes - if c[0] + full_address(c[1]): c[0] for c in self.used_symbols.imported_classes if c[0] } potential_classes.update( - (full_address(c), c.__name__) for c in self.used_symbols.local_classes + (full_address(c), c.__name__) for c in self.used_symbols.classes ) return { potential_classes[address]: workflow diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index ec424d5..d7a2c1e 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -657,7 +657,9 @@ def pydra_fld_input(self, field, nm): val = getattr(field, key) if val is not None: if key == "argstr" and "%" in val: - val = self.string_formats(argstr=val, name=nm) + val = self.string_formats( + argstr=val, name=nm, type_=field.trait_type + ) elif key == "mandatory" and pydra_default is not None: val = False # Overwrite mandatory to False if default is provided pydra_metadata[pydra_key_nm] = val @@ -666,7 +668,9 @@ def pydra_fld_input(self, field, nm): template = getattr(field, "name_template") name_source = ensure_list(getattr(field, "name_source")) if name_source: - tmpl = self.string_formats(argstr=template, name=name_source[0]) + tmpl = self.string_formats( + argstr=template, name=name_source[0], type_=field.trait_type + ) else: tmpl = template if nm in self.nipype_interface.output_spec().class_trait_names(): @@ -829,11 +833,14 @@ def pydra_type_converter(self, field, spec_type, name): pydra_type = ty.Any return pydra_type - def string_formats(self, argstr, name): + def 
string_formats(self, argstr, name, type_): keys = re.findall(r"(%[0-9\.]*(?:s|d|i|g|f))", argstr) new_argstr = argstr for i, key in enumerate(keys): - repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" + if isinstance(type_, traits.trait_types.Bool): + repl = f"{name}:d" + else: + repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" match = re.match(r"%([0-9\.]+)f", key) if match: repl += ":" + match.group(1) @@ -972,7 +979,7 @@ def _converted_test(self): ) return spec_str, UsedSymbols( - module_name=self.nipype_module.__name__, imports=imports + module_name=self.nipype_module.__name__, import_stmts=imports ) def create_doctests(self, input_fields, nonstd_types): @@ -1032,7 +1039,7 @@ def _misc_cleanups(self, body: str) -> str: body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"') body = body.replace("self.output_spec().get()", "{}") - body = body.replace("self._outputs()", "{}") + body = body.replace("self._outputs().get()", "{}") # body = re.sub( # r"outputs = self\.(output_spec|_outputs)\(\).*$", # r"outputs = {}", diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 1f8f6f4..27c6f8e 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -91,7 +91,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(method_used, from_other_module=False) + used.update(method_used) method_body = "" for field in input_fields: @@ -129,7 +129,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(init_used, from_other_module=False) + used.update(init_used) method_body += init_code + "\n" # Combined src of run_interface and list_outputs @@ -163,7 +163,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(run_interface_used, from_other_module=False) + 
used.update(run_interface_used) method_body += run_interface_code + "\n" list_outputs_code = inspect.getsource( @@ -197,7 +197,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(list_outputs_used, from_other_module=False) + used.update(list_outputs_used) method_body += list_outputs_code + "\n" assert method_body, "Neither `run_interface` and `list_outputs` are defined" @@ -250,12 +250,12 @@ def types_to_names(spec_fields): additional_imports.add(imprt) spec_str = repl_spec_str - used.imports.update( + used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, include_task=False, - base=base_imports + list(used.imports) + list(additional_imports), + base=base_imports + list(used.import_stmts) + list(additional_imports), ) ) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 4a551a8..7d2dcb6 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -210,7 +210,7 @@ def types_to_names(spec_fields): ) used.update(super_used) - used.imports.update( + used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index dfbc01f..4e2741d 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -400,7 +400,7 @@ def write(self, package_root: Path, to_include: ty.List[str] = None): workflow.prepare_connections() def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): - for _, klass in used.intra_pkg_classes: + for _, klass in used.imported_classes: address = full_address(klass) if address in self.nipype_port_converters: if port_nipype: @@ -412,10 +412,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): ) elif full_address(klass) not in self.interfaces: intra_pkg_modules[klass.__module__].add(klass) - for _, func in used.intra_pkg_funcs: + for _, func in 
used.imported_funcs: if full_address(func) not in list(self.workflows): intra_pkg_modules[func.__module__].add(func) - for const_mod_address, _, const_name in used.intra_pkg_constants: + for const_mod_address, _, const_name in used.imported_constants: intra_pkg_modules[const_mod_address].add(const_name) for conv in list(self.functions.values()) + list(self.classes.values()): @@ -429,7 +429,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - class_addrs = [full_address(c) for _, c in all_used.intra_pkg_classes] + class_addrs = [full_address(c) for _, c in all_used.imported_classes] included_addrs = [c.full_address for c in interfaces_to_include] interfaces_to_include.extend( self.interfaces[a] @@ -555,14 +555,12 @@ def write_intra_pkg_modules( always_include=self.all_explicit, ) - classes = used.local_classes + [ - o for o in objs if inspect.isclass(o) and o not in used.local_classes + classes = used.classes + [ + o for o in objs if inspect.isclass(o) and o not in used.classes ] - functions = list(used.local_functions) + [ - o - for o in objs - if inspect.isfunction(o) and o not in used.local_functions + functions = list(used.functions) + [ + o for o in objs if inspect.isfunction(o) and o not in used.functions ] self.write_to_module( @@ -570,10 +568,10 @@ def write_intra_pkg_modules( module_name=out_mod_name, used=UsedSymbols( module_name=mod_name, - imports=used.imports, + import_stmts=used.import_stmts, constants=used.constants, - local_classes=classes, - local_functions=functions, + classes=classes, + functions=functions, ), find_replace=self.find_replace, inline_intra_pkg=False, @@ -871,11 +869,11 @@ def write_to_module( existing_imports = parse_imports(existing_import_strs, relative_to=module_name) converter_imports = [] - for klass in used.local_classes: + for klass in used.classes: if f"\nclass {klass.__name__}(" not in code_str: try: class_converter = 
self.classes[full_address(klass)] - converter_imports.extend(class_converter.used_symbols.imports) + converter_imports.extend(class_converter.used_symbols.import_stmts) except KeyError: class_converter = ClassConverter.from_object(klass, self) code_str += "\n" + class_converter.converted_code + "\n" @@ -903,11 +901,13 @@ def write_to_module( if converted_code.strip() not in code_str: code_str += "\n" + converted_code + "\n" - for func in sorted(used.local_functions, key=attrgetter("__name__")): + for func in sorted(used.functions, key=attrgetter("__name__")): if f"\ndef {func.__name__}(" not in code_str: if func.__name__ in self.functions: function_converter = self.functions[full_address(func)] - converter_imports.extend(function_converter.used_symbols.imports) + converter_imports.extend( + function_converter.used_symbols.import_stmts + ) else: function_converter = FunctionConverter.from_object(func, self) code_str += "\n" + function_converter.converted_code + "\n" @@ -923,7 +923,7 @@ def write_to_module( code_str += ( "\n\n# Intra-package imports that have been inlined in this module\n\n" ) - for func_name, func in sorted(used.intra_pkg_funcs, key=itemgetter(0)): + for func_name, func in sorted(used.imported_funcs, key=itemgetter(0)): func_src = get_source_code(func) func_src = re.sub( r"^(#[^\n]+\ndef) (\w+)(?=\()", @@ -934,7 +934,7 @@ def write_to_module( code_str += "\n\n" + cleanup_function_body(func_src) inlined_symbols.append(func_name) - for klass_name, klass in sorted(used.intra_pkg_classes, key=itemgetter(0)): + for klass_name, klass in sorted(used.imported_classes, key=itemgetter(0)): klass_src = get_source_code(klass) klass_src = re.sub( r"^(#[^\n]+\nclass) (\w+)(?=\()", @@ -973,7 +973,7 @@ def write_to_module( imports = ImportStatement.collate( existing_imports + converter_imports - + [i for i in used.imports if not i.indent] + + [i for i in used.import_stmts if not i.indent] + GENERIC_PYDRA_IMPORTS + additional_imports ) diff --git 
a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index 9ab8151..170a6d5 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -1123,13 +1123,13 @@ def insert_args_in_method_calls( mod = import_module(mod_name) used = UsedSymbols.find(mod, methods, omit_classes=(BaseInterface, TraitedSpec)) all_funcs.update(methods) - for func in used.local_functions: + for func in used.functions: all_funcs.add(cleanup_function_body(get_source_code(func))) - for klass in used.local_classes: + for klass in used.classes: klass_src = cleanup_function_body(get_source_code(klass)) if klass_src not in all_classes: all_classes.append(klass_src) - for new_func_name, func in used.intra_pkg_funcs: + for new_func_name, func in used.imported_funcs: if new_func_name is None: continue # Not referenced directly in this module func_src = get_source_code(func) @@ -1148,7 +1148,7 @@ def insert_args_in_method_calls( + match.group(2) ) all_funcs.add(cleanup_function_body(func_src)) - for new_klass_name, klass in used.intra_pkg_classes: + for new_klass_name, klass in used.imported_classes: if new_klass_name is None: continue # Not referenced directly in this module klass_src = get_source_code(klass) @@ -1169,7 +1169,7 @@ def insert_args_in_method_calls( klass_src = cleanup_function_body(klass_src) if klass_src not in all_classes: all_classes.append(klass_src) - all_imports.update(used.imports) + all_imports.update(used.import_stmts) all_constants.update(used.constants) return ( sorted( diff --git a/nipype2pydra/statements/imports.py b/nipype2pydra/statements/imports.py index 0e9887a..a1b993a 100644 --- a/nipype2pydra/statements/imports.py +++ b/nipype2pydra/statements/imports.py @@ -587,6 +587,8 @@ def parse_imports( "from fileformats.generic import File, Directory", "from pydra.engine.specs import MultiInputObj", "from pathlib import Path", + "import json", + "import yaml", "import logging", "import pydra.mark", "import typing as ty", diff --git 
a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 78c4f99..c91c58c 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -22,6 +22,7 @@ from importlib import import_module from logging import getLogger +from pydra.engine.specs import MultiInputObj logger = getLogger("nipype2pydra") @@ -482,12 +483,20 @@ def from_named_dicts_converter( def str_to_type(type_str: str) -> type: """Resolve a string representation of a type into a valid type""" if "/" in type_str: + if type_str.startswith("multi["): + assert type_str.endswith("]"), f"Invalid multi type: {type_str}" + type_str = type_str[6:-1] + multi = True + else: + multi = False tp = from_mime(type_str) try: # If datatype is a field, use its primitive instead tp = tp.primitive # type: ignore except AttributeError: pass + if multi: + tp = MultiInputObj[tp] else: def resolve_type(type_str: str) -> type: diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/utils/symbols.py index c4524eb..0d169c8 100644 --- a/nipype2pydra/utils/symbols.py +++ b/nipype2pydra/utils/symbols.py @@ -24,7 +24,9 @@ class UsedSymbols: A class to hold the used symbols in a module Parameters - ------- + ---------- + module_name: str + the name of the module containing the functions to be converted imports : list[str] the import statements that need to be included in the converted file local_functions: set[callable] @@ -45,16 +47,31 @@ class UsedSymbols: set of all the constants defined within the package that are referenced by the function, (, , ), where the local alias and the definition of the constant + methods: set[callable] + the names of the methods that are referenced, by default None is a function not + a method + class_constants: set[tuple[str, str]] + the names of the class attributes that are referenced by the method + + class_name: str, optional + the name of the class that the methods originate from """ module_name: str - imports: ty.Set[str] = attrs.field(factory=set) - local_functions: 
ty.Set[ty.Callable] = attrs.field(factory=set) - local_classes: ty.List[type] = attrs.field(factory=list) + import_stmts: ty.Set[str] = attrs.field(factory=set) + functions: ty.Set[ty.Callable] = attrs.field(factory=set) + classes: ty.List[type] = attrs.field(factory=list) constants: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) - intra_pkg_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) - intra_pkg_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) - intra_pkg_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) + methods: ty.Set[ty.Callable] = attrs.field(factory=set) + class_attrs: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) + imported_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) + imported_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) + imported_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) + super_methoods: ty.Set[ty.Tuple[type, ty.Callable]] = attrs.field(factory=set) + super_class_attrs: ty.Set[ty.Tuple[type, ty.Tuple[str, str]]] = attrs.field( + factory=set + ) + klass: ty.Optional[type] = None ALWAYS_OMIT_MODULES = [ "traits.trait_handlers", # Old traits module, pre v6.0 @@ -74,34 +91,46 @@ def update( other: "UsedSymbols", absolute_imports: bool = False, to_be_inlined: bool = False, - from_other_module: bool = True, ): - if to_be_inlined or not from_other_module: - self.imports.update( - i.absolute() if absolute_imports else i for i in other.imports + if (self.module_name == other.module_name) or to_be_inlined: + self.import_stmts.update( + i.absolute() if absolute_imports else i for i in other.import_stmts ) - self.intra_pkg_funcs.update(other.intra_pkg_funcs) - self.intra_pkg_classes.extend( - c for c in other.intra_pkg_classes if c not in self.intra_pkg_classes + self.imported_funcs.update(other.imported_funcs) + self.imported_classes.extend( + c for c in other.imported_classes if c not in 
self.imported_classes ) - self.intra_pkg_constants.update(other.intra_pkg_constants) - if from_other_module: - self.intra_pkg_funcs.update((None, f) for f in other.local_functions) - self.intra_pkg_classes.extend( + self.imported_constants.update(other.imported_constants) + if self.module_name != other.module_name: + self.imported_funcs.update((None, f) for f in other.functions) + self.imported_classes.extend( (None, c) - for c in other.local_classes - if (None, c) not in self.intra_pkg_classes + for c in other.classes + if (None, c) not in self.imported_classes ) - self.intra_pkg_constants.update( + self.imported_constants.update( (other.module_name, None, c[0]) for c in other.constants ) else: - self.local_functions.update(other.local_functions) - self.local_classes.extend( - c for c in other.local_classes if c not in self.local_classes - ) - + self.functions.update(other.functions) + self.classes.extend(c for c in other.classes if c not in self.classes) self.constants.update(other.constants) + if other.klass: + if not self.klass: + raise ValueError( + f"Attempting to merge class symbols for {other.klass} with module " + f"symbols ({self.module_name}) with different names" + ) + if self.klass is other.klass: + self.methods.update(other.methods) + self.constants.update(other.constants) + else: + self.super_methoods.update( + (other.klass, m) for m in other.super_methoods + ) + self.super_class_attrs.update( + (other.klass, a) for a in other.super_class_attrs + ) DEFAULT_FILTERED_CONSTANTS = ( Undefined, @@ -243,19 +272,19 @@ def find( for local_func in local_functions: if ( local_func.__name__ in used_symbols - and local_func not in used.local_functions + and local_func not in used.functions ): - used.local_functions.add(local_func) + used.functions.add(local_func) cls._get_symbols(local_func, used_symbols) all_src += "\n\n" + inspect.getsource(local_func) for local_class in local_classes: if ( local_class.__name__ in used_symbols - and local_class not in 
used.local_classes + and local_class not in used.classes ): if issubclass(local_class, (BaseInterface, TraitedSpec)): continue - used.local_classes.append(local_class) + used.classes.append(local_class) class_body = inspect.getsource(local_class) bases = extract_args(class_body)[1] used_symbols.update(bases) @@ -334,12 +363,12 @@ def find( ) or inspect.isbuiltin(imported.object): # Case where an object is a nested import from a different package # which is imported in a chain from a neighbouring module - used.imports.add( + used.import_stmts.add( imported.as_independent_statement(resolve=True) ) stmt.drop(imported) elif inspect.isfunction(imported.object): - used.intra_pkg_funcs.add((imported.local_name, imported.object)) + used.imported_funcs.add((imported.local_name, imported.object)) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object @@ -353,8 +382,8 @@ def find( # like we did for functions here because we need to preserve the # order the classes are defined in the module in case one inherits # from the other - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) + if class_def not in used.imported_classes: + used.imported_classes.append(class_def) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object, @@ -375,15 +404,15 @@ def find( obj = getattr(imported.object, attr_name) if inspect.isfunction(obj): - used.intra_pkg_funcs.add((obj.__name__, obj)) + used.imported_funcs.add((obj.__name__, obj)) intra_pkg_objs[imported.object.__name__].add(obj) elif inspect.isclass(obj): class_def = (obj.__name__, obj) - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) + if class_def not in used.imported_classes: + used.imported_classes.append(class_def) intra_pkg_objs[imported.object.__name__].add(obj) else: - used.intra_pkg_constants.add( + 
used.imported_constants.add( ( imported.object.__name__, attr_name, @@ -397,7 +426,7 @@ def find( f"Cannot inline imported module in statement '{stmt}'" ) else: - used.intra_pkg_constants.add( + used.imported_constants.add( ( stmt.module_name, imported.local_name, @@ -423,7 +452,7 @@ def find( ) used.update(used_in_mod, to_be_inlined=collapse_intra_pkg) if stmt: - used.imports.add(stmt) + used.import_stmts.add(stmt) return used @classmethod @@ -495,7 +524,7 @@ def get_imported_object(self, name: str) -> ty.Any: # if not i.from_ # } all_imported = {} - for stmt in self.imports: + for stmt in self.import_stmts: all_imported.update(stmt.imported) try: return all_imported[name].object @@ -514,7 +543,7 @@ def get_imported_object(self, name: str) -> ty.Any: if imported_obj is None: raise ImportError( f"Could not find object named {name} in any of the imported modules:\n" - + "\n".join(str(i) for i in self.imports) + + "\n".join(str(i) for i in self.import_stmts) ) for part in parts[-i:]: imported_obj = getattr(imported_obj, part) diff --git a/nipype2pydra/utils/tests/test_utils_imports.py b/nipype2pydra/utils/tests/test_utils_imports.py index 483ebf0..ecb12c4 100644 --- a/nipype2pydra/utils/tests/test_utils_imports.py +++ b/nipype2pydra/utils/tests/test_utils_imports.py @@ -49,7 +49,9 @@ def test_get_imported_object1(): import_stmts = [ "import nipype.interfaces.utility as niu", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("niu.IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -60,7 +62,9 @@ def test_get_imported_object2(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( 
used.get_imported_object("nipype.interfaces.utility") is nipype.interfaces.utility @@ -71,7 +75,9 @@ def test_get_imported_object3(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -82,7 +88,9 @@ def test_get_imported_object4(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface.input_spec") is nipype.interfaces.utility.IdentityInterface.input_spec @@ -93,7 +101,9 @@ def test_get_imported_object5(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object( "nipype.interfaces.utility.IdentityInterface.input_spec" @@ -106,7 +116,9 @@ def test_get_imported_object_fail1(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("nipype.interfaces.utilityboo") @@ -115,6 +127,8 @@ def test_get_imported_object_fail2(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + 
module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("IdentityBoo") diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index f32544c..033a5b7 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -666,10 +666,10 @@ def func_body(self): @cached_property def nested_workflows(self): potential_funcs = { - full_address(f[1]): f[0] for f in self.used_symbols.intra_pkg_funcs if f[0] + full_address(f[1]): f[0] for f in self.used_symbols.imported_funcs if f[0] } potential_funcs.update( - (full_address(f), f.__name__) for f in self.used_symbols.local_functions + (full_address(f), f.__name__) for f in self.used_symbols.functions ) return { potential_funcs[address]: workflow @@ -731,7 +731,7 @@ def write( # main workflow code_str = self.converted_code - local_func_names = {f.__name__ for f in used.local_functions} + local_func_names = {f.__name__ for f in used.functions} # Convert any nested workflows for name, conv in self.nested_workflows.items(): if conv.address in already_converted: @@ -990,7 +990,7 @@ def test_used(self): return UsedSymbols( module_name=self.nipype_module.__name__, - imports=( + import_stmts=( nonstd_type_imports + parse_imports( [ From 3d0e218f95cc54ac7b6593ff6be69794aa17877b Mon Sep 17 00:00:00 2001 From: Tom Close Date: Fri, 5 Jul 2024 11:27:06 +1000 Subject: [PATCH 21/21] debugging switch to class-symbols --- nipype2pydra/helpers.py | 39 +- nipype2pydra/interface/base.py | 456 +++++-------- nipype2pydra/interface/function.py | 141 ++-- nipype2pydra/interface/shell_command.py | 214 +++--- nipype2pydra/package.py | 25 +- nipype2pydra/pkg_gen/__init__.py | 2 +- nipype2pydra/statements/workflow_build.py | 4 +- nipype2pydra/{utils => }/symbols.py | 635 ++++++++++++++---- nipype2pydra/utils/__init__.py | 6 - nipype2pydra/utils/misc.py | 10 +- .../utils/tests/test_utils_imports.py | 2 +- nipype2pydra/workflow.py | 
23 +- 12 files changed, 890 insertions(+), 667 deletions(-) rename nipype2pydra/{utils => }/symbols.py (52%) diff --git a/nipype2pydra/helpers.py b/nipype2pydra/helpers.py index d344d9d..10e164f 100644 --- a/nipype2pydra/helpers.py +++ b/nipype2pydra/helpers.py @@ -9,8 +9,8 @@ from types import ModuleType import black.report import yaml +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, extract_args, full_address, multiline_comment, @@ -121,17 +121,13 @@ def nipype_object(self): return getattr(self.nipype_module, self.nipype_name) @cached_property - def used_symbols(self) -> UsedSymbols: + def used(self) -> UsedSymbols: used = UsedSymbols.find( self.nipype_module, [self.src], + package=self.package, collapse_intra_pkg=False, - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, always_include=self.package.all_explicit, - translations=self.package.all_import_translations, ) used.import_stmts.update(i.to_statement() for i in self.imports) return used @@ -147,10 +143,10 @@ def converted_code(self) -> ty.List[str]: @cached_property def nested_interfaces(self): potential_classes = { - full_address(c[1]): c[0] for c in self.used_symbols.imported_classes if c[0] + full_address(c[1]): c[0] for c in self.used.imported_classes if c[0] } potential_classes.update( - (full_address(c), c.__name__) for c in self.used_symbols.classes + (full_address(c), c.__name__) for c in self.used.classes ) return { potential_classes[address]: workflow @@ -377,8 +373,7 @@ class ClassConverter(BaseHelperConverter): @cached_property def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: - """Convert the Nipype workflow function to a Pydra workflow function and determine - the configuration parameters that are used + """Convert a class into Pydra- Returns ------- @@ -389,18 +384,30 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: """ 
used_configs = set() - parts = re.split( - r"\n (?!\s|\))", replace_undefined(self.src), flags=re.MULTILINE - ) + + src = replace_undefined(self.src)[len("class ") :] + name, bases, class_body = extract_args(src, drop_parens=True) + bases = [ + b + for b in bases + if not self.package.is_omitted(getattr(self.nipype_module, b)) + ] + + parts = re.split(r"\n (?!\s|\))", class_body, flags=re.MULTILINE) converted_parts = [] - for part in parts: + for part in parts[1:]: if part.startswith("def"): converted_func, func_used_configs = self._convert_function(part) converted_parts.append(converted_func) used_configs.update(func_used_configs) else: converted_parts.append(part) - code_str = "\n ".join(converted_parts) + code_str = ( + f"class {name}(" + + ", ".join(bases) + + "):\n " + + "\n ".join(converted_parts) + ) # Format the the code before the find and replace so it is more predictable try: code_str = black.format_file_contents( diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index d7a2c1e..71baa58 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -5,7 +5,6 @@ from abc import ABCMeta, abstractmethod from importlib import import_module from types import ModuleType -from collections import defaultdict import itertools import inspect import traits.trait_types @@ -22,16 +21,14 @@ from nipype.interfaces.base.core import SimpleInterface from pydra.engine import specs from pydra.engine.helpers import ensure_list +from .. 
import symbols from ..utils import ( import_module_from_path, is_fileset, to_snake_case, - UsedSymbols, types_converter, from_dict_converter, unwrap_nested_type, - get_local_functions, - get_local_constants, get_return_line, cleanup_function_body, insert_args_in_signature, @@ -447,95 +444,21 @@ def add_nonstd_types(tp): add_nonstd_types(f[1]) return nonstd_types - @property - def converted_code(self): - return self._converted[0] - - @property - def used_symbols(self): - return self._converted[1] - @cached_property - def _converted(self): - """writing pydra task to the dile based on the input and output spec""" - + def converted_code(self): return self.generate_code( self.input_fields, self.nonstd_types, self.output_fields ) - @property - def referenced_local_functions(self): - return self._referenced_funcs_and_methods[0] - - @property - def referenced_methods(self): - return self._referenced_funcs_and_methods[1] - - @property - def referenced_supers(self): - return self._referenced_funcs_and_methods[2] - - @property - def method_args(self): - return self._referenced_funcs_and_methods[3] - - @property - def method_returns(self): - return self._referenced_funcs_and_methods[4] - - @property - def method_stacks(self): - return self._referenced_funcs_and_methods[5] - - @property - def method_supers(self): - return self._referenced_funcs_and_methods[6] - @cached_property - def _referenced_funcs_and_methods(self): - referenced_funcs = set() - referenced_methods = set() - referenced_supers = {} - method_args = {} - method_returns = {} - method_stacks = {} - method_supers = defaultdict(dict) - already_processed = set( - getattr(self.nipype_interface, m) for m in self.included_methods - ) - for method_name in self.included_methods: - method_args[method_name] = [] - method_returns[method_name] = [] - method_stacks[method_name] = () - for method_name in self.included_methods: - method = getattr(self.nipype_interface, method_name) - super_base = find_super_method( - 
self.nipype_interface, method_name, include_class=True - )[1] - # if super_base is not self.nipype_interface: - # method_supers[self.nipype_interface][method_name] = ( - # self._common_parent_pkg_prefix(super_base) + method_name - # ) - self._get_referenced( - method, - referenced_funcs=referenced_funcs, - referenced_methods=referenced_methods, - referenced_supers=referenced_supers, - method_args=method_args, - method_returns=method_returns, - method_stacks=method_stacks, - method_supers=method_supers, - already_processed=already_processed, - super_base=super_base, - ) - return ( - referenced_funcs, - referenced_methods, - referenced_supers, - method_args, - method_returns, - method_stacks, - method_supers, + def used(self) -> "symbols.UsedClassSymbols": + return symbols.UsedClassSymbols.find( + klass=self.nipype_interface, + method_names=self.included_methods, + package=self.package, + collapse_intra_pkg=False, + pull_out_inline_imports=True, + absolute_imports=True, ) @cached_property @@ -579,8 +502,7 @@ def write( package_root=package_root, module_name=self.output_module, converted_code=self.converted_code, - used=self.used_symbols, - # inline_intra_pkg=True, + used=self.used, find_replace=self.find_replace + self.package.find_replace, ) @@ -591,8 +513,6 @@ def write( depth=self.package.init_depth, auto_import_depth=self.package.auto_import_init_depth, import_find_replace=self.package.import_find_replace, - # + [f.__name__ for f in self.used_symbols.local_functions] - # + [c.__name__ for c in self.used_symbols.local_classes], ) test_module_fspath = self.package.write_to_module( @@ -601,7 +521,7 @@ def write( self.output_module, f".tests.test_{self.task_name.lower()}" ), converted_code=self.converted_test_code, - used=self.used_symbols_test, + used=self.used_test, inline_intra_pkg=False, find_replace=self.find_replace, ) @@ -850,14 +770,14 @@ def string_formats(self, argstr, name, type_): @abstractmethod def generate_code(self, input_fields, nonstd_types, 
output_fields) -> ty.Tuple[ str, - UsedSymbols, + "symbols.UsedSymbols", ]: """ Returns ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ @@ -901,7 +821,7 @@ def converted_test_code(self): return self._converted_test[0] @property - def used_symbols_test(self): + def used_test(self): return self._converted_test[1] @cached_property @@ -978,7 +898,7 @@ def _converted_test(self): }, ) - return spec_str, UsedSymbols( + return spec_str, symbols.UsedSymbols( module_name=self.nipype_module.__name__, import_stmts=imports ) @@ -1059,172 +979,162 @@ def _misc_cleanups(self, body: str) -> str: body = " " * min_indentation(body) + "self_dict = {}\n" + new_body body = body.replace("return runtime", "") body = body.replace("TraitError", "KeyError") - body = body.replace("os.getcwd()", "output_dir") return body - def _get_referenced( - self, - method: ty.Callable, - referenced_funcs: ty.Set[ty.Callable], - referenced_methods: ty.Set[ty.Callable], - referenced_supers: ty.Dict[str, ty.Tuple[ty.Callable, type]], - method_args: ty.Dict[str, ty.List[str]] = None, - method_returns: ty.Dict[str, ty.List[str]] = None, - method_stacks: ty.Dict[str, ty.Tuple[ty.Callable]] = None, - method_supers: ty.Dict[type, ty.Dict[str, str]] = None, - already_processed: ty.Set[ty.Callable] = None, - method_stack: ty.Optional[ty.Tuple[ty.Callable]] = None, - super_base: ty.Optional[type] = None, - ) -> ty.Tuple[ty.Set, ty.Set]: - """Get the local functions referenced in the source code - - Parameters - ---------- - src: str - the source of the file to extract the import statements from - referenced_funcs: set[function] - the set of local functions that have been referenced so far - referenced_methods: set[function] - the set of methods that have been referenced so far - method_args: dict[str, list[str]] - a dictionary to hold additional arguments that need to be added to each method, - where the dictionary key is 
the names of the methods - method_returns: dict[str, list[str]] - a dictionary to hold the return values of each method, - where the dictionary key is the names of the methods - - Returns - ------- - referenced_inputs: set[str] - inputs that have been referenced - referenced_outputs: set[str] - outputs that have been referenced - """ - if already_processed: - already_processed.add(method) - else: - already_processed = {method} - if method_stack is None: - method_stack = (method,) - else: - method_stack += (method,) - if super_base is None: - super_base = self.nipype_interface - method_body = inspect.getsource(method) - method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments - return_value = get_return_line(method_body) - ref_local_func_names = re.findall(r"(? str: - """Return the common part of two package names""" - ref_parts = self.nipype_interface.__module__.split(".") - mod_parts = base.__module__.split(".") - common = [] - for r_part, m_part in zip(ref_parts, mod_parts): - if r_part == m_part: - common.append(r_part) - else: - break - if not common: - return "" - return "_".join(common + [base.__name__]) + "__" - - @cached_property - def local_functions(self): - """Get the functions defined in the same file as the interface""" - return get_local_functions(self.nipype_module) - - @cached_property - def local_constants(self): - return get_local_constants(self.nipype_module) + # def _get_referenced( + # self, + # method: ty.Callable, + # referenced_funcs: ty.Set[ty.Callable], + # referenced_methods: ty.Set[ty.Callable], + # referenced_supers: ty.Dict[str, ty.Tuple[ty.Callable, type]], + # method_args: ty.Dict[str, ty.List[str]] = None, + # method_returns: ty.Dict[str, ty.List[str]] = None, + # method_stacks: ty.Dict[str, ty.Tuple[ty.Callable]] = None, + # method_supers: ty.Dict[type, ty.Dict[str, str]] = None, + # already_processed: ty.Set[ty.Callable] = None, + # method_stack: ty.Optional[ty.Tuple[ty.Callable]] = None, + # super_base: 
ty.Optional[type] = None, + # ) -> ty.Tuple[ty.Set, ty.Set]: + # """Get the local functions referenced in the source code + + # Parameters + # ---------- + # src: str + # the source of the file to extract the import statements from + # referenced_funcs: set[function] + # the set of local functions that have been referenced so far + # referenced_methods: set[function] + # the set of methods that have been referenced so far + # method_args: dict[str, list[str]] + # a dictionary to hold additional arguments that need to be added to each method, + # where the dictionary key is the names of the methods + # method_returns: dict[str, list[str]] + # a dictionary to hold the return values of each method, + # where the dictionary key is the names of the methods + + # Returns + # ------- + # referenced_inputs: set[str] + # inputs that have been referenced + # referenced_outputs: set[str] + # outputs that have been referenced + # """ + # if already_processed: + # already_processed.add(method) + # else: + # already_processed = {method} + # if method_stack is None: + # method_stack = (method,) + # else: + # method_stack += (method,) + # if super_base is None: + # super_base = self.nipype_interface + # method_body = inspect.getsource(method) + # method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments + # return_value = get_return_line(method_body) + # ref_local_func_names = re.findall(r"(? 
str: + # """Return the common part of two package names""" + # ref_parts = self.nipype_interface.__module__.split(".") + # mod_parts = base.__module__.split(".") + # common = [] + # for r_part, m_part in zip(ref_parts, mod_parts): + # if r_part == m_part: + # common.append(r_part) + # else: + # break + # if not common: + # return "" + # return "_".join(common + [base.__name__]) + "__" def process_method( self, @@ -1247,7 +1157,7 @@ def process_method( pass if "runtime" in args: args.remove("runtime") - args_to_add = list(self.method_args.get(method.__name__, [])) + list( + args_to_add = list(self.used.method_args.get(method.__name__, [])) + list( additional_args ) if args_to_add: @@ -1259,8 +1169,8 @@ def process_method( method_body = self.process_method_body( method_body, input_names, output_names, super_base ) - if self.method_returns.get(method.__name__): - return_args = self.method_returns[method.__name__] + if self.used.method_returns.get(method.__name__): + return_args = self.used.method_returns[method.__name__] method_body = ( " " * min_indentation(method_body) + " = ".join(return_args) @@ -1338,7 +1248,7 @@ def process_method_body( def replace_supers(self, method_body, super_base=None): if super_base is None: super_base = self.nipype_interface - name_map = self.method_supers[super_base] + name_map = self.used.super_func_names.get(super_base) splits = re.split(r"super\([^\)]*\)\.(\w+)(?=\()", method_body) new_body = splits[0] for name, block in zip(splits[1::2], splits[2::2]): @@ -1371,9 +1281,7 @@ def unwrap_nested_methods( method_body = self._misc_cleanups(method_body) # Add args to the function signature of method calls method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) - method_names = [m.__name__ for m in self.referenced_methods] + list( - self.included_methods - ) + method_names = [m.__name__ for m in self.used.methods] method_body = strip_comments(method_body) omitted_methods = {} for method_name in set( @@ -1401,7 +1309,7 @@ 
def unwrap_nested_methods( continue # Assign additional return values (which were previously saved to member # attributes) to new variables from the method call - if self.method_returns[name]: + if self.used.method_returns[name]: last_line = new_body.splitlines()[-1] match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? *$", last_line) if match: @@ -1411,22 +1319,26 @@ def unwrap_nested_methods( last_line = new_body_lines[-1] new_body += "\n" + re.sub( r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$", - r"\1\2, " + ",".join(self.method_returns[name]) + " = ", + r"\1\2, " + + ",".join(self.used.method_returns[name]) + + " = ", last_line, flags=re.MULTILINE, ) else: - new_body += ",".join(self.method_returns[name]) + " = " + new_body += ",".join(self.used.method_returns[name]) + " = " else: logger.warning( "Could not augment the return value of the method converted from " f"a function '{name}' with the previously assigned attributes " - f"{self.method_returns[name]} as the method doesn't have a " + f"{self.used.method_returns[name]} as the method doesn't have a " f"singular return statement at the end of the method" ) # Insert additional arguments to the method call (which were previously # accessed via member attributes) - args_to_be_inserted = list(self.method_args[name]) + list(additional_args) + args_to_be_inserted = list(self.used.method_args[name]) + list( + additional_args + ) try: new_body += name + insert_args_in_signature( args, diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 27c6f8e..1994e7f 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -7,7 +7,7 @@ import attrs from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter -from ..utils import UsedSymbols, get_return_line, find_super_method +from ..symbols import UsedSymbols, get_return_line, find_super_method logger = logging.getLogger("nipype2pydra") @@ -31,7 +31,7 @@ def generate_code(self, 
input_fields, nonstd_types, output_fields) -> ty.Tuple[ ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ @@ -64,35 +64,6 @@ def types_to_names(spec_fields): output_names = [o[0] for o in output_fields] output_type_names = [o[1] for o in output_fields_str] - used = UsedSymbols.find( - self.nipype_module, - self.referenced_local_functions, - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) - - for ref_method in self.referenced_methods: - method_module = find_super_method( - self.nipype_interface, ref_method.__name__, include_class=True - )[1].__module__ - method_used = UsedSymbols.find( - method_module, - [ref_method], - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) - used.update(method_used) - method_body = "" for field in input_fields: if field[-1].get("copyfile"): @@ -100,12 +71,14 @@ def types_to_names(spec_fields): for field in output_fields: method_body += f" {field[0]} = attrs.NOTHING\n" + used_method_names = [m.__name__ for m in self.used.methods] # Combined src of init and list_outputs - init_code = inspect.getsource(self.nipype_interface.__init__).strip() - init_class = find_super_method( - self.nipype_interface, "__init__", include_class=True - )[1] - if not self.package.is_omitted(init_class): + if "__init__" in used_method_names: + init_code = inspect.getsource(self.nipype_interface.__init__).strip() + init_class = find_super_method( + self.nipype_interface, 
"__init__", include_class=True + )[1] + assert not self.package.is_omitted(init_class) # Strip out method def and return statement method_lines = init_code.strip().split("\n")[1:] if re.match(r"\s*return", method_lines[-1]): @@ -117,29 +90,17 @@ def types_to_names(spec_fields): output_names, super_base=init_class, ) - - init_used = UsedSymbols.find( - init_class.__module__, - [init_code], - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) - used.update(init_used) method_body += init_code + "\n" # Combined src of run_interface and list_outputs - run_interface_code = inspect.getsource( - self.nipype_interface._run_interface - ).strip() - run_interface_class = find_super_method( - self.nipype_interface, "_run_interface", include_class=True - )[1] - if not self.package.is_omitted(run_interface_class): + if "_run_interface" in used_method_names: + run_interface_code = inspect.getsource( + self.nipype_interface._run_interface + ).strip() + run_interface_class = find_super_method( + self.nipype_interface, "_run_interface", include_class=True + )[1] + assert not self.package.is_omitted(run_interface_class) # Strip out method def and return statement method_lines = run_interface_code.strip().split("\n")[1:] if re.match(r"\s*return", method_lines[-1]): @@ -151,28 +112,16 @@ def types_to_names(spec_fields): output_names, super_base=run_interface_class, ) - - run_interface_used = UsedSymbols.find( - run_interface_class.__module__, - [run_interface_code], - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - 
translations=self.package.all_import_translations, - absolute_imports=True, - ) - used.update(run_interface_used) method_body += run_interface_code + "\n" - list_outputs_code = inspect.getsource( - self.nipype_interface._list_outputs - ).strip() - list_outputs_class = find_super_method( - self.nipype_interface, "_list_outputs", include_class=True - )[1] - if not self.package.is_omitted(list_outputs_class): + if "_list_outputs" in used_method_names: + list_outputs_code = inspect.getsource( + self.nipype_interface._list_outputs + ).strip() + list_outputs_class = find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1] + assert not self.package.is_omitted(list_outputs_class) # Strip out method def and return statement lo_lines = list_outputs_code.strip().split("\n")[1:] if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): @@ -185,19 +134,6 @@ def types_to_names(spec_fields): super_base=list_outputs_class, unwrap_return_dict=True, ) - - list_outputs_used = UsedSymbols.find( - list_outputs_class.__module__, - [list_outputs_code], - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) - used.update(list_outputs_used) method_body += list_outputs_code + "\n" assert method_body, "Neither `run_interface` and `list_outputs` are defined" @@ -232,15 +168,16 @@ def types_to_names(spec_fields): spec_str += "\n\n# Nipype methods converted into functions\n\n" - for m in sorted(self.referenced_methods, key=attrgetter("__name__")): - spec_str += "\n\n" + self.process_method( - m, - input_names, - output_names, - super_base=find_super_method( - self.nipype_interface, m.__name__, include_class=True - )[1], - ) + for m in sorted(self.used.methods, key=attrgetter("__name__")): + if m.__name__ not in 
self.included_methods: + spec_str += "\n\n" + self.process_method( + m, + input_names, + output_names, + super_base=find_super_method( + self.nipype_interface, m.__name__, include_class=True + )[1], + ) # Replace runtime attributes additional_imports = set() @@ -250,16 +187,18 @@ def types_to_names(spec_fields): additional_imports.add(imprt) spec_str = repl_spec_str - used.import_stmts.update( + self.used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, include_task=False, - base=base_imports + list(used.import_stmts) + list(additional_imports), + base=base_imports + + list(self.used.import_stmts) + + list(additional_imports), ) ) - return spec_str, used + return spec_str def replace_attributes(self, function_body: ty.Callable) -> str: """Replace self.inputs. with in the function body and add args to the diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 7d2dcb6..0e59e60 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -5,16 +5,14 @@ import logging from functools import cached_property from copy import copy -from operator import attrgetter, itemgetter -from importlib import import_module -from nipype.interfaces.base import BaseInterface, TraitedSpec +from operator import attrgetter from .base import BaseInterfaceConverter from ..utils import ( - UsedSymbols, split_source_into_statements, INBUILT_NIPYPE_TRAIT_NAMES, extract_args, find_super_method, + cleanup_function_body, ) from fileformats.core.mixin import WithClassifiers from fileformats.generic import File, Directory @@ -34,6 +32,8 @@ class ShellCommandInterfaceConverter(BaseInterfaceConverter): @cached_property def included_methods(self) -> ty.Tuple[str, ...]: included = [] + # if not self.method_omitted("__init__"): + # included.append("__init__"), if not self.method_omitted("_parse_inputs"): included.append("_parse_inputs"), if not self.method_omitted("_format_arg"): @@ -47,21 +47,19 @@ 
def included_methods(self) -> ty.Tuple[str, ...]: included.append("_list_outputs") return tuple(included) - def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ - str, - UsedSymbols, - ]: + def generate_code(self, input_fields, nonstd_types, output_fields) -> str: """ Returns ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ base_imports = [ "from pydra.engine import specs", + "import os", ] task_base = "ShellCommandTask" @@ -130,7 +128,8 @@ def types_to_names(spec_fields): # functions_imports, functions_str = functions_str.split("\n\n", 1) # spec_str = functions_str spec_str = ( - self.format_arg_code + self.init_code + + self.format_arg_code + self.parse_inputs_code + self.callables_code + self.defaults_code @@ -152,65 +151,23 @@ def types_to_names(spec_fields): spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str) - for m in sorted(self.referenced_methods, key=attrgetter("__name__")): + for m in sorted(self.used.methods, key=attrgetter("__name__")): if m.__name__ in self.included_methods: continue - if self.method_stacks[m.__name__][0] == self.nipype_interface._list_outputs: + if any( + s[0] == self.nipype_interface._list_outputs + for s in self.used.method_stacks[m.__name__] + ): additional_args = CALLABLES_ARGS else: additional_args = [] - spec_str += "\n\n" + self.process_method( + method_str = self.process_method( m, input_names, output_names, additional_args=additional_args ) + method_str = method_str.replace("os.getcwd()", "output_dir") + spec_str += "\n\n" + method_str - for new_name, (m, _) in sorted( - self.referenced_supers.items(), key=itemgetter(0) - ): - if self.method_stacks[new_name][0] == self.nipype_interface._list_outputs: - additional_args = CALLABLES_ARGS - else: - additional_args = [] - spec_str += "\n\n" + self.process_method( - m, - input_names, - output_names, - additional_args=additional_args, - new_name=new_name, - ) - - 
used = UsedSymbols.find( - self.nipype_module, - [ - self.format_arg_code, - self.parse_inputs_code, - self.callables_code, - self.defaults_code, - ] - + list(self.referenced_methods), - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) - for super_method, base in self.referenced_supers.values(): - super_used = UsedSymbols.find( - import_module(base.__module__), - [super_method], - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - collapse_intra_pkg=True, - ) - used.update(super_used) - - used.import_stmts.update( + self.used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, @@ -219,7 +176,7 @@ def types_to_names(spec_fields): ) ) - return spec_str, used + return spec_str @cached_property def input_fields(self): @@ -272,7 +229,10 @@ def callable_output_field_names(self): def _format_arg_body(self): if self.method_omitted("_format_arg"): return "" - return self._unwrap_supers(self.nipype_interface._format_arg) + return self._unwrap_supers( + self.nipype_interface._format_arg, + base_replacement="return argstr.format(**inputs)", + ) @cached_property def _gen_filename_body(self): @@ -280,6 +240,17 @@ def _gen_filename_body(self): return "" return self._unwrap_supers(self.nipype_interface._gen_filename) + @property + def init_code(self): + if "__init__" not in self.included_methods: + return "" + body = self._unwrap_supers( + self.nipype_interface.__init__, + base_replacement="", + ) + code_str = 
f"def _init():\n {body}\n" + return code_str + @property def format_arg_code(self): if "_format_arg" not in self.included_methods: @@ -289,12 +260,15 @@ def format_arg_code(self): existing_args = list( inspect.signature(self.nipype_interface._format_arg).parameters )[1:] - name_arg, _, val_arg = existing_args + name_arg, spec_arg, val_arg = existing_args + + # Single-line replacement args body = re.sub( - r"trait_spec\.argstr % (.*)", + spec_arg + r"\.argstr % +([^\( ].+)", r"argstr.format(**{" + name_arg + r": \1})", body, ) + body = body.replace(f"{spec_arg}.argstr", "argstr") # Strip out return value body = re.sub( @@ -307,16 +281,10 @@ def format_arg_code(self): body, flags=re.MULTILINE, ) - if not body: + if not body.strip(): return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) - body = self.replace_supers( - body, - super_base=find_super_method( - self.nipype_interface, "_format_arg", include_class=True - )[1], - ) - # body = self._misc_cleanups(body) + code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} if {val_arg} is None: return "" @@ -339,7 +307,9 @@ def format_arg_code(self): def parse_inputs_code(self) -> str: if "_parse_inputs" not in self.included_methods: return "" - body = self._unwrap_supers(self.nipype_interface._parse_inputs) + body = self._unwrap_supers( + self.nipype_interface._parse_inputs, base_replacement="return {}" + ) body = self._process_inputs(body) body = re.sub( r"self.\_format_arg\((\w+), (\w+), (\w+)\)", @@ -349,18 +319,19 @@ def parse_inputs_code(self) -> str: # Strip out return value body = re.sub(r"\s*return .*\n", "", body) - if not body: + if not body.strip(): return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) - body = self.replace_supers( - body, - super_base=find_super_method( - self.nipype_interface, "_parse_inputs", include_class=True - )[1], - ) + # Supers are already unwrapped so this isn't necessary + # body = self.replace_supers( + # body, 
+ # super_base=find_super_method( + # self.nipype_interface, "_parse_inputs", include_class=True + # )[1], + # ) # body = self._misc_cleanups(body) - code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}" + code_str = "def _parse_inputs(inputs, output_dir=None):\n if not output_dir:\n output_dir = os.getcwd()\n parsed_inputs = {}" if re.findall(r"\bargstrs\b", body): code_str += f"\n argstrs = {self._format_argstrs!r}" code_str += f""" @@ -382,7 +353,7 @@ def defaults_code(self): ) body = self._process_inputs(body) - if not body: + if not body.strip(): return "" body = self.unwrap_nested_methods(body, inputs_as_dict=True) body = self.replace_supers( @@ -416,11 +387,14 @@ def callables_code(self): code_str = "" if "aggregate_outputs" in self.included_methods: func_name = "aggregate_outputs" - agg_body = self._unwrap_supers(self.nipype_interface.aggregate_outputs) + agg_body = self._unwrap_supers( + self.nipype_interface.aggregate_outputs, + base_replacement=" return {}", + ) need_list_outputs = bool(re.findall(r"\b_list_outputs\b", agg_body)) agg_body = self._process_inputs(agg_body) - if not agg_body: + if not agg_body.strip(): return "" agg_body = self.unwrap_nested_methods( agg_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True @@ -476,10 +450,16 @@ def callables_code(self): return code_str else: - lo_body = self._unwrap_supers(self.nipype_interface._list_outputs) + lo_body = self._unwrap_supers( + self.nipype_interface._list_outputs, + base_replacement=" return {}", + ) lo_body = self._process_inputs(lo_body) + lo_body = re.sub( + r"(\w+) = self\.output_spec\(\).get\(\)", r"\1 = {}", lo_body + ) - if not lo_body: + if not lo_body.strip(): return "" lo_body = self.unwrap_nested_methods( lo_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True @@ -491,7 +471,13 @@ def callables_code(self): )[1], ) - code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call} + 
parse_inputs_call = ( + "\n parsed_inputs = _parse_inputs(inputs, output_dir=output_dir)" + if self.parse_inputs_code + else "" + ) + + code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{parse_inputs_call} {lo_body} @@ -538,26 +524,66 @@ def _unwrap_supers( self, method: ty.Callable, base=None, base_replacement="", arg_names=None ) -> str: if base is None: - base = self.nipype_interface + base = find_super_method( + self.nipype_interface, method.__name__, include_class=True + )[1] if self.package.is_omitted(base): return base_replacement method_name = method.__name__ - sig, body = inspect.getsource(method).split("\n", 1) - body = _strip_doc_string(body) - args = extract_args(sig)[1][1:] + body = inspect.getsource(method).split("\n", 1)[1] + body = "\n" + _strip_doc_string(body) + body = cleanup_function_body(body) + args = list(inspect.signature(method).parameters.keys())[1:] if arg_names: for new, old in zip(args, arg_names): if new != old: body = re.sub(r"\b" + old + r"\b", new, body) super_re = re.compile( - r"\n\s*(return )?super\([^\)]*\)\." + method_name + r"\([^\)]+\)" + r"\n( *(?:return|\w+\s*=)?\s*super\([^\)]*\)\." 
+ method_name + ")" ) if super_re.search(body): super_method, base = find_super_method(base, method_name) super_body = self._unwrap_supers( super_method, base, base_replacement, arg_names=args ) - body = super_re.sub("\n" + super_body, body) + return_indent = return_val = None + if super_body: + super_args = list(inspect.signature(super_method).parameters.keys())[1:] + lines = super_body.splitlines() + match = re.match(r"(\s*)return\s+(.*)", lines[-1]) + if match: + return_indent, return_val = match.groups() + super_body = "\n".join(lines[:-1]) + else: + super_args = [] + + splits = super_re.split(body) + new_body = splits[0] + for call, block in zip(splits[1::2], splits[2::2]): + _, args, post = extract_args(block) + indent = re.match(r"^(\s*)", call).group(1) + arg_str = ", ".join(args) + if "=" in call: + assert return_val + assigned_to_varname = call.split("=")[0].strip() + if return_val == assigned_to_varname: + replacement = super_body + else: + replacement = ( + super_body + + f"\n{indent}{assigned_to_varname} = {return_val}" + ) + elif super_body: + replacement = super_body + else: + if len(indent) > 4: + new_body += f"\n{indent}pass" + new_body += post[1:] + continue + for o, n in zip(args, super_args): + replacement = re.sub(r"\b" + o + r"\b", n, replacement) + new_body += replacement + "(" + arg_str + post + return new_body return body diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index 4e2741d..9052ad5 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -15,9 +15,10 @@ import black.report from tqdm import tqdm import yaml +import nipype.utils.logger from . 
import interface +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, full_address, to_snake_case, cleanup_function_body, @@ -270,6 +271,10 @@ class PackageConverter: }, ) + def __attrs_post_init__(self): + # Adds in some default omissions + self.omit_constants.append("nipype.logging") + @init_depth.default def _init_depth_default(self) -> int: if self.name.startswith("pydra.tasks."): @@ -420,7 +425,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): for conv in list(self.functions.values()) + list(self.classes.values()): intra_pkg_modules[conv.nipype_module_name].add(conv.nipype_object) - collect_intra_pkg_objects(conv.used_symbols) + collect_intra_pkg_objects(conv.used) for workflow in tqdm( workflows_to_include, "converting workflows from Nipype to Pydra syntax" @@ -447,7 +452,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - collect_intra_pkg_objects(converter.used_symbols) + collect_intra_pkg_objects(converter.used) for converter in tqdm( nipype_ports, "Porting interfaces from the core nipype package" @@ -456,7 +461,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - collect_intra_pkg_objects(converter.used_symbols, port_nipype=False) + collect_intra_pkg_objects(converter.used, port_nipype=False) # Write any additional functions in other modules in the package self.write_intra_pkg_modules(package_root, intra_pkg_modules) @@ -547,12 +552,8 @@ def write_intra_pkg_modules( mod, objs, pull_out_inline_imports=False, - translations=self.all_import_translations, - omit_classes=self.omit_classes, - omit_modules=self.omit_modules, - omit_functions=self.omit_functions, - omit_constants=self.omit_constants, always_include=self.all_explicit, + package=self, ) classes = used.classes + [ @@ -873,7 +874,7 @@ def write_to_module( if f"\nclass {klass.__name__}(" 
not in code_str: try: class_converter = self.classes[full_address(klass)] - converter_imports.extend(class_converter.used_symbols.import_stmts) + converter_imports.extend(class_converter.used.import_stmts) except KeyError: class_converter = ClassConverter.from_object(klass, self) code_str += "\n" + class_converter.converted_code + "\n" @@ -905,9 +906,7 @@ def write_to_module( if f"\ndef {func.__name__}(" not in code_str: if func.__name__ in self.functions: function_converter = self.functions[full_address(func)] - converter_imports.extend( - function_converter.used_symbols.import_stmts - ) + converter_imports.extend(function_converter.used.import_stmts) else: function_converter = FunctionConverter.from_object(func, self) code_str += "\n" + function_converter.converted_code + "\n" diff --git a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index 170a6d5..26bb23b 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -33,8 +33,8 @@ TestGenerator, DocTestGenerator, ) +from nipype2pydra.symbols import UsedSymbols from nipype2pydra.utils import ( - UsedSymbols, extract_args, get_source_code, cleanup_function_body, diff --git a/nipype2pydra/statements/workflow_build.py b/nipype2pydra/statements/workflow_build.py index ac2d5d1..5cf0307 100644 --- a/nipype2pydra/statements/workflow_build.py +++ b/nipype2pydra/statements/workflow_build.py @@ -539,9 +539,7 @@ def parse( if intf_name.endswith("("): # strip trailing parenthesis intf_name = intf_name[:-1] try: - imported_obj = workflow_converter.used_symbols.get_imported_object( - intf_name - ) + imported_obj = workflow_converter.used.get_imported_object(intf_name) except ImportError: imported_obj = None is_factory = "already-initialised" diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/symbols.py similarity index 52% rename from nipype2pydra/utils/symbols.py rename to nipype2pydra/symbols.py index 0d169c8..74bd8ac 100644 --- a/nipype2pydra/utils/symbols.py +++ 
b/nipype2pydra/symbols.py @@ -8,11 +8,21 @@ from collections import defaultdict from logging import getLogger from importlib import import_module +import itertools +from functools import cached_property import attrs from nipype.interfaces.base import BaseInterface, TraitedSpec, isdefined, Undefined from nipype.interfaces.base import traits_extension -from .misc import split_source_into_statements, extract_args -from ..statements.imports import ImportStatement, parse_imports +from .utils.misc import ( + split_source_into_statements, + extract_args, + find_super_method, + get_return_line, +) +from .statements.imports import ImportStatement, parse_imports + +if ty.TYPE_CHECKING: + from .package import PackageConverter logger = getLogger("nipype2pydra") @@ -62,16 +72,10 @@ class UsedSymbols: functions: ty.Set[ty.Callable] = attrs.field(factory=set) classes: ty.List[type] = attrs.field(factory=list) constants: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) - methods: ty.Set[ty.Callable] = attrs.field(factory=set) - class_attrs: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) imported_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) imported_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) imported_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) - super_methoods: ty.Set[ty.Tuple[type, ty.Callable]] = attrs.field(factory=set) - super_class_attrs: ty.Set[ty.Tuple[type, ty.Tuple[str, str]]] = attrs.field( - factory=set - ) - klass: ty.Optional[type] = None + package: "PackageConverter" = attrs.field(default=None) ALWAYS_OMIT_MODULES = [ "traits.trait_handlers", # Old traits module, pre v6.0 @@ -83,6 +87,11 @@ class UsedSymbols: ] _cache = {} + _stmts_cache = {} + _imports_cache = {} + _funcs_cache = {} + _classes_cache = {} + _constants_cache = {} symbols_re = re.compile(r"(? 
"UsedSymbols": """Get the imports and local functions/classes/constants referenced in the @@ -197,10 +190,6 @@ def find( UsedSymbols a class containing the used symbols in the module """ - if omit_classes is None: - omit_classes = [] - if omit_modules is None: - omit_modules = [] if always_include is None: always_include = [] if isinstance(module, str): @@ -210,28 +199,45 @@ def find( tuple(f.__name__ if not isinstance(f, str) else f for f in function_bodies), collapse_intra_pkg, pull_out_inline_imports, - tuple(omit_constants) if omit_constants else None, - tuple(omit_functions) if omit_functions else None, - tuple(omit_classes) if omit_classes else None, - tuple(omit_modules) if omit_modules else None, - tuple(always_include) if always_include else None, - tuple(translations) if translations else None, ) try: return cls._cache[cache_key] except KeyError: pass - used = cls(module_name=module.__name__) + used = cls(module_name=module.__name__, package=package) cls._cache[cache_key] = used + used._find_referenced( + module, + function_bodies, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + ) + return used + + @classmethod + def _module_statements(cls, module) -> list: + try: + return cls._stmts_cache[module.__name__] + except KeyError: + pass source_code = inspect.getsource(module) - # Sort local func/classes/consts so they are iterated in a consistent order to - # remove stochastic element of traversal and make debugging easier - local_functions = sorted( - get_local_functions(module), key=attrgetter("__name__") + cls._stmts_cache[module.__name__] = stmts = split_source_into_statements( + source_code ) - local_constants = sorted(get_local_constants(module)) - local_classes = sorted(get_local_classes(module), key=attrgetter("__name__")) - module_statements = split_source_into_statements(source_code) + return stmts + + @classmethod + def _global_imports( + cls, module, package, absolute_imports, pull_out_inline_imports + ): + 
"""Get the global imports in the module""" + try: + return cls._imports_cache[module.__name__] + except KeyError: + pass + module_statements = cls._module_statements(module) imports: ty.List[ImportStatement] = [] global_scope = True for stmt in module_statements: @@ -249,62 +255,39 @@ def find( parse_imports( stmt, relative_to=module, - translations=translations, + translations=package.all_import_translations, absolute=absolute_imports, ) ) imports = sorted(imports) + cls._imports_cache[module.__name__] = imports + return imports - all_src = "" # All the source code that is searched for symbols + def _find_referenced( + self, + module, + function_bodies, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + ): - used_symbols = set() - for function_body in function_bodies: - if not isinstance(function_body, str): - function_body = inspect.getsource(function_body) - all_src += "\n\n" + function_body - cls._get_symbols(function_body, used_symbols) + imports = self._global_imports( + module, self.package, absolute_imports, pull_out_inline_imports + ) + # Sort local func/classes/consts so they are iterated in a consistent order to + # remove stochastic element of traversal and make debugging easier - # Keep stepping into nested referenced local function/class sources until all local - # functions and constants that are referenced are added to the used symbols - prev_num_symbols = -1 - while len(used_symbols) > prev_num_symbols: - prev_num_symbols = len(used_symbols) - for local_func in local_functions: - if ( - local_func.__name__ in used_symbols - and local_func not in used.functions - ): - used.functions.add(local_func) - cls._get_symbols(local_func, used_symbols) - all_src += "\n\n" + inspect.getsource(local_func) - for local_class in local_classes: - if ( - local_class.__name__ in used_symbols - and local_class not in used.classes - ): - if issubclass(local_class, (BaseInterface, TraitedSpec)): - continue - 
used.classes.append(local_class) - class_body = inspect.getsource(local_class) - bases = extract_args(class_body)[1] - used_symbols.update(bases) - cls._get_symbols(class_body, used_symbols) - all_src += "\n\n" + class_body - for const_name, const_def in local_constants: - if ( - const_name in used_symbols - and (const_name, const_def) not in used.constants - ): - used.constants.add((const_name, const_def)) - cls._get_symbols(const_def, used_symbols) - all_src += "\n\n" + const_def - used_symbols -= set(cls.SYMBOLS_TO_IGNORE) + used_symbols, all_src = self._get_used_symbols(function_bodies, module) base_pkg = module.__name__.split(".")[0] module_omit_re = re.compile( r"^\b(" - + "|".join(cls.ALWAYS_OMIT_MODULES + [module.__name__] + omit_modules) + + "|".join( + self.ALWAYS_OMIT_MODULES + [module.__name__] + self.package.omit_modules + ) + r")\b", ) @@ -316,7 +299,12 @@ def find( continue # Filter out Nipype-specific objects that aren't relevant in Pydra module_omit = bool(module_omit_re.match(stmt.module_name)) - if module_omit or omit_classes or omit_functions or omit_constants: + if ( + module_omit + or self.package.omit_classes + or self.package.omit_functions + or self.package.omit_constants + ): to_include = [] for imported in stmt.values(): if imported.address in always_include: @@ -335,18 +323,23 @@ def find( ), imported.name, imported.statement.module_name, - omit_classes, - omit_functions, + self.package.omit_classes, + self.package.omit_functions, ) to_include.append(imported.local_name) continue if inspect.isclass(obj): - if omit_classes and issubclass(obj, tuple(omit_classes)): + if self.package.omit_classes and issubclass( + obj, tuple(self.package.omit_classes) + ): continue elif inspect.isfunction(obj): - if omit_functions and obj in omit_functions: + if ( + self.package.omit_functions + and obj in self.package.omit_functions + ): continue - elif imported.address in omit_constants: + elif imported.address in self.package.omit_constants: continue 
to_include.append(imported.local_name) if not to_include: @@ -363,12 +356,12 @@ def find( ) or inspect.isbuiltin(imported.object): # Case where an object is a nested import from a different package # which is imported in a chain from a neighbouring module - used.import_stmts.add( + self.import_stmts.add( imported.as_independent_statement(resolve=True) ) stmt.drop(imported) elif inspect.isfunction(imported.object): - used.imported_funcs.add((imported.local_name, imported.object)) + self.imported_funcs.add((imported.local_name, imported.object)) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object @@ -382,8 +375,8 @@ def find( # like we did for functions here because we need to preserve the # order the classes are defined in the module in case one inherits # from the other - if class_def not in used.imported_classes: - used.imported_classes.append(class_def) + if class_def not in self.imported_classes: + self.imported_classes.append(class_def) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object, @@ -404,15 +397,15 @@ def find( obj = getattr(imported.object, attr_name) if inspect.isfunction(obj): - used.imported_funcs.add((obj.__name__, obj)) + self.imported_funcs.add((obj.__name__, obj)) intra_pkg_objs[imported.object.__name__].add(obj) elif inspect.isclass(obj): class_def = (obj.__name__, obj) - if class_def not in used.imported_classes: - used.imported_classes.append(class_def) + if class_def not in self.imported_classes: + self.imported_classes.append(class_def) intra_pkg_objs[imported.object.__name__].add(obj) else: - used.imported_constants.add( + self.imported_constants.add( ( imported.object.__name__, attr_name, @@ -426,7 +419,7 @@ def find( f"Cannot inline imported module in statement '{stmt}'" ) else: - used.imported_constants.add( + self.imported_constants.add( ( stmt.module_name, imported.local_name, 
@@ -439,21 +432,64 @@ def find( # Recursively include neighbouring objects imported in the module for from_mod, inlined_objs in intra_pkg_objs.items(): - used_in_mod = cls.find( + used_in_mod = UsedSymbols.find( from_mod, function_bodies=inlined_objs, collapse_intra_pkg=collapse_intra_pkg, - translations=translations, - omit_modules=omit_modules, - omit_classes=omit_classes, - omit_functions=omit_functions, - omit_constants=omit_constants, + package=self.package, always_include=always_include, ) - used.update(used_in_mod, to_be_inlined=collapse_intra_pkg) + self.update(used_in_mod, to_be_inlined=collapse_intra_pkg) if stmt: - used.import_stmts.add(stmt) - return used + self.import_stmts.add(stmt) + + def _get_used_symbols(self, function_bodies, module): + """Search the given source code for any symbols that are used in the function bodies""" + + all_src = "" + used_symbols = set() + for function_body in function_bodies: + if not isinstance(function_body, str): + function_body = inspect.getsource(function_body) + all_src += "\n\n" + function_body + self._get_symbols(function_body, used_symbols) + + # Keep stepping into nested referenced local function/class sources until all local + # functions and constants that are referenced are added to the used symbols + prev_num_symbols = -1 + while len(used_symbols) > prev_num_symbols: + prev_num_symbols = len(used_symbols) + for local_func in self.local_functions(module): + if ( + local_func.__name__ in used_symbols + and local_func not in self.functions + ): + self.functions.add(local_func) + self._get_symbols(local_func, used_symbols) + all_src += "\n\n" + inspect.getsource(local_func) + for local_class in self.local_classes(module): + if ( + local_class.__name__ in used_symbols + and local_class not in self.classes + ): + if issubclass(local_class, (BaseInterface, TraitedSpec)): + continue + self.classes.append(local_class) + class_body = inspect.getsource(local_class) + bases = extract_args(class_body)[1] + 
@classmethod
def local_classes(cls, mod) -> ty.List[type]:
    """Get the classes defined in the module

    Only classes whose ``__module__`` equals ``mod.__name__`` are returned, so
    classes that are merely imported into ``mod`` are left out. The result is
    memoised in ``cls._classes_cache`` under the module name; later calls for
    the same module return the very same list object.

    Parameters
    ----------
    mod : module
        the module to inspect

    Returns
    -------
    list[type]
        the classes defined in the module, sorted by their ``__name__``
    """
    try:
        return cls._classes_cache[mod.__name__]
    except KeyError:
        pass
    # Walk every public/private attribute of the module and keep the classes that
    # were defined in it (as opposed to imported into it)
    classes = sorted(
        (
            attr
            for attr in (getattr(mod, attr_name) for attr_name in dir(mod))
            if inspect.isclass(attr) and attr.__module__ == mod.__name__
        ),
        key=attrgetter("__name__"),
    )
    cls._classes_cache[mod.__name__] = classes
    return classes
@classmethod
def find(
    cls,
    klass: type,
    method_names: ty.List[str],
    package: "PackageConverter",
    collapse_intra_pkg: bool = False,
    pull_out_inline_imports: bool = True,
    always_include: ty.Optional[ty.List[str]] = None,
    absolute_imports: bool = False,
) -> "UsedClassSymbols":
    """Get the imports and local functions/classes/constants referenced by the
    given methods of a class, and those nested within them

    Results are cached in ``cls._cache``. The cache key is built from the class
    name and module, the method names and the boolean/``always_include`` options;
    ``package`` is not part of the key.

    Parameters
    ----------
    klass: type
        the klass the methods belong to
    method_names: list[str]
        the names of the methods of ``klass`` that need to be checked for used
        symbols
    package : PackageConverter
        the package converter, stored on the returned object
    collapse_intra_pkg : bool
        whether functions and classes defined within the same package, but not the
        same module, are to be included in the output module or not, i.e. whether
        the local funcs/classes/constants they referenced need to be included also
    pull_out_inline_imports : bool, optional
        whether to pull out imports that are inline in the function bodies
        or not, by default True
    always_include : list[str], optional
        a list of module objects (e.g. functions, classes, etc...) to always include
        in list of used imports, even if they would be normally filtered out by
        one of the `omit` clauses, by default None
    absolute_imports : bool, optional
        whether to convert relative imports to absolute imports, by default False

    Returns
    -------
    UsedClassSymbols
        a class containing the used symbols in the methods
    """
    if always_include is None:
        always_include = []
    cache_key = (
        klass.__name__,
        klass.__module__,
        tuple(method_names),
        collapse_intra_pkg,
        pull_out_inline_imports,
        absolute_imports,
        tuple(always_include),
    )
    try:
        return cls._cache[cache_key]
    except KeyError:
        pass
    used = cls(klass=klass, module_name=klass.__module__, package=package)
    # The (still empty) object is registered in the cache before the methods are
    # scanned, so a lookup with the same key made while scanning gets this object
    cls._cache[cache_key] = used

    for method_name in method_names:
        used._find_referenced_by_method(
            method_name=method_name,
            pull_out_inline_imports=pull_out_inline_imports,
            absolute_imports=absolute_imports,
            always_include=always_include,
            collapse_intra_pkg=collapse_intra_pkg,
        )

    return used
dir(mod): - attr = getattr(mod, attr_name) - if inspect.isclass(attr) and attr.__module__ == mod.__name__: - classes.append(attr) - return classes + def _find_referenced_by_method( + self, + method_name, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + already_processed=None, + method_stack=(), + super_base=None, + ): + if super_base: + method = getattr(super_base, method_name) + else: + method, super_base = find_super_method( + self.klass, method_name, include_class=True + ) + module = import_module(super_base.__module__) + if already_processed: + already_processed.add(method) + else: + already_processed = {method} + + if self.package.is_omitted(super_base): + return + self.methods.add(method) + method_stack += (method,) + method_body = inspect.getsource(method) + method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments + + meth_ref_inputs = set(re.findall(r"(?<=self\.inputs\.)(\w+)", method_body)) + meth_ref_outputs = set(re.findall(r"self\.(\w+) *=", method_body)) + + self._find_referenced( + module, + [method], + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + ) + # Find all referenced methods + ref_method_names = re.findall(r"(?<=self\.)(\w+)\(", method_body) + # Filter methods in omitted common base-classes like BaseInterface & CommandLine + ref_method_names = [ + m + for m in ref_method_names + if ( + m != "output_spec" + and not self.package.is_omitted( + find_super_method(super_base, m, include_class=True)[1] + ) + ) + ] + ref_methods = set(getattr(self.klass, m) for m in ref_method_names) + for meth in ref_methods: + if meth in already_processed: + continue + if inspect.isclass(meth): + logger.warning( + "Found %s type, that is instantiated as a method, " + "should be treated as a nested type, skipping for now", + meth, + ) + continue + ref_inputs, ref_outputs = 
self._find_referenced_by_method( + meth.__name__, + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + already_processed=already_processed, + method_stack=method_stack, + ) + self.method_args[meth.__name__].update(ref_inputs) + self.method_returns[meth.__name__].update(ref_outputs) + self.method_stacks[meth.__name__].add(method_stack) + self.inputs.update(ref_inputs) + self.outputs.update(ref_outputs) + meth_ref_inputs.update(ref_inputs) + meth_ref_outputs.update(ref_outputs) + + # Find all referenced supers + for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): + super_method, base = find_super_method(super_base, match) + if self.package.is_omitted(base): + continue + func_name = self._different_parent_pkg_prefix(base) + match + if func_name not in self.supers: + self.supers[func_name] = (super_method, base) + self.super_func_names[super_base][match] = func_name + self.method_stacks[func_name].add(method_stack) + ref_inputs, ref_outputs = self._find_referenced_by_method( + super_method.__name__, + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + already_processed=already_processed, + method_stack=method_stack, + super_base=base, + ) + self.inputs.update(ref_inputs) + self.outputs.update(ref_outputs) + self.method_args[func_name].update(ref_inputs) + self.method_returns[func_name].update(ref_outputs) + meth_ref_inputs.update(ref_inputs) + meth_ref_outputs.update(ref_outputs) + + # Find all referenced constants/class attributes + local_class_attrs = self.local_class_attrs(super_base) + for match in re.findall(r"self\.(\w+)\b(?! 
def _different_parent_pkg_prefix(self, base: type) -> str:
    """Return a name prefix built from the part of the base's module path that
    differs from the module path of ``self.klass``

    The dotted module path of ``base`` is compared component by component with
    that of ``self.klass``. Every component of the base's path from the first
    mismatch onwards (including components beyond the end of the klass path) is
    joined with ``"_"`` and followed by ``"__<base.__name__>__"``.

    Parameters
    ----------
    base : type
        the class whose module path is compared against that of ``self.klass``

    Returns
    -------
    str
        the prefix, or an empty string if no component of the base's module path
        differs from the corresponding component of the klass module path
    """
    ref_parts = self.klass.__module__.split(".")
    mod_parts = base.__module__.split(".")
    different = []
    is_common = True
    # Pad the reference path with None so that the iteration always runs over the
    # full length of the base's module path
    for r_part, m_part in zip(
        itertools.chain(ref_parts, itertools.repeat(None)), mod_parts
    ):
        if r_part != m_part:
            is_common = False
        # NOTE(review): reconstructed from whitespace-collapsed text; `is_common`
        # only has an effect if this check is a sibling of the one above rather
        # than nested in it -- confirm against the repository
        if not is_common:
            different.append(m_part)
    if not different:
        return ""
    return "_".join(different) + "__" + base.__name__ + "__"
@classmethod
def local_class_attrs(cls, klass) -> ty.Dict[str, str]:
    """
    Get the constant attrs defined in the klass

    The source of ``klass`` is split into statements and every statement that
    matches the ``<name> = <value>`` pattern below is collected. The result is
    memoised in ``cls._class_attrs_cache`` under ``<module>__<class name>``.

    Parameters
    ----------
    klass : type
        the class to extract the attributes from

    Returns
    -------
    dict[str, str]
        the source text of the assigned value, keyed by attribute name. If a name
        is matched more than once the last match wins
    """
    cache_key = klass.__module__ + "__" + klass.__name__
    try:
        return cls._class_attrs_cache[cache_key]
    except KeyError:
        pass
    source_code = inspect.getsource(klass)
    # Join backslash-continued lines so that each statement sits on a single line
    source_code = source_code.replace("\\\n", " ")
    class_attrs = []
    for stmt in split_source_into_statements(source_code):
        # NOTE(review): the leading whitespace in this pattern is reproduced as it
        # appears in the received text, in which runs of whitespace look collapsed;
        # confirm it matches the indentation width of the class body
        match = re.match(r"^ (\w+) *= *(.*)", stmt, flags=re.MULTILINE | re.DOTALL)
        if match:
            class_attrs.append(tuple(match.groups()))
    class_attrs = dict(class_attrs)
    # Store in the same cache that is read above; writing to `_constants_cache`
    # meant the lookup never hit and the module-constants cache got polluted
    cls._class_attrs_cache[cache_key] = class_attrs
    return class_attrs
@@ -256,7 +258,11 @@ def extract_args(snippet) -> ty.Tuple[str, ty.List[str], str]: if matching_open == first and depth[matching_open] == 0: if next_item: contents.append(next_item) - return pre, contents, "".join(splits[i:]) + post = "".join(splits[i:]) + if drop_parens: + pre = pre[:-1] + post = post[1:] + return pre, contents, post if ( first and depth[first] == 1 diff --git a/nipype2pydra/utils/tests/test_utils_imports.py b/nipype2pydra/utils/tests/test_utils_imports.py index ecb12c4..b0ce1cc 100644 --- a/nipype2pydra/utils/tests/test_utils_imports.py +++ b/nipype2pydra/utils/tests/test_utils_imports.py @@ -1,5 +1,5 @@ import pytest -from nipype2pydra.utils.symbols import UsedSymbols +from nipype2pydra.symbols import UsedSymbols from nipype2pydra.statements.imports import ImportStatement, parse_imports import nipype.interfaces.utility diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index 033a5b7..a2bf908 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -14,8 +14,8 @@ import yaml from fileformats.core import from_mime, FileSet, Field from fileformats.core.exceptions import FormatRecognitionError +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, split_source_into_statements, extract_args, full_address, @@ -622,17 +622,12 @@ def add_connection_from_output(self, out_conn: ConnectionStatement): self._add_output_conn(out_conn, "from") @cached_property - def used_symbols(self) -> UsedSymbols: + def used(self) -> UsedSymbols: return UsedSymbols.find( self.nipype_module, [self.func_body], collapse_intra_pkg=False, - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, + package=self.package, ) @property @@ -666,10 +661,10 @@ def func_body(self): @cached_property def nested_workflows(self): 
potential_funcs = { - full_address(f[1]): f[0] for f in self.used_symbols.imported_funcs if f[0] + full_address(f[1]): f[0] for f in self.used.imported_funcs if f[0] } potential_funcs.update( - (full_address(f), f.__name__) for f in self.used_symbols.functions + (full_address(f), f.__name__) for f in self.used.functions ) return { potential_funcs[address]: workflow @@ -724,8 +719,8 @@ def write( if additional_funcs is None: additional_funcs = [] - used = self.used_symbols.copy() - all_used = self.used_symbols.copy() + used = self.used.copy() + all_used = self.used.copy() # Start writing output module with used imports and converted function body of # main workflow @@ -737,10 +732,10 @@ def write( if conv.address in already_converted: continue already_converted.add(conv.address) - all_used.update(conv.used_symbols) + all_used.update(conv.used) if name in local_func_names: code_str += "\n\n\n" + conv.converted_code - used.update(conv.used_symbols) + used.update(conv.used) else: conv_all_used = conv.write( package_root,