Made 2 improvements
1. Improve vmap_fix to allow registering custom vmap fixing functions
2. Fix ParamsAndVector to enable vmap of its methods
sses7757 committed Feb 13, 2025
1 parent 261e007 commit 31fa017
Showing 5 changed files with 313 additions and 60 deletions.
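
For orientation, a minimal usage sketch of the registration API introduced here (a sketch under assumptions: the import path `evox.core._vmap_fix` is inferred from the file path below, and `torch.Tensor.clamp` is an arbitrary illustrative patch target, not something this commit registers):

    import torch
    from evox.core._vmap_fix import register_fix_function, unregister_fix_function, use_batch_fixing

    def _my_batch_clamp(tensor: torch.Tensor, *args, **kwargs):
        # Replacement installed only while `use_batch_fixing` is active; here it simply
        # forwards to torch.clamp, but a real fix function would special-case batched tensors.
        return torch.clamp(tensor, *args, **kwargs)

    register_fix_function("torch.Tensor", "clamp", _my_batch_clamp)
    with use_batch_fixing():
        # Inside this context, torch.Tensor.clamp resolves to _my_batch_clamp.
        x = torch.randn(4).clamp(min=0.0)
    unregister_fix_function("torch.Tensor", "clamp")
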
258 changes: 223 additions & 35 deletions src/evox/core/_vmap_fix.py
@@ -1,6 +1,21 @@
__all__ = [
"add_batch_dim",
"get_level",
"current_level",
"unwrap_batch_tensor",
"wrap_batch_tensor",
"register_fix_function",
"unregister_fix_function",
"use_batch_fixing",
"_set_func_id",
]

import math
import warnings
from contextlib import contextmanager
from contextvars import ContextVar, Token
from typing import Any, Callable, List, Sequence, Tuple
from threading import Lock
from typing import Any, Callable, Dict, List, Sequence, Tuple

# cSpell:words bdim batchedtensor
import torch
@@ -196,6 +211,10 @@ def batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs
_original_randint_like = torch.randint_like
_original_get_item = torch.Tensor.__getitem__
_original_set_item = torch.Tensor.__setitem__
_original_reshape = torch.reshape
_original_view = torch.Tensor.view
_original_flatten = torch.flatten
_original_unflatten = torch.unflatten


def _batch_size(tensor: torch.Tensor, dim: int | None = None):
@@ -297,16 +316,43 @@ def _batch_randint_like(like_tensor, **kwargs):
return batched_random_like(_original_randint_like, like_tensor, **kwargs)


def _batch_getitem(tensor: torch.Tensor, indices, dim=0):
def _batch_getitem(tensor: torch.Tensor, indices):
level = current_level()
if level is None or level <= 0:
return _original_get_item(tensor, indices)
# else
if isinstance(indices, torch.Tensor) and indices.ndim <= 1:
tensor = torch.index_select(tensor, dim, indices)
# special case for scalar
if indices is None and tensor.ndim == 0:
return tensor.unsqueeze(0)
# special case for single index
if isinstance(indices, torch.Tensor):
tensor = torch.index_select(tensor, 0, indices.flatten())
if indices.ndim == 0:
tensor = tensor.__getitem__(*(([slice(None)] * dim) + [0]))
return tensor
return _original_get_item(tensor, 0)
else:
return tensor.unflatten(0, indices.size())
if not isinstance(indices, Sequence):
return _original_get_item(tensor, indices)
# else
if all(map(lambda ind: isinstance(ind, torch.Tensor) and not ind.dtype.is_floating_point, indices)):
# https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
indices: List[torch.Tensor] = list(indices)
assert all(map(lambda ind: ind.size() == indices[0].size(), indices[1:])), "Expect all index tensors have same shape."
if get_level(tensor) <= 0:
return _original_get_item(tensor, indices)
original_indices = [unwrap_batch_tensor(ind)[0] for ind in indices]
original_tensor, dims, sizes = unwrap_batch_tensor(tensor)
for d, s in zip(dims, sizes):
identity = torch.arange(0, s, dtype=indices[0].dtype, device=tensor.device)
original_indices.insert(d, (identity,))
for i, identity in enumerate(original_indices):
if not isinstance(identity, tuple):
continue
original_shape = [1] * original_tensor.ndim
original_shape[i] = -1
identity = identity[0].view(*original_shape)
original_indices[i] = identity
original_tensor = torch._unsafe_index(original_tensor, original_indices)
return wrap_batch_tensor(original_tensor, dims)
# default
return _original_get_item(tensor, indices)
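
As a standalone illustration of the integer-array-indexing branch above (plain PyTorch, no evox; shapes are illustrative): indexing a conceptually batched tensor with per-sample index tensors amounts to pairing them with an `arange` identity over the batch dimension, which is essentially what the inserted `identity` indices achieve.

    import torch

    x = torch.randn(5, 7)                      # dim 0 plays the role of the batch dimension
    idx = torch.randint(0, 7, (5, 3))          # per-sample integer indices, one row per batch element
    batch_ids = torch.arange(5).unsqueeze(-1)  # identity over the batch dim, broadcasts against idx
    gathered = x[batch_ids, idx]               # advanced indexing applied per batch element -> (5, 3)
    assert torch.equal(gathered, torch.gather(x, 1, idx))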

@@ -319,38 +365,182 @@ def _batch_setitem(tensor: torch.Tensor, indices, values, dim=0):
return _original_set_item(tensor, indices, values)


def _get_original_dims(tensor: torch.Tensor, batch_dims: Tuple[int], batch_sizes: Tuple[int]):
ori_shape = list(tensor.size())
for i, (d, s) in enumerate(zip(batch_dims, batch_sizes)):
ori_shape.insert(d, (s, i))
ori_dims = [-1] * len(batch_dims)
for i, s in enumerate(ori_shape):
if isinstance(s, tuple):
ori_shape[i] = s[0]
ori_dims[s[1]] = i
return tuple(ori_dims)


def _special_reshape(ori_tensor: torch.Tensor, ori_dims: Tuple[int], new_shape: Tuple[int]):
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# check
original_ndim = ori_tensor.ndim
for d in ori_dims:
assert 0 <= d < original_ndim
assert len(set(ori_dims)) == len(ori_dims)
# get remaining dims
preserved_dims = list(ori_dims)
remaining_dims = [d for d in range(original_ndim) if d not in ori_dims]
permute_order = preserved_dims + remaining_dims
# get actual new shape
preserved_shape = [ori_tensor.size(d) for d in preserved_dims]
remaining_size = 1
for d in remaining_dims:
remaining_size *= ori_tensor.size(d)
# check new shape
new_shape_prod = 1
for s in new_shape:
assert new_shape_prod > 0 or s > 0, "Cannot have multiple dimensions with size=-1"
new_shape_prod *= s
assert remaining_size == new_shape_prod or (new_shape_prod < 0 and remaining_size % (-new_shape_prod) == 0), (
f"Cannot reshape size {remaining_size} to {new_shape} ({ori_dims}, {ori_tensor.size()})"
)
# permute and reshape
permuted = ori_tensor.permute(*permute_order)
return _original_reshape(permuted, preserved_shape + list(new_shape))
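
The helper above reduces to: move the batch dimensions to the front, then reshape only the remaining per-sample dimensions. A standalone sketch of that idea (plain PyTorch, no evox; shapes are illustrative):

    import torch

    t = torch.randn(2, 5, 3, 4)                    # suppose dim 1 is a batch dim of size 5
    batch_dims = (1,)
    per_sample_shape = (3, 2, 4)                   # requested reshape of the remaining (2, 3, 4) block

    remaining = [d for d in range(t.ndim) if d not in batch_dims]
    permuted = t.permute(*batch_dims, *remaining)  # batch dims first: (5, 2, 3, 4)
    new_shape = [t.size(d) for d in batch_dims] + list(per_sample_shape)
    reshaped = permuted.reshape(*new_shape)        # (5, 3, 2, 4); the batch dim is untouched
    assert reshaped.shape == (5, 3, 2, 4)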


def _batch_reshape(tensor: torch.Tensor, *shape):
if isinstance(shape[0], Sequence):
shape = shape[0]
level = get_level(tensor)
if level is None or level <= 0:
return _original_reshape(tensor, shape)
# else
original_tensor, dims, sizes = unwrap_batch_tensor(tensor)
ori_dims = _get_original_dims(tensor, dims, sizes)
new_tensor = _special_reshape(original_tensor, ori_dims, shape)
return wrap_batch_tensor(new_tensor, tuple(range(len(ori_dims))))


def _batch_view(tensor: torch.Tensor, *args, **kwargs):
if "dtype" in kwargs or isinstance(args[0], torch.dtype):
return _original_view(tensor, *args, **kwargs)
level = get_level(tensor)
if level is None or level <= 0:
return _original_view(tensor, *args, **kwargs)
# else
shape = kwargs.get("size", None) or args
if isinstance(shape, Sequence) and isinstance(shape[0], Sequence):
shape = shape[0]
original_tensor, dims, sizes = unwrap_batch_tensor(tensor)
ori_dims = _get_original_dims(tensor, dims, sizes)
new_tensor = _special_reshape(original_tensor, ori_dims, shape)
return wrap_batch_tensor(new_tensor, tuple(range(len(ori_dims))))


def _batch_flatten(tensor: torch.Tensor, start_dim=0, end_dim=-1):
level = get_level(tensor)
if level is None or level <= 0:
return _original_flatten(tensor, start_dim, end_dim)
# else
original_tensor, dims, sizes = unwrap_batch_tensor(tensor)
ori_dims = _get_original_dims(tensor, dims, sizes)
shape = list(tensor.size())
if end_dim not in [-1, tensor.ndim - 1]:
shape = shape[:start_dim] + [math.prod(shape[start_dim:end_dim + 1])] + shape[end_dim + 1:]
else:
shape = shape[:start_dim] + [math.prod(shape[start_dim:])]
new_tensor = _special_reshape(original_tensor, ori_dims, shape)
return wrap_batch_tensor(new_tensor, tuple(range(len(ori_dims))))


def _batch_unflatten(tensor: torch.Tensor, dim: int, unflattened_size: Sequence[int]):
level = get_level(tensor)
if level is None or level <= 0:
return _original_unflatten(tensor, dim, unflattened_size)
# else
original_tensor, dims, sizes = unwrap_batch_tensor(tensor)
ori_dims = _get_original_dims(tensor, dims, sizes)
shape = list(tensor.size())
if dim not in [-1, tensor.ndim - 1]:
shape = shape[:dim] + list(unflattened_size) + shape[dim + 1 :]
else:
shape = shape[:dim] + list(unflattened_size)
new_tensor = _special_reshape(original_tensor, ori_dims, shape)
return wrap_batch_tensor(new_tensor, tuple(range(len(ori_dims))))


_fix_functions_lock = Lock()
_fix_functions: Dict[str, Tuple[Any, Callable, Callable]] = {}


def register_fix_function(namespace: str, name: str, fix_function: Callable):
"""Register a function that fixes a PyTorch function for `torch.vmap`.
:param namespace: The namespace of the function to be fixed, e.g. "torch".
:param name: The name of the function to be fixed, e.g. "randn_like".
:param fix_function: The function that fixes the original function.
:raises AssertionError: If the specified function is not callable.
"""
namespace_obj = eval(namespace)
original_function = getattr(namespace_obj, name, None)
assert original_function is not None and callable(original_function), f"{namespace}.{name} is not callable"
_fix_functions_lock.acquire()
try:
_fix_functions[f"{namespace}.{name}"] = (namespace_obj, original_function, fix_function)
finally:
_fix_functions_lock.release()


def unregister_fix_function(namespace: str, name: str):
"""Unregister a function that fixes a PyTorch function for `torch.vmap`.
:param namespace: The namespace of the function to be unregistered, e.g. "torch".
:param name: The name of the function to be unregistered, e.g. "randn_like".
"""
_fix_functions_lock.acquire()
try:
del _fix_functions[f"{namespace}.{name}"]
finally:
_fix_functions_lock.release()


register_fix_function("torch", "rand", _batch_rand)
register_fix_function("torch", "randn", _batch_randn)
register_fix_function("torch", "randint", _batch_randint)
register_fix_function("torch", "randperm", _batch_randperm)
register_fix_function("torch", "rand_like", _batch_rand_like)
register_fix_function("torch", "randn_like", _batch_randn_like)
register_fix_function("torch", "randint_like", _batch_randint_like)
register_fix_function("torch.Tensor", "__getitem__", _batch_getitem)
register_fix_function("torch.Tensor", "__setitem__", _batch_setitem)
register_fix_function("torch.Tensor", "size", _batch_size)

register_fix_function("torch.Tensor", "reshape", _batch_reshape)
register_fix_function("torch.Tensor", "view", _batch_view)
register_fix_function("torch.Tensor", "flatten", _batch_flatten)
register_fix_function("torch.Tensor", "unflatten", _batch_unflatten)
register_fix_function("torch", "reshape", _batch_reshape)
register_fix_function("torch", "flatten", _batch_flatten)
register_fix_function("torch", "unflatten", _batch_unflatten)


_batch_fixing: ContextVar[bool] = ContextVar("batch_fixing", default=False)


@contextmanager
def use_batch_fixing(new_batch_fixing: bool = True):
# Set the new state and obtain a token
token: Token = _batch_fixing.set(new_batch_fixing)
torch.Tensor.size = _batch_size if new_batch_fixing else _original_size
torch.rand = _batch_rand if new_batch_fixing else _original_rand
torch.randn = _batch_randn if new_batch_fixing else _original_randn
torch.randint = _batch_randint if new_batch_fixing else _original_randint
torch.randperm = _batch_randperm if new_batch_fixing else _original_randperm
torch.rand_like = _batch_rand_like if new_batch_fixing else _original_rand_like
torch.randn_like = _batch_randn_like if new_batch_fixing else _original_randn_like
torch.randint_like = _batch_randint_like if new_batch_fixing else _original_randint_like
torch.Tensor.__getitem__ = _batch_getitem if new_batch_fixing else _original_get_item
torch.Tensor.__setitem__ = _batch_setitem if new_batch_fixing else _original_set_item
if new_batch_fixing:
for name, (namespace_obj, _, fix_function) in _fix_functions.items():
setattr(namespace_obj, name.split(".")[-1], fix_function)
try:
yield token
finally:
# Reset the state to its previous value
_batch_fixing.reset(token)
torch.Tensor.size = _original_size
torch.rand = _original_rand
torch.randn = _original_randn
torch.randint = _original_randint
torch.randperm = _original_randperm
torch.rand_like = _original_rand_like
torch.randn_like = _original_randn_like
torch.randint_like = _original_randint_like
torch.Tensor.__getitem__ = _original_get_item
torch.Tensor.__setitem__ = _original_set_item
for name, (namespace_obj, original_function, _) in _fix_functions.items():
setattr(namespace_obj, name.split(".")[-1], original_function)


def align_vmap_tensor(value: Any, current_value: Any | None) -> torch.Tensor:
@@ -362,14 +552,11 @@ def align_vmap_tensor(value: Any, current_value: Any | None) -> torch.Tensor:
already a batched tensor or `current_value` is not a batched tensor, it
returns `value` unchanged.
:param value: The tensor to be aligned. If not a `torch.Tensor`, it is
returned unchanged.
:param current_value: The reference batched tensor. If `None` or
not a batched tensor, `value` is returned
unchanged.
:param value: The tensor to be aligned. If not a `torch.Tensor`, it is returned unchanged.
:param current_value: The reference batched tensor. If `None` or not a batched tensor,
`value` is returned unchanged.
:return: The input `value` aligned with the batch dimensions of
`current_value`, if applicable.
:return: The input `value` aligned with the batch dimensions of `current_value`, if applicable.
"""

if not isinstance(value, torch.Tensor):
@@ -393,7 +580,8 @@ def _debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
def debug_print(format: str, arg: torch.Tensor) -> torch.Tensor:
"""Prints a formatted string with one positional tensor used for debugging inside JIT traced functions on-the-fly.
When vectorized-mapping, it unwraps the batched tensor to print the underlying values. Otherwise, the function behaves like `format.format(*args, **kwargs)`.
When vectorized-mapping, it unwraps the batched tensor to print the underlying values.
Otherwise, the function behaves like `format.format(*args, **kwargs)`.
:param format: A string format.
:param arg: The positional tensor.
4 changes: 2 additions & 2 deletions src/evox/core/jit_util.py
@@ -228,7 +228,7 @@ def jit(
no_cache: bool = False,
return_dummy_output: bool = False,
debug_manual_seed: int | None = None,
) -> T | UseStateFunc | MappedUseStateFunc:
) -> T | UseStateFunc | MappedUseStateFunc | Tuple[T | UseStateFunc | MappedUseStateFunc, Any]:
"""Just-In-Time (JIT) compile the given `func` via [`torch.jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.script.html) (`trace=True`) or [`torch.jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) (`trace=False`).
This function wrapper effectively deals with nested JIT and vector map (`vmap`) expressions like `jit(func1)` -> `vmap` -> `jit(func2)`,
@@ -250,7 +250,7 @@ def jit(
:param return_dummy_output: Whether to return the dummy output of `func` as the second output or not. Defaults to False. Has no effect when `trace=False` or `lazy=True` or `no_cache=True`.
:param debug_manual_seed: The manual seed to be set before each running of the function. Defaults to None. Has no effect when `trace=False`. None means no manual seed will be set. Notice that any value other than None changes the GLOBAL random seed.
:return: The JIT version of `func`
:return: The JIT version of `func`. If `return_dummy_output=True` takes effect, the JIT function and its dummy output are returned.
"""
if is_generator:
func = func()
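
As a generic illustration of why a dummy output exists at all (plain `torch.jit`, not evox's `jit`; `f` is a placeholder function): tracing executes the function once on the example inputs, so a wrapper can record that first result and hand it back alongside the compiled function.

    import torch

    def f(x):
        return x * 2 + 1

    example = torch.randn(3)
    dummy_output = f(example)                  # equal to what the trace-time run produces
    traced_f = torch.jit.trace(f, (example,))  # tracing runs f once on `example` internally
    assert torch.equal(traced_f(example), dummy_output)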