@@ -290,11 +290,20 @@ void TensorDesc::ForceUnsignedDataType()
 }
 
 // Add additional padding 1's to ensure the count is at least that large.
-void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
+void TensorDesc::EnsureMinimumDimensionCount(uint32_t minimumDimensionCount, TensorAxis alignment)
 {
-    if (m_bufferTensorDesc.DimensionCount < newDimensionCount)
+    if (m_bufferTensorDesc.DimensionCount < minimumDimensionCount)
     {
-        SetDimensionCount(newDimensionCount, alignment);
+        SetDimensionCount(minimumDimensionCount, alignment);
+    }
+}
+
+// Ensure the dimension count is less than or equal to the limit.
+void TensorDesc::EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment)
+{
+    if (m_bufferTensorDesc.DimensionCount > maximumDimensionCount)
+    {
+        SetDimensionCount(maximumDimensionCount, alignment);
     }
 }
 
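Together, the renamed EnsureMinimumDimensionCount and the new EnsureMaximumDimensionCount let a caller clamp a tensor's rank into an inclusive range. A minimal sketch of that pattern (ClampDimensionCount is a hypothetical helper, not part of this change):

    // Hypothetical caller: pad with leading 1's up to minRank, then fold
    // excess leading dimensions down to maxRank.
    void ClampDimensionCount(TensorDesc& desc, uint32_t minRank, uint32_t maxRank)
    {
        desc.EnsureMinimumDimensionCount(minRank, TensorAxis::RightAligned);
        desc.EnsureMaximumDimensionCount(maxRank, TensorAxis::RightAligned);
    }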
@@ -313,7 +322,52 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
     int32_t fillOffset = oldDimensionCount;
     int32_t fillCount = std::max(0, difference);
 
-    // alignment == TensorAxis::LeftAligned is the easy case.
+    // If shrinking the rank, fold dimensions into the first/last dimension.
+    // e.g. Folding 4D dimensions [2,3,4,5] to 3D right-aligned yields [6,4,5].
+    // e.g. Folding 6D dimensions [2,3,4,5,6,7] to 3D left-aligned yields [2,3,840].
+    if (difference < 0 && newDimensionCount > 0)
+    {
+        uint32_t dimensionCountRemoved = -difference;
+        uint32_t dimensionCountFolded = dimensionCountRemoved + 1; // If 2 dimensions are removed, then 3 dimensions are folded into one.
+        uint32_t targetDimensionIndex;
+        uint32_t firstFoldedDimensionIndex;
+
+        // Determine the range of dimensions to fold and which dimension to fold them into.
+        if (alignment == TensorAxis::RightAligned)
+        {
+            targetDimensionIndex = dimensionCountRemoved; // Fold extra dimensions into the first dimension of the new size.
+            firstFoldedDimensionIndex = 0;
+        }
+        else // alignment == TensorAxis::LeftAligned
+        {
+            targetDimensionIndex = newDimensionCount - 1; // Fold extra dimensions into the last dimension of the new size.
+            firstFoldedDimensionIndex = targetDimensionIndex;
+        }
+        auto sizeFoldBegin = &m_sizes[firstFoldedDimensionIndex];
+        auto sizeFoldEnd = &m_sizes[firstFoldedDimensionIndex + dimensionCountFolded];
+
+        // Ensure no stride broadcasting is lost during the fold, which would silently give incorrect results.
+        ML_CHECK_VALID_ARGUMENT(
+            m_bufferTensorDesc.Strides == nullptr ||
+            !HasBroadcastedDimensions(
+                { sizeFoldBegin, sizeFoldEnd },
+                { &m_strides[firstFoldedDimensionIndex], dimensionCountFolded }
+            )
+        );
+
+        m_sizes[targetDimensionIndex] = std::accumulate(sizeFoldBegin, sizeFoldEnd, 1u, std::multiplies<uint32_t>());
+
+        // Update strides too.
+        if (alignment == TensorAxis::LeftAligned)
+        {
+            m_strides[targetDimensionIndex] = m_strides[oldDimensionCount - 1]; // Migrate the last stride to its new position.
+        }
+        // Ensure the target stride is at least 1, not 0, in case a dimension of size 1 was folded that had a stride
+        // of 0 (which might happen because a stride of 0 for a dimension of size 1 is ignorable).
+        m_strides[targetDimensionIndex] = std::max(m_strides[targetDimensionIndex], 1u);
+    }
+
+    // alignment == TensorAxis::LeftAligned is the easy case (just truncate the end).
     // Right alignment needs more work, shifting values over.
     if (alignment == TensorAxis::RightAligned)
     {
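The fold collapses dimensionCountFolded adjacent sizes into one via std::accumulate, and the ML_CHECK_VALID_ARGUMENT guard rejects folds that would erase broadcasting: sizes [2,3] with strides [0,4], for example, cannot be expressed as a single size-6 dimension with any one stride. A self-contained sketch of the right-aligned size arithmetic (a hypothetical standalone helper on plain vectors, not the TensorDesc members):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Fold leading dimensions so that only newRank dimensions remain (right-aligned).
    std::vector<uint32_t> FoldRightAligned(std::vector<uint32_t> sizes, uint32_t newRank)
    {
        uint32_t removed = static_cast<uint32_t>(sizes.size()) - newRank;
        // removed + 1 leading dimensions collapse into the first surviving one.
        sizes[removed] = std::accumulate(sizes.begin(), sizes.begin() + removed + 1, 1u, std::multiplies<uint32_t>());
        return {sizes.begin() + removed, sizes.end()};
    }
    // FoldRightAligned({2,3,4,5}, 3) yields {6,4,5}, matching the comment above.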
@@ -322,6 +376,7 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
         memmove(&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof(m_sizes[0]) * moveCount);
         memmove(&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof(m_strides[0]) * moveCount);
     }
+    // For any new dimensions, insert leading/trailing 1's for sizes and 0's for strides.
     if (fillCount > 0)
     {
         std::fill(&m_sizes[fillOffset], &m_sizes[fillOffset] + fillCount, 1u);
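The grow path is the mirror image of the fold: for right alignment, existing values shift toward the back and the vacated front fills with 1's (sizes) or 0's (strides). A sketch of that shift-and-fill on a plain array (hypothetical helper; the real code operates on m_sizes and m_strides in place):

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    // Shift oldCount values to the back of a newCount-sized buffer, then
    // fill the vacated front with fillValue (1 for sizes, 0 for strides).
    void GrowRightAligned(uint32_t* values, uint32_t oldCount, uint32_t newCount, uint32_t fillValue)
    {
        uint32_t fillCount = newCount - oldCount;
        memmove(&values[fillCount], &values[0], sizeof(values[0]) * oldCount);
        std::fill(&values[0], &values[0] + fillCount, fillValue);
    }
    // Growing sizes [3,4] to rank 4 yields [1,1,3,4]; strides [4,1] become [0,0,4,1].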
@@ -375,3 +430,30 @@ void TensorDesc::EnsureStridesExist() noexcept
     GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount});
     m_bufferTensorDesc.Strides = m_strides;
 }
+
+bool TensorDesc::HasBroadcastedDimensions(
+    gsl::span<const uint32_t> dimensions,
+    gsl::span<const uint32_t> strides
+    ) noexcept
+{
+    assert(dimensions.size() == strides.size());
+    for (uint32_t i = 0; i < dimensions.size(); ++i)
+    {
+        // Note logical dimensions of size 1 (even when stride is 0) are not considered broadcasted.
+        if (strides[i] == 0 && dimensions[i] != 1)
+        {
+            return true;
+        }
+    }
+    return false;
+}
+
+bool TensorDesc::HasBroadcastedDimensions() const noexcept
+{
+    return IsValid()
+        && m_bufferTensorDesc.Strides != nullptr
+        && HasBroadcastedDimensions(
+            { m_sizes, m_bufferTensorDesc.DimensionCount },
+            { m_strides, m_bufferTensorDesc.DimensionCount }
+        );
+}
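For intuition on what the new overloads report, here are illustrative values (this sketch assumes the two-span overload is static, as the unqualified call inside SetDimensionCount suggests):

    #include <cassert>
    #include <cstdint>

    void BroadcastExamples()
    {
        // A 3x4 view where all 3 rows alias the same 4 elements: broadcasted.
        const uint32_t sizes[]   = {3, 4};
        const uint32_t strides[] = {0, 1};
        assert(TensorDesc::HasBroadcastedDimensions(sizes, strides));

        // A 1x4 view: the stride-0 dimension has size 1, so it is ignorable.
        const uint32_t onesSizes[]   = {1, 4};
        const uint32_t onesStrides[] = {0, 1};
        assert(!TensorDesc::HasBroadcastedDimensions(onesSizes, onesStrides));
    }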