From 908fd2d6836c148db564a649aa52e58e3ed1fa78 Mon Sep 17 00:00:00 2001 From: Kevin Sun Date: Fri, 14 Feb 2025 09:49:44 +0800 Subject: [PATCH] Move debug_print to jit_util --- src/evox/core/__init__.py | 3 +-- src/evox/core/_vmap_fix.py | 30 +++--------------------------- src/evox/core/jit_util.py | 26 ++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/evox/core/__init__.py b/src/evox/core/__init__.py index d43c9d2e2..64506f506 100644 --- a/src/evox/core/__init__.py +++ b/src/evox/core/__init__.py @@ -19,9 +19,8 @@ # deal with vmap nesting and JIT from . import _vmap_fix -from ._vmap_fix import debug_print from .components import Algorithm, Monitor, Problem, Workflow -from .jit_util import jit, vmap +from .jit_util import debug_print, jit, vmap # export symbols from .module import ModuleBase, Mutable, Parameter, assign_load_state_dict, jit_class, trace_impl, use_state, vmap_impl diff --git a/src/evox/core/_vmap_fix.py b/src/evox/core/_vmap_fix.py index b4d38e1b0..ae87e8902 100644 --- a/src/evox/core/_vmap_fix.py +++ b/src/evox/core/_vmap_fix.py @@ -7,6 +7,8 @@ "register_fix_function", "unregister_fix_function", "use_batch_fixing", + "tree_flatten", + "tree_unflatten", "_set_func_id", ] @@ -50,7 +52,7 @@ def current_level() -> int | None: current_level = _functorch.maybe_current_level from torch import nn -from torch.utils._pytree import tree_flatten, tree_unflatten # noqa: F401 +from torch.utils._pytree import tree_flatten, tree_unflatten if "Buffer" not in nn.__dict__: nn.Buffer = nn.parameter.Buffer @@ -570,29 +572,3 @@ def align_vmap_tensor(value: Any, current_value: Any | None) -> torch.Tensor: value = value.unsqueeze(dim).expand(*value.shape[:dim], size, *value.shape[dim:]) value = wrap_batch_tensor(value, batch_dims) return value - - -def _debug_print(format: str, arg: torch.Tensor) -> torch.Tensor: - print(format.format(arg)) - return arg - - -def debug_print(format: str, arg: torch.Tensor) -> torch.Tensor: - """Prints a formatted string with one positional tensor used for debugging inside JIT traced functions on-the-fly. - - When vectorized-mapping, it unwraps the batched tensor to print the underlying values. - Otherwise, the function behaves like `format.format(*args, **kwargs)`. - - :param format: A string format. - :param arg: The positional tensor. - :return: The unchanged tensor. - """ - level = current_level() - if level is None or level <= 0: - inner_arg = arg - else: - inner_arg = unwrap_batch_tensor(arg)[0] - return torch.jit.script_if_tracing(_debug_print)(format, inner_arg) - - -debug_print.__prepare_scriptable__ = lambda: _debug_print diff --git a/src/evox/core/jit_util.py b/src/evox/core/jit_util.py index 3b666038e..63f69a17a 100644 --- a/src/evox/core/jit_util.py +++ b/src/evox/core/jit_util.py @@ -365,3 +365,29 @@ def jit_wrapper(*args, **kwargs): _vmap_fix._set_func_id(jit_wrapper, func) return jit_wrapper + + +def _debug_print(format: str, arg: torch.Tensor) -> torch.Tensor: + print(format.format(arg)) + return arg + + +def debug_print(format: str, arg: torch.Tensor) -> torch.Tensor: + """Prints a formatted string with one positional tensor used for debugging inside JIT traced functions on-the-fly. + + When vectorized-mapping, it unwraps the batched tensor to print the underlying values. + Otherwise, the function behaves like `format.format(*args, **kwargs)`. + + :param format: A string format. + :param arg: The positional tensor. + :return: The unchanged tensor. + """ + level = _vmap_fix.current_level() + if level is None or level <= 0: + inner_arg = arg + else: + inner_arg = _vmap_fix.unwrap_batch_tensor(arg)[0] + return torch.jit.script_if_tracing(_debug_print)(format, inner_arg) + + +debug_print.__prepare_scriptable__ = lambda: _debug_print \ No newline at end of file