@@ -252,32 +252,69 @@ def collapse(
252252 fun : Callable [[np .ndarray ], Union [float , np .ndarray ]] = np .sum ,
253253 ) -> Union [float , np .ndarray , tensor ]:
254254 """
255- Collapse tensor along specified dimensions.
255+ Collapse tensor along specified dimensions using a function .
256256
257257 Parameters
258258 ----------
259- dims:
260- Dimensions to collapse.
261- fun:
262- Method used to collapse dimensions.
259+ dims: optional
260+ Dimensions to collapse (default: all) .
261+ fun: optional
262+ Method used to collapse dimensions (default: :meth:`numpy.sum`) .
263263
264264 Returns
265265 -------
266- Collapsed value .
266+ Scalar (if all dimensions collapsed) or tensor .
267267
268268 Examples
269269 --------
270- >>> T = ttb.tensor(np.ones((2, 2)))
271- >>> T.collapse()
272- 4.0
273- >>> T.collapse(np.array([0]))
274- tensor of shape (2,) with order F
275- data[:] =
276- [2. 2.]
277- >>> T.collapse(np.arange(T.ndims), sum)
278- 4.0
279- >>> T.collapse(np.arange(T.ndims), np.prod)
280- 1.0
270+ Sum all elements of tensor::
271+
272+ >>> T = ttb.tensor(np.ones((4,3,2),order='F'))
273+ >>> T.collapse()
274+ 24.0
275+
276+ Compute the sum for each mode-0 fiber (output is a tensor)::
277+
278+ >>> T.collapse(0)
279+ tensor of shape (3, 2) with order F
280+ data[:, :] =
281+ [[4. 4.]
282+ [4. 4.]
283+ [4. 4.]]
284+
285+ Compute the sum of the entries in each mode-0 slice (output is a tensor)::
286+
287+ >>> T.collapse([1, 2])
288+ tensor of shape (4,) with order F
289+ data[:] =
290+ [6. 6. 6. 6.]
291+
292+ Compute the max entry in each mode-2 slice (output is a tensor)::
293+
294+ >>> T.collapse([0, 1], np.max)
295+ tensor of shape (2,) with order F
296+ data[:] =
297+ [1. 1.]
298+
299+ Find the maximum and minimum values in a tensor::
300+
301+ >>> randn = lambda s : np.random.randn(np.prod(s))
302+ >>> np.random.seed(0) # reproducibility
303+ >>> T = ttb.tensor.from_function(randn, (2, 2, 2))
304+ >>> print(T)
305+ tensor of shape (2, 2, 2) with order F
306+ data[:, :, 0] =
307+ [[1.76405235 0.97873798]
308+ [0.40015721 2.2408932 ]]
309+ data[:, :, 1] =
310+ [[ 1.86755799 0.95008842]
311+ [-0.97727788 -0.15135721]]
312+ >>> max_val = T.collapse(fun=np.max)
313+ >>> min_val = T.collapse(fun=np.min)
314+ >>> print(f"Max value: {max_val}")
315+ Max value: 2.240893199201458
316+ >>> print(f"Min value: {min_val}")
317+ Min value: -0.977277879876411
281318 """
282319 if self .data .size == 0 :
283320 return np .array ([], order = self .order )
0 commit comments