Skip to content

Commit 0d6b554

Browse files
authored
LinAlg packed type clarifications/protections (#845)
Packed types are only intended to be used as inputs to makeinterpreted vector after which they will be used only through that wrapper. Using them directly in the APIs introduces complicated dimension checking and isn't intended to be supported. Mostly sprinkling heavy usage of is_arithmetic checks for fall native vector inputs. Some of these were added already for MultiplyAdd, but left out elsewhere. Removes the restriction on using packed types for groupshared load/store/accumulate operations. Makes a few incidental typo corrections here and there. Removes mention of matrices of packed types as matrices cannot have packed types, though they may have types that can only be represented as packed when converted and wrapped in interpreted vectors Fixes #823
1 parent e1201a4 commit 0d6b554

1 file changed

Lines changed: 32 additions & 24 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ class Matrix {
111111
MatrixLayoutEnum Layout, uint Align = 128);
112112

113113
template <typename T, SIZE_TYPE Size>
114-
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
115-
(M * N / ElementsPerScalar <= Size),
114+
static typename hlsl::enable_if<M * N / ElementsPerScalar <= Size,
116115
Matrix>::type
117116
Load(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
118117
MatrixLayoutEnum Layout);
@@ -141,8 +140,7 @@ class Matrix {
141140
MatrixLayoutEnum Layout, uint Align = 128);
142141

143142
template <typename T, SIZE_TYPE Size>
144-
typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
145-
(M * N / ElementsPerScalar <= Size),
143+
typename hlsl::enable_if<M * N / ElementsPerScalar <= Size,
146144
void>::type
147145
Store(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
148146
MatrixLayoutEnum Layout);
@@ -158,8 +156,7 @@ class Matrix {
158156
template <typename T, MatrixUseEnum UseLocal = Use,
159157
MatrixScopeEnum ScopeLocal = Scope, SIZE_TYPE Size>
160158
typename hlsl::enable_if<
161-
hlsl::is_arithmetic<T>::value && Use == MatrixUse::Accumulator &&
162-
UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
159+
UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
163160
Scope == MatrixScope::Wave && ScopeLocal == Scope,
164161
void>::type
165162
InterlockedAccumulate(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
@@ -232,21 +229,25 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
232229

233230
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
234231
ComponentEnum MatrixDT>
235-
vector<OutputElTy, M>
232+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
233+
vector<OutputElTy, M> >::type
236234
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
237235
vector<InputElTy, K> Vec);
238236

239237
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
240238
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
241-
vector<OutputElTy, M>
239+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
240+
hlsl::is_arithmetic<BiasElTy>::value,
241+
vector<OutputElTy, M> >::type
242242
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
243243
vector<InputElTy, K>, vector<BiasElTy, M> Vec);
244244

245245
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
246246
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
247247
ComponentEnum MatrixDT>
248248
typename hlsl::enable_if<
249-
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
249+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K &&
250+
hlsl::is_arithmetic<BiasElTy>::value,
250251
vector<OutputElTy, M> >::type
251252
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
252253
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
@@ -270,7 +271,8 @@ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
270271
VectorRef<BiasElTy, M> BiasRef);
271272

272273
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
273-
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
274+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
275+
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread> >::type
274276
OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
275277

276278
template <typename InputElTy, SIZE_TYPE M>
@@ -529,9 +531,6 @@ DXIL validation.
529531
| Wave | [4,128] |
530532
| ThreadGroup | [1,1024] |
531533

532-
Sizes for matrices of packed data types are 4 times the valid size for a scalar
533-
element.
534-
535534
Not all hardware is required to support all possible dimensions for thread and
536535
wave scope matrices, or all possible element types. The shader compiler will
537536
encode the dimensions and input and output data types used by each shader in the
@@ -1016,7 +1015,7 @@ When accumulating to `RWByteAddressBuffer` objects, the accumulation is
10161015
performed on the component type of the matrix object. When accumulating to
10171016
`groupshared` memory, the matrix component data is converted to the target
10181017
arithmetic or packed data type before atomic arithmetic is performed. No
1019-
conversion is performed if the target aritmetic type matches the matrix
1018+
conversion is performed if the target arithmetic type matches the matrix
10201019
component type.
10211020
10221021
#### Matrix::MultiplyAccumulate(Matrix, Matrix)
@@ -1117,7 +1116,8 @@ type and takes arguments with potentially mismatched element types.
11171116
``` c++
11181117
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
11191118
ComponentEnum MatrixDT>
1120-
vector<OutputElTy, M>
1119+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
1120+
vector<OutputElTy, M> >::type
11211121
linalg::Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
11221122
vector<InputElTy, K> Vec);
11231123
```
@@ -1133,7 +1133,8 @@ matrix by the `K`-element vector `Vec` producing a result `M`-element vector.
11331133
```c++
11341134
template <ComponentType OutTy, typename InputElTy,
11351135
uint M, uint N>
1136-
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
1136+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
1137+
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread> >::type
11371138
linalg::OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
11381139
```
11391140
@@ -1147,7 +1148,9 @@ parameter for the output matrix element type.
11471148
``` c++
11481149
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
11491150
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
1150-
vector<OutputElTy, M>
1151+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
1152+
hlsl::is_arithmetic<BiasElTy>::value,
1153+
vector<OutputElTy, M> >::type
11511154
linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
11521155
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
11531156
```
@@ -1579,9 +1582,9 @@ declare void @dx.op.linAlgMatrixAccumulateToDescriptor.[MatTy](
15791582
Accumulates a matrix to a RWByteAddressBuffer at a specified offset. This
15801583
operation is only available for matrices with `MatrixUse::Accumulator`. The
15811584
matrix data is added to the existing data in the buffer. The matrix component
1582-
data is converted to the target arithmetic or packed data type if the data types
1583-
do not match, then added to the existing data in memory. This operation must
1584-
observe [bounds checking behavior](#bounds-checking-behavior) described below.
1585+
data is added to the existing data in memory using the component type of the
1586+
matrix. This operation must observe
1587+
[bounds checking behavior](#bounds-checking-behavior) described below.
15851588

15861589
Validation rules will enforce that:
15871590
* `Layout` is `OuterProductOptimal` for matrix with `MatrixScope` of `Thread`
@@ -2150,21 +2153,25 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
21502153

21512154
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
21522155
ComponentEnum MatrixDT>
2153-
vector<OutputElTy, M>
2156+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
2157+
vector<OutputElTy, M> >::type
21542158
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
21552159
vector<InputElTy, K> Vec);
21562160

21572161
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
21582162
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
2159-
vector<OutputElTy, M>
2163+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
2164+
hlsl::is_arithmetic<BiasElTy>::value,
2165+
vector<OutputElTy, M> >::type
21602166
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
21612167
vector<InputElTy, K> Vec, vector<BiasElTy, M> Vec);
21622168

21632169
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
21642170
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
21652171
ComponentEnum MatrixDT>
21662172
typename hlsl::enable_if<
2167-
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
2173+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K &&
2174+
hlsl::is_arithmetic<BiasElTy>::value,
21682175
vector<OutputElTy, M> >::type
21692176
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
21702177
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
@@ -2188,7 +2195,8 @@ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
21882195
VectorRef<BiasElTy, M> BiasRef);
21892196

21902197
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
2191-
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
2198+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
2199+
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread> >::type
21922200
OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
21932201

21942202
template <typename InputElTy, SIZE_TYPE M>

0 commit comments

Comments
 (0)