Skip to content

Commit cef17ad

Browse files
committed
Restrict DML EP MatMul to 4D
1 parent d4d419f commit cef17ad

File tree

4 files changed

+99
-6
lines changed

4 files changed

+99
-6
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
368368
}
369369
}
370370
tensorDesc.SetDimensionsAndStrides(newSizes, newStrides);
371-
tensorDesc.EnsureDimensionCount(1, TensorAxis::RightAligned);
371+
tensorDesc.EnsureMinimumDimensionCount(1, TensorAxis::RightAligned);
372372
}
373373

374374
// Reproject a tensor to the given axis arrangement.

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMul.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class DmlOperatorMatMul : public DmlOperator
3232
// Initialize the output description while overriding the shape
3333
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
3434

35+
// DirectML only supports ranks up to 4D for GEMM, and so leading dimensions must be clamped.
36+
m_inputTensorDescs[0].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
37+
m_inputTensorDescs[1].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
38+
m_outputTensorDescs[0].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
39+
3540
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
3641
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
3742

onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,20 @@ void TensorDesc::ForceUnsignedDataType()
290290
}
291291

292292
// Add additional padding 1's to ensure the count is at least that large.
293-
void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
293+
void TensorDesc::EnsureMinimumDimensionCount(uint32_t minimumDimensionCount, TensorAxis alignment)
294294
{
295-
if (m_bufferTensorDesc.DimensionCount < newDimensionCount)
295+
if (m_bufferTensorDesc.DimensionCount < minimumDimensionCount)
296296
{
297-
SetDimensionCount(newDimensionCount, alignment);
297+
SetDimensionCount(minimumDimensionCount, alignment);
298+
}
299+
}
300+
301+
// Ensure the dimension count is less than or equal to the limit.
302+
void TensorDesc::EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment)
303+
{
304+
if (m_bufferTensorDesc.DimensionCount > maximumDimensionCount)
305+
{
306+
SetDimensionCount(maximumDimensionCount, alignment);
298307
}
299308
}
300309

@@ -313,7 +322,52 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
313322
int32_t fillOffset = oldDimensionCount;
314323
int32_t fillCount = std::max(0, difference);
315324

316-
// alignment == TensorAxis::LeftAligned is the easy case.
325+
// If shrinking the rank, fold dimensions into the first/last dimension.
326+
// e.g. Folding 4D dimensions [2,3,4,5] to 3D right-aligned yield [6,4,5]
327+
// e.g. 6D dimensions [2,3,4,5,6,7] to 3D left-aligned yield [1,2,840]
328+
if (difference < 0 && newDimensionCount > 0)
329+
{
330+
uint32_t dimensionCountRemoved = -difference;
331+
uint32_t dimensionCountFolded = dimensionCountRemoved + 1; // If 2 dimensions are removed, then 3 dimensions are folded into one.
332+
uint32_t targetDimensionIndex;
333+
uint32_t firstFoldedDimensionIndex;
334+
335+
// Determine the range to fold and which dimension to fold them into.
336+
if (alignment == TensorAxis::RightAligned)
337+
{
338+
targetDimensionIndex = dimensionCountRemoved; // Fold extra dimensions into the first dimension of the new size.
339+
firstFoldedDimensionIndex = 0;
340+
}
341+
else // alignment == TensorAxis::LeftAligned
342+
{
343+
targetDimensionIndex = newDimensionCount - 1; // Fold extra dimensions into the last dimension of the new size.
344+
firstFoldedDimensionIndex = targetDimensionIndex;
345+
}
346+
auto sizeFoldBegin = &m_sizes[firstFoldedDimensionIndex];
347+
auto sizeFoldEnd = &m_sizes[firstFoldedDimensionIndex + dimensionCountFolded];
348+
349+
// Ensure no stride broadcasting is lost during the fold, which would silently give incorrect results.
350+
ML_CHECK_VALID_ARGUMENT(
351+
m_bufferTensorDesc.Strides == nullptr ||
352+
!HasBroadcastedDimensions(
353+
{ sizeFoldBegin, sizeFoldEnd },
354+
{ &m_strides[firstFoldedDimensionIndex], dimensionCountFolded }
355+
)
356+
);
357+
358+
m_sizes[targetDimensionIndex] = std::accumulate(sizeFoldBegin, sizeFoldEnd, 1u, std::multiplies<uint32_t>());
359+
360+
// Update strides too.
361+
if (alignment == TensorAxis::LeftAligned)
362+
{
363+
m_strides[targetDimensionIndex] = m_strides[oldDimensionCount - 1]; // Migrate the last stride to its new position.
364+
}
365+
// Ensure the target stride is at least 1, not 0, in case a dimension of size 1 was folded that had a stride
366+
// of 0 (which might happen because a stride of 0 for dimension of size 1 is ignorable).
367+
m_strides[targetDimensionIndex] = std::max(m_strides[targetDimensionIndex], 1u);
368+
}
369+
370+
// alignment == TensorAxis::LeftAligned is the easy case (just truncate the end).
317371
// Right alignment needs more work, shifting values over.
318372
if (alignment == TensorAxis::RightAligned)
319373
{
@@ -322,6 +376,7 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
322376
memmove(&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof(m_sizes[0]) * moveCount);
323377
memmove(&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof(m_strides[0]) * moveCount);
324378
}
379+
// For any new dimensions, insert leading/trailing 1's for sizes and 0's for strides.
325380
if (fillCount > 0)
326381
{
327382
std::fill(&m_sizes[fillOffset], &m_sizes[fillOffset] + fillCount, 1u);
@@ -375,3 +430,30 @@ void TensorDesc::EnsureStridesExist() noexcept
375430
GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount});
376431
m_bufferTensorDesc.Strides = m_strides;
377432
}
433+
434+
bool TensorDesc::HasBroadcastedDimensions(
435+
gsl::span<const uint32_t> dimensions,
436+
gsl::span<const uint32_t> strides
437+
) noexcept
438+
{
439+
assert(dimensions.size() == strides.size());
440+
for (uint32_t i = 0; i < dimensions.size(); ++i)
441+
{
442+
// Note logical dimensions of size 1 (even when stride is 0) are not considered broadcasted.
443+
if (strides[i] == 0 && dimensions[i] != 1)
444+
{
445+
return true;
446+
}
447+
}
448+
return false;
449+
}
450+
451+
bool TensorDesc::HasBroadcastedDimensions() const noexcept
452+
{
453+
return IsValid()
454+
&& m_bufferTensorDesc.Strides != nullptr
455+
&& HasBroadcastedDimensions(
456+
{ m_sizes, m_bufferTensorDesc.DimensionCount },
457+
{ m_strides, m_bufferTensorDesc.DimensionCount }
458+
);
459+
}

onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,18 @@ namespace Dml
4141
inline bool IsValid() const noexcept { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
4242
inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
4343
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
44-
void EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
44+
void EnsureMinimumDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
45+
void EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment);
4546

4647
gsl::span<const uint32_t> GetSizes() const noexcept { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
4748
gsl::span<const uint32_t> GetStrides() const noexcept;
4849
void SetStrides(gsl::span<const uint32_t> strides);
4950
void EnsureStridesExist() noexcept;
51+
bool HasBroadcastedDimensions() const noexcept;
52+
static bool HasBroadcastedDimensions(
53+
gsl::span<const uint32_t> dimensions,
54+
gsl::span<const uint32_t> strides
55+
) noexcept;
5056

5157
void SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides);
5258

0 commit comments

Comments
 (0)