Skip to content

Commit cd3a91c

Browse files
committed
[SM6.10][Spec update] Update vector sizes on Multiply* functions to match matrix x column-vector multiplication
1 parent 9ae1619 commit cd3a91c

1 file changed

Lines changed: 50 additions & 50 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -232,42 +232,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
232232

233233
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
234234
ComponentEnum MatrixDT>
235-
vector<OutputElTy, K>
235+
vector<OutputElTy, M>
236236
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
237-
vector<InputElTy, M> Vec);
237+
vector<InputElTy, K> Vec);
238238

239239
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
240240
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
241-
vector<OutputElTy, K>
241+
vector<OutputElTy, M>
242242
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
243-
vector<InputElTy, M>, vector<BiasElTy, K> Vec);
243+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
244244

245245
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
246-
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
246+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
247247
ComponentEnum MatrixDT>
248248
typename hlsl::enable_if<
249-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
250-
vector<OutputElTy, K> >::type
249+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
250+
vector<OutputElTy, M> >::type
251251
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
252-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
253-
vector<BiasElTy, K> Bias);
252+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
253+
vector<BiasElTy, M> Bias);
254254

255255
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
256256
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
257257
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
258-
vector<OutputElTy, K> >::type
258+
vector<OutputElTy, M> >::type
259259
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
260-
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef);
260+
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
261261

262262
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
263-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
263+
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
264264
ComponentEnum MatrixDT>
265265
typename hlsl::enable_if<
266-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
267-
vector<OutputElTy, K> >::type
266+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
267+
vector<OutputElTy, M> >::type
268268
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
269-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
270-
VectorRef<BiasElTy, K> BiasRef);
269+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
270+
VectorRef<BiasElTy, M> BiasRef);
271271

272272
// Outer product functions
273273
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
@@ -1071,16 +1071,16 @@ type and takes arguments with potentially mismatched element types.
10711071
``` c++
10721072
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
10731073
ComponentEnum MatrixDT>
1074-
vector<OutputElTy, K>
1074+
vector<OutputElTy, M>
10751075
linalg::Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1076-
vector<InputElTy, M> Vec);
1076+
vector<InputElTy, K> Vec);
10771077
```
10781078

10791079
Requires `Thread` scope matrix input, may be called from divergent control flow.
10801080

1081-
The `linalg::Multiply` function has an overload that takes an `M`-element vector
1082-
and an MxK `A` matrix with `Thread` scope. The function returns a `K`-element
1083-
vector.
1081+
The `linalg::Multiply` function has an overload that takes an MxK `A` matrix with `Thread`
1082+
scope and a `K`-element vector `Vec`. The operation multiplies the matrix by the
1083+
`K`-element vector `Vec` producing a result `M`-element vector.
10841084

10851085
#### linalg::OuterProduct(vector, vector)
10861086

@@ -1101,23 +1101,23 @@ parameter for the output matrix element type.
11011101
``` c++
11021102
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
11031103
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
1104-
vector<OutputElTy, K>
1104+
vector<OutputElTy, M>
11051105
linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1106-
vector<InputElTy, M> Vec, vector<BiasElTy, K> Bias);
1106+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
11071107
```
11081108

11091109
Requires `Thread` scope matrix input, may be called from divergent control flow.
11101110

11111111
The `linalg::MultiplyAdd` function has an overload that takes an MxK `A` matrix
1112-
with `Thread` scope, an `M`-element vector, and a `K`-element vector. The operation
1113-
multiplies the `M`-element vector by the matrix then adds the `K`-element vector
1114-
producing a result `K`-element vector.
1112+
with `Thread` scope, a `K`-element vector `Vec`, and an `M`-element vector
1113+
`Bias`. The operation multiplies the matrix by the `K`-element vector `Vec` and
1114+
then adds the `M`-element vector `Bias` producing a result `M`-element vector.
11151115

11161116
Either vector may be a native vector or an `InterpretedVector` which combines a
1117-
packed element vector with an interpretation type. The `K`-element vector may
1118-
also be a `VectorRef` which refers to a vector in memory. Using the `VectorRef`
1119-
overload makes it easier for the backend compiler to optimize the bias vector
1120-
loads with the ALU operations.
1117+
packed element vector with an interpretation type. The `M`-element vector `Bias`
1118+
may also be a `VectorRef` which refers to a vector in memory. Using the
1119+
`VectorRef` overload makes it easier for the backend compiler to optimize the
1120+
bias vector loads with the ALU operations.
11211121

11221122
### DXIL Types
11231123

@@ -1471,8 +1471,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi
14711471
)
14721472
```
14731473

1474-
This operation implements a row-vector multiplication against an `A` matrix of
1475-
`Thread` scope.
1474+
This operation implements a column-vector multiplication against an `A` matrix
1475+
of `Thread` scope.
14761476

14771477
Validation will enforce that:
14781478
* The input vector length matches the `K` matrix dimension
@@ -1495,8 +1495,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMulAdd.v[NUMo][TYo].[MatTy].v[NUMi][
14951495
)
14961496
```
14971497

1498-
This operation implements a row-vector multiplication against an `A` matrix of
1499-
`Thread` scope with a bias vector added to the result.
1498+
This operation implements a column-vector multiplication against an `A` matrix
1499+
of `Thread` scope with a bias vector added to the result.
15001500

15011501
Validation will enforce that:
15021502
* The input vector length matches the `K` matrix dimension
@@ -2071,42 +2071,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
20712071

20722072
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
20732073
ComponentEnum MatrixDT>
2074-
vector<OutputElTy, K>
2074+
vector<OutputElTy, M>
20752075
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2076-
vector<InputElTy, M> Vec);
2076+
vector<InputElTy, K> Vec);
20772077

20782078
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
20792079
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
2080-
vector<OutputElTy, K>
2080+
vector<OutputElTy, M>
20812081
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2082-
vector<InputElTy, M>, vector<BiasElTy, K> Vec);
2082+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
20832083

20842084
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2085-
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
2085+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
20862086
ComponentEnum MatrixDT>
20872087
typename hlsl::enable_if<
2088-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
2089-
vector<OutputElTy, K> >::type
2088+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
2089+
vector<OutputElTy, M> >::type
20902090
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2091-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
2092-
vector<BiasElTy, K> Bias);
2091+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
2092+
vector<BiasElTy, M> Bias);
20932093

20942094
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
20952095
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
20962096
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
2097-
vector<OutputElTy, K> >::type
2097+
vector<OutputElTy, M> >::type
20982098
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2099-
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef);
2099+
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
21002100

21012101
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2102-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
2102+
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
21032103
ComponentEnum MatrixDT>
21042104
typename hlsl::enable_if<
2105-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
2106-
vector<OutputElTy, K> >::type
2105+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
2106+
vector<OutputElTy, M> >::type
21072107
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2108-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
2109-
VectorRef<BiasElTy, K> BiasRef);
2108+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
2109+
VectorRef<BiasElTy, M> BiasRef);
21102110

21112111
// Outer product functions
21122112
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>

0 commit comments

Comments
 (0)