Skip to content

Commit bc7836b

Browse files
committed
cleanup
1 parent cc1a9d1 commit bc7836b

File tree

7 files changed

+21
-209
lines changed

7 files changed

+21
-209
lines changed

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def _get_pytorch_version(is_nightly):
6666
# if "PYTORCH_VERSION" in os.environ:
6767
# return f"torch=={os.environ['PYTORCH_VERSION']}"
6868
if is_nightly:
69-
return "torch>=2.2.0.dev"
70-
return "torch>=2.1.0"
69+
return "torch>=2.3.0.dev"
70+
return "torch>=2.2.1"
7171

7272

7373
def _get_packages():

tensordict/_td.py

-16
Original file line numberDiff line numberDiff line change
@@ -724,22 +724,6 @@ def _apply_nest(
724724
validated=checked,
725725
)
726726

727-
if filter_empty and not any_set:
728-
return
729-
elif filter_empty is None and not any_set and not self.is_empty():
730-
# we raise the deprecation warning only if the tensordict wasn't already empty.
731-
# After we introduce the new behaviour, we will have to consider what happens
732-
# to empty tensordicts by default: will they disappear or stay?
733-
warn(
734-
"Your resulting tensordict has no leaves but you did not specify filter_empty=False. "
735-
"Currently, this returns an empty tree (filter_empty=True), but from v0.5 it will return "
736-
"a None unless filter_empty=False. "
737-
"To silcence this warning, set filter_empty to the desired value in your call to `apply`.",
738-
category=DeprecationWarning,
739-
)
740-
if result is None:
741-
result = make_result()
742-
743727
if not inplace and is_locked:
744728
out.lock_()
745729
return out

tensordict/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
_td_fields,
5151
_unravel_key_to_tuple,
5252
as_decorator,
53+
Buffer,
5354
cache,
5455
convert_ellipsis_to_idx,
5556
DeviceType,
@@ -62,6 +63,7 @@
6263
lock_blocked,
6364
NestedKey,
6465
prod,
66+
set_lazy_legacy,
6567
TensorDictFuture,
6668
unravel_key,
6769
unravel_key_list,
@@ -2319,7 +2321,7 @@ def _filter(x):
23192321
return x.filter_non_tensor_data()
23202322
return x
23212323

2322-
return self._apply_nest(_filter, call_on_nested=True, filter_empty=False)
2324+
return self._apply_nest(_filter, call_on_nested=True)
23232325

