Move debug_print to jit_util
sses7757 committed Feb 14, 2025
1 parent 31fa017 commit 908fd2d
Showing 3 changed files with 30 additions and 29 deletions.
3 changes: 1 addition & 2 deletions src/evox/core/__init__.py
@@ -19,9 +19,8 @@

 # deal with vmap nesting and JIT
 from . import _vmap_fix
-from ._vmap_fix import debug_print
 from .components import Algorithm, Monitor, Problem, Workflow
-from .jit_util import jit, vmap
+from .jit_util import debug_print, jit, vmap
 
 # export symbols
 from .module import ModuleBase, Mutable, Parameter, assign_load_state_dict, jit_class, trace_impl, use_state, vmap_impl
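
Note that the public import path is unaffected by this move; only the internal source of the re-export changes. A quick check (a sketch, assuming the package layout shown in this commit):

from evox.core import debug_print  # still exported at the top level, now sourced from jit_util
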
30 changes: 3 additions & 27 deletions src/evox/core/_vmap_fix.py
@@ -7,6 +7,8 @@
"register_fix_function",
"unregister_fix_function",
"use_batch_fixing",
"tree_flatten",
"tree_unflatten",
"_set_func_id",
]

@@ -50,7 +52,7 @@ def current_level() -> int | None:
 current_level = _functorch.maybe_current_level
 
 from torch import nn
-from torch.utils._pytree import tree_flatten, tree_unflatten  # noqa: F401
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 if "Buffer" not in nn.__dict__:
     nn.Buffer = nn.parameter.Buffer
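
With tree_flatten and tree_unflatten added to __all__ above, the F401 suppression is no longer needed: the two names are now intentional re-exports rather than unused imports. For context, these helpers decompose a nested container into a flat list of leaves plus a structure spec, and rebuild it; a minimal sketch with made-up data:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

nested = {"a": torch.ones(2), "b": (torch.zeros(3), 42)}
leaves, spec = tree_flatten(nested)     # flat leaves [ones(2), zeros(3), 42] plus a structure spec
rebuilt = tree_unflatten(leaves, spec)  # reconstructs the original dict/tuple nesting
assert torch.equal(rebuilt["a"], nested["a"])
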
@@ -570,29 +572,3 @@ def align_vmap_tensor(value: Any, current_value: Any | None) -> torch.Tensor:
         value = value.unsqueeze(dim).expand(*value.shape[:dim], size, *value.shape[dim:])
         value = wrap_batch_tensor(value, batch_dims)
     return value
-
-
-def _debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
-    print(format.format(arg))
-    return arg
-
-
-def debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
-    """Print a formatted string with one positional tensor on-the-fly, for debugging inside JIT-traced functions.
-
-    When vectorized-mapped, it unwraps the batched tensor and prints the underlying values.
-    Otherwise, the function behaves like `format.format(arg)`.
-
-    :param format: The format string.
-    :param arg: The positional tensor.
-    :return: The tensor `arg`, unchanged.
-    """
-    level = current_level()
-    if level is None or level <= 0:
-        inner_arg = arg
-    else:
-        inner_arg = unwrap_batch_tensor(arg)[0]
-    return torch.jit.script_if_tracing(_debug_print)(format, inner_arg)
-
-
-debug_print.__prepare_scriptable__ = lambda: _debug_print
26 changes: 26 additions & 0 deletions src/evox/core/jit_util.py
@@ -365,3 +365,29 @@ def jit_wrapper(*args, **kwargs):

     _vmap_fix._set_func_id(jit_wrapper, func)
     return jit_wrapper
+
+
+def _debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
+    print(format.format(arg))
+    return arg
+
+
+def debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
+    """Print a formatted string with one positional tensor on-the-fly, for debugging inside JIT-traced functions.
+
+    When vectorized-mapped, it unwraps the batched tensor and prints the underlying values.
+    Otherwise, the function behaves like `format.format(arg)`.
+
+    :param format: The format string.
+    :param arg: The positional tensor.
+    :return: The tensor `arg`, unchanged.
+    """
+    level = _vmap_fix.current_level()
+    if level is None or level <= 0:
+        inner_arg = arg
+    else:
+        inner_arg = _vmap_fix.unwrap_batch_tensor(arg)[0]
+    return torch.jit.script_if_tracing(_debug_print)(format, inner_arg)
+
+
+debug_print.__prepare_scriptable__ = lambda: _debug_print
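
After the move, debug_print is called exactly as before; a minimal usage sketch (the input values here are arbitrary, and evox.core's vmap is assumed to wrap a function the way torch.vmap does):

import torch
from evox.core import debug_print, vmap

def affine(x: torch.Tensor) -> torch.Tensor:
    debug_print("x = {}", x)  # under vmap, the batched tensor is unwrapped so the underlying values are printed
    return x * 2.0 + 1.0

print(vmap(affine)(torch.rand(4, 3)))
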

Check failure on line 393 in src/evox/core/jit_util.py
GitHub Actions / ruff: Ruff (W292) src/evox/core/jit_util.py:393:58: No newline at end of file
