Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,32 +317,69 @@ def collapse(
fun: Callable[[np.ndarray], Union[float, np.ndarray]] = np.sum,
) -> Union[float, np.ndarray, tensor]:
"""
Collapse tensor along specified dimensions.
Collapse tensor along specified dimensions using a function.

Parameters
----------
dims:
Dimensions to collapse.
fun:
Method used to collapse dimensions.
dims: optional
Dimensions to collapse (default: all).
fun: optional
Method used to collapse dimensions (default: :meth:`numpy.sum`).

Returns
-------
Collapsed value.
Scalar (if all dimensions collapsed) or tensor.

Examples
--------
>>> T = ttb.tensor(np.ones((2, 2)))
>>> T.collapse()
4.0
>>> T.collapse(np.array([0]))
tensor of shape (2,) with order F
data[:] =
[2. 2.]
>>> T.collapse(np.arange(T.ndims), sum)
4.0
>>> T.collapse(np.arange(T.ndims), np.prod)
1.0
Sum all elements of tensor::

>>> T = ttb.tensor(np.ones((4,3,2),order='F'))
>>> T.collapse()
24.0

Compute the sum for each mode-0 fiber (output is a tensor)::

>>> T.collapse(0)
tensor of shape (3, 2) with order F
data[:, :] =
[[4. 4.]
[4. 4.]
[4. 4.]]

Compute the sum of the entries in each mode-0 slice (output is a tensor)::

>>> T.collapse([1, 2])
tensor of shape (4,) with order F
data[:] =
[6. 6. 6. 6.]

Compute the max entry in each mode-2 slice (output is a tensor)::

>>> T.collapse([0, 1], np.max)
tensor of shape (2,) with order F
data[:] =
[1. 1.]

Find the maximum and minimum values in a tensor::

>>> randn = lambda s : np.random.randn(np.prod(s))
>>> np.random.seed(0) # reproducibility
>>> T = ttb.tensor.from_function(randn, (2, 2, 2))
>>> print(T)
tensor of shape (2, 2, 2) with order F
data[:, :, 0] =
[[1.76405235 0.97873798]
[0.40015721 2.2408932 ]]
data[:, :, 1] =
[[ 1.86755799 0.95008842]
[-0.97727788 -0.15135721]]
>>> max_val = T.collapse(fun=np.max)
>>> min_val = T.collapse(fun=np.min)
>>> print(f"Max value: {max_val}")
Max value: 2.240893199201458
>>> print(f"Min value: {min_val}")
Min value: -0.977277879876411
"""
if self.data.size == 0:
return np.array([], order=self.order)
Expand Down
Loading