23242326
def _convert_inplace(self, inplace, key):
23252327
if inplace is not False:
@@ -3718,7 +3720,7 @@ def _reduce(
37183720
return
37193721

37203722
# Apply and map functionality
3721-
def apply_(self, fn: Callable, *others) -> T:
3723+
def apply_(self, fn: Callable, *others, **kwargs) -> T:
37223724
"""Applies a callable to all values stored in the tensordict and re-writes them in-place.
37233725
37243726
Args:

tensordict/persistent.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,26 @@
66
"""Persistent tensordicts (H5 and others)."""
77
from __future__ import annotations
88

9+
import importlib
10+
11+
import json
12+
import os
13+
914
import tempfile
1015
import warnings
1116
from pathlib import Path
1217
from typing import Any, Callable, Type
1318

14-
from tensordict._td import _unravel_key_to_tuple
15-
from torch import multiprocessing as mp
16-
17-
H5_ERR = None
18-
try:
19-
import h5py
20-
21-
_has_h5 = True
22-
except ModuleNotFoundError as err:
23-
H5_ERR = err
24-
_has_h5 = False
25-
26-
import json
27-
import os
28-
2919
import numpy as np
3020
import torch
31-
from tensordict._td import _TensorDictKeysView, CompatibleType, NO_DEFAULT, TensorDict
21+
22+
from tensordict._td import (
23+
_TensorDictKeysView,
24+
_unravel_key_to_tuple,
25+
CompatibleType,
26+
NO_DEFAULT,
27+
TensorDict,
28+
)
3229
from tensordict.base import _default_is_leaf, is_tensor_collection, T, TensorDictBase
3330
from tensordict.memmap import MemoryMappedTensor
3431
from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor

tensordict/utils.py

-169
Original file line numberDiff line numberDiff line change
@@ -1919,85 +1919,6 @@ def format_size(size):
19191919
logging.info(indent + os.path.basename(path))
19201920

19211921

1922-
def isin(
1923-
input: TensorDictBase,
1924-
reference: TensorDictBase,
1925-
key: NestedKey,
1926-
dim: int = 0,
1927-
) -> Tensor:
1928-
"""Tests if each element of ``key`` in input ``dim`` is also present in the reference.
1929-
1930-
This function returns a boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
1931-
the entry ``key`` that are also present in the ``reference``. This function assumes that both ``input`` and
1932-
``reference`` have the same batch size and contain the specified entry, otherwise an error will be raised.
1933-
1934-
Args:
1935-
input (TensorDictBase): Input TensorDict.
1936-
reference (TensorDictBase): Target TensorDict against which to test.
1937-
key (NestedKey): The key to test.
1938-
dim (int, optional): The dimension along which to test. Defaults to ``0``.
1939-
1940-
Returns:
1941-
out (Tensor): A boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
1942-
the ``input`` ``key`` tensor that are also present in the ``reference``.
1943-
1944-
Examples:
1945-
>>> td = TensorDict(
1946-
... {
1947-
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
1948-
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
1949-
... },
1950-
... batch_size=[4],
1951-
... )
1952-
>>> td_ref = TensorDict(
1953-
... {
1954-
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
1955-
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
1956-
... },
1957-
... batch_size=[3],
1958-
... )
1959-
>>> in_reference = isin(td, td_ref, key="tensor1")
1960-
>>> expected_in_reference = torch.tensor([True, True, True, False])
1961-
>>> torch.testing.assert_close(in_reference, expected_in_reference)
1962-
"""
1963-
# Get the data
1964-
reference_tensor = reference.get(key, default=None)
1965-
target_tensor = input.get(key, default=None)
1966-
1967-
# Check key is present in both tensordict and reference_tensordict
1968-
if not isinstance(target_tensor, torch.Tensor):
1969-
raise KeyError(f"Key '{key}' not found in input or not a tensor.")
1970-
if not isinstance(reference_tensor, torch.Tensor):
1971-
raise KeyError(f"Key '{key}' not found in reference or not a tensor.")
1972-
1973-
# Check that both TensorDicts have the same number of dimensions
1974-
if len(input.batch_size) != len(reference.batch_size):
1975-
raise ValueError(
1976-
"The number of dimensions in the batch size of the input and reference must be the same."
1977-
)
1978-
1979-
# Check dim is valid
1980-
batch_dims = input.ndim
1981-
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
1982-
raise ValueError(
1983-
f"The specified dimension '{dim}' is invalid for an input TensorDict with batch size '{input.batch_size}'."
1984-
)
1985-
1986-
# Convert negative dimension to its positive equivalent
1987-
if dim < 0:
1988-
dim = batch_dims + dim
1989-
1990-
# Find the common indices
1991-
N = reference_tensor.shape[dim]
1992-
cat_data = torch.cat([reference_tensor, target_tensor], dim=dim)
1993-
_, unique_indices = torch.unique(
1994-
cat_data, dim=dim, sorted=True, return_inverse=True
1995-
)
1996-
out = torch.isin(unique_indices[N:], unique_indices[:N], assume_unique=True)
1997-
1998-
return out
1999-
2000-
20011922
def _index_preserve_data_ptr(index):
20021923
if isinstance(index, tuple):
20031924
return all(_index_preserve_data_ptr(idx) for idx in index)
@@ -2011,96 +1932,6 @@ def _index_preserve_data_ptr(index):
20111932
return False
20121933

20131934

2014-
def remove_duplicates(
2015-
input: TensorDictBase,
2016-
key: NestedKey,
2017-
dim: int = 0,
2018-
*,
2019-
return_indices: bool = False,
2020-
) -> TensorDictBase:
2021-
"""Removes indices duplicated in `key` along the specified dimension.
2022-
2023-
This method detects duplicate elements in the tensor associated with the specified `key` along the specified
2024-
`dim` and removes elements in the same indices in all other tensors within the TensorDict. It is expected for
2025-
`dim` to be one of the dimensions within the batch size of the input TensorDict to ensure consistency in all
2026-
tensors. Otherwise, an error will be raised.
2027-
2028-
Args:
2029-
input (TensorDictBase): The TensorDict containing potentially duplicate elements.
2030-
key (NestedKey): The key of the tensor along which duplicate elements should be identified and removed. It
2031-
must be one of the leaf keys within the TensorDict, pointing to a tensor and not to another TensorDict.
2032-
dim (int, optional): The dimension along which duplicate elements should be identified and removed. It must be one of
2033-
the dimensions within the batch size of the input TensorDict. Defaults to ``0``.
2034-
return_indices (bool, optional): If ``True``, the indices of the unique elements in the input tensor will be
2035-
returned as well. Defaults to ``False``.
2036-
2037-
Returns:
2038-
output (TensorDictBase): input tensordict with the indices corresponding to duplicated elements
2039-
in tensor `key` along dimension `dim` removed.
2040-
unique_indices (torch.Tensor, optional): The indices of the first occurrences of the unique elements in the
2041-
input tensordict for the specified `key` along the specified `dim`. Only provided if return_index is True.
2042-
2043-
Example:
2044-
>>> td = TensorDict(
2045-
... {
2046-
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
2047-
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
2048-
... },
2049-
... batch_size=[4],
2050-
... )
2051-
>>> output_tensordict = remove_duplicates(td, key="tensor1", dim=0)
2052-
>>> expected_output = TensorDict(
2053-
... {
2054-
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
2055-
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
2056-
... },
2057-
... batch_size=[3],
2058-
... )
2059-
>>> assert (output_tensordict == expected_output).all()
2060-
"""
2061-
tensor = input.get(key, default=None)
2062-
2063-
# Check if the key is a TensorDict
2064-
if tensor is None:
2065-
raise KeyError(f"The key '{key}' does not exist in the TensorDict.")
2066-
2067-
# Check that the key points to a tensor
2068-
if not isinstance(tensor, torch.Tensor):
2069-
raise KeyError(f"The key '{key}' does not point to a tensor in the TensorDict.")
2070-
2071-
# Check dim is valid
2072-
batch_dims = input.ndim
2073-
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
2074-
raise ValueError(
2075-
f"The specified dimension '{dim}' is invalid for a TensorDict with batch size '{input.batch_size}'."
2076-
)
2077-
2078-
# Convert negative dimension to its positive equivalent
2079-
if dim < 0:
2080-
dim = batch_dims + dim
2081-
2082-
# Get indices of unique elements (e.g. [0, 1, 0, 2])
2083-
_, unique_indices, counts = torch.unique(
2084-
tensor, dim=dim, sorted=True, return_inverse=True, return_counts=True
2085-
)
2086-
2087-
# Find first occurrence of each index (e.g. [0, 1, 3])
2088-
_, unique_indices_sorted = torch.sort(unique_indices, stable=True)
2089-
cum_sum = counts.cumsum(0, dtype=torch.long)
2090-
cum_sum = torch.cat(
2091-
(torch.zeros(1, device=input.device, dtype=torch.long), cum_sum[:-1])
2092-
)
2093-
first_indices = unique_indices_sorted[cum_sum]
2094-
2095-
# Remove duplicate elements in the TensorDict
2096-
output = input[(slice(None),) * dim + (first_indices,)]
2097-
2098-
if return_indices:
2099-
return output, unique_indices
2100-
2101-
return output
2102-
2103-
21041935
class _CloudpickleWrapper(object):
21051936
def __init__(self, fn):
21061937
self.fn = fn

test/test_functorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def zero_grad(p):
362362
for p in params.flatten_keys().values()
363363
)
364364
assert params.requires_grad
365-
params.apply_(zero_grad, filter_empty=True)
365+
params.apply_(zero_grad)
366366
assert params.requires_grad
367367

368368
def test_repopulate(self):

test/test_tensordict.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -3220,7 +3220,7 @@ def named_plus(name, x):
32203220
with pytest.raises(ValueError, match="Failed to update"):
32213221
td.named_apply(named_plus, inplace=inplace)
32223222
return
3223-
td_1 = td.named_apply(named_plus, inplace=inplace, filter_empty=True)
3223+
td_1 = td.named_apply(named_plus, inplace=inplace)
32243224
if inplace:
32253225
assert td_1 is td
32263226
for key in td_1.keys(True, True):
@@ -3253,12 +3253,10 @@ def count(name, value, keys):
32533253
td.named_apply(
32543254
functools.partial(count, keys=keys_complete),
32553255
nested_keys=True,
3256-
filter_empty=True,
32573256
)
32583257
td.named_apply(
32593258
functools.partial(count, keys=keys_not_complete),
32603259
nested_keys=False,
3261-
filter_empty=True,
32623260
)
32633261
assert len(keys_complete) == len(list(td.keys(True, True)))
32643262
assert len(keys_complete) > len(keys_not_complete)

0 commit comments

Comments
 (0)