diff --git a/example-specs/pkg-gen/niworkflows.yaml b/example-specs/pkg-gen/niworkflows.yaml index 55651ea5..7da8e024 100644 --- a/example-specs/pkg-gen/niworkflows.yaml +++ b/example-specs/pkg-gen/niworkflows.yaml @@ -5,6 +5,7 @@ niworkflows: - niworkflows.interfaces.bids.ReadSidecarJSON - niworkflows.interfaces.fixes.FixHeaderApplyTransforms - niworkflows.interfaces.fixes.FixN4BiasFieldCorrection + - niworkflows.interfaces.fixes.FixHeaderRegistration - niworkflows.interfaces.header.SanitizeImage - niworkflows.interfaces.images.RobustAverage - niworkflows.interfaces.morphology.BinaryDilation diff --git a/nipype2pydra/cli/convert.py b/nipype2pydra/cli/convert.py index a50c27e4..9c3a8d65 100644 --- a/nipype2pydra/cli/convert.py +++ b/nipype2pydra/cli/convert.py @@ -66,10 +66,15 @@ def convert( # Clean previous version of output dir package_dir = converter.package_dir(package_root) if converter.interface_only: - shutil.rmtree(package_dir / "auto") + auto_dir = package_dir / "auto" + if auto_dir.exists(): + shutil.rmtree(auto_dir) else: for fspath in package_dir.iterdir(): - if fspath == package_dir / "__init__.py": + if fspath.parent == package_dir and fspath.name in ( + "_version.py", + "__init__.py", + ): continue if fspath.is_dir(): shutil.rmtree(fspath) diff --git a/nipype2pydra/helpers.py b/nipype2pydra/helpers.py index 1a3910d7..10e164f6 100644 --- a/nipype2pydra/helpers.py +++ b/nipype2pydra/helpers.py @@ -9,8 +9,8 @@ from types import ModuleType import black.report import yaml +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, extract_args, full_address, multiline_comment, @@ -121,19 +121,15 @@ def nipype_object(self): return getattr(self.nipype_module, self.nipype_name) @cached_property - def used_symbols(self) -> UsedSymbols: + def used(self) -> UsedSymbols: used = UsedSymbols.find( self.nipype_module, [self.src], + package=self.package, collapse_intra_pkg=False, - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, always_include=self.package.all_explicit, - translations=self.package.all_import_translations, ) - used.imports.update(i.to_statement() for i in self.imports) + used.import_stmts.update(i.to_statement() for i in self.imports) return used @cached_property @@ -147,12 +143,10 @@ def converted_code(self) -> ty.List[str]: @cached_property def nested_interfaces(self): potential_classes = { - full_address(c[1]): c[0] - for c in self.used_symbols.intra_pkg_classes - if c[0] + full_address(c[1]): c[0] for c in self.used.imported_classes if c[0] } potential_classes.update( - (full_address(c), c.__name__) for c in self.used_symbols.local_classes + (full_address(c), c.__name__) for c in self.used.classes ) return { potential_classes[address]: workflow @@ -350,6 +344,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.full_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -378,8 +373,7 @@ class ClassConverter(BaseHelperConverter): @cached_property def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: - """Convert the Nipype workflow function to a Pydra workflow function and determine - the configuration parameters that are used + """Convert a class into Pydra- Returns ------- @@ -390,18 +384,30 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: """ used_configs = set() - parts = re.split( - r"\n (?!\s|\))", replace_undefined(self.src), flags=re.MULTILINE - ) + + src = replace_undefined(self.src)[len("class ") :] + name, bases, class_body = extract_args(src, drop_parens=True) + bases = [ + b + for b in bases + if not self.package.is_omitted(getattr(self.nipype_module, b)) + ] + + parts = re.split(r"\n (?!\s|\))", class_body, flags=re.MULTILINE) converted_parts = [] - for part in parts: + for part in parts[1:]: if part.startswith("def"): converted_func, func_used_configs = self._convert_function(part) converted_parts.append(converted_func) used_configs.update(func_used_configs) else: converted_parts.append(part) - code_str = "\n ".join(converted_parts) + code_str = ( + f"class {name}(" + + ", ".join(bases) + + "):\n " + + "\n ".join(converted_parts) + ) # Format the the code before the find and replace so it is more predictable try: code_str = black.format_file_contents( @@ -413,6 +419,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]: # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.full_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index f9b73149..71baa583 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -13,26 +13,40 @@ import attrs from attrs.converters import default_if_none import nipype.interfaces.base -from nipype.interfaces.base import traits_extension +from nipype.interfaces.base import ( + traits_extension, + CommandLine, + BaseInterface, +) +from nipype.interfaces.base.core import SimpleInterface from pydra.engine import specs from pydra.engine.helpers import ensure_list +from .. import symbols from ..utils import ( import_module_from_path, is_fileset, to_snake_case, - UsedSymbols, types_converter, from_dict_converter, unwrap_nested_type, + get_return_line, + cleanup_function_body, + insert_args_in_signature, + extract_args, + strip_comments, + find_super_method, + min_indentation, ) from ..statements import ( ImportStatement, parse_imports, ExplicitImport, from_list_to_imports, + make_imports_absolute, ) from fileformats.generic import File import nipype2pydra.package +from nipype2pydra.exceptions import UnmatchedParensException logger = logging.getLogger("nipype2pydra") @@ -354,6 +368,8 @@ class BaseInterfaceConverter(metaclass=ABCMeta): }, ) + _output_name_mappings: ty.Dict[str, str] = attrs.field(factory=dict) + def __attrs_post_init__(self): if self.output_module is None: if self.nipype_module.__name__.startswith("nipype.interfaces."): @@ -397,13 +413,17 @@ def nipype_output_spec(self) -> nipype.interfaces.base.BaseTraitedSpec: def input_fields(self): return self._convert_input_fields[0] + @property + def input_names(self): + return [f[0] for f in self.input_fields] + @cached_property def input_templates(self): return self._convert_input_fields[1] @cached_property def output_fields(self): - return self.convert_output_spec(fields_from_template=self.input_templates) + return self._convert_output_fields(fields_from_template=self.input_templates) @cached_property def nonstd_types(self): @@ -424,21 +444,43 @@ def add_nonstd_types(tp): add_nonstd_types(f[1]) return nonstd_types - @property + @cached_property def converted_code(self): - return self._converted[0] + return self.generate_code( + self.input_fields, self.nonstd_types, self.output_fields + ) - @property - def used_symbols(self): - return self._converted[1] + @cached_property + def used(self) -> "symbols.UsedClassSymbols": + return symbols.UsedClassSymbols.find( + klass=self.nipype_interface, + method_names=self.included_methods, + package=self.package, + collapse_intra_pkg=False, + pull_out_inline_imports=True, + absolute_imports=True, + ) @cached_property - def _converted(self): - """writing pydra task to the dile based on the input and output spec""" + def source_code(self): + with open(inspect.getsourcefile(self.nipype_interface)) as f: + return f.read() - return self.generate_code( - self.input_fields, self.nonstd_types, self.output_fields - ) + @cached_property + def methods(self): + """Get the methods defined in the interface""" + methods = [] + for attr_name in dir(self.nipype_interface): + if attr_name.startswith("__"): + continue + attr = getattr(self.nipype_interface, attr_name) + if inspect.isfunction(attr): + methods.append(attr) + return methods + + @cached_property + def local_function_names(self): + return [f.__name__ for f in self.local_functions] def write( self, @@ -460,8 +502,7 @@ def write( package_root=package_root, module_name=self.output_module, converted_code=self.converted_code, - used=self.used_symbols, - # inline_intra_pkg=True, + used=self.used, find_replace=self.find_replace + self.package.find_replace, ) @@ -472,8 +513,6 @@ def write( depth=self.package.init_depth, auto_import_depth=self.package.auto_import_init_depth, import_find_replace=self.package.import_find_replace, - # + [f.__name__ for f in self.used_symbols.local_functions] - # + [c.__name__ for c in self.used_symbols.local_classes], ) test_module_fspath = self.package.write_to_module( @@ -482,7 +521,7 @@ def write( self.output_module, f".tests.test_{self.task_name.lower()}" ), converted_code=self.converted_test_code, - used=self.used_symbols_test, + used=self.used_test, inline_intra_pkg=False, find_replace=self.find_replace, ) @@ -538,14 +577,20 @@ def pydra_fld_input(self, field, nm): val = getattr(field, key) if val is not None: if key == "argstr" and "%" in val: - val = self.string_formats(argstr=val, name=nm) + val = self.string_formats( + argstr=val, name=nm, type_=field.trait_type + ) + elif key == "mandatory" and pydra_default is not None: + val = False # Overwrite mandatory to False if default is provided pydra_metadata[pydra_key_nm] = val if getattr(field, "name_template"): template = getattr(field, "name_template") name_source = ensure_list(getattr(field, "name_source")) if name_source: - tmpl = self.string_formats(argstr=template, name=name_source[0]) + tmpl = self.string_formats( + argstr=template, name=name_source[0], type_=field.trait_type + ) else: tmpl = template if nm in self.nipype_interface.output_spec().class_trait_names(): @@ -570,17 +615,18 @@ def pydra_fld_input(self, field, nm): f"the filed {nm} has genfile=True, but no template or " "`callables_default` function in the callables_module provided" ) + self._output_name_mappings[getattr(field, "output_name")] = nm pydra_metadata.update(metadata_extra_spec) pos = pydra_metadata.get("position", None) - if pydra_default is not None and not pydra_metadata.get("mandatory", None): + if pydra_default is not None: # and not pydra_metadata.get("mandatory", None): return (pydra_type, pydra_default, pydra_metadata), pos else: return (pydra_type, pydra_metadata), pos - def convert_output_spec(self, fields_from_template): + def _convert_output_fields(self, fields_from_template): """creating fields list for pydra input spec""" pydra_fields_l = [] if not self.nipype_output_spec: @@ -646,11 +692,14 @@ def function_callables(self): "callables module must be provided if output_callables are set in the spec file" ) fun_str = "" - fun_names = list(set(self.outputs.callables.values())) - fun_names.sort() - for fun_nm in fun_names: - fun = getattr(self.callables_module, fun_nm) - fun_str += inspect.getsource(fun) + "\n" + if list(set(self.outputs.callables.values())): + fun_str = inspect.getsource(self.callables_module) + # fun_names.sort() + # for fun_nm in fun_names: + # fun = getattr(self.callables_module, fun_nm) + # fun_str += inspect.getsource(fun) + "\n" + # list_outputs = getattr(self.callables_module, "_list_outputs") + # fun_str += inspect.getsource(list_outputs) + "\n" return fun_str def pydra_type_converter(self, field, spec_type, name): @@ -704,11 +753,14 @@ def pydra_type_converter(self, field, spec_type, name): pydra_type = ty.Any return pydra_type - def string_formats(self, argstr, name): + def string_formats(self, argstr, name, type_): keys = re.findall(r"(%[0-9\.]*(?:s|d|i|g|f))", argstr) new_argstr = argstr for i, key in enumerate(keys): - repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" + if isinstance(type_, traits.trait_types.Bool): + repl = f"{name}:d" + else: + repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" match = re.match(r"%([0-9\.]+)f", key) if match: repl += ":" + match.group(1) @@ -718,14 +770,14 @@ def string_formats(self, argstr, name): @abstractmethod def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, - UsedSymbols, + "symbols.UsedSymbols", ]: """ Returns ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ @@ -769,7 +821,7 @@ def converted_test_code(self): return self._converted_test[0] @property - def used_symbols_test(self): + def used_test(self): return self._converted_test[1] @cached_property @@ -846,8 +898,8 @@ def _converted_test(self): }, ) - return spec_str, UsedSymbols( - module_name=self.nipype_module.__name__, imports=imports + return spec_str, symbols.UsedSymbols( + module_name=self.nipype_module.__name__, import_stmts=imports ) def create_doctests(self, input_fields, nonstd_types): @@ -902,6 +954,430 @@ def create_doctests(self, input_fields, nonstd_types): return " Examples\n -------\n\n" + doctest_str + def _misc_cleanups(self, body: str) -> str: + if hasattr(self.nipype_interface, "_cmd"): + body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"') + + body = body.replace("self.output_spec().get()", "{}") + body = body.replace("self._outputs().get()", "{}") + # body = re.sub( + # r"outputs = self\.(output_spec|_outputs)\(\).*$", + # r"outputs = {}", + # body, + # flags=re.MULTILINE, + # ) + body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body) + body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body) + body = re.sub(r"getattr\(inputs, ([^)]+)\)", r"inputs[\1]", body) + body = re.sub( + r"setattr\(outputs, ([^,]+), ([^)]+)\)", r"outputs[\1] = \2", body + ) + body = re.sub(r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", body) + body = re.sub(r"\s+runtime.returncode = (.*)", "", body) + new_body = re.sub(r"self\.(\w+)\b(?!\()", r"self_dict['\1']", body) + if new_body != body: + body = " " * min_indentation(body) + "self_dict = {}\n" + new_body + body = body.replace("return runtime", "") + body = body.replace("TraitError", "KeyError") + return body + + # def _get_referenced( + # self, + # method: ty.Callable, + # referenced_funcs: ty.Set[ty.Callable], + # referenced_methods: ty.Set[ty.Callable], + # referenced_supers: ty.Dict[str, ty.Tuple[ty.Callable, type]], + # method_args: ty.Dict[str, ty.List[str]] = None, + # method_returns: ty.Dict[str, ty.List[str]] = None, + # method_stacks: ty.Dict[str, ty.Tuple[ty.Callable]] = None, + # method_supers: ty.Dict[type, ty.Dict[str, str]] = None, + # already_processed: ty.Set[ty.Callable] = None, + # method_stack: ty.Optional[ty.Tuple[ty.Callable]] = None, + # super_base: ty.Optional[type] = None, + # ) -> ty.Tuple[ty.Set, ty.Set]: + # """Get the local functions referenced in the source code + + # Parameters + # ---------- + # src: str + # the source of the file to extract the import statements from + # referenced_funcs: set[function] + # the set of local functions that have been referenced so far + # referenced_methods: set[function] + # the set of methods that have been referenced so far + # method_args: dict[str, list[str]] + # a dictionary to hold additional arguments that need to be added to each method, + # where the dictionary key is the names of the methods + # method_returns: dict[str, list[str]] + # a dictionary to hold the return values of each method, + # where the dictionary key is the names of the methods + + # Returns + # ------- + # referenced_inputs: set[str] + # inputs that have been referenced + # referenced_outputs: set[str] + # outputs that have been referenced + # """ + # if already_processed: + # already_processed.add(method) + # else: + # already_processed = {method} + # if method_stack is None: + # method_stack = (method,) + # else: + # method_stack += (method,) + # if super_base is None: + # super_base = self.nipype_interface + # method_body = inspect.getsource(method) + # method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments + # return_value = get_return_line(method_body) + # ref_local_func_names = re.findall(r"(? str: + # """Return the common part of two package names""" + # ref_parts = self.nipype_interface.__module__.split(".") + # mod_parts = base.__module__.split(".") + # common = [] + # for r_part, m_part in zip(ref_parts, mod_parts): + # if r_part == m_part: + # common.append(r_part) + # else: + # break + # if not common: + # return "" + # return "_".join(common + [base.__name__]) + "__" + + def process_method( + self, + method: str, + input_names: ty.List[str], + output_names: ty.List[str], + method_args: ty.Dict[str, ty.List[str]] = None, + method_returns: ty.Dict[str, ty.List[str]] = None, + additional_args: ty.Sequence[str] = (), + new_name: ty.Optional[str] = None, + super_base: ty.Optional[type] = None, + ): + if super_base is None: + super_base = self.nipype_interface + src = inspect.getsource(method) + pre, args, post = extract_args(src) + try: + args.remove("self") + except ValueError: + pass + if "runtime" in args: + args.remove("runtime") + args_to_add = list(self.used.method_args.get(method.__name__, [])) + list( + additional_args + ) + if args_to_add: + kwargs = [args.pop()] if args and args[-1].startswith("**") else [] + args += [f"{a}=None" for a in args_to_add] + kwargs + # Insert method args in signature if present + return_types, method_body = post.split(":", maxsplit=1) + method_body = method_body.split("\n", maxsplit=1)[1] + method_body = self.process_method_body( + method_body, input_names, output_names, super_base + ) + if self.used.method_returns.get(method.__name__): + return_args = self.used.method_returns[method.__name__] + method_body = ( + " " * min_indentation(method_body) + + " = ".join(return_args) + + " = attrs.NOTHING\n" + + method_body + ) + method_lines = method_body.rstrip().splitlines() + method_body = "\n".join(method_lines[:-1]) + last_line = method_lines[-1] + if "return" in last_line: + method_body += "\n" + last_line + "," + ",".join(return_args) + else: + method_body += ( + "\n" + last_line + "\n return " + ",".join(return_args) + ) + pre = re.sub(r"^\s*", "", pre, flags=re.MULTILINE) + pre = pre.replace("@staticmethod\n", "") + if new_name: + pre = re.sub(r"^def (\w+)\(", f"def {new_name}(", pre, flags=re.MULTILINE) + return f"{pre}{', '.join(args)}{return_types}:\n{method_body}" + + def process_method_body( + self, + method_body: str, + input_names: ty.List[str], + output_names: ty.List[str], + super_base: ty.Optional[type] = None, + unwrap_return_dict: bool = False, + ) -> str: + if not method_body: + return "" + if super_base is None: + super_base = self.nipype_interface + method_body = method_body.replace("if self.output_spec:", "if True:") + # Replace self.inputs. with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(method_body) if m not in input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + method_body = input_re.sub(r"\1", method_body) + method_body = self.replace_supers(method_body, super_base) + + if unwrap_return_dict: + return_value = get_return_line(method_body) + if return_value is None: + return_value = "outputs" + output_re = re.compile(return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") + unrecognised_outputs = set( + m for m in output_re.findall(method_body) if m not in output_names + ) + if unrecognised_outputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) outputs %s in " + "'%s' task", + unrecognised_outputs, + self.task_name, + ) + method_body = output_re.sub(r"\1", method_body) + method_body = self.unwrap_nested_methods(method_body) + method_body = make_imports_absolute( + method_body, + super_base.__module__, + translations=self.package.all_import_translations, + ) + # method_body = self._misc_cleanups(method_body) + return method_body + + def replace_supers(self, method_body, super_base=None): + if super_base is None: + super_base = self.nipype_interface + name_map = self.used.super_func_names.get(super_base) + splits = re.split(r"super\([^\)]*\)\.(\w+)(?=\()", method_body) + new_body = splits[0] + for name, block in zip(splits[1::2], splits[2::2]): + super_method, base = find_super_method(super_base, name) + _, args, post = extract_args(block) + arg_str = ", ".join(args) + try: + new_body += self.SPECIAL_SUPER_MAPPINGS[super_method].format( + args=arg_str + ) + except KeyError: + try: + new_body += name_map[name] + "(" + arg_str + ")" + except KeyError: + if self.package.is_omitted(base): + raise KeyError( + f"Require special mapping for '{name}' in {base} class " + "as methods in that module are being omitted from the conversion" + ) from None + raise + new_body += post[1:] + return new_body + + def unwrap_nested_methods( + self, method_body, additional_args=(), inputs_as_dict: bool = False + ): + """ + Converts nested method calls into function calls + """ + method_body = self._misc_cleanups(method_body) + # Add args to the function signature of method calls + method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) + method_names = [m.__name__ for m in self.used.methods] + method_body = strip_comments(method_body) + omitted_methods = {} + for method_name in set( + m for m in method_re.findall(method_body) if m not in method_names + ): + omitted_methods[method_name] = find_super_method( + self.nipype_interface, method_name + )[0] + splits = method_re.split(method_body) + new_body = splits[0] + for name, args in zip(splits[1::2], splits[2::2]): + if name in omitted_methods: + args, post = extract_args(args)[1:] + omitted_method = omitted_methods[name] + try: + new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_method].format( + args=", ".join(args) + ) + except KeyError: + raise KeyError( + f"Require special mapping for {omitted_methods[name]} method " + "as methods in that module are being omitted from the conversion" + ) from None + new_body += post[1:] # drop the leading parenthesis + continue + # Assign additional return values (which were previously saved to member + # attributes) to new variables from the method call + if self.used.method_returns[name]: + last_line = new_body.splitlines()[-1] + match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? *$", last_line) + if match: + if match.group(1): + new_body_lines = new_body.splitlines() + new_body = "\n".join(new_body_lines[:-1]) + last_line = new_body_lines[-1] + new_body += "\n" + re.sub( + r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$", + r"\1\2, " + + ",".join(self.used.method_returns[name]) + + " = ", + last_line, + flags=re.MULTILINE, + ) + else: + new_body += ",".join(self.used.method_returns[name]) + " = " + else: + logger.warning( + "Could not augment the return value of the method converted from " + f"a function '{name}' with the previously assigned attributes " + f"{self.used.method_returns[name]} as the method doesn't have a " + f"singular return statement at the end of the method" + ) + # Insert additional arguments to the method call (which were previously + # accessed via member attributes) + args_to_be_inserted = list(self.used.method_args[name]) + list( + additional_args + ) + try: + new_body += name + insert_args_in_signature( + args, + [ + f"{a}=inputs['{a}']" if inputs_as_dict else f"{a}={a}" + for a in args_to_be_inserted + ], + ) + except UnmatchedParensException: + logger.warning( + f"Nested method call inside '{name}' in {self.full_address}, " + "the following args will need to be manually inserted up after the " + f"conversion: {args_to_be_inserted}" + ) + new_body += name + args + method_body = new_body + # Convert assignment to self attributes into method-scoped variables (hopefully + # there aren't any name clashes) + method_body = re.sub( + r"self\.(\w+ *)(?==)", r"\1", method_body, flags=re.MULTILINE | re.DOTALL + ) + return cleanup_function_body(method_body) + + SPECIAL_SUPER_MAPPINGS = { + CommandLine._list_outputs: "{{}}", + CommandLine._format_arg: "argstr.format(**inputs)", + CommandLine._filename_from_source: "{args} + '_generated'", + BaseInterface._check_version_requirements: "[]", + CommandLine._parse_inputs: "{{}}", + CommandLine._gen_filename: "", + BaseInterface.aggregate_outputs: "{{}}", + BaseInterface.run: "None", + BaseInterface._list_outputs: "{{}}", + BaseInterface.__init__: "", + SimpleInterface.__init__: "", + BaseInterface._outputs: "{{}}", + None: "", + } + INPUT_KEYS = [ "allowed_values", "argstr", diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index f2bc1533..1994e7fb 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -3,19 +3,11 @@ import inspect from operator import attrgetter from functools import cached_property -import itertools import logging import attrs from nipype.interfaces.base import BaseInterface, TraitedSpec from .base import BaseInterfaceConverter -from ..utils import ( - extract_args, - UsedSymbols, - get_local_functions, - get_local_constants, - cleanup_function_body, - insert_args_in_signature, -) +from ..symbols import UsedSymbols, get_return_line, find_super_method logger = logging.getLogger("nipype2pydra") @@ -24,6 +16,12 @@ @attrs.define(slots=False) class FunctionInterfaceConverter(BaseInterfaceConverter): + converter_type = "function" + + @property + def included_methods(self) -> ty.Tuple[str, ...]: + return ("__init__", "_run_interface", "_list_outputs") + def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ str, UsedSymbols, @@ -33,7 +31,7 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[ ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ @@ -66,46 +64,93 @@ def types_to_names(spec_fields): output_names = [o[0] for o in output_fields] output_type_names = [o[1] for o in output_fields_str] + method_body = "" + for field in input_fields: + if field[-1].get("copyfile"): + method_body += f" {field[0]} = {field[0]}.copy(Path.cwd())\n" + for field in output_fields: + method_body += f" {field[0]} = attrs.NOTHING\n" + + used_method_names = [m.__name__ for m in self.used.methods] + # Combined src of init and list_outputs + if "__init__" in used_method_names: + init_code = inspect.getsource(self.nipype_interface.__init__).strip() + init_class = find_super_method( + self.nipype_interface, "__init__", include_class=True + )[1] + assert not self.package.is_omitted(init_class) + # Strip out method def and return statement + method_lines = init_code.strip().split("\n")[1:] + if re.match(r"\s*return", method_lines[-1]): + method_lines = method_lines[:-1] + init_code = "\n".join(method_lines) + init_code = self.process_method_body( + init_code, + input_names, + output_names, + super_base=init_class, + ) + method_body += init_code + "\n" + # Combined src of run_interface and list_outputs - method_body = inspect.getsource(self.nipype_interface._run_interface).strip() - # Strip out method def and return statement - method_lines = method_body.strip().split("\n")[1:] - if re.match(r"\s*return", method_lines[-1]): - method_lines = method_lines[:-1] - method_body = "\n".join(method_lines) - lo_src = inspect.getsource(self.nipype_interface._list_outputs).strip() - # Strip out method def and return statement - lo_lines = lo_src.strip().split("\n")[1:] - if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): - lo_lines = lo_lines[:-1] - lo_src = "\n".join(lo_lines) - method_body += "\n" + lo_src - method_body = self.process_method_body(method_body, input_names, output_names) + if "_run_interface" in used_method_names: + run_interface_code = inspect.getsource( + self.nipype_interface._run_interface + ).strip() + run_interface_class = find_super_method( + self.nipype_interface, "_run_interface", include_class=True + )[1] + assert not self.package.is_omitted(run_interface_class) + # Strip out method def and return statement + method_lines = run_interface_code.strip().split("\n")[1:] + if re.match(r"\s*return", method_lines[-1]): + method_lines = method_lines[:-1] + run_interface_code = "\n".join(method_lines) + run_interface_code = self.process_method_body( + run_interface_code, + input_names, + output_names, + super_base=run_interface_class, + ) + method_body += run_interface_code + "\n" + + if "_list_outputs" in used_method_names: + list_outputs_code = inspect.getsource( + self.nipype_interface._list_outputs + ).strip() + list_outputs_class = find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1] + assert not self.package.is_omitted(list_outputs_class) + # Strip out method def and return statement + lo_lines = list_outputs_code.strip().split("\n")[1:] + if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]): + lo_lines = lo_lines[:-1] + list_outputs_code = "\n".join(lo_lines) + list_outputs_code = self.process_method_body( + list_outputs_code, + input_names, + output_names, + super_base=list_outputs_class, + unwrap_return_dict=True, + ) + method_body += list_outputs_code + "\n" - used = UsedSymbols.find( - self.nipype_module, - [method_body] - + [ - inspect.getsource(f) - for f in itertools.chain( - self.referenced_local_functions, self.referenced_methods - ) - ], - omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec], - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, - absolute_imports=True, - ) + assert method_body, "Neither `run_interface` and `list_outputs` are defined" spec_str = "@pydra.mark.task\n" spec_str += "@pydra.mark.annotate({'return': {" spec_str += ", ".join(f"'{n}': {t}" for n, t, _ in output_fields_str) spec_str += "}})\n" spec_str += f"def {self.task_name}(" - spec_str += ", ".join(f"{i[0]}: {i[1]}" for i in input_fields_str) + spec_str += ", ".join( + ( + f"{i[0]}: {i[1]} = {i[2]!r}" + if len(i) == 4 + else f"{i[0]}: {i[1]} = attrs.NOTHING" + ) + for i in input_fields_str + ) spec_str += ")" if output_type_names: spec_str += "-> " @@ -123,8 +168,16 @@ def types_to_names(spec_fields): spec_str += "\n\n# Nipype methods converted into functions\n\n" - for m in sorted(self.referenced_methods, key=attrgetter("__name__")): - spec_str += "\n\n" + self.process_method(m, input_names, output_names) + for m in sorted(self.used.methods, key=attrgetter("__name__")): + if m.__name__ not in self.included_methods: + spec_str += "\n\n" + self.process_method( + m, + input_names, + output_names, + super_base=find_super_method( + self.nipype_interface, m.__name__, include_class=True + )[1], + ) # Replace runtime attributes additional_imports = set() @@ -134,287 +187,30 @@ def types_to_names(spec_fields): additional_imports.add(imprt) spec_str = repl_spec_str - used.imports = self.construct_imports( - nonstd_types, - spec_str, - include_task=False, - base=base_imports + list(used.imports) + list(additional_imports), - ) - - return spec_str, used - - def process_method( - self, - method: str, - input_names: ty.List[str], - output_names: ty.List[str], - method_args: ty.Dict[str, ty.List[str]] = None, - method_returns: ty.Dict[str, ty.List[str]] = None, - ): - src = inspect.getsource(method) - pre, args, post = extract_args(src) - args.remove("self") - if "runtime" in args: - args.remove("runtime") - if method.__name__ in self.method_args: - args += [f"{a}=None" for a in self.method_args[method.__name__]] - # Insert method args in signature if present - return_types, method_body = post.split(":", maxsplit=1) - method_body = method_body.split("\n", maxsplit=1)[1] - method_body = self.process_method_body(method_body, input_names, output_names) - if self.method_returns.get(method.__name__): - return_args = self.method_returns[method.__name__] - method_body = ( - " " + " = ".join(return_args) + " = attrs.NOTHING\n" + method_body + self.used.import_stmts.update( + self.construct_imports( + nonstd_types, + spec_str, + include_task=False, + base=base_imports + + list(self.used.import_stmts) + + list(additional_imports), ) - method_lines = method_body.rstrip().splitlines() - method_body = "\n".join(method_lines[:-1]) - last_line = method_lines[-1] - if "return" in last_line: - method_body += "\n" + last_line + "," + ",".join(return_args) - else: - method_body += ( - "\n" + last_line + "\n return " + ",".join(return_args) - ) - return f"{pre.strip()}{', '.join(args)}{return_types}:\n{method_body}" - - def process_method_body( - self, method_body: str, input_names: ty.List[str], output_names: ty.List[str] - ) -> str: - # Replace self.inputs. with in the function body - input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") - unrecognised_inputs = set( - m for m in input_re.findall(method_body) if m not in input_names - ) - if unrecognised_inputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) inputs %s in " - "'%s' task", - unrecognised_inputs, - self.task_name, - ) - method_body = input_re.sub(r"\1", method_body) - - output_re = re.compile(self.return_value + r"\[(?:'|\")(\w+)(?:'|\")\]") - unrecognised_outputs = set( - m for m in output_re.findall(method_body) if m not in output_names - ) - if unrecognised_outputs: - logger.warning( - "Found the following unrecognised (potentially dynamic) outputs %s in " - "'%s' task", - unrecognised_outputs, - self.task_name, - ) - method_body = output_re.sub(r"\1", method_body) - # Strip initialisation of outputs - method_body = re.sub( - r"outputs = self.output_spec().*", r"outputs = {}", method_body ) - # Add args to the function signature of method calls - method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL) - method_names = [m.__name__ for m in self.referenced_methods] - unrecognised_methods = set( - m for m in method_re.findall(method_body) if m not in method_names - ) - assert ( - not unrecognised_methods - ), f"Found the following unrecognised methods {unrecognised_methods}" - splits = method_re.split(method_body) - new_body = splits[0] - for name, args in zip(splits[1::2], splits[2::2]): - # Assign additional return values (which were previously saved to member - # attributes) to new variables from the method call - if self.method_returns[name]: - last_line = new_body.splitlines()[-1] - match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? *$", last_line) - if match: - if match.group(1): - new_body_lines = new_body.splitlines() - new_body = "\n".join(new_body_lines[:-1]) - last_line = new_body_lines[-1] - new_body += "\n" + re.sub( - r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$", - r"\1\2, " + ",".join(self.method_returns[name]) + " = ", - last_line, - flags=re.MULTILINE, - ) - else: - new_body += ",".join(self.method_returns[name]) + " = " - else: - raise NotImplementedError( - "Could not augment the return value of the method converted from " - "a function with the previously assigned attributes as it is used " - "directly. Need to replace the method call with a variable and " - "assign the return value to it on a previous line" - ) - # Insert additional arguments to the method call (which were previously - # accessed via member attributes) - new_body += name + insert_args_in_signature( - args, [f"{a}={a}" for a in self.method_args[name]] - ) - method_body = new_body - # Convert assignment to self attributes into method-scoped variables (hopefully - # there aren't any name clashes) - method_body = re.sub( - r"self\.(\w+ *)(?==)", r"\1", method_body, flags=re.MULTILINE | re.DOTALL - ) - return cleanup_function_body(method_body) - - @property - def referenced_local_functions(self): - return self._referenced_funcs_and_methods[0] - - @property - def referenced_methods(self): - return self._referenced_funcs_and_methods[1] - - @property - def method_args(self): - return self._referenced_funcs_and_methods[2] - - @property - def method_returns(self): - return self._referenced_funcs_and_methods[3] - @cached_property - def _referenced_funcs_and_methods(self): - referenced_funcs = set() - referenced_methods = set() - method_args = {} - method_returns = {} - self._get_referenced( - self.nipype_interface._run_interface, - referenced_funcs, - referenced_methods, - method_args, - method_returns, - ) - self._get_referenced( - self.nipype_interface._list_outputs, - referenced_funcs, - referenced_methods, - method_args, - method_returns, - ) - return referenced_funcs, referenced_methods, method_args, method_returns + return spec_str def replace_attributes(self, function_body: ty.Callable) -> str: """Replace self.inputs. with in the function body and add args to the function signature""" function_body = re.sub(r"self\.inputs\.(\w+)", r"\1", function_body) - def _get_referenced( - self, - method: ty.Callable, - referenced_funcs: ty.Set[ty.Callable], - referenced_methods: ty.Set[ty.Callable] = None, - method_args: ty.Dict[str, ty.List[str]] = None, - method_returns: ty.Dict[str, ty.List[str]] = None, - ): - """Get the local functions referenced in the source code - - Parameters - ---------- - src: str - the source of the file to extract the import statements from - referenced_funcs: set[function] - the set of local functions that have been referenced so far - referenced_methods: set[function] - the set of methods that have been referenced so far - method_args: dict[str, list[str]] - a dictionary to hold additional arguments that need to be added to each method, - where the dictionary key is the names of the methods - method_returns: dict[str, list[str]] - a dictionary to hold the return values of each method, - where the dictionary key is the names of the methods - """ - method_body = inspect.getsource(method) - method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments - ref_local_func_names = re.findall(r"(? ty.Tuple[ - str, - UsedSymbols, - ]: + + converter_type = "shell_command" + _format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict) + + @cached_property + def included_methods(self) -> ty.Tuple[str, ...]: + included = [] + # if not self.method_omitted("__init__"): + # included.append("__init__"), + if not self.method_omitted("_parse_inputs"): + included.append("_parse_inputs"), + if not self.method_omitted("_format_arg"): + included.append("_format_arg") + if not self.method_omitted("_gen_filename"): + included.append("_gen_filename") + if self.callable_output_fields: + if not self.method_omitted("aggregate_outputs"): + included.append("aggregate_outputs") + if not self.method_omitted("_list_outputs"): + included.append("_list_outputs") + return tuple(included) + + def generate_code(self, input_fields, nonstd_types, output_fields) -> str: """ Returns ------- converted_code : str the core converted code for the task - used_symbols: UsedSymbols + used: UsedSymbols symbols used in the code """ base_imports = [ "from pydra.engine import specs", + "import os", ] task_base = "ShellCommandTask" @@ -80,10 +114,26 @@ def types_to_names(spec_fields): spec_fields_str.append(tuple(el)) return spec_fields_str - input_fields_str = types_to_names(spec_fields=input_fields) - output_fields_str = types_to_names(spec_fields=output_fields) - functions_str = self.function_callables() - spec_str = functions_str + input_names = [i[0] for i in input_fields] + output_names = [o[0] for o in output_fields] + input_fields_str = str(types_to_names(spec_fields=input_fields)) + input_fields_str = re.sub( + r"'formatter': '(\w+)'", r"'formatter': \1", input_fields_str + ) + output_fields_str = str(types_to_names(spec_fields=output_fields)) + output_fields_str = re.sub( + r"'callable': '(\w+)'", r"'callable': \1", output_fields_str + ) + # functions_str = self.function_callables() + # functions_imports, functions_str = functions_str.split("\n\n", 1) + # spec_str = functions_str + spec_str = ( + self.init_code + + self.format_arg_code + + self.parse_inputs_code + + self.callables_code + + self.defaults_code + ) spec_str += f"input_fields = {input_fields_str}\n" spec_str += f"{self.task_name}_input_spec = specs.SpecInfo(name='Input', fields=input_fields, bases=(specs.ShellSpec,))\n\n" spec_str += f"output_fields = {output_fields_str}\n" @@ -101,14 +151,443 @@ def types_to_names(spec_fields): spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str) - imports = self.construct_imports( - nonstd_types, - spec_str, - include_task=False, - base=base_imports, + for m in sorted(self.used.methods, key=attrgetter("__name__")): + if m.__name__ in self.included_methods: + continue + if any( + s[0] == self.nipype_interface._list_outputs + for s in self.used.method_stacks[m.__name__] + ): + additional_args = CALLABLES_ARGS + else: + additional_args = [] + method_str = self.process_method( + m, input_names, output_names, additional_args=additional_args + ) + method_str = method_str.replace("os.getcwd()", "output_dir") + spec_str += "\n\n" + method_str + + self.used.import_stmts.update( + self.construct_imports( + nonstd_types, + spec_str, + include_task=False, + base=base_imports, + ) + ) + + return spec_str + + @cached_property + def input_fields(self): + input_fields = super().input_fields + for field in input_fields: + if field[0] in self.formatted_input_field_names: + field[-1]["formatter"] = f"{field[0]}_formatter" + self._format_argstrs[field[0]] = field[-1].pop("argstr", "") + return input_fields + + @cached_property + def output_fields(self): + output_fields = super().output_fields + for field in self.callable_output_fields: + field[-1]["callable"] = f"{field[0]}_callable" + return output_fields + + @property + def formatted_input_field_names(self): + if not self._format_arg_body: + return [] + sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0] + name_arg = re.match(r"\s*def _format_arg\(self, (\w+),", sig).group(1) + return re.findall(name_arg + r" == \"(\w+)\"", self._format_arg_body) + + @property + def callable_default_input_field_names(self): + if not self._gen_filename_body: + return [] + sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0] + name_arg = re.match(r"\s*def _gen_filename\((\w+),", sig).group(1) + return re.findall(name_arg + r" == \"(\w+)\"", self._gen_filename_body) + + @property + def callable_output_fields(self): + return [ + f + for f in super().output_fields + if ( + "output_file_template" not in f[-1] + and f[0] not in INBUILT_NIPYPE_TRAIT_NAMES + ) + ] + + @property + def callable_output_field_names(self): + return [f[0] for f in self.callable_output_fields] + + @cached_property + def _format_arg_body(self): + if self.method_omitted("_format_arg"): + return "" + return self._unwrap_supers( + self.nipype_interface._format_arg, + base_replacement="return argstr.format(**inputs)", ) - # spec_str = "\n".join(str(i) for i in imports) + "\n\n" + spec_str - return spec_str, UsedSymbols( - module_name=self.nipype_module.__name__, imports=imports + @cached_property + def _gen_filename_body(self): + if self.method_omitted("_gen_filename"): + return "" + return self._unwrap_supers(self.nipype_interface._gen_filename) + + @property + def init_code(self): + if "__init__" not in self.included_methods: + return "" + body = self._unwrap_supers( + self.nipype_interface.__init__, + base_replacement="", + ) + code_str = f"def _init():\n {body}\n" + return code_str + + @property + def format_arg_code(self): + if "_format_arg" not in self.included_methods: + return "" + body = self._format_arg_body + body = self._process_inputs(body) + existing_args = list( + inspect.signature(self.nipype_interface._format_arg).parameters + )[1:] + name_arg, spec_arg, val_arg = existing_args + + # Single-line replacement args + body = re.sub( + spec_arg + r"\.argstr % +([^\( ].+)", + r"argstr.format(**{" + name_arg + r": \1})", + body, + ) + body = body.replace(f"{spec_arg}.argstr", "argstr") + + # Strip out return value + body = re.sub( + ( + r"^ return super\((\w+,\s*self)?\)\._format_arg\(" + + ", ".join(existing_args) + + r"\)\n" + ), + "return argstr.format(**inputs)", + body, + flags=re.MULTILINE, + ) + if not body.strip(): + return "" + body = self.unwrap_nested_methods(body, inputs_as_dict=True) + + code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call} + if {val_arg} is None: + return "" +{body}""" + + if not code_str.rstrip().endswith("return argstr.format(**inputs)"): + code_str += "\n return argstr.format(**inputs)" + + code_str += "\n\n" + + for field_name in self.formatted_input_field_names: + code_str += ( + f"def {field_name}_formatter(field, inputs):\n" + f" return _format_arg({field_name!r}, field, inputs, " + f"argstr={self._format_argstrs[field_name]!r})\n\n\n" + ) + return code_str + + @property + def parse_inputs_code(self) -> str: + if "_parse_inputs" not in self.included_methods: + return "" + body = self._unwrap_supers( + self.nipype_interface._parse_inputs, base_replacement="return {}" + ) + body = self._process_inputs(body) + body = re.sub( + r"self.\_format_arg\((\w+), (\w+), (\w+)\)", + r"_format_arg(\1, \3, inputs, parsed_inputs, argstrs.get(\1))", + body, + ) + + # Strip out return value + body = re.sub(r"\s*return .*\n", "", body) + if not body.strip(): + return "" + body = self.unwrap_nested_methods(body, inputs_as_dict=True) + # Supers are already unwrapped so this isn't necessary + # body = self.replace_supers( + # body, + # super_base=find_super_method( + # self.nipype_interface, "_parse_inputs", include_class=True + # )[1], + # ) + # body = self._misc_cleanups(body) + + code_str = "def _parse_inputs(inputs, output_dir=None):\n if not output_dir:\n output_dir = os.getcwd()\n parsed_inputs = {}" + if re.findall(r"\bargstrs\b", body): + code_str += f"\n argstrs = {self._format_argstrs!r}" + code_str += f""" + skip = [] +{body} + return parsed_inputs + + +""" + return code_str + + @cached_property + def defaults_code(self): + if "_gen_filename" not in self.included_methods: + return "" + + body = _strip_doc_string( + inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1] ) + body = self._process_inputs(body) + + if not body.strip(): + return "" + body = self.unwrap_nested_methods(body, inputs_as_dict=True) + body = self.replace_supers( + body, + super_base=find_super_method( + self.nipype_interface, "_gen_filename", include_class=True + )[1], + ) + # body = self._misc_cleanups(body) + + code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call} +{body} +""" + # Create separate default function for each input field with genfile, which + # reference the magic "_gen_filename" method + for inpt_name, inpt in sorted( + self.nipype_interface.input_spec().traits().items() + ): + if inpt.genfile: + code_str += ( + f"\n\n\ndef {inpt_name}_default(inputs):\n" + f' return _gen_filename("{inpt_name}", inputs=inputs)\n\n' + ) + return code_str + + @cached_property + def callables_code(self): + + if not self.callable_output_fields: + return "" + code_str = "" + if "aggregate_outputs" in self.included_methods: + func_name = "aggregate_outputs" + agg_body = self._unwrap_supers( + self.nipype_interface.aggregate_outputs, + base_replacement=" return {}", + ) + need_list_outputs = bool(re.findall(r"\b_list_outputs\b", agg_body)) + agg_body = self._process_inputs(agg_body) + + if not agg_body.strip(): + return "" + agg_body = self.unwrap_nested_methods( + agg_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + ) + agg_body = self.replace_supers( + agg_body, + super_base=find_super_method( + self.nipype_interface, "aggregate_outputs", include_class=True + )[1], + ) + + code_str += f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None): + inputs = attrs.asdict(inputs){self.parse_inputs_call} + needed_outputs = {self.callable_output_field_names!r} +{agg_body} + + +""" + inputs_as_dict_call = "" + + else: + func_name = "_list_outputs" + inputs_as_dict_call = "\n inputs = attrs.asdict(inputs)" + need_list_outputs = True + + if need_list_outputs: + if "_list_outputs" not in self.included_methods: + assert self.callable_output_fields + # Need to reimplemt the base _list_outputs method in Pydra, which maps + # inputs with 'output_name' to outputs + for f in self.callable_output_fields: + output_name = f[0] + code_str += f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" + try: + input_name = self._output_name_mappings[output_name] + except KeyError: + logger.warning( + "Could not find input name with 'output_name' for " + "%s output, attempting to create something that can be worked " + "with", + output_name, + ) + if "_parse_inputs" in self.included_methods: + code_str += ( + f" parsed_inputs = _parse_inputs(inputs)\n" + f" return parsed_inputs.get('{output_name}', attrs.NOTHING)\n" + ) + else: + code_str += " raise NotImplementedError\n" + + else: + code_str += f" return inputs.{input_name}\n" + + return code_str + else: + lo_body = self._unwrap_supers( + self.nipype_interface._list_outputs, + base_replacement=" return {}", + ) + lo_body = self._process_inputs(lo_body) + lo_body = re.sub( + r"(\w+) = self\.output_spec\(\).get\(\)", r"\1 = {}", lo_body + ) + + if not lo_body.strip(): + return "" + lo_body = self.unwrap_nested_methods( + lo_body, additional_args=CALLABLES_ARGS, inputs_as_dict=True + ) + lo_body = self.replace_supers( + lo_body, + super_base=find_super_method( + self.nipype_interface, "_list_outputs", include_class=True + )[1], + ) + + parse_inputs_call = ( + "\n parsed_inputs = _parse_inputs(inputs, output_dir=output_dir)" + if self.parse_inputs_code + else "" + ) + + code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{parse_inputs_call} +{lo_body} + + +""" + # Create separate function for each output field in the "callables" section + for output_field in self.callable_output_fields: + output_name = output_field[0] + code_str += ( + f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n" + f" outputs = {func_name}(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" + ' return outputs.get("' + output_name + '", attrs.NOTHING)\n\n' + ) + return code_str + + def _process_inputs(self, body: str) -> str: + # Replace self.inputs. with in the function body + input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()") + unrecognised_inputs = set( + m for m in input_re.findall(body) if m not in self.input_names + ) + if unrecognised_inputs: + logger.warning( + "Found the following unrecognised (potentially dynamic) inputs %s in " + "'%s' task", + unrecognised_inputs, + self.task_name, + ) + body = input_re.sub(r"inputs['\1']", body) + body = re.sub(r"self\.(?!inputs)(\w+)\b(?!\()", r"parsed_inputs['\1']", body) + return body + + @property + def parse_inputs_call(self): + if not self.parse_inputs_code: + return "" + return "\n parsed_inputs = _parse_inputs(inputs) if inputs else {}" + + def method_omitted(self, method_name: str) -> bool: + return self.package.is_omitted( + find_super_method(self.nipype_interface, method_name, include_class=True)[1] + ) + + def _unwrap_supers( + self, method: ty.Callable, base=None, base_replacement="", arg_names=None + ) -> str: + if base is None: + base = find_super_method( + self.nipype_interface, method.__name__, include_class=True + )[1] + if self.package.is_omitted(base): + return base_replacement + method_name = method.__name__ + body = inspect.getsource(method).split("\n", 1)[1] + body = "\n" + _strip_doc_string(body) + body = cleanup_function_body(body) + args = list(inspect.signature(method).parameters.keys())[1:] + if arg_names: + for new, old in zip(args, arg_names): + if new != old: + body = re.sub(r"\b" + old + r"\b", new, body) + super_re = re.compile( + r"\n( *(?:return|\w+\s*=)?\s*super\([^\)]*\)\." + method_name + ")" + ) + if super_re.search(body): + super_method, base = find_super_method(base, method_name) + super_body = self._unwrap_supers( + super_method, base, base_replacement, arg_names=args + ) + return_indent = return_val = None + if super_body: + super_args = list(inspect.signature(super_method).parameters.keys())[1:] + lines = super_body.splitlines() + match = re.match(r"(\s*)return\s+(.*)", lines[-1]) + if match: + return_indent, return_val = match.groups() + super_body = "\n".join(lines[:-1]) + else: + super_args = [] + + splits = super_re.split(body) + new_body = splits[0] + for call, block in zip(splits[1::2], splits[2::2]): + _, args, post = extract_args(block) + indent = re.match(r"^(\s*)", call).group(1) + arg_str = ", ".join(args) + if "=" in call: + assert return_val + assigned_to_varname = call.split("=")[0].strip() + if return_val == assigned_to_varname: + replacement = super_body + else: + replacement = ( + super_body + + f"\n{indent}{assigned_to_varname} = {return_val}" + ) + elif super_body: + replacement = super_body + else: + if len(indent) > 4: + new_body += f"\n{indent}pass" + new_body += post[1:] + continue + for o, n in zip(args, super_args): + replacement = re.sub(r"\b" + o + r"\b", n, replacement) + new_body += replacement + "(" + arg_str + post + return new_body + return body + + +def _strip_doc_string(body: str) -> str: + if re.match(r"\s*(\"|')", body): + body = "\n".join(split_source_into_statements(body)[1:]) + return body diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index 176065ed..9052ad59 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -15,9 +15,10 @@ import black.report from tqdm import tqdm import yaml +import nipype.utils.logger from . import interface +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, full_address, to_snake_case, cleanup_function_body, @@ -270,6 +271,10 @@ class PackageConverter: }, ) + def __attrs_post_init__(self): + # Adds in some default omissions + self.omit_constants.append("nipype.logging") + @init_depth.default def _init_depth_default(self) -> int: if self.name.startswith("pydra.tasks."): @@ -310,7 +315,7 @@ def all_import_translations(self) -> ty.List[ty.Tuple[str, str]]: @property def all_omit_modules(self) -> ty.List[str]: - return self.omit_modules + ["nipype.interfaces.utility"] + return self.omit_modules + UsedSymbols.ALWAYS_OMIT_MODULES @property def all_explicit(self): @@ -344,6 +349,17 @@ def config_defaults(self) -> ty.Dict[str, ty.Dict[str, str]]: all_defaults[name] = defaults return all_defaults + def is_omitted(self, obj: ty.Any) -> bool: + if full_address(obj) in self.omit_classes + self.omit_functions: + return True + if inspect.ismodule(obj): + mod_name = obj.__name__ + else: + mod_name = obj.__module__ + if any(re.match(m + r"\b", mod_name) for m in self.all_omit_modules): + return True + return False + def write(self, package_root: Path, to_include: ty.List[str] = None): """Writes the package to the specified package root""" @@ -389,7 +405,7 @@ def write(self, package_root: Path, to_include: ty.List[str] = None): workflow.prepare_connections() def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): - for _, klass in used.intra_pkg_classes: + for _, klass in used.imported_classes: address = full_address(klass) if address in self.nipype_port_converters: if port_nipype: @@ -401,24 +417,24 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): ) elif full_address(klass) not in self.interfaces: intra_pkg_modules[klass.__module__].add(klass) - for _, func in used.intra_pkg_funcs: + for _, func in used.imported_funcs: if full_address(func) not in list(self.workflows): intra_pkg_modules[func.__module__].add(func) - for const_mod_address, _, const_name in used.intra_pkg_constants: + for const_mod_address, _, const_name in used.imported_constants: intra_pkg_modules[const_mod_address].add(const_name) for conv in list(self.functions.values()) + list(self.classes.values()): intra_pkg_modules[conv.nipype_module_name].add(conv.nipype_object) - collect_intra_pkg_objects(conv.used_symbols) + collect_intra_pkg_objects(conv.used) - for converter in tqdm( + for workflow in tqdm( workflows_to_include, "converting workflows from Nipype to Pydra syntax" ): - all_used = converter.write( + all_used = workflow.write( package_root, already_converted=already_converted, ) - class_addrs = [full_address(c) for _, c in all_used.intra_pkg_classes] + class_addrs = [full_address(c) for _, c in all_used.imported_classes] included_addrs = [c.full_address for c in interfaces_to_include] interfaces_to_include.extend( self.interfaces[a] @@ -436,7 +452,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - collect_intra_pkg_objects(converter.used_symbols) + collect_intra_pkg_objects(converter.used) for converter in tqdm( nipype_ports, "Porting interfaces from the core nipype package" @@ -445,7 +461,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - collect_intra_pkg_objects(converter.used_symbols, port_nipype=False) + collect_intra_pkg_objects(converter.used, port_nipype=False) # Write any additional functions in other modules in the package self.write_intra_pkg_modules(package_root, intra_pkg_modules) @@ -536,22 +552,16 @@ def write_intra_pkg_modules( mod, objs, pull_out_inline_imports=False, - translations=self.all_import_translations, - omit_classes=self.omit_classes, - omit_modules=self.omit_modules, - omit_functions=self.omit_functions, - omit_constants=self.omit_constants, always_include=self.all_explicit, + package=self, ) - classes = used.local_classes + [ - o for o in objs if inspect.isclass(o) and o not in used.local_classes + classes = used.classes + [ + o for o in objs if inspect.isclass(o) and o not in used.classes ] - functions = list(used.local_functions) + [ - o - for o in objs - if inspect.isfunction(o) and o not in used.local_functions + functions = list(used.functions) + [ + o for o in objs if inspect.isfunction(o) and o not in used.functions ] self.write_to_module( @@ -559,10 +569,10 @@ def write_intra_pkg_modules( module_name=out_mod_name, used=UsedSymbols( module_name=mod_name, - imports=used.imports, + import_stmts=used.import_stmts, constants=used.constants, - local_classes=classes, - local_functions=functions, + classes=classes, + functions=functions, ), find_replace=self.find_replace, inline_intra_pkg=False, @@ -860,15 +870,11 @@ def write_to_module( existing_imports = parse_imports(existing_import_strs, relative_to=module_name) converter_imports = [] - for const_name, const_val in sorted(used.constants): - if f"\n{const_name} = " not in code_str: - code_str += f"\n{const_name} = {const_val}\n" - - for klass in used.local_classes: + for klass in used.classes: if f"\nclass {klass.__name__}(" not in code_str: try: class_converter = self.classes[full_address(klass)] - converter_imports.extend(class_converter.used_symbols.imports) + converter_imports.extend(class_converter.used.import_stmts) except KeyError: class_converter = ClassConverter.from_object(klass, self) code_str += "\n" + class_converter.converted_code + "\n" @@ -886,6 +892,7 @@ def write_to_module( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(converted_code) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -895,11 +902,11 @@ def write_to_module( if converted_code.strip() not in code_str: code_str += "\n" + converted_code + "\n" - for func in sorted(used.local_functions, key=attrgetter("__name__")): + for func in sorted(used.functions, key=attrgetter("__name__")): if f"\ndef {func.__name__}(" not in code_str: if func.__name__ in self.functions: function_converter = self.functions[full_address(func)] - converter_imports.extend(function_converter.used_symbols.imports) + converter_imports.extend(function_converter.used.import_stmts) else: function_converter = FunctionConverter.from_object(func, self) code_str += "\n" + function_converter.converted_code + "\n" @@ -915,7 +922,7 @@ def write_to_module( code_str += ( "\n\n# Intra-package imports that have been inlined in this module\n\n" ) - for func_name, func in sorted(used.intra_pkg_funcs, key=itemgetter(0)): + for func_name, func in sorted(used.imported_funcs, key=itemgetter(0)): func_src = get_source_code(func) func_src = re.sub( r"^(#[^\n]+\ndef) (\w+)(?=\()", @@ -926,7 +933,7 @@ def write_to_module( code_str += "\n\n" + cleanup_function_body(func_src) inlined_symbols.append(func_name) - for klass_name, klass in sorted(used.intra_pkg_classes, key=itemgetter(0)): + for klass_name, klass in sorted(used.imported_classes, key=itemgetter(0)): klass_src = get_source_code(klass) klass_src = re.sub( r"^(#[^\n]+\nclass) (\w+)(?=\()", @@ -937,6 +944,10 @@ def write_to_module( code_str += "\n\n" + cleanup_function_body(klass_src) inlined_symbols.append(klass_name) + for const_name, const_val in sorted(used.constants): + if f"\n{const_name} = " not in code_str: + code_str += f"\n{const_name} = {const_val}\n" + # We run the formatter before the find/replace so that the find/replace can be more # predictable try: @@ -949,6 +960,7 @@ def write_to_module( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): {e}\n\n{code_str}" @@ -960,7 +972,7 @@ def write_to_module( imports = ImportStatement.collate( existing_imports + converter_imports - + [i for i in used.imports if not i.indent] + + [i for i in used.import_stmts if not i.indent] + GENERIC_PYDRA_IMPORTS + additional_imports ) @@ -1068,6 +1080,7 @@ def write_pkg_inits( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -1092,6 +1105,7 @@ def write_pkg_inits( # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.nipype_name}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " diff --git a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index bb12826d..26bb23b1 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -33,8 +33,8 @@ TestGenerator, DocTestGenerator, ) +from nipype2pydra.symbols import UsedSymbols from nipype2pydra.utils import ( - UsedSymbols, extract_args, get_source_code, cleanup_function_body, @@ -403,7 +403,8 @@ def generate_callables(self, nipype_interface) -> str: if output_name not in INBUILT_NIPYPE_TRAIT_NAMES: callables_str += ( f"def {output_name}_callable(output_dir, inputs, stdout, stderr):\n" - " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n" + " parsed_inputs = {}\n" + " outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr, parsed_inputs=parsed_inputs)\n" ' return outputs["' + output_name + '"]\n\n' ) @@ -421,7 +422,7 @@ def generate_callables(self, nipype_interface) -> str: callables_str, fast=False, mode=black.FileMode() ) except black.parsing.InvalidInput as e: - with open(Path("~/Desktop/gen-code.py").expanduser(), "w") as f: + with open(Path("~/unparsable-gen-code.py").expanduser(), "w") as f: f.write(callables_str) raise RuntimeError( f"Black could not parse generated code: {e}\n\n{callables_str}" @@ -922,7 +923,7 @@ def test_generate_sample_{frmt.lower()}_data(): def get_callable_sources( - nipype_interface, + nipype_interface, attrs_as_parsed_inputs: bool = False ) -> ty.Tuple[ty.Set[str], ty.List[str], ty.Set[str], ty.Set[ty.Tuple[str, str]]]: """ Convert the _gen_filename method of a nipype interface into a function that can be @@ -1025,7 +1026,14 @@ def process_method( ) if hasattr(nipype_interface, "_cmd"): body = body.replace("self.cmd", f'"{nipype_interface._cmd}"') - body = body.replace("self.", "") + body = re.sub(r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body) + body = re.sub(r"getattr\(self\.inputs, (\w+)\)", r"inputs[\1]", body) + if attrs_as_parsed_inputs: + body = re.sub( + r"self\.(?!inputs)(\w+)\b(?!\()", r"parsed_inputs['\1']", body + ) + else: + body = body.replace("self.", "") body = re.sub( r"super\([^\)]*\)\.(\w+)\(", lambda m: name_map[m.group(1)] + "(", body ) @@ -1115,13 +1123,13 @@ def insert_args_in_method_calls( mod = import_module(mod_name) used = UsedSymbols.find(mod, methods, omit_classes=(BaseInterface, TraitedSpec)) all_funcs.update(methods) - for func in used.local_functions: + for func in used.functions: all_funcs.add(cleanup_function_body(get_source_code(func))) - for klass in used.local_classes: + for klass in used.classes: klass_src = cleanup_function_body(get_source_code(klass)) if klass_src not in all_classes: all_classes.append(klass_src) - for new_func_name, func in used.intra_pkg_funcs: + for new_func_name, func in used.imported_funcs: if new_func_name is None: continue # Not referenced directly in this module func_src = get_source_code(func) @@ -1140,7 +1148,7 @@ def insert_args_in_method_calls( + match.group(2) ) all_funcs.add(cleanup_function_body(func_src)) - for new_klass_name, klass in used.intra_pkg_classes: + for new_klass_name, klass in used.imported_classes: if new_klass_name is None: continue # Not referenced directly in this module klass_src = get_source_code(klass) @@ -1161,7 +1169,7 @@ def insert_args_in_method_calls( klass_src = cleanup_function_body(klass_src) if klass_src not in all_classes: all_classes.append(klass_src) - all_imports.update(used.imports) + all_imports.update(used.import_stmts) all_constants.update(used.constants) return ( sorted( diff --git a/nipype2pydra/statements/__init__.py b/nipype2pydra/statements/__init__.py index 9300d5df..9a29c723 100644 --- a/nipype2pydra/statements/__init__.py +++ b/nipype2pydra/statements/__init__.py @@ -5,6 +5,7 @@ GENERIC_PYDRA_IMPORTS, ExplicitImport, from_list_to_imports, + make_imports_absolute, ) from .workflow_build import ( # noqa: F401 AddNestedWorkflowStatement, diff --git a/nipype2pydra/statements/imports.py b/nipype2pydra/statements/imports.py index 9639d3cc..a1b993a9 100644 --- a/nipype2pydra/statements/imports.py +++ b/nipype2pydra/statements/imports.py @@ -465,6 +465,39 @@ def collate( ) +def translate( + module_name: str, translations: ty.Sequence[ty.Tuple[str, str]] +) -> ty.Optional[str]: + for from_pkg, to_pkg in translations: + if re.match(from_pkg, module_name): + return re.sub( + from_pkg, + to_pkg, + module_name, + count=1, + flags=re.MULTILINE | re.DOTALL, + ) + return None + + +def make_imports_absolute( + src: str, modulepath: str, translations: ty.Sequence[ty.Tuple[str, str]] = () +) -> str: + parts = modulepath.split(".") + + def replacer(match) -> str: + levels = len(match.group(2)) + assert levels < len(parts) + abs_modulepath = ".".join(parts[:-levels]) + "." + match.group(3) + if translations: + newpath = translate(abs_modulepath, translations) + if newpath: + abs_modulepath = newpath + return f"{match.group(1)}from {abs_modulepath}" + + return re.sub(r"(\s*)from\s+(\.+)(\w+)", replacer, src) + + def parse_imports( stmts: ty.Union[str, ty.Sequence[str]], relative_to: ty.Union[str, ModuleType, None] = None, @@ -497,18 +530,6 @@ def parse_imports( ".__init__" if relative_to.__file__.endswith("__init__.py") else "" ) - def translate(module_name: str) -> ty.Optional[str]: - for from_pkg, to_pkg in translations: - if re.match(from_pkg, module_name): - return re.sub( - from_pkg, - to_pkg, - module_name, - count=1, - flags=re.MULTILINE | re.DOTALL, - ) - return None - parsed = [] for stmt in stmts: if isinstance(stmt, ImportStatement): @@ -543,7 +564,7 @@ def translate(module_name: str) -> ty.Optional[str]: ) if absolute: import_stmt = import_stmt.absolute() - import_stmt.translation = translate(import_stmt.module_name) + import_stmt.translation = translate(import_stmt.module_name, translations) parsed.append(import_stmt) else: @@ -554,7 +575,7 @@ def translate(module_name: str) -> ty.Optional[str]: ImportStatement( indent=match.group(1), imported={imp.local_name: imp}, - translation=translate(imp.name), + translation=translate(imp.name, translations), ) ) return parsed @@ -566,6 +587,8 @@ def translate(module_name: str) -> ty.Optional[str]: "from fileformats.generic import File, Directory", "from pydra.engine.specs import MultiInputObj", "from pathlib import Path", + "import json", + "import yaml", "import logging", "import pydra.mark", "import typing as ty", diff --git a/nipype2pydra/statements/workflow_build.py b/nipype2pydra/statements/workflow_build.py index 05c05496..5cf0307f 100644 --- a/nipype2pydra/statements/workflow_build.py +++ b/nipype2pydra/statements/workflow_build.py @@ -106,7 +106,7 @@ class DynamicField(VarField): callable: ty.Callable = attrs.field() def __repr__(self): - return f"DelayedVarField({self.varname}, callable={self.callable})" + return f"DynamicField({self.varname}, callable={self.callable})" @attrs.define @@ -539,9 +539,7 @@ def parse( if intf_name.endswith("("): # strip trailing parenthesis intf_name = intf_name[:-1] try: - imported_obj = workflow_converter.used_symbols.get_imported_object( - intf_name - ) + imported_obj = workflow_converter.used.get_imported_object(intf_name) except ImportError: imported_obj = None is_factory = "already-initialised" diff --git a/nipype2pydra/symbols.py b/nipype2pydra/symbols.py new file mode 100644 index 00000000..74bd8ac1 --- /dev/null +++ b/nipype2pydra/symbols.py @@ -0,0 +1,931 @@ +import typing as ty +import re +import keyword +import types +import inspect +import builtins +from operator import attrgetter +from collections import defaultdict +from logging import getLogger +from importlib import import_module +import itertools +from functools import cached_property +import attrs +from nipype.interfaces.base import BaseInterface, TraitedSpec, isdefined, Undefined +from nipype.interfaces.base import traits_extension +from .utils.misc import ( + split_source_into_statements, + extract_args, + find_super_method, + get_return_line, +) +from .statements.imports import ImportStatement, parse_imports + +if ty.TYPE_CHECKING: + from .package import PackageConverter + + +logger = getLogger("nipype2pydra") + + +@attrs.define +class UsedSymbols: + """ + A class to hold the used symbols in a module + + Parameters + ---------- + module_name: str + the name of the module containing the functions to be converted + imports : list[str] + the import statements that need to be included in the converted file + local_functions: set[callable] + locally-defined functions used in the function bodies, or nested functions thereof + local_classes : set[type] + like local_functions but classes + constants: set[tuple[str, str]] + constants used in the function bodies, or nested functions thereof, tuples consist + of the constant name and its definition + intra_pkg_funcs: set[tuple[str, callable]] + list of functions that are defined in neighbouring modules that need to be + included in the converted file (as opposed of just imported from independent + packages) along with the name that they were imported as and therefore should + be named as in the converted module if they are included inline + intra_pkg_classes: list[tuple[str, callable]] + like neigh_mod_funcs but classes + intra_pkg_constants: set[tuple[str, str, str]] + set of all the constants defined within the package that are referenced by the + function, (, , ), where + the local alias and the definition of the constant + methods: set[callable] + the names of the methods that are referenced, by default None is a function not + a method + class_constants: set[tuple[str, str]] + the names of the class attributes that are referenced by the method + + class_name: str, optional + the name of the class that the methods originate from + """ + + module_name: str + import_stmts: ty.Set[str] = attrs.field(factory=set) + functions: ty.Set[ty.Callable] = attrs.field(factory=set) + classes: ty.List[type] = attrs.field(factory=list) + constants: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) + imported_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) + imported_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) + imported_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) + package: "PackageConverter" = attrs.field(default=None) + + ALWAYS_OMIT_MODULES = [ + "traits.trait_handlers", # Old traits module, pre v6.0 + "nipype.pipeline", + "nipype.logging", + "nipype.config", + "nipype.interfaces.base", + "nipype.interfaces.utility", + ] + + _cache = {} + _stmts_cache = {} + _imports_cache = {} + _funcs_cache = {} + _classes_cache = {} + _constants_cache = {} + + symbols_re = re.compile(r"(? "UsedSymbols": + """Get the imports and local functions/classes/constants referenced in the + provided function bodies, and those nested within them + + Parameters + ---------- + module: ModuleType + the module containing the functions to be converted + function_bodies: list[str | callable | type] + the source of all functions/classes (or the functions/classes themselves) + that need to be checked for used imports + collapse_intra_pkg : bool + whether functions and classes defined within the same package, but not the + same module, are to be included in the output module or not, i.e. whether + the local funcs/classes/constants they referenced need to be included also + pull_out_inline_imports : bool, optional + whether to pull out imports that are inline in the function bodies + or not, by default True + omit_constants : list, optional + a list of objects to filter out from the used symbols, + by default (Undefined, traits_extension.File, traits_extension.Directory) + omit_functions : list[type], optional + a list of functions to filter out from the used symbols, + by default [isdefined] + omit_classes : list[type], optional + a list of classes (including subclasses) to filter out from the used symbols, + by default None + always_include : list[str], optional + a list of module objects (e.g. functions, classes, etc...) to always include + in list of used imports, even if they would be normally filtered out by + one of the `omit` clauses, by default None + translations : list[tuple[str, str]], optional + a list of tuples where the first element is the name of the symbol to be + replaced and the second element is the name of the symbol to replace it with, + regex supported, by default None + absolute_imports : bool, optional + whether to convert relative imports to absolute imports, by default False + + Returns + ------- + UsedSymbols + a class containing the used symbols in the module + """ + if always_include is None: + always_include = [] + if isinstance(module, str): + module = import_module(module) + cache_key = ( + module.__name__, + tuple(f.__name__ if not isinstance(f, str) else f for f in function_bodies), + collapse_intra_pkg, + pull_out_inline_imports, + ) + try: + return cls._cache[cache_key] + except KeyError: + pass + used = cls(module_name=module.__name__, package=package) + cls._cache[cache_key] = used + used._find_referenced( + module, + function_bodies, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + ) + return used + + @classmethod + def _module_statements(cls, module) -> list: + try: + return cls._stmts_cache[module.__name__] + except KeyError: + pass + source_code = inspect.getsource(module) + cls._stmts_cache[module.__name__] = stmts = split_source_into_statements( + source_code + ) + return stmts + + @classmethod + def _global_imports( + cls, module, package, absolute_imports, pull_out_inline_imports + ): + """Get the global imports in the module""" + try: + return cls._imports_cache[module.__name__] + except KeyError: + pass + module_statements = cls._module_statements(module) + imports: ty.List[ImportStatement] = [] + global_scope = True + for stmt in module_statements: + if not pull_out_inline_imports: + if stmt.startswith("def ") or stmt.startswith("class "): + global_scope = False + continue + if not global_scope: + if stmt and not stmt.startswith(" "): + global_scope = True + else: + continue + if ImportStatement.matches(stmt): + imports.extend( + parse_imports( + stmt, + relative_to=module, + translations=package.all_import_translations, + absolute=absolute_imports, + ) + ) + imports = sorted(imports) + cls._imports_cache[module.__name__] = imports + return imports + + def _find_referenced( + self, + module, + function_bodies, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + ): + + imports = self._global_imports( + module, self.package, absolute_imports, pull_out_inline_imports + ) + # Sort local func/classes/consts so they are iterated in a consistent order to + # remove stochastic element of traversal and make debugging easier + + used_symbols, all_src = self._get_used_symbols(function_bodies, module) + + base_pkg = module.__name__.split(".")[0] + + module_omit_re = re.compile( + r"^\b(" + + "|".join( + self.ALWAYS_OMIT_MODULES + [module.__name__] + self.package.omit_modules + ) + + r")\b", + ) + + # functions to copy from a relative or nipype module into the output module + for stmt in imports: + stmt = stmt.only_include(used_symbols) + # Skip if no required symbols are in the import statement + if not stmt: + continue + # Filter out Nipype-specific objects that aren't relevant in Pydra + module_omit = bool(module_omit_re.match(stmt.module_name)) + if ( + module_omit + or self.package.omit_classes + or self.package.omit_functions + or self.package.omit_constants + ): + to_include = [] + for imported in stmt.values(): + if imported.address in always_include: + to_include.append(imported.local_name) + continue + if module_omit: + continue + try: + obj = imported.object + except ImportError: + logger.warning( + ( + "Could not import %s from %s, unable to check whether " + "it is is present in list of classes %s or objects %s " + "to be filtered out" + ), + imported.name, + imported.statement.module_name, + self.package.omit_classes, + self.package.omit_functions, + ) + to_include.append(imported.local_name) + continue + if inspect.isclass(obj): + if self.package.omit_classes and issubclass( + obj, tuple(self.package.omit_classes) + ): + continue + elif inspect.isfunction(obj): + if ( + self.package.omit_functions + and obj in self.package.omit_functions + ): + continue + elif imported.address in self.package.omit_constants: + continue + to_include.append(imported.local_name) + if not to_include: + continue + stmt = stmt.only_include(to_include) + intra_pkg_objs = defaultdict(set) + if stmt.in_package(base_pkg) or ( + stmt.in_package("nipype") and not stmt.in_package("nipype.interfaces") + ): + + for imported in list(stmt.values()): + if not ( + imported.in_package(base_pkg) or imported.in_package("nipype") + ) or inspect.isbuiltin(imported.object): + # Case where an object is a nested import from a different package + # which is imported in a chain from a neighbouring module + self.import_stmts.add( + imported.as_independent_statement(resolve=True) + ) + stmt.drop(imported) + elif inspect.isfunction(imported.object): + self.imported_funcs.add((imported.local_name, imported.object)) + # Recursively include objects imported in the module + intra_pkg_objs[import_module(imported.object.__module__)].add( + imported.object + ) + if collapse_intra_pkg: + stmt.drop(imported) + elif inspect.isclass(imported.object): + class_def = (imported.local_name, imported.object) + # Add the class to the intra_pkg_classes list if it is not + # already there. NB: we can't use a set for intra_pkg_classes + # like we did for functions here because we need to preserve the + # order the classes are defined in the module in case one inherits + # from the other + if class_def not in self.imported_classes: + self.imported_classes.append(class_def) + # Recursively include objects imported in the module + intra_pkg_objs[import_module(imported.object.__module__)].add( + imported.object, + ) + if collapse_intra_pkg: + stmt.drop(imported) + elif inspect.ismodule(imported.object): + # Skip if the module is the same as the module being converted + if module_omit_re.match(imported.object.__name__): + stmt.drop(imported) + continue + # Findall references to the module's attributes in the source code + # and add them to the list of intra package objects + used_attrs = re.findall( + r"\b" + imported.local_name + r"\.(\w+)\b", all_src + ) + for attr_name in used_attrs: + obj = getattr(imported.object, attr_name) + + if inspect.isfunction(obj): + self.imported_funcs.add((obj.__name__, obj)) + intra_pkg_objs[imported.object.__name__].add(obj) + elif inspect.isclass(obj): + class_def = (obj.__name__, obj) + if class_def not in self.imported_classes: + self.imported_classes.append(class_def) + intra_pkg_objs[imported.object.__name__].add(obj) + else: + self.imported_constants.add( + ( + imported.object.__name__, + attr_name, + attr_name, + ) + ) + intra_pkg_objs[imported.object.__name__].add(attr_name) + + if collapse_intra_pkg: + raise NotImplementedError( + f"Cannot inline imported module in statement '{stmt}'" + ) + else: + self.imported_constants.add( + ( + stmt.module_name, + imported.local_name, + imported.name, + ) + ) + intra_pkg_objs[stmt.module].add(imported.local_name) + if collapse_intra_pkg: + stmt.drop(imported) + + # Recursively include neighbouring objects imported in the module + for from_mod, inlined_objs in intra_pkg_objs.items(): + used_in_mod = UsedSymbols.find( + from_mod, + function_bodies=inlined_objs, + collapse_intra_pkg=collapse_intra_pkg, + package=self.package, + always_include=always_include, + ) + self.update(used_in_mod, to_be_inlined=collapse_intra_pkg) + if stmt: + self.import_stmts.add(stmt) + + def _get_used_symbols(self, function_bodies, module): + """Search the given source code for any symbols that are used in the function bodies""" + + all_src = "" + used_symbols = set() + for function_body in function_bodies: + if not isinstance(function_body, str): + function_body = inspect.getsource(function_body) + all_src += "\n\n" + function_body + self._get_symbols(function_body, used_symbols) + + # Keep stepping into nested referenced local function/class sources until all local + # functions and constants that are referenced are added to the used symbols + prev_num_symbols = -1 + while len(used_symbols) > prev_num_symbols: + prev_num_symbols = len(used_symbols) + for local_func in self.local_functions(module): + if ( + local_func.__name__ in used_symbols + and local_func not in self.functions + ): + self.functions.add(local_func) + self._get_symbols(local_func, used_symbols) + all_src += "\n\n" + inspect.getsource(local_func) + for local_class in self.local_classes(module): + if ( + local_class.__name__ in used_symbols + and local_class not in self.classes + ): + if issubclass(local_class, (BaseInterface, TraitedSpec)): + continue + self.classes.append(local_class) + class_body = inspect.getsource(local_class) + bases = extract_args(class_body)[1] + used_symbols.update(bases) + self._get_symbols(class_body, used_symbols) + all_src += "\n\n" + class_body + for const_name, const_def in self.local_constants(module): + if ( + const_name in used_symbols + and (const_name, const_def) not in self.constants + ): + self.constants.add((const_name, const_def)) + self._get_symbols(const_def, used_symbols) + all_src += "\n\n" + const_def + used_symbols -= set(self.SYMBOLS_TO_IGNORE) + return used_symbols, all_src + + @classmethod + def filter_imports( + cls, imports: ty.List[ImportStatement], source_code: str + ) -> ty.List[ImportStatement]: + """Filter out the imports that are not used in the function bodies""" + symbols = set() + cls._get_symbols(source_code, symbols) + symbols -= set(cls.SYMBOLS_TO_IGNORE) + filtered = [] + for stmt in imports: + if stmt.from_: + stmt = stmt.only_include(symbols) + if stmt: + filtered.append(stmt) + elif stmt.sole_imported.local_name in symbols: + filtered.append(stmt) + return filtered + + def copy(self) -> "UsedSymbols": + return attrs.evolve(self) + + @classmethod + def _get_symbols( + cls, func: ty.Union[str, ty.Callable, ty.Type], symbols: ty.Set[str] + ): + """Get the symbols used in a function body""" + try: + fbody = inspect.getsource(func) + except TypeError: + fbody = func + for stmt in split_source_into_statements(fbody): + if stmt and not re.match( + r"\s*(#|\"|'|from |import |r'|r\"|f'|f\")", stmt + ): # skip comments/docs + for sym in cls.symbols_re.findall(stmt): + if "." in sym: + parts = sym.split(".") + symbols.update( + ".".join(parts[: i + 1]) for i in range(len(parts)) + ) + else: + symbols.add(sym) + + # Nipype-specific names and Python keywords + SYMBOLS_TO_IGNORE = ["isdefined"] + keyword.kwlist + list(builtins.__dict__.keys()) + + def get_imported_object(self, name: str) -> ty.Any: + """Get the object with the given name from used import statements + + Parameters + ---------- + name : str + the name of the object to get + imports : list[ImportStatement], optional + the import statements to search in (used in tests), by default the imports + in the used symbols + + Returns + ------- + Any + the object with the given name referenced by the given import statements + """ + # Check to see if it isn't an imported module + # imported = { + # i.sole_imported.local_name: i.sole_imported.object + # for i in self.imports + # if not i.from_ + # } + all_imported = {} + for stmt in self.import_stmts: + all_imported.update(stmt.imported) + try: + return all_imported[name].object + except KeyError: + pass + parts = name.rsplit(".") + imported_obj = None + for i in range(1, len(parts)): + obj_name = ".".join(parts[:-i]) + try: + imported_obj = all_imported[obj_name].object + except KeyError: + continue + else: + break + if imported_obj is None: + raise ImportError( + f"Could not find object named {name} in any of the imported modules:\n" + + "\n".join(str(i) for i in self.import_stmts) + ) + for part in parts[-i:]: + imported_obj = getattr(imported_obj, part) + return imported_obj + + @classmethod + def local_functions(cls, mod) -> ty.List[ty.Callable]: + """Get the functions defined in the module""" + try: + return cls._funcs_cache[mod.__name__] + except KeyError: + pass + functions = [] + for attr_name in dir(mod): + attr = getattr(mod, attr_name) + if inspect.isfunction(attr) and attr.__module__ == mod.__name__: + functions.append(attr) + functions = sorted(functions, key=attrgetter("__name__")) + cls._funcs_cache[mod.__name__] = functions + return functions + + @classmethod + def local_classes(cls, mod) -> ty.List[type]: + """Get the functions defined in the module""" + try: + return cls._classes_cache[mod.__name__] + except KeyError: + pass + classes = [] + for attr_name in dir(mod): + attr = getattr(mod, attr_name) + if inspect.isclass(attr) and attr.__module__ == mod.__name__: + classes.append(attr) + classes = sorted(classes, key=attrgetter("__name__")) + cls._classes_cache[mod.__name__] = classes + return classes + + @classmethod + def local_constants(cls, mod) -> ty.List[ty.Tuple[str, str]]: + """ + Get the constants defined in the module + """ + try: + return cls._constants_cache[mod.__name__] + except KeyError: + pass + source_code = inspect.getsource(mod) + source_code = source_code.replace("\\\n", " ") + constants = [] + for stmt in split_source_into_statements(source_code): + match = re.match(r"^(\w+) *= *(.*)", stmt, flags=re.MULTILINE | re.DOTALL) + if match: + constants.append(tuple(match.groups())) + constants = sorted(constants) + cls._constants_cache[mod.__name__] = constants + return constants + + +@attrs.define(kw_only=True) +class UsedClassSymbols(UsedSymbols): + """Class to detect/hold symbols that are used in class methods""" + + klass: type + methods: ty.Set[ty.Callable] = attrs.field(factory=set) + class_attrs: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) + supers: ty.Dict[str, ty.Tuple[ty.Callable, type]] = attrs.field(factory=dict) + method_args: ty.Dict[str, ty.Set[str]] = attrs.field( + factory=lambda: defaultdict(set) + ) + method_returns: ty.Dict[str, ty.Set[str]] = attrs.field( + factory=lambda: defaultdict(set) + ) + method_stacks: ty.Dict[str, ty.Set[ty.Tuple[str, ...]]] = attrs.field( + factory=lambda: defaultdict(set) + ) + super_func_names: ty.Dict[type, ty.Dict[str, str]] = attrs.field( + factory=lambda: defaultdict(dict) + ) + outputs: ty.Set[str] = attrs.field(factory=set) + inputs: ty.Set[str] = attrs.field(factory=set) + + _class_attrs_cache = {} + + @classmethod + def find( + cls, + klass: type, + method_names: ty.List[str], + package: "PackageConverter", + collapse_intra_pkg: bool = False, + pull_out_inline_imports: bool = True, + always_include: ty.Optional[ty.List[str]] = None, + absolute_imports: bool = False, + ) -> "UsedSymbols": + """Get the imports and local functions/classes/constants referenced in the + provided function bodies, and those nested within them + + Parameters + ---------- + klass: type + the klass the methods belong to + method_names: list[str] + the source of all functions/classes (or the functions/classes themselves) + that need to be checked for used imports + collapse_intra_pkg : bool + whether functions and classes defined within the same package, but not the + same module, are to be included in the output module or not, i.e. whether + the local funcs/classes/constants they referenced need to be included also + pull_out_inline_imports : bool, optional + whether to pull out imports that are inline in the function bodies + or not, by default True + always_include : list[str], optional + a list of module objects (e.g. functions, classes, etc...) to always include + in list of used imports, even if they would be normally filtered out by + one of the `omit` clauses, by default None + translations : list[tuple[str, str]], optional + a list of tuples where the first element is the name of the symbol to be + replaced and the second element is the name of the symbol to replace it with, + regex supported, by default None + absolute_imports : bool, optional + whether to convert relative imports to absolute imports, by default False + + Returns + ------- + UsedSymbols + a class containing the used symbols in the module + """ + if always_include is None: + always_include = [] + cache_key = ( + klass.__name__, + klass.__module__, + tuple(method_names), + collapse_intra_pkg, + pull_out_inline_imports, + absolute_imports, + tuple(always_include), + ) + try: + return cls._cache[cache_key] + except KeyError: + pass + used = cls(klass=klass, module_name=klass.__module__, package=package) + cls._cache[cache_key] = used + + for method_name in method_names: + used._find_referenced_by_method( + method_name=method_name, + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + ) + + return used + + def _find_referenced_by_method( + self, + method_name, + pull_out_inline_imports, + absolute_imports, + always_include, + collapse_intra_pkg, + already_processed=None, + method_stack=(), + super_base=None, + ): + if super_base: + method = getattr(super_base, method_name) + else: + method, super_base = find_super_method( + self.klass, method_name, include_class=True + ) + module = import_module(super_base.__module__) + if already_processed: + already_processed.add(method) + else: + already_processed = {method} + + if self.package.is_omitted(super_base): + return + self.methods.add(method) + method_stack += (method,) + method_body = inspect.getsource(method) + method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments + + meth_ref_inputs = set(re.findall(r"(?<=self\.inputs\.)(\w+)", method_body)) + meth_ref_outputs = set(re.findall(r"self\.(\w+) *=", method_body)) + + self._find_referenced( + module, + [method], + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + ) + # Find all referenced methods + ref_method_names = re.findall(r"(?<=self\.)(\w+)\(", method_body) + # Filter methods in omitted common base-classes like BaseInterface & CommandLine + ref_method_names = [ + m + for m in ref_method_names + if ( + m != "output_spec" + and not self.package.is_omitted( + find_super_method(super_base, m, include_class=True)[1] + ) + ) + ] + ref_methods = set(getattr(self.klass, m) for m in ref_method_names) + for meth in ref_methods: + if meth in already_processed: + continue + if inspect.isclass(meth): + logger.warning( + "Found %s type, that is instantiated as a method, " + "should be treated as a nested type, skipping for now", + meth, + ) + continue + ref_inputs, ref_outputs = self._find_referenced_by_method( + meth.__name__, + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + already_processed=already_processed, + method_stack=method_stack, + ) + self.method_args[meth.__name__].update(ref_inputs) + self.method_returns[meth.__name__].update(ref_outputs) + self.method_stacks[meth.__name__].add(method_stack) + self.inputs.update(ref_inputs) + self.outputs.update(ref_outputs) + meth_ref_inputs.update(ref_inputs) + meth_ref_outputs.update(ref_outputs) + + # Find all referenced supers + for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body): + super_method, base = find_super_method(super_base, match) + if self.package.is_omitted(base): + continue + func_name = self._different_parent_pkg_prefix(base) + match + if func_name not in self.supers: + self.supers[func_name] = (super_method, base) + self.super_func_names[super_base][match] = func_name + self.method_stacks[func_name].add(method_stack) + ref_inputs, ref_outputs = self._find_referenced_by_method( + super_method.__name__, + pull_out_inline_imports=pull_out_inline_imports, + absolute_imports=absolute_imports, + always_include=always_include, + collapse_intra_pkg=collapse_intra_pkg, + already_processed=already_processed, + method_stack=method_stack, + super_base=base, + ) + self.inputs.update(ref_inputs) + self.outputs.update(ref_outputs) + self.method_args[func_name].update(ref_inputs) + self.method_returns[func_name].update(ref_outputs) + meth_ref_inputs.update(ref_inputs) + meth_ref_outputs.update(ref_outputs) + + # Find all referenced constants/class attributes + local_class_attrs = self.local_class_attrs(super_base) + for match in re.findall(r"self\.(\w+)\b(?! =|\(.)", method_body): + try: + value = local_class_attrs[match] + except KeyError: + continue + base = find_super_method(super_base, match, include_class=True)[1] + if self.package.is_omitted(base): + base = self.klass + self.class_attrs.add((match, value)) + + return_value = get_return_line(method_body) + if return_value and return_value.startswith("self."): + self.outputs.update( + re.findall( + return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=", + method_body, + ) + ) + + return sorted(meth_ref_inputs), sorted(meth_ref_outputs) + + # def _find_referenced( + # self, + # klass: type, + # method_names: ty.List[str], + # package: "PackageConverter", + # collapse_intra_pkg: bool = False, + # pull_out_inline_imports: bool = True, + # always_include: ty.Optional[ty.List[str]] = None, + # absolute_imports: bool = False, + # method_stack: ty.Tuple[ty.Callable] = (), + # already_processed: ty.Optional[ty.Set[ty.Callable]] = None, + # ) -> "UsedClassSymbols": + + # module = import_module(klass.__module__) + + # for method_name in method_names: + # self._find_referenced_by_method( + # method_name=method_name, + # already_processed=already_processed, + # method_stack=method_stack, + # ) + + def _different_parent_pkg_prefix(self, base: type) -> str: + """Return the common part of two package names""" + ref_parts = self.klass.__module__.split(".") + mod_parts = base.__module__.split(".") + different = [] + is_common = True + for r_part, m_part in zip( + itertools.chain(ref_parts, itertools.repeat(None)), mod_parts + ): + if r_part != m_part: + is_common = False + if not is_common: + different.append(m_part) + if not different: + return "" + return "_".join(different) + "__" + base.__name__ + "__" + + @classmethod + def local_class_attrs(cls, klass) -> ty.Dict[str, str]: + """ + Get the constant attrs defined in the klass + """ + cache_key = klass.__module__ + "__" + klass.__name__ + try: + return cls._class_attrs_cache[cache_key] + except KeyError: + pass + source_code = inspect.getsource(klass) + source_code = source_code.replace("\\\n", " ") + class_attrs = [] + for stmt in split_source_into_statements(source_code): + match = re.match( + r"^ (\w+) *= *(.*)", stmt, flags=re.MULTILINE | re.DOTALL + ) + if match: + class_attrs.append(tuple(match.groups())) + class_attrs = dict(class_attrs) + cls._constants_cache[cache_key] = class_attrs + return class_attrs diff --git a/nipype2pydra/utils/__init__.py b/nipype2pydra/utils/__init__.py index 58d7df7a..c6bb9b95 100644 --- a/nipype2pydra/utils/__init__.py +++ b/nipype2pydra/utils/__init__.py @@ -19,11 +19,9 @@ str_to_type, types_converter, unwrap_nested_type, + get_return_line, + find_super_method, + strip_comments, + min_indentation, INBUILT_NIPYPE_TRAIT_NAMES, ) -from .symbols import ( # noqa: F401 - UsedSymbols, - get_local_functions, - get_local_classes, - get_local_constants, -) diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 56254d3b..a49d6315 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -22,6 +22,7 @@ from importlib import import_module from logging import getLogger +from pydra.engine.specs import MultiInputObj logger = getLogger("nipype2pydra") @@ -156,7 +157,9 @@ def add_exc_note(e, note): return e -def extract_args(snippet) -> ty.Tuple[str, ty.List[str], str]: +def extract_args( + snippet, drop_parens: bool = False +) -> ty.Tuple[str, ty.List[str], str]: """Splits the code snippet at the first opening brackets into a 3-tuple consisting of the preceding text + opening bracket, the arguments/items within the parenthesis/bracket pair, and the closing paren/bracket + trailing text. @@ -255,7 +258,11 @@ def extract_args(snippet) -> ty.Tuple[str, ty.List[str], str]: if matching_open == first and depth[matching_open] == 0: if next_item: contents.append(next_item) - return pre, contents, "".join(splits[i:]) + post = "".join(splits[i:]) + if drop_parens: + pre = pre[:-1] + post = post[1:] + return pre, contents, post if ( first and depth[first] == 1 @@ -305,9 +312,7 @@ def cleanup_function_body(function_body: str) -> str: with_signature = True else: with_signature = False - # Detect the indentation of the source code in src and reduce it to 4 spaces - non_empty_lines = [ln for ln in function_body.splitlines() if ln] - indent_size = len(re.match(r"^( *)", non_empty_lines[0]).group(1)) + indent_size = min_indentation(function_body) indent_reduction = indent_size - (0 if with_signature else 4) assert indent_reduction >= 0, ( "Indentation reduction cannot be negative, probably didn't detect signature of " @@ -317,11 +322,18 @@ def cleanup_function_body(function_body: str) -> str: function_body = re.sub( r"^" + " " * indent_reduction, "", function_body, flags=re.MULTILINE ) + # Other misc replacements # function_body = function_body.replace("LOGGER.", "logger.") return replace_undefined(function_body) +def min_indentation(function_body: str) -> int: + # Detect the indentation of the source code in src and reduce it to 4 spaces + non_empty_lines = [ln for ln in function_body.splitlines() if ln] + return len(re.match(r"^( *)", non_empty_lines[0]).group(1)) + + def replace_undefined(function_body: str) -> str: parts = re.split(r"not isdefined\b", function_body, flags=re.MULTILINE) new_function_body = parts[0] @@ -360,7 +372,11 @@ def insert_args_in_signature(snippet: str, new_args: ty.Iterable[str]) -> str: pre, args, post = extract_args(snippet) if "runtime" in args: args.remove("runtime") - return pre + ", ".join(args + new_args) + post + if args and args[-1].startswith("**"): + kwargs = [args.pop()] + else: + kwargs = [] + return pre + ", ".join(args + new_args + kwargs) + post def get_source_code(func_or_klass: ty.Union[ty.Callable, ty.Type]) -> str: @@ -415,7 +431,7 @@ def split_source_into_statements(source_code: str) -> ty.List[str]: else: # Handle dictionary assignments where the first open-closing bracket is # before the assignment, e.g. outputs["out_file"] = [..." - if post and re.match(r"\s*=", post[1:]): + if post and re.match(r"\s*=|.*[\(\[\{\"'].*", post[1:]): try: extract_args(post[1:]) except (UnmatchedParensException, UnmatchedQuoteException): @@ -473,12 +489,20 @@ def from_named_dicts_converter( def str_to_type(type_str: str) -> type: """Resolve a string representation of a type into a valid type""" if "/" in type_str: + if type_str.startswith("multi["): + assert type_str.endswith("]"), f"Invalid multi type: {type_str}" + type_str = type_str[6:-1] + multi = True + else: + multi = False tp = from_mime(type_str) try: # If datatype is a field, use its primitive instead tp = tp.primitive # type: ignore except AttributeError: pass + if multi: + tp = MultiInputObj[tp] else: def resolve_type(type_str: str) -> type: @@ -528,3 +552,33 @@ def unwrap_nested_type(t: type) -> ty.List[type]: unwrapped.extend(unwrap_nested_type(c)) return unwrapped return [t] + + +def get_return_line(func: ty.Union[str, ty.Callable]) -> str: + if not isinstance(func, str): + func = inspect.getsource(func) + return_line = func.strip().split("\n")[-1] + match = re.match(r"\s*return(.*)", return_line) + if not match: + return None + return match.group(1).strip() + + +def find_super_method( + super_base: type, method_name: str, include_class: bool = False +) -> ty.Tuple[ty.Optional[ty.Callable], ty.Optional[type]]: + mro = super_base.__mro__ + if not include_class: + mro = mro[1:] + for base in mro: + if method_name in base.__dict__: # Found the match + return getattr(base, method_name), base + return None, None + # raise RuntimeError( + # f"Could not find super of '{method_name}' method in base classes of " + # f"{super_base}" + # ) + + +def strip_comments(src: str) -> str: + return re.sub(r"^\s+#.*", "", src, flags=re.MULTILINE) diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/utils/symbols.py deleted file mode 100644 index 1140163c..00000000 --- a/nipype2pydra/utils/symbols.py +++ /dev/null @@ -1,546 +0,0 @@ -import typing as ty -import re -import keyword -import types -import inspect -import builtins -from operator import attrgetter -from collections import defaultdict -from logging import getLogger -from importlib import import_module -import attrs -from nipype.interfaces.base import BaseInterface, TraitedSpec, isdefined, Undefined -from nipype.interfaces.base import traits_extension -from .misc import split_source_into_statements, extract_args -from ..statements.imports import ImportStatement, parse_imports - - -logger = getLogger("nipype2pydra") - - -@attrs.define -class UsedSymbols: - """ - A class to hold the used symbols in a module - - Parameters - ------- - imports : list[str] - the import statements that need to be included in the converted file - local_functions: set[callable] - locally-defined functions used in the function bodies, or nested functions thereof - local_classes : set[type] - like local_functions but classes - constants: set[tuple[str, str]] - constants used in the function bodies, or nested functions thereof, tuples consist - of the constant name and its definition - intra_pkg_funcs: set[tuple[str, callable]] - list of functions that are defined in neighbouring modules that need to be - included in the converted file (as opposed of just imported from independent - packages) along with the name that they were imported as and therefore should - be named as in the converted module if they are included inline - intra_pkg_classes: list[tuple[str, callable]] - like neigh_mod_funcs but classes - intra_pkg_constants: set[tuple[str, str, str]] - set of all the constants defined within the package that are referenced by the - function, (, , ), where - the local alias and the definition of the constant - """ - - module_name: str - imports: ty.Set[str] = attrs.field(factory=set) - local_functions: ty.Set[ty.Callable] = attrs.field(factory=set) - local_classes: ty.List[type] = attrs.field(factory=list) - constants: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) - intra_pkg_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) - intra_pkg_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) - intra_pkg_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) - - ALWAYS_OMIT_MODULES = [ - "traits.trait_handlers", # Old traits module, pre v6.0 - "nipype.pipeline", - "nipype.logging", - "nipype.config", - "nipype.interfaces.base", - "nipype.interfaces.utility", - ] - - _cache = {} - - symbols_re = re.compile(r"(? "UsedSymbols": - """Get the imports and local functions/classes/constants referenced in the - provided function bodies, and those nested within them - - Parameters - ---------- - module: ModuleType - the module containing the functions to be converted - function_bodies: list[str | callable | type] - the source of all functions/classes (or the functions/classes themselves) - that need to be checked for used imports - collapse_intra_pkg : bool - whether functions and classes defined within the same package, but not the - same module, are to be included in the output module or not, i.e. whether - the local funcs/classes/constants they referenced need to be included also - pull_out_inline_imports : bool, optional - whether to pull out imports that are inline in the function bodies - or not, by default True - omit_constants : list, optional - a list of objects to filter out from the used symbols, - by default (Undefined, traits_extension.File, traits_extension.Directory) - omit_functions : list[type], optional - a list of functions to filter out from the used symbols, - by default [isdefined] - omit_classes : list[type], optional - a list of classes (including subclasses) to filter out from the used symbols, - by default None - always_include : list[str], optional - a list of module objects (e.g. functions, classes, etc...) to always include - in list of used imports, even if they would be normally filtered out by - one of the `omit` clauses, by default None - translations : list[tuple[str, str]], optional - a list of tuples where the first element is the name of the symbol to be - replaced and the second element is the name of the symbol to replace it with, - regex supported, by default None - absolute_imports : bool, optional - whether to convert relative imports to absolute imports, by default False - - Returns - ------- - UsedSymbols - a class containing the used symbols in the module - """ - if omit_classes is None: - omit_classes = [] - if omit_modules is None: - omit_modules = [] - if always_include is None: - always_include = [] - if isinstance(module, str): - module = import_module(module) - cache_key = ( - module.__name__, - tuple(f.__name__ if not isinstance(f, str) else f for f in function_bodies), - collapse_intra_pkg, - pull_out_inline_imports, - tuple(omit_constants) if omit_constants else None, - tuple(omit_functions) if omit_functions else None, - tuple(omit_classes) if omit_classes else None, - tuple(omit_modules) if omit_modules else None, - tuple(always_include) if always_include else None, - tuple(translations) if translations else None, - ) - try: - return cls._cache[cache_key] - except KeyError: - pass - used = cls(module_name=module.__name__) - cls._cache[cache_key] = used - source_code = inspect.getsource(module) - # Sort local func/classes/consts so they are iterated in a consistent order to - # remove stochastic element of traversal and make debugging easier - local_functions = sorted( - get_local_functions(module), key=attrgetter("__name__") - ) - local_constants = sorted(get_local_constants(module)) - local_classes = sorted(get_local_classes(module), key=attrgetter("__name__")) - module_statements = split_source_into_statements(source_code) - imports: ty.List[ImportStatement] = [] - global_scope = True - for stmt in module_statements: - if not pull_out_inline_imports: - if stmt.startswith("def ") or stmt.startswith("class "): - global_scope = False - continue - if not global_scope: - if stmt and not stmt.startswith(" "): - global_scope = True - else: - continue - if ImportStatement.matches(stmt): - imports.extend( - parse_imports( - stmt, - relative_to=module, - translations=translations, - absolute=absolute_imports, - ) - ) - imports = sorted(imports) - - all_src = "" # All the source code that is searched for symbols - - used_symbols = set() - for function_body in function_bodies: - if not isinstance(function_body, str): - function_body = inspect.getsource(function_body) - all_src += "\n\n" + function_body - cls._get_symbols(function_body, used_symbols) - - # Keep stepping into nested referenced local function/class sources until all local - # functions and constants that are referenced are added to the used symbols - prev_num_symbols = -1 - while len(used_symbols) > prev_num_symbols: - prev_num_symbols = len(used_symbols) - for local_func in local_functions: - if ( - local_func.__name__ in used_symbols - and local_func not in used.local_functions - ): - used.local_functions.add(local_func) - cls._get_symbols(local_func, used_symbols) - all_src += "\n\n" + inspect.getsource(local_func) - for local_class in local_classes: - if ( - local_class.__name__ in used_symbols - and local_class not in used.local_classes - ): - if issubclass(local_class, (BaseInterface, TraitedSpec)): - continue - used.local_classes.append(local_class) - class_body = inspect.getsource(local_class) - bases = extract_args(class_body)[1] - used_symbols.update(bases) - cls._get_symbols(class_body, used_symbols) - all_src += "\n\n" + class_body - for const_name, const_def in local_constants: - if ( - const_name in used_symbols - and (const_name, const_def) not in used.constants - ): - used.constants.add((const_name, const_def)) - cls._get_symbols(const_def, used_symbols) - all_src += "\n\n" + const_def - used_symbols -= set(cls.SYMBOLS_TO_IGNORE) - - base_pkg = module.__name__.split(".")[0] - - module_omit_re = re.compile( - r"^\b(" - + "|".join(cls.ALWAYS_OMIT_MODULES + [module.__name__] + omit_modules) - + r")\b", - ) - - # functions to copy from a relative or nipype module into the output module - for stmt in imports: - stmt = stmt.only_include(used_symbols) - # Skip if no required symbols are in the import statement - if not stmt: - continue - # Filter out Nipype-specific objects that aren't relevant in Pydra - module_omit = bool(module_omit_re.match(stmt.module_name)) - if module_omit or omit_classes or omit_functions or omit_constants: - to_include = [] - for imported in stmt.values(): - if imported.address in always_include: - to_include.append(imported.local_name) - continue - if module_omit: - continue - try: - obj = imported.object - except ImportError: - logger.warning( - ( - "Could not import %s from %s, unable to check whether " - "it is is present in list of classes %s or objects %s " - "to be filtered out" - ), - imported.name, - imported.statement.module_name, - omit_classes, - omit_functions, - ) - to_include.append(imported.local_name) - continue - if inspect.isclass(obj): - if omit_classes and issubclass(obj, tuple(omit_classes)): - continue - elif inspect.isfunction(obj): - if omit_functions and obj in omit_functions: - continue - elif imported.address in omit_constants: - continue - to_include.append(imported.local_name) - if not to_include: - continue - stmt = stmt.only_include(to_include) - intra_pkg_objs = defaultdict(set) - if stmt.in_package(base_pkg) or ( - stmt.in_package("nipype") and not stmt.in_package("nipype.interfaces") - ): - - for imported in list(stmt.values()): - if not ( - imported.in_package(base_pkg) or imported.in_package("nipype") - ) or inspect.isbuiltin(imported.object): - # Case where an object is a nested import from a different package - # which is imported in a chain from a neighbouring module - used.imports.add( - imported.as_independent_statement(resolve=True) - ) - stmt.drop(imported) - elif inspect.isfunction(imported.object): - used.intra_pkg_funcs.add((imported.local_name, imported.object)) - # Recursively include objects imported in the module - intra_pkg_objs[import_module(imported.object.__module__)].add( - imported.object - ) - if collapse_intra_pkg: - stmt.drop(imported) - elif inspect.isclass(imported.object): - class_def = (imported.local_name, imported.object) - # Add the class to the intra_pkg_classes list if it is not - # already there. NB: we can't use a set for intra_pkg_classes - # like we did for functions here because we need to preserve the - # order the classes are defined in the module in case one inherits - # from the other - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) - # Recursively include objects imported in the module - intra_pkg_objs[import_module(imported.object.__module__)].add( - imported.object, - ) - if collapse_intra_pkg: - stmt.drop(imported) - elif inspect.ismodule(imported.object): - # Skip if the module is the same as the module being converted - if module_omit_re.match(imported.object.__name__): - stmt.drop(imported) - continue - # Findall references to the module's attributes in the source code - # and add them to the list of intra package objects - used_attrs = re.findall( - r"\b" + imported.local_name + r"\.(\w+)\b", all_src - ) - for attr_name in used_attrs: - obj = getattr(imported.object, attr_name) - - if inspect.isfunction(obj): - used.intra_pkg_funcs.add((obj.__name__, obj)) - intra_pkg_objs[imported.object.__name__].add(obj) - elif inspect.isclass(obj): - class_def = (obj.__name__, obj) - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) - intra_pkg_objs[imported.object.__name__].add(obj) - else: - used.intra_pkg_constants.add( - ( - imported.object.__name__, - attr_name, - attr_name, - ) - ) - intra_pkg_objs[imported.object.__name__].add(attr_name) - - if collapse_intra_pkg: - raise NotImplementedError( - f"Cannot inline imported module in statement '{stmt}'" - ) - else: - used.intra_pkg_constants.add( - ( - stmt.module_name, - imported.local_name, - imported.name, - ) - ) - intra_pkg_objs[stmt.module].add(imported.local_name) - if collapse_intra_pkg: - stmt.drop(imported) - - # Recursively include neighbouring objects imported in the module - for from_mod, inlined_objs in intra_pkg_objs.items(): - used_in_mod = cls.find( - from_mod, - function_bodies=inlined_objs, - collapse_intra_pkg=collapse_intra_pkg, - translations=translations, - omit_modules=omit_modules, - omit_classes=omit_classes, - omit_functions=omit_functions, - omit_constants=omit_constants, - always_include=always_include, - ) - used.update(used_in_mod, to_be_inlined=collapse_intra_pkg) - if stmt: - used.imports.add(stmt) - return used - - @classmethod - def filter_imports( - cls, imports: ty.List[ImportStatement], source_code: str - ) -> ty.List[ImportStatement]: - """Filter out the imports that are not used in the function bodies""" - symbols = set() - cls._get_symbols(source_code, symbols) - symbols -= set(cls.SYMBOLS_TO_IGNORE) - filtered = [] - for stmt in imports: - if stmt.from_: - stmt = stmt.only_include(symbols) - if stmt: - filtered.append(stmt) - elif stmt.sole_imported.local_name in symbols: - filtered.append(stmt) - return filtered - - def copy(self) -> "UsedSymbols": - return attrs.evolve(self) - - @classmethod - def _get_symbols( - cls, func: ty.Union[str, ty.Callable, ty.Type], symbols: ty.Set[str] - ): - """Get the symbols used in a function body""" - try: - fbody = inspect.getsource(func) - except TypeError: - fbody = func - for stmt in split_source_into_statements(fbody): - if stmt and not re.match( - r"\s*(#|\"|'|from |import |r'|r\"|f'|f\")", stmt - ): # skip comments/docs - for sym in cls.symbols_re.findall(stmt): - if "." in sym: - parts = sym.split(".") - symbols.update( - ".".join(parts[: i + 1]) for i in range(len(parts)) - ) - else: - symbols.add(sym) - - # Nipype-specific names and Python keywords - SYMBOLS_TO_IGNORE = ["isdefined"] + keyword.kwlist + list(builtins.__dict__.keys()) - - def get_imported_object(self, name: str) -> ty.Any: - """Get the object with the given name from used import statements - - Parameters - ---------- - name : str - the name of the object to get - imports : list[ImportStatement], optional - the import statements to search in (used in tests), by default the imports - in the used symbols - - Returns - ------- - Any - the object with the given name referenced by the given import statements - """ - # Check to see if it isn't an imported module - # imported = { - # i.sole_imported.local_name: i.sole_imported.object - # for i in self.imports - # if not i.from_ - # } - all_imported = {} - for stmt in self.imports: - all_imported.update(stmt.imported) - try: - return all_imported[name].object - except KeyError: - pass - parts = name.rsplit(".") - imported_obj = None - for i in range(1, len(parts)): - obj_name = ".".join(parts[:-i]) - try: - imported_obj = all_imported[obj_name].object - except KeyError: - continue - else: - break - if imported_obj is None: - raise ImportError( - f"Could not find object named {name} in any of the imported modules:\n" - + "\n".join(str(i) for i in self.imports) - ) - for part in parts[-i:]: - imported_obj = getattr(imported_obj, part) - return imported_obj - - -def get_local_functions(mod) -> ty.List[ty.Callable]: - """Get the functions defined in the module""" - functions = [] - for attr_name in dir(mod): - attr = getattr(mod, attr_name) - if inspect.isfunction(attr) and attr.__module__ == mod.__name__: - functions.append(attr) - return functions - - -def get_local_classes(mod) -> ty.List[type]: - """Get the functions defined in the module""" - classes = [] - for attr_name in dir(mod): - attr = getattr(mod, attr_name) - if inspect.isclass(attr) and attr.__module__ == mod.__name__: - classes.append(attr) - return classes - - -def get_local_constants(mod) -> ty.List[ty.Tuple[str, str]]: - """ - Get the constants defined in the module - """ - source_code = inspect.getsource(mod) - source_code = source_code.replace("\\\n", " ") - local_vars = [] - for stmt in split_source_into_statements(source_code): - match = re.match(r"^(\w+) *= *(.*)", stmt, flags=re.MULTILINE | re.DOTALL) - if match: - local_vars.append(tuple(match.groups())) - return local_vars diff --git a/nipype2pydra/utils/tests/test_utils_imports.py b/nipype2pydra/utils/tests/test_utils_imports.py index 483ebf00..b0ce1cc1 100644 --- a/nipype2pydra/utils/tests/test_utils_imports.py +++ b/nipype2pydra/utils/tests/test_utils_imports.py @@ -1,5 +1,5 @@ import pytest -from nipype2pydra.utils.symbols import UsedSymbols +from nipype2pydra.symbols import UsedSymbols from nipype2pydra.statements.imports import ImportStatement, parse_imports import nipype.interfaces.utility @@ -49,7 +49,9 @@ def test_get_imported_object1(): import_stmts = [ "import nipype.interfaces.utility as niu", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("niu.IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -60,7 +62,9 @@ def test_get_imported_object2(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("nipype.interfaces.utility") is nipype.interfaces.utility @@ -71,7 +75,9 @@ def test_get_imported_object3(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -82,7 +88,9 @@ def test_get_imported_object4(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface.input_spec") is nipype.interfaces.utility.IdentityInterface.input_spec @@ -93,7 +101,9 @@ def test_get_imported_object5(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object( "nipype.interfaces.utility.IdentityInterface.input_spec" @@ -106,7 +116,9 @@ def test_get_imported_object_fail1(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("nipype.interfaces.utilityboo") @@ -115,6 +127,8 @@ def test_get_imported_object_fail2(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("IdentityBoo") diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index 82505371..a2bf9086 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -14,8 +14,8 @@ import yaml from fileformats.core import from_mime, FileSet, Field from fileformats.core.exceptions import FormatRecognitionError +from .symbols import UsedSymbols from .utils import ( - UsedSymbols, split_source_into_statements, extract_args, full_address, @@ -36,6 +36,7 @@ WorkflowInitStatement, AssignmentStatement, OtherStatement, + DynamicField, ) import nipype2pydra.package @@ -71,6 +72,7 @@ class WorkflowInterfaceField: }, ) node_name: ty.Optional[str] = attrs.field( + default=None, metadata={ "help": "The name of the node that the input/output is connected to", }, @@ -93,8 +95,7 @@ class WorkflowInterfaceField: factory=list, metadata={ "help": ( - "node-name/field-name pairs of other fields that are to be routed to " - "from other node fields to this input/output", + "node-name/field-name pairs of additional fields that this input/output replaces", ) }, ) @@ -130,7 +131,7 @@ def type_repr_(t): elif issubclass(t, Field): return t.primitive.__name__ elif issubclass(t, FileSet): - return t.__name__ + return t.type_name elif t.__module__ == "builtins": return t.__name__ else: @@ -158,6 +159,11 @@ def __hash__(self): @attrs.define class WorkflowInput(WorkflowInterfaceField): + connections: ty.Tuple[ty.Tuple[str, str]] = attrs.field( + converter=lambda lst: tuple(sorted(tuple(t) for t in lst)), + factory=list, + metadata={"help": ("Explicit connections to be made from this input field",)}, + ) out_conns: ty.List[ConnectionStatement] = attrs.field( factory=list, eq=False, @@ -169,9 +175,7 @@ class WorkflowInput(WorkflowInterfaceField): ) }, ) - include: bool = attrs.field( - default=False, eq=False, hash=False, metadata={ @@ -182,6 +186,10 @@ class WorkflowInput(WorkflowInterfaceField): }, ) + @include.default + def _include_default(self) -> bool: + return bool(self.connections) + def __hash__(self): return super().__hash__() @@ -189,6 +197,11 @@ def __hash__(self): @attrs.define class WorkflowOutput(WorkflowInterfaceField): + connection: ty.Tuple[str, str] = attrs.field( + converter=tuple, + factory=list, + metadata={"help": ("Explicit connection to be made to this output field",)}, + ) in_conns: ty.List[ConnectionStatement] = attrs.field( factory=list, eq=False, @@ -412,6 +425,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput: """ Returns the name of the input field in the workflow for the given node and field escaped by the prefix of the node if present""" + if isinstance(conn.source_out, DynamicField): + logger.warning( + f"Not able to connect inputs from {conn.source_name}:{conn.source_out}->" + f"{conn.target_name}:{conn.target_in} properly due to adynamic-field " + "just connecting to source input for now" + ) try: return self.make_input( field_name=conn.source_out, @@ -603,17 +622,12 @@ def add_connection_from_output(self, out_conn: ConnectionStatement): self._add_output_conn(out_conn, "from") @cached_property - def used_symbols(self) -> UsedSymbols: + def used(self) -> UsedSymbols: return UsedSymbols.find( self.nipype_module, [self.func_body], collapse_intra_pkg=False, - omit_classes=self.package.omit_classes, - omit_modules=self.package.omit_modules, - omit_functions=self.package.omit_functions, - omit_constants=self.package.omit_constants, - always_include=self.package.all_explicit, - translations=self.package.all_import_translations, + package=self.package, ) @property @@ -647,10 +661,10 @@ def func_body(self): @cached_property def nested_workflows(self): potential_funcs = { - full_address(f[1]): f[0] for f in self.used_symbols.intra_pkg_funcs if f[0] + full_address(f[1]): f[0] for f in self.used.imported_funcs if f[0] } potential_funcs.update( - (full_address(f), f.__name__) for f in self.used_symbols.local_functions + (full_address(f), f.__name__) for f in self.used.functions ) return { potential_funcs[address]: workflow @@ -705,23 +719,23 @@ def write( if additional_funcs is None: additional_funcs = [] - used = self.used_symbols.copy() - all_used = self.used_symbols.copy() + used = self.used.copy() + all_used = self.used.copy() # Start writing output module with used imports and converted function body of # main workflow code_str = self.converted_code - local_func_names = {f.__name__ for f in used.local_functions} + local_func_names = {f.__name__ for f in used.functions} # Convert any nested workflows for name, conv in self.nested_workflows.items(): if conv.address in already_converted: continue already_converted.add(conv.address) - all_used.update(conv.used_symbols) + all_used.update(conv.used) if name in local_func_names: code_str += "\n\n\n" + conv.converted_code - used.update(conv.used_symbols) + used.update(conv.used) else: conv_all_used = conv.write( package_root, @@ -764,7 +778,9 @@ def write( ), converted_code=self.test_code, used=self.test_used, - additional_imports=self.input_output_imports, + additional_imports=( + self.input_output_imports + parse_imports("import pytest") + ), ) conftest_fspath = test_module_fspath.parent / "conftest.py" @@ -911,6 +927,7 @@ def add_nonstd_types(tp): # Write to file for debugging debug_file = "~/unparsable-nipype2pydra-output.py" with open(Path(debug_file).expanduser(), "w") as f: + f.write(f"# Attemping to convert {self.full_address}\n") f.write(code_str) raise RuntimeError( f"Black could not parse generated code (written to {debug_file}): " @@ -931,22 +948,51 @@ def parsed_statements(self): def test_code(self): args_str = ", ".join(f"{n}={v}" for n, v in self.test_inputs.items()) - return f""" + code_str = f""" -def test_{self.name}(): + +def test_{self.name}_build(): workflow = {self.name}({args_str}) assert isinstance(workflow, Workflow) """ + inputs_dict = {} + for inpt in self.inputs.values(): + if issubclass(inpt.type, FileSet): + inputs_dict[inpt.name] = inpt.type.type_name + ".sample()" + elif inpt.name in self.test_inputs: + inputs_dict[inpt.name] = self.test_inputs[inpt.name] + args_str = ", ".join(f"{n}={v}" for n, v in inputs_dict.items()) + + code_str += f""" + +@pytest.mark.skip(reason="Appropriate inputs for this workflow haven't been specified yet") +def test_{self.name}_run(): + workflow = {self.name}({args_str}) + result = workflow(plugin='serial') + print(result.out) +""" + return code_str + @property def test_used(self): + nonstd_types = [ + i.type for i in self.inputs.values() if issubclass(i.type, FileSet) + ] + nonstd_type_imports = [] + for tp in itertools.chain(*(unwrap_nested_type(t) for t in nonstd_types)): + nonstd_type_imports.append(ImportStatement.from_object(tp)) + return UsedSymbols( module_name=self.nipype_module.__name__, - imports=parse_imports( - [ - f"from {self.output_module} import {self.name}", - "from pydra.engine import Workflow", - ] + import_stmts=( + nonstd_type_imports + + parse_imports( + [ + f"from {self.output_module} import {self.name}", + "from pydra.engine import Workflow", + ] + ) ), ) @@ -1003,7 +1049,7 @@ def prepare_connections(self): # append to parsed statements so set_output can be set self.parsed_statements.append(conn_stmt) while self._unprocessed_connections: - conn = self._unprocessed_connections.pop() + conn = self._unprocessed_connections.pop(0) try: inpt = self.get_input_from_conn(conn) except KeyError: @@ -1023,6 +1069,47 @@ def prepare_connections(self): conn.target_in = outpt.name outpt.in_conns.append(conn) + # Overwrite connections with explict connections + for inpt in list(self.inputs.values()): + for target_name, target_in in inpt.connections: + conn = ConnectionStatement( + indent=" ", + source_name=None, + source_out=inpt.name, + target_name=target_name, + target_in=target_in, + workflow_converter=self, + ) + for tgt_node in self.nodes[conn.target_name]: + try: + existing_conn = next( + c for c in tgt_node.in_conns if c.target_in == target_in + ) + except StopIteration: + pass + else: + tgt_node.in_conns.remove(existing_conn) + self.inputs[existing_conn.source_out].out_conns.remove( + existing_conn + ) + inpt.out_conns.append(conn) + tgt_node.add_input_connection(conn) + + for outpt in list(self.outputs.values()): + if outpt.connection: + source_name, source_out = outpt.connection + conn = ConnectionStatement( + indent=" ", + source_name=source_name, + source_out=source_out, + target_name=None, + target_in=outpt.name, + workflow_converter=self, + ) + for src_node in self.nodes[conn.source_name]: + src_node.add_output_connection(conn) + outpt.in_conns.append(conn) + def _parse_statements(self, func_body: str) -> ty.Tuple[ ty.List[ ty.Union[