diff --git a/.coveragerc b/.coveragerc index 45facbe67..214e36002 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,9 +1,10 @@ [run] source = scico command_line = -m pytest -omit = +omit = scico/test/* scico/plot.py + scico/trace.py [report] # Regexes for lines to exclude from consideration diff --git a/CHANGES.rst b/CHANGES.rst index c297f8bc5..472c777ed 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,7 @@ SCICO Release Notes Version 0.0.7 (unreleased) ---------------------------- -• No changes yet. +• New module ``scico.trace`` for tracing function/method calls. diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index b19087e01..0ab0e57d4 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,4 +1,5 @@ -r ../requirements.txt +colorama colour_demosaicing svmbir>=0.4.0 astra-toolbox diff --git a/examples/jnb.py b/examples/jnb.py index f51c3f9a8..6fe58b984 100644 --- a/examples/jnb.py +++ b/examples/jnb.py @@ -48,15 +48,17 @@ def py_file_to_string(src): if re.match("^import|^from .* import", line): import_seen = True lines.append(line) - # Backtrack through list of lines to find last import statement - n = 1 - for line in lines[-2::-1]: - if re.match("^(import|from)", line): - break - else: - n += 1 - # Insert notebook plotting config directly after last import statement - lines.insert(-n, "plot.config_notebook_plotting()\n") + + if "plot" in "".join(lines): + # Backtrack through list of lines to find last import statement + n = 1 + for line in lines[-2::-1]: + if re.match("^(import|from)", line): + break + else: + n += 1 + # Insert notebook plotting config directly after last import statement + lines.insert(-n, "plot.config_notebook_plotting()\n") # Process remainder of source file for line in srcfile: @@ -73,7 +75,8 @@ def py_file_to_string(src): n += 1 else: break - lines = lines[0:-n] + if n > 0: + lines = lines[0:-n] return "".join(lines) diff --git a/examples/scripts/trace_example.py b/examples/scripts/trace_example.py new file mode 100644 index 000000000..10df12814 --- /dev/null +++ b/examples/scripts/trace_example.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +SCICO Call Tracing +================== + +This example demonstrates the call tracing functionality provided by the +[trace](../_autosummary/scico.trace.rst) module. It is based on the +[non-negative BPDN example](sparsecode_nn_admm.rst). +""" + + +import numpy as np + +import jax + +import scico.numpy as snp +from scico import functional, linop, loss, metric +from scico.optimize.admm import ADMM, MatrixSubproblemSolver +from scico.trace import register_variable, trace_scico_calls +from scico.util import device_info + +""" +Initialize tracing. JIT must be disabled for correct tracing. + +The call tracing mechanism prints the name, arguments, and return values +of functions/methods as they are called. Module and class names are +printed in light red, function and method names in dark red, arguments +and return values in light blue, and the names of registered variables +in light yellow. When a method defined in a class is called for an object +of a derived class type, the class of that object is printed in light +magenta, in square brackets. Function names and return values are +distinguished by initial ">>" and "<<" characters respectively. +""" +jax.config.update("jax_disable_jit", True) +trace_scico_calls() + + +""" +Create random dictionary, reference random sparse representation, and +test signal consisting of the synthesis of the reference sparse +representation. +""" +m = 32 # signal size +n = 128 # dictionary size +s = 10 # sparsity level + +np.random.seed(1) +D = np.random.randn(m, n).astype(np.float32) +D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary + +xt = np.zeros(n, dtype=np.float32) # true signal +idx = np.random.randint(low=0, high=n, size=s) # support of xt +xt[idx] = np.random.rand(s) +y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal + +xt = snp.array(xt) # convert to jax array +y = snp.array(y) # convert to jax array + + +""" +Register a variable so that it can be referenced by name in the call trace. +Any hashable object and numpy arrays may be registered, but JAX arrays +cannot. +""" +register_variable(D, "D") + + +""" +Set up the forward operator and ADMM solver object. +""" +lmbda = 1e-1 +A = linop.MatrixOperator(D) +register_variable(A, "A") +f = loss.SquaredL2Loss(y=y, A=A) +g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()] +C_list = [linop.Identity((n)), linop.Identity((n))] +rho_list = [1.0, 1.0] +maxiter = 1 # number of ADMM iterations (set to small value to simplify trace output) + +register_variable(f, "f") +register_variable(g_list[0], "g_list[0]") +register_variable(g_list[1], "g_list[1]") +register_variable(C_list[0], "C_list[0]") +register_variable(C_list[1], "C_list[1]") + +solver = ADMM( + f=f, + g_list=g_list, + C_list=C_list, + rho_list=rho_list, + x0=A.adj(y), + maxiter=maxiter, + subproblem_solver=MatrixSubproblemSolver(), + itstat_options={"display": True, "period": 5}, +) + +register_variable(solver, "solver") + + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +x = solver.solve() +mse = metric.mse(xt, x) diff --git a/scico/optimize/_pgm.py b/scico/optimize/_pgm.py index 6188f7480..83aae022a 100644 --- a/scico/optimize/_pgm.py +++ b/scico/optimize/_pgm.py @@ -11,6 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations +from functools import partial from typing import Optional, Union import jax @@ -101,15 +102,22 @@ def __init__( self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of ∇f) self.fixed_point_residual = snp.inf - def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: - return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L) - - self.x_step = jax.jit(x_step) - self.x: Union[Array, BlockArray] = x0 # current estimate of solution super().__init__(**kwargs) + def x_step(self, v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: + """Compute update for variable `x`.""" + return PGM._x_step(self.f, self.g, v, L) + + @staticmethod + @partial(jax.jit, static_argnums=(0, 1)) + def _x_step( + f: Functional, g: Functional, v: Union[Array, BlockArray], L: float + ) -> Union[Array, BlockArray]: + """Jit-able static method for computing update for variable `x`.""" + return g.prox(v - 1.0 / L * f.grad(v), 1.0 / L) + def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index d84681f55..9908038d4 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/trace.py b/scico/trace.py new file mode 100644 index 000000000..b7d61cf07 --- /dev/null +++ b/scico/trace.py @@ -0,0 +1,456 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2024 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Call tracing of scico functions and methods. + +JIT must be disabled for tracing to function correctly (set environment +variable :code:`JAX_DISABLE_JIT=1`, or call +:code:`jax.config.update('jax_disable_jit', True)` before importing `jax` +or `scico`). Call :code:`trace_scico_calls` to initialize tracing, and +call :code:`register_variable` to associate a name with a variable so +that it can be referenced by name in the call trace. + +The call trace is color-code as follows if +`colorama `_ is installed: + +- `module and class names`: light red +- `function and method names`: dark red +- `arguments and return values`: light blue +- `names of registered variables`: light yellow + +When a method defined in a class is called for an object of a derived +class type, the class of that object is displayed in light magenta, in +square brackets. Function names and return values are distinguished by +initial ``>>`` and ``<<`` characters respectively. + +A usage example is provided in the script :code:`trace_example.py`. +""" + + +from __future__ import annotations + +import inspect +import sys +import types +import warnings +from collections import defaultdict +from functools import wraps +from typing import Any, Callable, Optional, Sequence + +import numpy as np + +import jax + +from jaxlib.xla_extension import PjitFunction + +try: + import colorama + + have_colorama = True +except ImportError: + have_colorama = False + + +if have_colorama: + clr_main = colorama.Fore.LIGHTRED_EX # main trace information + clr_rvar = colorama.Fore.LIGHTYELLOW_EX # registered variable names + clr_self = colorama.Fore.LIGHTMAGENTA_EX # type of object for which method is called + clr_func = colorama.Fore.RED # function/method name + clr_args = colorama.Fore.LIGHTBLUE_EX # function/method arguments + clr_retv = colorama.Fore.LIGHTBLUE_EX # function/method return values + clr_devc = colorama.Fore.CYAN # JAX array device and sharding + clr_reset = colorama.Fore.RESET # reset color +else: + clr_main, clr_rvar, clr_self, clr_func = "", "", "", "" + clr_args, clr_retv, clr_devc, clr_reset = "", "", "", "" + + +def _get_hash(val: Any) -> Optional[int]: + """Get a hash representing an object. + + Args: + val: An object for which the hash is required. + + Returns: + A hash value of ``None`` if a hash cannot be computed. + """ + if isinstance(val, np.ndarray): + hash = val.ctypes.data # for an ndarray, hash is the memory address + elif hasattr(val, "__hash__") and callable(val.__hash__): + try: + hash = val.__hash__() + except TypeError: + hash = None + else: + hash = None + return hash + + +def _trace_arg_repr(val: Any) -> str: + """Compute string representation of function arguments. + + Args: + val: Argument value + + Returns: + A string representation of the argument. + """ + if val is None: + return "None" + elif np.isscalar(val): # a scalar value + return str(val) + elif isinstance(val, tuple) and len(val) < 6 and all([np.isscalar(s) for s in val]): + return f"{val}" # a short sequence of scalars + elif isinstance(val, np.dtype): # a numpy dtype + return f"numpy.{val}" + elif isinstance(val, type): # a class name + return f"{val.__module__}.{val.__qualname__}" + elif isinstance(val, np.ndarray) and _get_hash(val) in call_trace.instance_hash: # type: ignore + return f"{clr_rvar}{call_trace.instance_hash[_get_hash(val)]}{clr_args}" # type: ignore + elif isinstance(val, (np.ndarray, jax.Array)): # a jax or numpy array + if val.shape == (): + return str(val) + else: + dev_str, shard_str = "", "" + if isinstance(val, jax.Array) and not isinstance( + val, jax._src.interpreters.partial_eval.JaxprTracer + ): + if call_trace.show_jax_device: # type: ignore + platform = list(val.devices())[0].platform # assume all of same type + devices = ",".join(map(str, sorted([d.id for d in val.devices()]))) + dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}" + if call_trace.show_jax_sharding and isinstance( # type: ignore + val.sharding, jax._src.sharding_impls.PositionalSharding + ): + shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}" + return f"Array{val.shape}{dev_str}{shard_str}" + else: + if _get_hash(val) in call_trace.instance_hash: # type: ignore + return f"{clr_rvar}{call_trace.instance_hash[val.__hash__()]}{clr_args}" # type: ignore + else: + return f"[{type(val).__name__}]" + + +def register_variable(var: Any, name: str): + """Register a variable name for call tracing. + + Any hashable object (or numpy array, with the memory address + used as a hash) may be registered. JAX arrays may not be registered + since they are not hashable and there is no clear mechanism for + associating them with a unique memory address. + + Args: + var: The variable to be registered. + name: The name to be associated with the variable. + """ + hash = _get_hash(var) + if hash is None: + raise ValueError(f"Can't get hash for variable {name}.") + call_trace.instance_hash[hash] = name # type: ignore + + +def _call_wrapped_function(func: Callable, *args, **kwargs) -> Any: + """Call a wrapped function within the wrapper. + + Handle different call mechanisms required for static and class + methods. + + Args: + func: Wrapped function + *args: Positional arguments + **kwargs: Named arguments + + Returns: + Return value of wrapped function. + """ + if isinstance(func, staticmethod): + ret = func(*args[1:], **kwargs) + elif isinstance(func, classmethod): + ret = func.__func__(*args, **kwargs) + else: + ret = func(*args, **kwargs) + return ret + + +def call_trace(func: Callable) -> Callable: + """Print log of calls to `func`. + + Decorator for printing a log of calls to the wrapped function. A + record of call levels is maintained so that call nesting is indicated + by call log indentation. + """ + try: + method_class = inspect._findclass(func) # type: ignore + except AttributeError: + method_class = None + + @wraps(func) + def wrapper(*args, **kwargs): + name = f"{func.__module__}.{clr_func}{func.__qualname__}" + arg_idx = 0 + if ( + args + and hasattr(args[0], "__hash__") + and callable(args[0].__hash__) + and method_class + and isinstance(args[0], method_class) + ): # first argument is self for a method call + arg_idx = 1 # skip self in handling arguments + if args[0].__hash__() in call_trace.instance_hash: + # self object registered using register_variable + name = ( + f"{clr_rvar}{call_trace.instance_hash[args[0].__hash__()]}." + f"{clr_func}{func.__name__}" + ) + else: + # self object not registered + func_class = method_class.__name__ + self_class = args[0].__class__.__name__ + # If the class in which this method is defined is same as that + # of the self object for which it's called, just display the + # class name. Otherwise, display the name of the name defining + # class followed by the name of the self object class in + # square brackets. + if func_class == self_class: + class_name = func_class + else: + class_name = f"{func_class}{clr_self}[{self_class}]{clr_main}" + name = f"{func.__module__}.{class_name}.{clr_func}{func.__name__}" + args_repr = [_trace_arg_repr(val) for val in args[arg_idx:]] + kwargs_repr = [f"{key}={_trace_arg_repr(val)}" for key, val in kwargs.items()] + args_str = clr_args + ", ".join(args_repr + kwargs_repr) + clr_main + print( + f"{clr_main}>> {' ' * 2 * call_trace.trace_level}{name}" + f"({args_str}{clr_func}){clr_reset}", + file=sys.stderr, + ) + # call wrapped function + call_trace.trace_level += 1 + ret = _call_wrapped_function(func, *args, **kwargs) + call_trace.trace_level -= 1 + # print representation of return value + if ret is not None and call_trace.show_return_value: + print( + f"{clr_main}<< {' ' * 2 * call_trace.trace_level}{clr_retv}" + f"{_trace_arg_repr(ret)}{clr_reset}", + file=sys.stderr, + ) + return ret + + # Set flag indicating that function is already wrapped + wrapper._call_trace_wrap = True # type: ignore + # Avoid multiple wrapper layers + if hasattr(func, "_call_trace_wrap"): + return func + else: + return wrapper + + +# call level counter for call_trace decorator +call_trace.trace_level = 0 # type: ignore +# hash dict allowing association of objects with variable names +call_trace.instance_hash = {} # type: ignore +# flag indicating whether to show function return value +call_trace.show_return_value = True # type: ignore +# flag indicating whether to show JAX array devices +call_trace.show_jax_device = False # type: ignore +# flag indicating whether to show JAX array sharding shape +call_trace.show_jax_sharding = False # type: ignore + + +def _submodule_name(module, obj): + if ( + len(obj.__name__) > len(module.__name__) + and obj.__name__[0 : len(module.__name__)] == module.__name__ + ): + short_name = obj.__name__[len(module.__name__) + 1 :] + else: + short_name = "" + return short_name + + +def _is_scico_object(obj: Any) -> bool: + """Determine whether an object is defined in a scico submodule. + + Args: + obj: Object to check. + + Returns: + A boolean value indicating whether `obj` is defined in a scico + submodule. + """ + return hasattr(obj, "__module__") and obj.__module__[0:5] == "scico" + + +def _is_scico_module(mod: types.ModuleType) -> bool: + """Determine whether a module is a scico submodule. + + Args: + mod: Module to check. + + Returns: + A boolean value indicating whether `mod` is a scico submodule. + """ + return hasattr(mod, "__name__") and mod.__name__[0:5] == "scico" + + +def _in_module(mod: types.ModuleType, obj: Any) -> bool: + """Determine whether an object is defined in a module. + + Args: + mod: Module of interest. + obj: Object to check. + + Returns: + A boolean value indicating whether `obj` is defined in `mod`. + """ + return obj.__module__ == mod.__name__ + + +def _is_submodule(mod: types.ModuleType, submod: types.ModuleType) -> bool: + """Determine whether a module is a submodule of another module. + + Args: + mod: Parent module of interest. + submod: Possible submodule to check. + + Returns: + A boolean value indicating whether `submod` is defined in `mod`. + """ + return submod.__name__[0 : len(mod.__name__)] == mod.__name__ + + +def apply_decorator( + module: types.ModuleType, + decorator: Callable, + recursive: bool = True, + skip: Optional[Sequence] = None, + seen: Optional[defaultdict[str, int]] = None, + verbose: bool = False, + level: int = 0, +) -> defaultdict[str, int]: + """Apply a decorator function to all functions in a scico module. + + Apply a decorator function to all functions in a scico module, + including methods of classes in that module. + + Args: + module: The module containing the functions/methods to be + decorated. + decorator: The decorator function to apply to each module + function/method. + recursive: Flag indicating whether to recurse into submodules + of the specified module. (Hidden modules with a name starting + with an underscore are ignored.) + skip: A list of class/function/method names to be skipped. + seen: A :class:`defaultdict` providing a count of the number of + times each function/method was seen. + verbose: Flag indicating whether to print a log of functions + as they are encountered. + level: Counter for recursive call levels. + + Returns: + A :class:`defaultdict` providing a count of the number of times + each function/method was seen. + """ + indent = " " * 4 * level + if skip is None: + skip = [] + if seen is None: + seen = defaultdict(int) + if verbose: + print(f"{indent}Module: {module.__name__}") + indent += " " * 4 + + # Iterate over functions in module + for name, func in inspect.getmembers( + module, + lambda obj: isinstance(obj, (types.FunctionType, PjitFunction)) and _in_module(module, obj), + ): + if name in skip: + continue + qualname = func.__module__ + "." + func.__qualname__ + if not seen[qualname]: # avoid multiple applications of decorator + setattr(module, name, decorator(func)) + seen[qualname] += 1 + if verbose: + print(f"{indent}Function: {qualname}") + + # Iterate over classes in module + for name, cls in inspect.getmembers( + module, lambda obj: inspect.isclass(obj) and _in_module(module, obj) + ): + qualname = cls.__module__ + "." + cls.__qualname__ # type: ignore + if verbose: + print(f"{indent}Class: {qualname}") + + # Iterate over methods in class + for name, func in inspect.getmembers( + cls, lambda obj: isinstance(obj, (types.FunctionType, PjitFunction)) + ): + if name in skip: + continue + qualname = func.__module__ + "." + func.__qualname__ # type: ignore + if not seen[qualname]: # avoid multiple applications of decorator + # Can't use cls returned by inspect.getmembers because it uses plain + # getattr internally, which interferes with identification of static + # methods. From Python 3.11 onwards one could use + # inspect.getmembers_static instead of inspect.getmembers, but that + # would imply incompatibility with earlier Python versions. + func = inspect.getattr_static(cls, name) + setattr(cls, name, decorator(func)) + seen[qualname] += 1 + if verbose: + print(f"{indent + ' '}Method: {qualname}") + + # Iterate over submodules of module + if recursive: + for name, mod in inspect.getmembers( + module, lambda obj: inspect.ismodule(obj) and _is_submodule(module, obj) + ): + if name[0:1] == "_": + continue + seen = apply_decorator( + mod, + decorator, + recursive=recursive, + skip=skip, + seen=seen, + verbose=verbose, + level=level + 1, + ) + + return seen + + +def trace_scico_calls(verbose: bool = False): + """Enable tracing of calls to all significant scico functions/methods. + + Enable tracing of calls to all significant scico functions and + methods. Note that JIT should be disabled to ensure correct + functioning of the tracing mechanism. + """ + if not jax.config.jax_disable_jit: + warnings.warn( + "Call tracing requested but jit is not disabled. Disable jit" + " by setting the environment variable JAX_DISABLE_JIT=1, or use" + " jax.config.update('jax_disable_jit', True)." + ) + from scico import ( + function, + functional, + linop, + loss, + metric, + operator, + optimize, + solver, + ) + + seen = None + for module in (functional, linop, loss, operator, optimize, function, metric, solver): + seen = apply_decorator(module, call_trace, skip=["__repr__"], seen=seen, verbose=verbose) diff --git a/scico/util.py b/scico/util.py index d57c8efc6..138ed0ffa 100644 --- a/scico/util.py +++ b/scico/util.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the