@@ -1788,6 +1788,7 @@ struct PSVLinAlgRuntimeInfo0 {
17881788 // Tables are serialized in this order, with each starting with the record
17891789 // stride in bytes, followed by the records.
17901790 uint32_t MatrixOperationShapeCount;
1791+ uint32_t MatrixConstructionCount;
17911792 uint32_t ThreadVectorMatrixMultiplyCount;
17921793 uint32_t WaveMatrixMultiplyCount;
17931794 uint32_t ThreadGroupMatrixMultiplyCount;
@@ -1807,10 +1808,15 @@ struct PSVLinAlgMatrixShapeArrayReference {
18071808 uint32_t Count;
18081809};
18091810
1811+ struct PSVLinAlgMatrixConstruction0 {
1812+ PSVLinAlgMatrixShapeArrayReference OperationShapes;
1813+ uint8_t MatrixType;
1814+ };
1815+
18101816enum class PSVLinAlgThreadVectorMatrixMultiplyFlag : uint8_t {
1811- None = 0x00000000 ,
1812- // Whether the matrix operand is loaded as transposed:
1813- MatrixTransposed = 0x00000001 ,
1817+ None = 0 ,
1818+ // MatrixTransposed: The matrix is loaded from MulOptimalTranspose layout.
1819+ MatrixTransposed = 1 << 0 ,
18141820};
18151821
18161822struct PSVLinAlgThreadVectorMatrixMultiply0 {
@@ -1853,9 +1859,14 @@ struct PSVLinAlgOuterProduct0 {
18531859};
18541860
18551861enum class PSVLinAlgAccumulateStoreFlag : uint8_t {
1856- None = 0x00000000,
1857- // Whether matrix is stored as transposed in an accumulate-store operation:
1858- MatrixTransposed = 0x00000001,
1862+ None = 0,
1863+ // MatrixTransposed: Accumulate to OuterProductOptimalTranspose layout,
1864+ // thread-scope only.
1865+ MatrixTransposed = 1 << 0,
1866+ // RawBuffer: Accumulate is to a raw buffer, all scopes.
1867+ RawBuffer = 1 << 1,
1868+ // GroupShared: Accumulate to GroupShared memory, wave/group scope only.
1869+ GroupShared = 1 << 2,
18591870};
18601871
18611872struct PSVLinAlgAccumulateStore0 {
@@ -1881,6 +1892,9 @@ PSVLinAlgRuntimeInfo0 record.
18811892 * If ` MatrixOperationShapeCount > 0 ` :
18821893 * ` uint32_t LinAlgMatrixOperationShapeSize `
18831894 * ` { (PSVLinAlgMatrixOperationShapeN) char[LinAlgMatrixOperationShapeSize] } * MatrixOperationShapeCount `
1895+ * If ` MatrixConstructionCount > 0 ` :
1896+ * ` uint32_t LinAlgMatrixConstructionSize `
1897+ * ` { (PSVLinAlgMatrixConstructionN) char[LinAlgMatrixConstructionSize] } * MatrixConstructionCount `
18841898 * If ` ThreadVectorMatrixMultiplyCount > 0 ` :
18851899 * ` uint32_t LinAlgThreadVectorMatrixMultiplySize `
18861900 * ` { (PSVLinAlgThreadVectorMatrixMultiplyN) char[LinAlgThreadVectorMatrixMultiplySize] } * ThreadVectorMatrixMultiplyCount `
@@ -1911,11 +1925,12 @@ enum class RuntimeDataPartType : uint32_t {
19111925 ...
19121926 ExtendedFunctionPropertiesTable = 12,
19131927 LinAlgMatrixOperationShapeTable = 13,
1914- LinAlgThreadVectorMatrixMultiplyTable = 14,
1915- LinAlgWaveMatrixMultiplyTable = 15,
1916- LinAlgThreadGroupMatrixMultiplyTable = 16,
1917- LinAlgOuterProductTable = 17,
1918- LinAlgAccumulateStoreTable = 18,
1928+ LinAlgMatrixConstructionTable = 14,
1929+ LinAlgThreadVectorMatrixMultiplyTable = 15,
1930+ LinAlgWaveMatrixMultiplyTable = 16,
1931+ LinAlgThreadGroupMatrixMultiplyTable = 17,
1932+ LinAlgOuterProductTable = 18,
1933+ LinAlgAccumulateStoreTable = 19,
19191934 ...
19201935};
19211936```
@@ -1929,30 +1944,132 @@ The following are record definitions associated with each new table type. Each
19291944record definition corresponds to the record format used in the associated PSV
19301945structure, with some adjustments to fit RDAT patterns.
19311946
1947+ The RDAT macro form from DXC is used because it encodes all information required
1948+ to implement serialization, deserialization, reader helpers, dumping, validation,
1949+ and so on, based on the meaning of the RDAT macros and the supplied parameters.
1950+ Here, basic definitions of the macros are provided to describe the serialization
1951+ format, with comments explaining their semantic meanings.
1952+
19321953```cpp
1954+ // Basic RDAT enum definitions stored as sTy, accessed as eTy.
1955+ // eTy defined under hlsl::RDAT namespace (in DXC)
1956+
1957+ #define RDAT_ENUM_START(eTy, sTy) enum class eTy : sTy {
1958+ #define RDAT_ENUM_VALUE(name, value) name = value,
1959+ #define RDAT_ENUM_VALUE_ALIAS(name, value) name = value,
1960+ #define RDAT_ENUM_VALUE_NODEF(name) name,
1961+ #define RDAT_ENUM_END() };
1962+
1963+ // Basic RDAT struct definitions define the stored record in little-endian.
1964+ // Comments for semantic meanings used in RDAT system.
1965+
1966+ // type defined under hlsl::RDAT namespace (in DXC)
1967+ #define RDAT_STRUCT(type) struct type {
1968+ // Derivation extends a record for versioning.
1969+ #define RDAT_STRUCT_DERIVED(type, base) struct type : public base {
1970+ #define RDAT_STRUCT_END() };
1971+
1972+ // RDAT_STRUCT_TABLE[_DERIVED]: Struct record stored in record table.
1973+ // Derivation is used for versioning, table is the same for a chain of record
1974+ // versions. Record stride is used to determine the version stored in the record
1975+ // table.
1976+ #define RDAT_STRUCT_TABLE(type, table) RDAT_STRUCT(type)
1977+ #define RDAT_STRUCT_TABLE_DERIVED(type, base, table) \
1978+ RDAT_STRUCT_DERIVED(type, base)
1979+
1980+ // RDAT_INDEX_ARRAY_REF is an offset into the uint32_t index buffer. The array
1981+ // at that offset starts with its length, followed by its uint32_t elements.
1982+ #define RDAT_INDEX_ARRAY_REF(name) uint32_t name;
1983+
1984+ // RDAT_RECORD_REF is an index into the record table defined for the record type
1985+ // by RDAT_STRUCT_TABLE, which may be extended by RDAT_STRUCT_TABLE_DERIVED.
1986+ // Available version is determined by the record stride in the table header.
1987+ #define RDAT_RECORD_REF(type, name) uint32_t name;
1988+ // RDAT_RECORD_ARRAY_REF is an array of indexes like RDAT_INDEX_ARRAY_REF, but
1989+ // each index is treated as an index into the record table associated with type,
1990+ // like RDAT_RECORD_REF.
1991+ #define RDAT_RECORD_ARRAY_REF(type, name) uint32_t name;
1992+
1993+ // By-value record stored in-place. Record version is fixed by 'type' used.
1994+ #define RDAT_RECORD_VALUE(type, name) type name;
1995+
1996+ // byte offset into the shared utf-8 null-terminated string buffer
1997+ #define RDAT_STRING(name) uint32_t name;
1998+
1999+ // RDAT_STRING_ARRAY_REF is the index into the index buffer like
2000+ // RDAT_INDEX_ARRAY_REF, but each index is treated as a byte offset into the
2001+ // shared string buffer containing null-terminated utf-8 strings.
2002+ #define RDAT_STRING_ARRAY_REF(name) uint32_t name;
2003+
2004+ // Arbitrary type stored inline by-value (always little-endian)
2005+ #define RDAT_VALUE(type, name) type name;
2006+
2007+ // ENUM/FLAGS stored as sTy, with associated enum type eTy for accessors.
2008+ #define RDAT_ENUM(sTy, eTy, name) sTy name;
2009+ #define RDAT_FLAGS(sTy, eTy, name) sTy name;
2010+
2011+ // Raw bytes as an offset into the shared byte buffer, and a size in bytes.
2012+ #define RDAT_BYTES(name) \
2013+ uint32_t name; \
2014+ uint32_t name##_Size;
2015+ #define RDAT_ARRAY_VALUE(type, count, type_name, name) type_name name;
2016+
2017+ // RDAT_UNION should wrap a set of RDAT_UNION_IF/RDAT_UNION_ELIF expressions,
2018+ // each containing a single element, only valid when the expression is true.
2019+ #define RDAT_UNION() union {
2020+ #define RDAT_UNION_END() };
2021+
2022+ // UNION IF/ELIF creates bool has##name() accessors in reader based on expr.
2023+ // Wraps a union member (typically of the same name) that is only accessed when
2024+ // expr is true.
2025+ #define RDAT_UNION_IF(name, expr) \
2026+ bool GLUE(RECORD_TYPE, _Reader)::has##name() const { \
2027+ if (auto *pRecord = asRecord()) \
2028+ return !!(expr); \
2029+ return false; \
2030+ }
2031+ #define RDAT_UNION_ELIF(name, expr) RDAT_UNION_IF(name, expr)
2032+
2033+ // expr example: getPropertyType() == FunctionPropertyType::Flags
2034+ // getPropertyType() calls the generated reader accessor for PropertyType that
2035+ // returns the FunctionPropertyType, and compares this with
2036+ // FunctionPropertyType::Flags defined in the enum below.
2037+
19332038RDAT_ENUM_START(LinAlgThreadVectorMatrixMultiplyFlag, uint8_t)
19342039 RDAT_ENUM_VALUE(None, 0)
1935- // The matrix operand is loaded as transposed:
2040+ // MatrixTransposed: The matrix is loaded from MulOptimalTranspose layout.
19362041 RDAT_ENUM_VALUE(MatrixTransposed, 1 << 0)
19372042RDAT_ENUM_END()
19382043
19392044RDAT_ENUM_START(LinAlgAccumulateStoreFlag, uint8_t)
19402045 RDAT_ENUM_VALUE(None, 0)
1941- // The matrix is stored in transposed layout:
2046+ // MatrixTransposed: Accumulate to OuterProductOptimalTranspose layout,
2047+ // thread-scope only.
19422048 RDAT_ENUM_VALUE(MatrixTransposed, 1 << 0)
2049+ // RawBuffer: Accumulate is to a raw buffer, all scopes.
2050+ RDAT_ENUM_VALUE(RawBuffer, 1 << 1)
2051+ // GroupShared: Accumulate to GroupShared memory, wave/group scope only.
2052+ RDAT_ENUM_VALUE(GroupShared, 1 << 2)
19432053RDAT_ENUM_END()
19442054
19452055RDAT_STRUCT_TABLE(LinAlgMatrixOperationShape,
19462056 LinAlgMatrixOperationShapeTable)
2057+ // Unused dimensions stored as 0.
19472058 RDAT_VALUE(uint32_t, M) // Rows in matrix A
19482059 RDAT_VALUE(uint32_t, N) // Columns in matrix B
19492060 RDAT_VALUE(uint32_t, K) // Columns in matrix A / Rows in matrix B
19502061RDAT_STRUCT_END()
19512062
2063+ // In the following, an unused ComponentType would be stored as
2064+ // hlsl::DXIL::ComponentType::Invalid, aka: 0.
2065+
2066+ RDAT_STRUCT_TABLE(LinAlgMatrixConstruction, LinAlgMatrixConstructionTable)
2067+ RDAT_RECORD_ARRAY_REF(LinAlgMatrixOperationShape, OperationShapes)
2068+ RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, MatrixType)
2069+ RDAT_STRUCT_END()
2070+
19522071RDAT_STRUCT_TABLE(LinAlgThreadVectorMatrixMultiply,
19532072 LinAlgThreadVectorMatrixMultiplyTable)
1954- // Do we need shapes? If so, K would be unused (0)
1955- RDAT_RECORD_ARRAY_REF(LinAlgMatrixOperationShape, OperationShapes)
19562073 RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, ResultType)
19572074 RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, MatrixType)
19582075 RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, VectorInputType)
@@ -1990,17 +2107,22 @@ RDAT_STRUCT_END()
19902107
19912108// ------------ RuntimeDataFunctionInfo3 dependencies ------------
19922109
2110+ // RuntimeDataFunctionInfo3 adds an extended property list. Each extended
2111+ // property has a FunctionPropertyType and a 32-bit field for the property.
2112+ // For LinAlg properties, the field is always a REF of some kind (index or
2113+ // offset into some other table). Currently, each is an RDAT_RECORD_ARRAY_REF:
2114+ // an index into the index table locating an array that starts with its size
2115+ // (in elements), followed by values interpreted as indexes into the record
2116+ // table associated with the specified type defined with RDAT_STRUCT_TABLE.
2117+
19932118RDAT_ENUM_START(FunctionPropertyType, uint32_t)
19942119 RDAT_ENUM_VALUE(Flags, 0)
1995- RDAT_ENUM_VALUE(LinAlgThreadVectorMatrixMultiply, 1)
1996- #ifdef UNIFY_MATRIX_MULTIPLY_STRUCTURES
1997- RDAT_ENUM_VALUE(LinAlgMatrixMultiply, 2)
1998- #else
1999- RDAT_ENUM_VALUE(LinAlgWaveMatrixMultiply, 2)
2000- RDAT_ENUM_VALUE(LinAlgThreadGroupMatrixMultiply, 3)
2001- #endif // UNIFY_MATRIX_MULTIPLY_STRUCTURES
2002- RDAT_ENUM_VALUE(LinAlgOuterProduct, 4)
2003- RDAT_ENUM_VALUE(LinAlgAccumulateStore, 5)
2120+ RDAT_ENUM_VALUE(LinAlgMatrixConstruction, 1)
2121+ RDAT_ENUM_VALUE(LinAlgThreadVectorMatrixMultiply, 2)
2122+ RDAT_ENUM_VALUE(LinAlgWaveMatrixMultiply, 3)
2123+ RDAT_ENUM_VALUE(LinAlgThreadGroupMatrixMultiply, 4)
2124+ RDAT_ENUM_VALUE(LinAlgOuterProduct, 5)
2125+ RDAT_ENUM_VALUE(LinAlgAccumulateStore, 6)
20042126RDAT_ENUM_END()
20052127
20062128RDAT_STRUCT_TABLE(ExtendedFunctionProperties, ExtendedFunctionPropertiesTable)
@@ -2009,30 +2131,33 @@ RDAT_STRUCT_TABLE(ExtendedFunctionProperties, ExtendedFunctionPropertiesTable)
20092131 RDAT_UNION()
20102132 RDAT_UNION_IF(Flags, getPropertyType() == FunctionPropertyType::Flags)
20112133 RDAT_VALUE(uint32_t, Flags)
2134+ RDAT_UNION_ELIF(LinAlgMatrixConstruction, getPropertyType() ==
2135+ FunctionPropertyType::LinAlgMatrixConstruction)
2136+ RDAT_RECORD_ARRAY_REF(LinAlgMatrixConstruction, MatrixConstructionArray)
20122137 RDAT_UNION_ELIF(
20132138 LinAlgThreadVectorMatrixMultiply,
20142139 getPropertyType() ==
20152140 FunctionPropertyType::LinAlgThreadVectorMatrixMultiply)
2016- RDAT_RECORD_ARRAY_REF(LinAlgThreadVectorMatrixMultiply,
2017- LinAlgThreadVectorMatrixMultiplyArray)
2141+ RDAT_RECORD_ARRAY_REF(LinAlgThreadVectorMatrixMultiply,
2142+ LinAlgThreadVectorMatrixMultiplyArray)
20182143 RDAT_UNION_ELIF(LinAlgWaveMatrixMultiply,
20192144 getPropertyType() ==
20202145 FunctionPropertyType::LinAlgWaveMatrixMultiply)
2021- RDAT_RECORD_ARRAY_REF(LinAlgWaveMatrixMultiply,
2022- LinAlgWaveMatrixMultiplyArray)
2146+ RDAT_RECORD_ARRAY_REF(LinAlgWaveMatrixMultiply,
2147+ LinAlgWaveMatrixMultiplyArray)
20232148 RDAT_UNION_ELIF(LinAlgThreadGroupMatrixMultiply,
20242149 getPropertyType() ==
20252150 FunctionPropertyType::LinAlgThreadGroupMatrixMultiply)
2026- RDAT_RECORD_ARRAY_REF(LinAlgThreadGroupMatrixMultiply,
2027- LinAlgThreadGroupMatrixMultiplyArray)
2151+ RDAT_RECORD_ARRAY_REF(LinAlgThreadGroupMatrixMultiply,
2152+ LinAlgThreadGroupMatrixMultiplyArray)
20282153 RDAT_UNION_ELIF(LinAlgOuterProduct,
20292154 getPropertyType() ==
20302155 FunctionPropertyType::LinAlgOuterProduct)
2031- RDAT_RECORD_ARRAY_REF(LinAlgOuterProduct, LinAlgOuterProductArray)
2156+ RDAT_RECORD_ARRAY_REF(LinAlgOuterProduct, LinAlgOuterProductArray)
20322157 RDAT_UNION_ELIF(LinAlgAccumulateStore,
20332158 getPropertyType() ==
20342159 FunctionPropertyType::LinAlgAccumulateStore)
2035- RDAT_RECORD_ARRAY_REF(LinAlgAccumulateStore, LinAlgAccumulateStoreArray)
2160+ RDAT_RECORD_ARRAY_REF(LinAlgAccumulateStore, LinAlgAccumulateStoreArray)
20362161 RDAT_UNION_ENDIF()
20372162 RDAT_UNION_END()
20382163
0 commit comments