@@ -232,42 +232,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
232232
233233template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
234234 ComponentEnum MatrixDT>
235- vector<OutputElTy, K >
235+ vector<OutputElTy, M >
236236Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
237- vector<InputElTy, M > Vec);
237+ vector<InputElTy, K > Vec);
238238
239239template <typename OutputElTy, typename InputElTy, typename BiasElTy,
240240 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
241- vector<OutputElTy, K >
241+ vector<OutputElTy, M >
242242MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
243- vector<InputElTy, M >, vector<BiasElTy, K > Vec);
243+ vector<InputElTy, K >, vector<BiasElTy, M > Vec);
244244
245245template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
246- typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
246+ typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK , SIZE_TYPE K,
247247 ComponentEnum MatrixDT>
248248typename hlsl::enable_if<
249- InterpretedVector<InputElTy, VecM , InputInterp>::Size == M ,
250- vector<OutputElTy, K > >::type
249+ InterpretedVector<InputElTy, VecK , InputInterp>::Size == K ,
250+ vector<OutputElTy, M > >::type
251251MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
252- InterpretedVector<InputElTy, VecM , InputInterp> InterpVec,
253- vector<BiasElTy, K > Bias);
252+ InterpretedVector<InputElTy, VecK , InputInterp> InterpVec,
253+ vector<BiasElTy, M > Bias);
254254
255255template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
256256 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
257257typename hlsl::enable_if< hlsl::is_arithmetic<InputElTy > ::value,
258- vector<OutputElTy, K > >::type
258+ vector<OutputElTy, M > >::type
259259MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
260- vector<InputElTy, M > Vec, VectorRef<BiasElTy, K > BiasRef);
260+ vector<InputElTy, K > Vec, VectorRef<BiasElTy, M > BiasRef);
261261
262262template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
263- ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
263+ ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK , SIZE_TYPE K,
264264 ComponentEnum MatrixDT>
265265typename hlsl::enable_if<
266- InterpretedVector<InputElTy, VecM , InputInterp>::Size == M ,
267- vector<OutputElTy, K > >::type
266+ InterpretedVector<InputElTy, VecK , InputInterp>::Size == K ,
267+ vector<OutputElTy, M > >::type
268268MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
269- InterpretedVector<InputElTy, VecM , InputInterp> InterpVec,
270- VectorRef<BiasElTy, K > BiasRef);
269+ InterpretedVector<InputElTy, VecK , InputInterp> InterpVec,
270+ VectorRef<BiasElTy, M > BiasRef);
271271
272272// Outer product functions
273273template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
@@ -1071,16 +1071,16 @@ type and takes arguments with potentially mismatched element types.
10711071``` c++
10721072template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
10731073 ComponentEnum MatrixDT>
1074- vector<OutputElTy, K >
1074+ vector<OutputElTy, M >
10751075linalg::Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1076- vector<InputElTy, M > Vec);
1076+ vector<InputElTy, K > Vec);
10771077```
10781078
10791079Requires ` Thread ` scope matrix input, may be called from divergent control flow.
10801080
1081- The ` linalg::Multiply ` function has an overload that takes an ` M ` -element vector
1082- and an MxK ` A ` matrix with ` Thread ` scope . The function returns a ` K ` -element
1083- vector.
1081+ The ` linalg::Multiply ` function has an takes an MxK ` A ` matrix with ` Thread `
1082+ scope, an ` K ` -element vector ` Vec ` . The operation multiplies the matrix by the
1083+ ` K ` -element vector ` Vec ` producing a result ` M ` -element vector.
10841084
10851085#### linalg::OuterProduct(vector, vector)
10861086
@@ -1101,23 +1101,23 @@ parameter for the output matrix element type.
11011101``` c++
11021102template <typename OutputElTy, typename InputElTy, typename BiasElTy,
11031103 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
1104- vector<OutputElTy, K >
1104+ vector<OutputElTy, M >
11051105linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1106- vector<InputElTy, M > Vec, vector<BiasElTy, K > Bias);
1106+ vector<InputElTy, K > Vec, vector<BiasElTy, M > Bias);
11071107```
11081108
11091109Requires ` Thread ` scope matrix input, may be called from divergent control flow.
11101110
11111111The ` linalg::MultiplyAdd ` function has an overload that takes an MxK ` A ` matrix
1112- with ` Thread ` scope, an ` M ` -element vector, and a ` K ` -element vector. The operation
1113- multiplies the ` M ` -element vector by the matrix then adds the ` K ` -element vector
1114- producing a result ` K ` -element vector.
1112+ with ` Thread ` scope, an ` K ` -element vector ` Vec ` , and a ` M ` -element vector
1113+ ` Bias ` . The operation multiplies the matrix by the ` K ` -element vector ` Vec ` and
1114+ then adds the ` M ` -element vector ` Bias ` producing a result ` M ` -element vector.
11151115
11161116Either vector may be a native vector or an ` InterpretedVector ` which combines a
1117- packed element vector with an interpretation type. The ` K ` -element vector may
1118- also be a ` VectorRef ` which refers to a vector in memory. Using the ` VectorRef `
1119- overload makes it easier for the backend compiler to optimize the bias vector
1120- loads with the ALU operations.
1117+ packed element vector with an interpretation type. The ` M ` -element vector ` Bias `
1118+ may also be a ` VectorRef ` which refers to a vector in memory. Using the
1119+ ` VectorRef ` overload makes it easier for the backend compiler to optimize the
1120+ bias vector loads with the ALU operations.
11211121
11221122### DXIL Types
11231123
@@ -1471,8 +1471,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi
14711471)
14721472```
14731473
1474- This operation implements a row -vector multiplication against an ` A ` matrix of
1475- ` Thread ` scope.
1474+ This operation implements a column -vector multiplication against an ` A ` matrix
1475+ of ` Thread ` scope.
14761476
14771477Validation will enforce that:
14781478* The input vector length matches the ` K ` matrix dimension
@@ -1495,8 +1495,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMulAdd.v[NUMo][TYo].[MatTy].v[NUMi][
14951495)
14961496```
14971497
1498- This operation implements a row -vector multiplication against an ` A ` matrix of
1499- ` Thread ` scope with a bias vector added to the result.
1498+ This operation implements a column -vector multiplication against an ` A ` matrix
1499+ of ` Thread ` scope with a bias vector added to the result.
15001500
15011501Validation will enforce that:
15021502* The input vector length matches the ` K ` matrix dimension
@@ -2071,42 +2071,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
20712071
20722072template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
20732073 ComponentEnum MatrixDT>
2074- vector<OutputElTy, K >
2074+ vector<OutputElTy, M >
20752075Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2076- vector<InputElTy, M > Vec);
2076+ vector<InputElTy, K > Vec);
20772077
20782078template <typename OutputElTy, typename InputElTy, typename BiasElTy,
20792079 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
2080- vector<OutputElTy, K >
2080+ vector<OutputElTy, M >
20812081MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2082- vector<InputElTy, M> , vector<BiasElTy, K > Vec);
2082+ vector<InputElTy, K> Vec , vector<BiasElTy, M > Vec);
20832083
20842084template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2085- typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
2085+ typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK , SIZE_TYPE K,
20862086 ComponentEnum MatrixDT>
20872087typename hlsl::enable_if<
2088- InterpretedVector<InputElTy, VecM , InputInterp>::Size == M ,
2089- vector<OutputElTy, K > >::type
2088+ InterpretedVector<InputElTy, VecK , InputInterp>::Size == K ,
2089+ vector<OutputElTy, M > >::type
20902090MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2091- InterpretedVector<InputElTy, VecM , InputInterp> InterpVec,
2092- vector<BiasElTy, K > Bias);
2091+ InterpretedVector<InputElTy, VecK , InputInterp> InterpVec,
2092+ vector<BiasElTy, M > Bias);
20932093
20942094template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
20952095 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
20962096typename hlsl::enable_if< hlsl::is_arithmetic<InputElTy > ::value,
2097- vector<OutputElTy, K > >::type
2097+ vector<OutputElTy, M > >::type
20982098MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2099- vector<InputElTy, M > Vec, VectorRef<BiasElTy, K > BiasRef);
2099+ vector<InputElTy, K > Vec, VectorRef<BiasElTy, M > BiasRef);
21002100
21012101template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2102- ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
2102+ ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK , SIZE_TYPE K,
21032103 ComponentEnum MatrixDT>
21042104typename hlsl::enable_if<
2105- InterpretedVector<InputElTy, VecM , InputInterp>::Size == M ,
2106- vector<OutputElTy, K > >::type
2105+ InterpretedVector<InputElTy, VecK , InputInterp>::Size == K ,
2106+ vector<OutputElTy, M > >::type
21072107MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2108- InterpretedVector<InputElTy, VecM , InputInterp> InterpVec,
2109- VectorRef<BiasElTy, K > BiasRef);
2108+ InterpretedVector<InputElTy, VecK , InputInterp> InterpVec,
2109+ VectorRef<BiasElTy, M > BiasRef);
21102110
21112111// Outer product functions
21122112template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
0 commit comments