@@ -357,17 +357,18 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
357357
358358 m_sizes[targetDimensionIndex] = std::accumulate (sizeFoldBegin, sizeFoldEnd, 1u , std::multiplies<uint32_t >());
359359
360- // Update strides too.
360+ // Update strides too (right alignment has no extra work) .
361361 if (alignment == TensorAxis::LeftAligned)
362362 {
363363 m_strides[targetDimensionIndex] = m_strides[oldDimensionCount - 1 ]; // Migrate the last stride to its new position.
364364 }
365365 // Ensure the target stride is at least 1, not 0, in case a dimension of size 1 was folded that had a stride
366- // of 0 (which might happen because a stride of 0 for dimension of size 1 is ignorable).
366+ // of 0 (which might happen because a stride of 0 for dimension of size 1 is ignorable), and other dimensions
367+ // were folded into the target too.
367368 m_strides[targetDimensionIndex] = std::max (m_strides[targetDimensionIndex], 1u );
368369 }
369370
370- // alignment == TensorAxis::LeftAligned is the easy case (just truncate the end).
371+ // Left alignment is the easy case (just truncate the end).
371372 // Right alignment needs more work, shifting values over.
372373 if (alignment == TensorAxis::RightAligned)
373374 {
@@ -376,6 +377,7 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
376377 memmove (&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof (m_sizes[0 ]) * moveCount);
377378 memmove (&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof (m_strides[0 ]) * moveCount);
378379 }
380+
379381 // For any new dimensions, insert leading/trailing 1's for sizes and 0's for strides.
380382 if (fillCount > 0 )
381383 {
0 commit comments