Skip to content

Commit 5ce81c4

Browse files
committed
[0035] Fix dimensions of Cast with Transpose=true
We need to swap the M and N dimensions on the return type of the Cast when Transpose is true.
1 parent 9ae1619 commit 5ce81c4

1 file changed

Lines changed: 32 additions & 5 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ class Matrix {
9696

9797
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
9898
bool Transpose = false>
99-
Matrix<NewCompTy, M, N, NewUse, Scope> Cast();
99+
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
100+
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
101+
Cast();
100102

101103
template <typename T>
102104
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
@@ -719,6 +721,16 @@ template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
719721
ComponentTypeTraits<DstTy>::ElementsPerScalar;
720722
};
721723
724+
template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
725+
static const SIZE_TYPE M = MVal;
726+
static const SIZE_TYPE N = NVal;
727+
};
728+
729+
template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
730+
static const SIZE_TYPE M = NVal;
731+
static const SIZE_TYPE N = MVal;
732+
};
733+
722734
} // namespace __detail
723735
```
724736

@@ -747,8 +759,11 @@ HLSL type casting rules, and they apply to native and non-native types.
747759
#### Matrix::Cast
748760
749761
```c++
750-
template <ComponentType NewCompTy, MatrixUse NewUse = Use>
751-
Matrix<NewCompTy, M, N, NewUse, Scope> Matrix::Cast();
762+
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
763+
bool Transpose = false>
764+
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
765+
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
766+
Matrix::Cast();
752767
```
753768

754769
Requires `Wave` or `ThreadGroup` scope input and output matrices.
@@ -1719,7 +1734,7 @@ in the [`DXIL::ComponentType` enumeration](#dxil-enumerations).
17191734
17201735
## Appendix 1: HLSL Header
17211736

1722-
[Compiler Explorer](https://godbolt.org/z/ajGbYbMP8)
1737+
[Compiler Explorer](https://godbolt.org/z/5qzYaosf1)
17231738
> Note: this mostly works with Clang, but has some issues to work out still.
17241739
17251740
```cpp
@@ -1891,6 +1906,16 @@ template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
18911906
ComponentTypeTraits<DstTy>::ElementsPerScalar;
18921907
};
18931908

1909+
template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
1910+
static const SIZE_TYPE M = MVal;
1911+
static const SIZE_TYPE N = NVal;
1912+
};
1913+
1914+
template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
1915+
static const SIZE_TYPE M = NVal;
1916+
static const SIZE_TYPE N = MVal;
1917+
};
1918+
18941919
} // namespace __detail
18951920

18961921
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
@@ -1935,7 +1960,9 @@ class Matrix {
19351960

19361961
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
19371962
bool Transpose = false>
1938-
Matrix<NewCompTy, M, N, NewUse, Scope> Cast();
1963+
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
1964+
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
1965+
Cast();
19391966

19401967
template <typename T>
19411968
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type

0 commit comments

Comments
 (0)