cleanup
vmoens committed Feb 26, 2024
1 parent cc1a9d1 commit bc7836b
Showing 7 changed files with 21 additions and 209 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -66,8 +66,8 @@ def _get_pytorch_version(is_nightly):
# if "PYTORCH_VERSION" in os.environ:
# return f"torch=={os.environ['PYTORCH_VERSION']}"
if is_nightly:
return "torch>=2.2.0.dev"
return "torch>=2.1.0"
return "torch>=2.3.0.dev"
return "torch>=2.2.1"


def _get_packages():
16 changes: 0 additions & 16 deletions tensordict/_td.py
@@ -724,22 +724,6 @@ def _apply_nest(
validated=checked,
)

if filter_empty and not any_set:
return
elif filter_empty is None and not any_set and not self.is_empty():
# we raise the deprecation warning only if the tensordict wasn't already empty.
# After we introduce the new behaviour, we will have to consider what happens
# to empty tensordicts by default: will they disappear or stay?
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=False. "
"Currently, this returns an empty tree (filter_empty=True), but from v0.5 it will return "
"a None unless filter_empty=False. "
"To silcence this warning, set filter_empty to the desired value in your call to `apply`.",
category=DeprecationWarning,
)
if result is None:
result = make_result()

if not inplace and is_locked:
out.lock_()
return out
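The removed warning above describes how ``filter_empty`` interacts with ``apply`` when the callable drops every leaf. A minimal sketch of that behaviour (illustrative only, not part of this commit; it assumes the documented convention that a callable returning ``None`` skips that leaf):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])

# The callable returns None for every leaf, so nothing is written to the result.
kept = td.apply(lambda x: None, filter_empty=False)    # empty TensorDict is kept
dropped = td.apply(lambda x: None, filter_empty=True)  # empty result becomes None

assert kept.is_empty()
assert dropped is None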
6 changes: 4 additions & 2 deletions tensordict/base.py
@@ -50,6 +50,7 @@
_td_fields,
_unravel_key_to_tuple,
as_decorator,
Buffer,
cache,
convert_ellipsis_to_idx,
DeviceType,
@@ -62,6 +63,7 @@
lock_blocked,
NestedKey,
prod,
set_lazy_legacy,
TensorDictFuture,
unravel_key,
unravel_key_list,
@@ -2319,7 +2321,7 @@ def _filter(x):
return x.filter_non_tensor_data()
return x

return self._apply_nest(_filter, call_on_nested=True, filter_empty=False)
return self._apply_nest(_filter, call_on_nested=True)

def _convert_inplace(self, inplace, key):
if inplace is not False:
@@ -3718,7 +3720,7 @@ def _reduce(
return

# Apply and map functionality
def apply_(self, fn: Callable, *others) -> T:
def apply_(self, fn: Callable, *others, **kwargs) -> T:
"""Applies a callable to all values stored in the tensordict and re-writes them in-place.
Args:
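``apply_`` is the in-place counterpart of ``apply``: it rewrites each leaf with the callable's return value and returns the tensordict itself. A short usage sketch (illustrative only, not part of this commit):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(3), "b": torch.full((3,), 2.0)}, batch_size=[3])
out = td.apply_(lambda x: x * 10)  # leaves are overwritten in place

assert out is td
assert (td["a"] == 10).all() and (td["b"] == 20).all()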
29 changes: 13 additions & 16 deletions tensordict/persistent.py
@@ -6,29 +6,26 @@
"""Persistent tensordicts (H5 and others)."""
from __future__ import annotations

import importlib

import json
import os

import tempfile
import warnings
from pathlib import Path
from typing import Any, Callable, Type

from tensordict._td import _unravel_key_to_tuple
from torch import multiprocessing as mp

H5_ERR = None
try:
import h5py

_has_h5 = True
except ModuleNotFoundError as err:
H5_ERR = err
_has_h5 = False

import json
import os

import numpy as np
import torch
from tensordict._td import _TensorDictKeysView, CompatibleType, NO_DEFAULT, TensorDict

from tensordict._td import (
_TensorDictKeysView,
_unravel_key_to_tuple,
CompatibleType,
NO_DEFAULT,
TensorDict,
)
from tensordict.base import _default_is_leaf, is_tensor_collection, T, TensorDictBase
from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor
169 changes: 0 additions & 169 deletions tensordict/utils.py
@@ -1919,85 +1919,6 @@ def format_size(size):
logging.info(indent + os.path.basename(path))


def isin(
input: TensorDictBase,
reference: TensorDictBase,
key: NestedKey,
dim: int = 0,
) -> Tensor:
"""Tests if each element of ``key`` in input ``dim`` is also present in the reference.
This function returns a boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
the entry ``key`` that are also present in the ``reference``. This function assumes that both ``input`` and
``reference`` have the same batch size and contain the specified entry, otherwise an error will be raised.
Args:
input (TensorDictBase): Input TensorDict.
reference (TensorDictBase): Target TensorDict against which to test.
key (NestedKey): The key to test.
dim (int, optional): The dimension along which to test. Defaults to ``0``.
Returns:
out (Tensor): A boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
the ``input`` ``key`` tensor that are also present in the ``reference``.
Examples:
>>> td = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
... },
... batch_size=[4],
... )
>>> td_ref = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
... },
... batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)
"""
# Get the data
reference_tensor = reference.get(key, default=None)
target_tensor = input.get(key, default=None)

# Check key is present in both tensordict and reference_tensordict
if not isinstance(target_tensor, torch.Tensor):
raise KeyError(f"Key '{key}' not found in input or not a tensor.")
if not isinstance(reference_tensor, torch.Tensor):
raise KeyError(f"Key '{key}' not found in reference or not a tensor.")

# Check that both TensorDicts have the same number of dimensions
if len(input.batch_size) != len(reference.batch_size):
raise ValueError(
"The number of dimensions in the batch size of the input and reference must be the same."
)

# Check dim is valid
batch_dims = input.ndim
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
raise ValueError(
f"The specified dimension '{dim}' is invalid for an input TensorDict with batch size '{input.batch_size}'."
)

# Convert negative dimension to its positive equivalent
if dim < 0:
dim = batch_dims + dim

# Find the common indices
N = reference_tensor.shape[dim]
cat_data = torch.cat([reference_tensor, target_tensor], dim=dim)
_, unique_indices = torch.unique(
cat_data, dim=dim, sorted=True, return_inverse=True
)
out = torch.isin(unique_indices[N:], unique_indices[:N], assume_unique=True)

return out
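The membership test above hinges on the inverse indices returned by ``torch.unique``: identical rows in the concatenation of reference and target receive the same index, so ``torch.isin`` on those indices yields a row-wise membership mask. A small trace on plain tensors (illustrative only, mirroring the values in the docstring example):

import torch

reference = torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]])
target = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]])

N = reference.shape[0]
cat_data = torch.cat([reference, target], dim=0)
# Rows that are equal get the same inverse index.
_, inverse = torch.unique(cat_data, dim=0, sorted=True, return_inverse=True)
mask = torch.isin(inverse[N:], inverse[:N])

print(mask)  # tensor([ True,  True,  True, False])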


def _index_preserve_data_ptr(index):
if isinstance(index, tuple):
return all(_index_preserve_data_ptr(idx) for idx in index)
@@ -2011,96 +1932,6 @@ def _index_preserve_data_ptr(index):
return False


def remove_duplicates(
input: TensorDictBase,
key: NestedKey,
dim: int = 0,
*,
return_indices: bool = False,
) -> TensorDictBase:
"""Removes indices duplicated in `key` along the specified dimension.
This method detects duplicate elements in the tensor associated with the specified `key` along the specified
`dim` and removes elements in the same indices in all other tensors within the TensorDict. It is expected for
`dim` to be one of the dimensions within the batch size of the input TensorDict to ensure consistency in all
tensors. Otherwise, an error will be raised.
Args:
input (TensorDictBase): The TensorDict containing potentially duplicate elements.
key (NestedKey): The key of the tensor along which duplicate elements should be identified and removed. It
must be one of the leaf keys within the TensorDict, pointing to a tensor and not to another TensorDict.
dim (int, optional): The dimension along which duplicate elements should be identified and removed. It must be one of
the dimensions within the batch size of the input TensorDict. Defaults to ``0``.
return_indices (bool, optional): If ``True``, the indices of the unique elements in the input tensor will be
returned as well. Defaults to ``False``.
Returns:
output (TensorDictBase): input tensordict with the indices corresponding to duplicated elements
in tensor `key` along dimension `dim` removed.
unique_indices (torch.Tensor, optional): The indices of the first occurrences of the unique elements in the
input tensordict for the specified `key` along the specified `dim`. Only provided if return_indices is True.
Example:
>>> td = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
... },
... batch_size=[4],
... )
>>> output_tensordict = remove_duplicates(td, key="tensor1", dim=0)
>>> expected_output = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
... },
... batch_size=[3],
... )
>>> assert (output_tensordict == expected_output).all()
"""
tensor = input.get(key, default=None)

# Check if the key is a TensorDict
if tensor is None:
raise KeyError(f"The key '{key}' does not exist in the TensorDict.")

# Check that the key points to a tensor
if not isinstance(tensor, torch.Tensor):
raise KeyError(f"The key '{key}' does not point to a tensor in the TensorDict.")

# Check dim is valid
batch_dims = input.ndim
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
raise ValueError(
f"The specified dimension '{dim}' is invalid for a TensorDict with batch size '{input.batch_size}'."
)

# Convert negative dimension to its positive equivalent
if dim < 0:
dim = batch_dims + dim

# Get indices of unique elements (e.g. [0, 1, 0, 2])
_, unique_indices, counts = torch.unique(
tensor, dim=dim, sorted=True, return_inverse=True, return_counts=True
)

# Find first occurrence of each index (e.g. [0, 1, 3])
_, unique_indices_sorted = torch.sort(unique_indices, stable=True)
cum_sum = counts.cumsum(0, dtype=torch.long)
cum_sum = torch.cat(
(torch.zeros(1, device=input.device, dtype=torch.long), cum_sum[:-1])
)
first_indices = unique_indices_sorted[cum_sum]

# Remove duplicate elements in the TensorDict
output = input[(slice(None),) * dim + (first_indices,)]

if return_indices:
return output, unique_indices

return output
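The first-occurrence logic above can be traced on a plain tensor: a stable sort of the inverse indices groups equal rows together while preserving their original order, and the cumulative counts give the offset of each group's first element (illustrative only, using the values from the docstring example):

import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]])

_, inverse, counts = torch.unique(
    tensor, dim=0, sorted=True, return_inverse=True, return_counts=True
)
_, inverse_sorted = torch.sort(inverse, stable=True)
offsets = torch.cat((torch.zeros(1, dtype=torch.long), counts.cumsum(0)[:-1]))
first_indices = inverse_sorted[offsets]

print(first_indices)          # tensor([0, 1, 3])
print(tensor[first_indices])  # the three unique rows, duplicates removed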


class _CloudpickleWrapper(object):
def __init__(self, fn):
self.fn = fn
2 changes: 1 addition & 1 deletion test/test_functorch.py
@@ -362,7 +362,7 @@ def zero_grad(p):
for p in params.flatten_keys().values()
)
assert params.requires_grad
params.apply_(zero_grad, filter_empty=True)
params.apply_(zero_grad)
assert params.requires_grad

def test_repopulate(self):
4 changes: 1 addition & 3 deletions test/test_tensordict.py
@@ -3220,7 +3220,7 @@ def named_plus(name, x):
with pytest.raises(ValueError, match="Failed to update"):
td.named_apply(named_plus, inplace=inplace)
return
td_1 = td.named_apply(named_plus, inplace=inplace, filter_empty=True)
td_1 = td.named_apply(named_plus, inplace=inplace)
if inplace:
assert td_1 is td
for key in td_1.keys(True, True):
@@ -3253,12 +3253,10 @@ def count(name, value, keys):
td.named_apply(
functools.partial(count, keys=keys_complete),
nested_keys=True,
filter_empty=True,
)
td.named_apply(
functools.partial(count, keys=keys_not_complete),
nested_keys=False,
filter_empty=True,
)
assert len(keys_complete) == len(list(td.keys(True, True)))
assert len(keys_complete) > len(keys_not_complete)
