@@ -1919,85 +1919,6 @@ def format_size(size):
        logging.info(indent + os.path.basename(path))


-def isin(
-    input: TensorDictBase,
-    reference: TensorDictBase,
-    key: NestedKey,
-    dim: int = 0,
-) -> Tensor:
-    """Tests whether each element of ``key`` along dimension ``dim`` of ``input`` is also present in ``reference``.
-
-    This function returns a boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
-    the entry ``key`` that are also present in the ``reference``. Both ``input`` and ``reference`` must have the
-    same number of batch dimensions and contain the specified entry, otherwise an error will be raised.
-
-    Args:
-        input (TensorDictBase): Input TensorDict.
-        reference (TensorDictBase): Target TensorDict against which to test.
-        key (NestedKey): The key to test.
-        dim (int, optional): The dimension along which to test. Defaults to ``0``.
-
-    Returns:
-        out (Tensor): A boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
-            the ``input`` ``key`` tensor that are also present in the ``reference``.
-
-    Examples:
-        >>> td = TensorDict(
-        ...     {
-        ...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
-        ...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
-        ...     },
-        ...     batch_size=[4],
-        ... )
-        >>> td_ref = TensorDict(
-        ...     {
-        ...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
-        ...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
-        ...     },
-        ...     batch_size=[3],
-        ... )
-        >>> in_reference = isin(td, td_ref, key="tensor1")
-        >>> expected_in_reference = torch.tensor([True, True, True, False])
-        >>> torch.testing.assert_close(in_reference, expected_in_reference)
-    """
-    # Get the data
-    reference_tensor = reference.get(key, default=None)
-    target_tensor = input.get(key, default=None)
-
-    # Check that the key is present in both input and reference, and points to a tensor
-    if not isinstance(target_tensor, torch.Tensor):
-        raise KeyError(f"Key '{key}' not found in input or not a tensor.")
-    if not isinstance(reference_tensor, torch.Tensor):
-        raise KeyError(f"Key '{key}' not found in reference or not a tensor.")
-
-    # Check that both TensorDicts have the same number of batch dimensions
-    if len(input.batch_size) != len(reference.batch_size):
-        raise ValueError(
-            "The number of dimensions in the batch size of the input and reference must be the same."
-        )
-
-    # Check dim is valid
-    batch_dims = input.ndim
-    if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
-        raise ValueError(
-            f"The specified dimension '{dim}' is invalid for an input TensorDict with batch size '{input.batch_size}'."
-        )
-
-    # Convert negative dimension to its positive equivalent
-    if dim < 0:
-        dim = batch_dims + dim
-
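-    # The membership test below concatenates the reference and input entries along
-    # ``dim`` and maps each slice to an integer id via ``torch.unique(...,
-    # return_inverse=True)``: equal slices receive the same id, and the first ``N``
-    # ids belong to the reference, so an integer-level ``torch.isin`` over the ids
-    # is equivalent to a slice-wise membership test.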
-    # Find the common indices
-    N = reference_tensor.shape[dim]
-    cat_data = torch.cat([reference_tensor, target_tensor], dim=dim)
-    _, unique_indices = torch.unique(
-        cat_data, dim=dim, sorted=True, return_inverse=True
-    )
-    out = torch.isin(unique_indices[N:], unique_indices[:N], assume_unique=True)
-
-    return out
-
-
def _index_preserve_data_ptr(index):
    if isinstance(index, tuple):
        return all(_index_preserve_data_ptr(idx) for idx in index)
@@ -2011,96 +1932,6 @@ def _index_preserve_data_ptr(index):
    return False


-def remove_duplicates(
-    input: TensorDictBase,
-    key: NestedKey,
-    dim: int = 0,
-    *,
-    return_indices: bool = False,
-) -> TensorDictBase:
-    """Removes duplicated elements in ``key`` along the specified dimension.
-
-    This method detects duplicate elements in the tensor associated with the specified ``key`` along the specified
-    ``dim`` and removes the elements at the same indices in all other tensors within the TensorDict. ``dim`` is
-    expected to be one of the dimensions within the batch size of the input TensorDict, to ensure consistency
-    across all tensors. Otherwise, an error will be raised.
-
-    Args:
-        input (TensorDictBase): The TensorDict containing potentially duplicate elements.
-        key (NestedKey): The key of the tensor along which duplicate elements should be identified and removed. It
-            must be one of the leaf keys within the TensorDict, pointing to a tensor and not to another TensorDict.
-        dim (int, optional): The dimension along which duplicate elements should be identified and removed. It must
-            be one of the dimensions within the batch size of the input TensorDict. Defaults to ``0``.
-        return_indices (bool, optional): If ``True``, the indices of the unique elements in the input tensor will be
-            returned as well. Defaults to ``False``.
-
-    Returns:
-        output (TensorDictBase): input tensordict with the indices corresponding to duplicated elements
-            in tensor ``key`` along dimension ``dim`` removed.
-        unique_indices (torch.Tensor, optional): The indices of the first occurrences of the unique elements in the
-            input tensordict for the specified ``key`` along the specified ``dim``. Only provided if
-            ``return_indices`` is ``True``.
-
-    Examples:
-        >>> td = TensorDict(
-        ...     {
-        ...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
-        ...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
-        ...     },
-        ...     batch_size=[4],
-        ... )
-        >>> output_tensordict = remove_duplicates(td, key="tensor1", dim=0)
-        >>> expected_output = TensorDict(
-        ...     {
-        ...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
-        ...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
-        ...     },
-        ...     batch_size=[3],
-        ... )
-        >>> assert (output_tensordict == expected_output).all()
-    """
-    tensor = input.get(key, default=None)
-
-    # Check that the key exists in the TensorDict
-    if tensor is None:
-        raise KeyError(f"The key '{key}' does not exist in the TensorDict.")
-
-    # Check that the key points to a tensor
-    if not isinstance(tensor, torch.Tensor):
-        raise KeyError(f"The key '{key}' does not point to a tensor in the TensorDict.")
-
-    # Check dim is valid
-    batch_dims = input.ndim
-    if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
-        raise ValueError(
-            f"The specified dimension '{dim}' is invalid for a TensorDict with batch size '{input.batch_size}'."
-        )
-
-    # Convert negative dimension to its positive equivalent
-    if dim < 0:
-        dim = batch_dims + dim
-
-    # Get the inverse indices of the unique elements (e.g. [0, 1, 0, 2])
-    _, unique_indices, counts = torch.unique(
-        tensor, dim=dim, sorted=True, return_inverse=True, return_counts=True
-    )
-
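-    # A stable argsort of the inverse indices groups slices by unique id while
-    # preserving their original order; the exclusive cumsum of ``counts`` gives the
-    # offset at which each group starts, so reading the sorted order at those
-    # offsets recovers the index of the first occurrence of each unique slice.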
-    # Find first occurrence of each index (e.g. [0, 1, 3])
-    _, unique_indices_sorted = torch.sort(unique_indices, stable=True)
-    cum_sum = counts.cumsum(0, dtype=torch.long)
-    cum_sum = torch.cat(
-        (torch.zeros(1, device=input.device, dtype=torch.long), cum_sum[:-1])
-    )
-    first_indices = unique_indices_sorted[cum_sum]
-
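-    # ``(slice(None),) * dim`` leaves the leading batch dimensions untouched, so
-    # ``first_indices`` selects only along ``dim``, consistently across every entry
-    # of the TensorDict.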
-    # Remove duplicate elements in the TensorDict
-    output = input[(slice(None),) * dim + (first_indices,)]
-
-    if return_indices:
-        return output, unique_indices
-
-    return output
-
-
class _CloudpickleWrapper(object):
    def __init__(self, fn):
        self.fn = fn