Skip to content

Commit 1f3fcca

Browse files
committed
Update PSV0 and RDAT definitions
Added macro definitions and comments to explain the exact meaning of RDAT record elements.
1 parent 008dcbf commit 1f3fcca

1 file changed

Lines changed: 157 additions & 32 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 157 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,7 @@ struct PSVLinAlgRuntimeInfo0 {
17881788
// Tables are serialized in this order, with each starting with the record
17891789
// stride in bytes, followed by the records.
17901790
uint32_t MatrixOperationShapeCount;
1791+
uint32_t MatrixConstructionCount;
17911792
uint32_t ThreadVectorMatrixMultiplyCount;
17921793
uint32_t WaveMatrixMultiplyCount;
17931794
uint32_t ThreadGroupMatrixMultiplyCount;
@@ -1807,10 +1808,15 @@ struct PSVLinAlgMatrixShapeArrayReference {
18071808
uint32_t Count;
18081809
};
18091810
1811+
struct PSVLinAlgMatrixConstruction0 {
1812+
PSVLinAlgMatrixShapeArrayReference OperationShapes;
1813+
uint8_t MatrixType;
1814+
};
1815+
18101816
enum class PSVLinAlgThreadVectorMatrixMultiplyFlag : uint8_t {
1811-
None = 0x00000000,
1812-
// Whether the matrix operand is loaded as transposed:
1813-
MatrixTransposed = 0x00000001,
1817+
None = 0,
1818+
// MatrixTransposed: The matrix is loaded from MulOptimalTranspose layout.
1819+
MatrixTransposed = 1 << 0,
18141820
};
18151821
18161822
struct PSVLinAlgThreadVectorMatrixMultiply0 {
@@ -1853,9 +1859,14 @@ struct PSVLinAlgOuterProduct0 {
18531859
};
18541860
18551861
enum class PSVLinAlgAccumulateStoreFlag : uint8_t {
1856-
None = 0x00000000,
1857-
// Whether matrix is stored as transposed in an accumulate-store operation:
1858-
MatrixTransposed = 0x00000001,
1862+
None = 0,
1863+
// MatrixTransposed: Accumulate to OuterProductOptimalTranspose layout,
1864+
// thread-scope only.
1865+
MatrixTransposed = 1 << 0,
1866+
// RawBuffer: Accumulate is to a raw buffer, all scopes.
1867+
RawBuffer = 1 << 1,
1868+
// GroupShared: Accumulate to GroupShared memory, wave/group scope only.
1869+
GroupShared = 1 << 2,
18591870
};
18601871
18611872
struct PSVLinAlgAccumulateStore0 {
@@ -1881,6 +1892,9 @@ PSVLinAlgRuntimeInfo0 record.
18811892
* If `MatrixOperationShapeCount > 0`:
18821893
* `uint32_t LinAlgMatrixOperationShapeSize`
18831894
* `{ (PSVLinAlgMatrixOperationShapeN) char[LinAlgMatrixOperationShapeSize] } * MatrixOperationShapeCount`
1895+
* If `MatrixConstructionCount > 0`:
1896+
* `uint32_t LinAlgMatrixConstructionSize`
1897+
* `{ (PSVLinAlgMatrixConstructionN) char[LinAlgMatrixConstructionSize] } * MatrixConstructionCount`
18841898
* If `ThreadVectorMatrixMultiplyCount > 0`:
18851899
* `uint32_t LinAlgThreadVectorMatrixMultiplySize`
18861900
* `{ (PSVLinAlgThreadVectorMatrixMultiplyN) char[LinAlgThreadVectorMatrixMultiplySize] } * ThreadVectorMatrixMultiplyCount`
@@ -1911,11 +1925,12 @@ enum class RuntimeDataPartType : uint32_t {
19111925
...
19121926
ExtendedFunctionPropertiesTable = 12,
19131927
LinAlgMatrixOperationShapeTable = 13,
1914-
LinAlgThreadVectorMatrixMultiplyTable = 14,
1915-
LinAlgWaveMatrixMultiplyTable = 15,
1916-
LinAlgThreadGroupMatrixMultiplyTable = 16,
1917-
LinAlgOuterProductTable = 17,
1918-
LinAlgAccumulateStoreTable = 18,
1928+
LinAlgMatrixConstructionTable = 14,
1929+
LinAlgThreadVectorMatrixMultiplyTable = 15,
1930+
LinAlgWaveMatrixMultiplyTable = 16,
1931+
LinAlgThreadGroupMatrixMultiplyTable = 17,
1932+
LinAlgOuterProductTable = 18,
1933+
LinAlgAccumulateStoreTable = 19,
19191934
...
19201935
};
19211936
```
@@ -1929,30 +1944,132 @@ The following are record definitions associated with each new table type. Each
19291944
record definition corresponds to the record format used in the associated PSV
19301945
structure, with some adjustments to fit RDAT patterns.
19311946
1947+
The RDAT macro form from DXC is used because it encodes all information required to
1948+
implement serialization, deserialization, reader helpers, dumping, validation,
1949+
and so on, based on the meaning of the RDAT macros and the supplied parameters.
1950+
Here, basic definitions are provided for the serialization format, with comments for
1951+
semantic meanings.
1952+
19321953
```cpp
1954+
// Basic RDAT enum definitions stored as sTy, accessed as eTy.
1955+
// eTy defined under hlsl::RDAT namespace (in DXC)
1956+
1957+
#define RDAT_ENUM_START(eTy, sTy) enum class eTy : sTy {
1958+
#define RDAT_ENUM_VALUE(name, value) name = value,
1959+
#define RDAT_ENUM_VALUE_ALIAS(name, value) name = value,
1960+
#define RDAT_ENUM_VALUE_NODEF(name) name,
1961+
#define RDAT_ENUM_END() };
1962+
1963+
// Basic RDAT struct definitions define the stored record in little-endian.
1964+
// Comments for semantic meanings used in RDAT system.
1965+
1966+
// type defined under hlsl::RDAT namespace (in DXC)
1967+
#define RDAT_STRUCT(type) struct type {
1968+
// Derivation extends a record for versioning.
1969+
#define RDAT_STRUCT_DERIVED(type, base) struct type : public base {
1970+
#define RDAT_STRUCT_END() };
1971+
1972+
// RDAT_STRUCT_TABLE[_DERIVED]: Struct record stored in record table.
1973+
// Derivation is used for versioning, table is the same for a chain of record
1974+
// versions. Record stride is used to determine the version stored in the record
1975+
// table.
1976+
#define RDAT_STRUCT_TABLE(type, table) RDAT_STRUCT(type)
1977+
#define RDAT_STRUCT_TABLE_DERIVED(type, base, table) \
1978+
RDAT_STRUCT_DERIVED(type, base)
1979+
1980+
// INDEX_ARRAY_REF is an offset into the uint32_t index buffer which starts with
1981+
// the array length and is followed by uint32_t array elements
1982+
#define RDAT_INDEX_ARRAY_REF(name) uint32_t name;
1983+
1984+
// RDAT_RECORD_REF is an index into the record table defined for the record type
1985+
// by RDAT_STRUCT_TABLE, which may be extended by RDAT_STRUCT_TABLE_DERIVED.
1986+
// Available version is determined by the record stride in the table header.
1987+
#define RDAT_RECORD_REF(type, name) uint32_t name;
1988+
// RDAT_RECORD_ARRAY_REF is an array of indexes like RDAT_INDEX_ARRAY_REF, but
1989+
// each index is treated as an index into the record table associated with type,
1990+
// like RDAT_RECORD_REF.
1991+
#define RDAT_RECORD_ARRAY_REF(type, name) uint32_t name;
1992+
1993+
// By-value record stored in-place. Record version is fixed by 'type' used.
1994+
#define RDAT_RECORD_VALUE(type, name) type name;
1995+
1996+
// byte offset into the shared utf-8 null-terminated string buffer
1997+
#define RDAT_STRING(name) uint32_t name;
1998+
1999+
// RDAT_STRING_ARRAY_REF is the index into the index buffer like
2000+
// RDAT_INDEX_ARRAY_REF, but each index is treated as a byte offset into the
2001+
// shared string buffer containing null-terminated utf-8 strings.
2002+
#define RDAT_STRING_ARRAY_REF(name) uint32_t name;
2003+
2004+
// Arbitrary type stored inline by-value (always little-endian)
2005+
#define RDAT_VALUE(type, name) type name;
2006+
2007+
// ENUM/FLAGS stored as sTy, with associated enum type eTy for accessors.
2008+
#define RDAT_ENUM(sTy, eTy, name) sTy name;
2009+
#define RDAT_FLAGS(sTy, eTy, name) sTy name;
2010+
2011+
// Raw bytes as an offset into the shared byte buffer, and a size in bytes.
2012+
#define RDAT_BYTES(name) \
2013+
uint32_t name; \
2014+
uint32_t name##_Size;
2015+
#define RDAT_ARRAY_VALUE(type, count, type_name, name) type_name name;
2016+
2017+
// RDAT_UNION should wrap a set of RDAT_UNION_IF/RDAT_UNION_ELIF expressions,
2018+
// each containing a single element, only valid when the expression is true.
2019+
#define RDAT_UNION() union {
2020+
#define RDAT_UNION_END() };
2021+
2022+
// UNION IF/ELIF creates bool has##name() accessors in reader based on expr.
2023+
// Wraps a union member (typically of the same name) that is only accessed when
2024+
// expr is true.
2025+
#define RDAT_UNION_IF(name, expr) \
2026+
bool GLUE(RECORD_TYPE, _Reader)::has##name() const { \
2027+
if (auto *pRecord = asRecord()) \
2028+
return !!(expr); \
2029+
return false; \
2030+
}
2031+
#define RDAT_UNION_ELIF(name, expr) RDAT_UNION_IF(name, expr)
2032+
2033+
// expr example: getPropertyType() == FunctionPropertyType::Flags
2034+
// getPropertyType() calls the generated reader accessor for PropertyType that
2035+
// returns the FunctionPropertyType, and compares this with
2036+
// FunctionPropertyType::Flags defined in the enum below.
2037+
19332038
RDAT_ENUM_START(LinAlgThreadVectorMatrixMultiplyFlag, uint8_t)
19342039
RDAT_ENUM_VALUE(None, 0)
1935-
// The matrix operand is loaded as transposed:
2040+
// MatrixTransposed: The matrix is loaded from MulOptimalTranspose layout.
19362041
RDAT_ENUM_VALUE(MatrixTransposed, 1 << 0)
19372042
RDAT_ENUM_END()
19382043
19392044
RDAT_ENUM_START(LinAlgAccumulateStoreFlag, uint8_t)
19402045
RDAT_ENUM_VALUE(None, 0)
1941-
// The matrix is stored in transposed layout:
2046+
// MatrixTransposed: Accumulate to OuterProductOptimalTranspose layout,
2047+
// thread-scope only.
19422048
RDAT_ENUM_VALUE(MatrixTransposed, 1 << 0)
2049+
// RawBuffer: Accumulate is to a raw buffer, all scopes.
2050+
RDAT_ENUM_VALUE(RawBuffer, 1 << 1)
2051+
// GroupShared: Accumulate to GroupShared memory, wave/group scope only.
2052+
RDAT_ENUM_VALUE(GroupShared, 1 << 2)
19432053
RDAT_ENUM_END()
19442054
19452055
RDAT_STRUCT_TABLE(LinAlgMatrixOperationShape,
19462056
LinAlgMatrixOperationShapeTable)
2057+
// Unused dimensions stored as 0.
19472058
RDAT_VALUE(uint32_t, M) // Rows in matrix A
19482059
RDAT_VALUE(uint32_t, N) // Columns in matrix B
19492060
RDAT_VALUE(uint32_t, K) // Columns in matrix A / Rows in matrix B
19502061
RDAT_STRUCT_END()
19512062
2063+
// In the following, an unused ComponentType would be stored as
2064+
// hlsl::DXIL::ComponentType::Invalid, aka: 0.
2065+
2066+
RDAT_STRUCT_TABLE(LinAlgMatrixConstruction, LinAlgMatrixConstructionTable)
2067+
RDAT_RECORD_ARRAY_REF(LinAlgMatrixOperationShape, OperationShapes)
2068+
RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, MatrixType)
2069+
RDAT_STRUCT_END()
2070+
19522071
RDAT_STRUCT_TABLE(LinAlgThreadVectorMatrixMultiply,
19532072
LinAlgThreadVectorMatrixMultiplyTable)
1954-
// Do we need shapes? If so, K would be unused (0)
1955-
RDAT_RECORD_ARRAY_REF(LinAlgMatrixOperationShape, OperationShapes)
19562073
RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, ResultType)
19572074
RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, MatrixType)
19582075
RDAT_ENUM(uint8_t, hlsl::DXIL::ComponentType, VectorInputType)
@@ -1990,17 +2107,22 @@ RDAT_STRUCT_END()
19902107
19912108
// ------------ RuntimeDataFunctionInfo3 dependencies ------------
19922109
2110+
// RuntimeDataFunctionInfo3 adds an extended property list. Each extended
2111+
// property has a FunctionPropertyType and a 32-bit field for the property.
2112+
// For LinAlg properties, the field is always a REF of some kind (index or
2113+
// offset into some other table). Currently, these are RDAT_RECORD_ARRAY_REF
2114+
// which are an index into the index table, which starts with an array size (in
2115+
// elements), followed by values interpreted as indexes into record tables
2116+
// associated with the specified type defined with RDAT_STRUCT_TABLE.
2117+
19932118
RDAT_ENUM_START(FunctionPropertyType, uint32_t)
19942119
RDAT_ENUM_VALUE(Flags, 0)
1995-
RDAT_ENUM_VALUE(LinAlgThreadVectorMatrixMultiply, 1)
1996-
#ifdef UNIFY_MATRIX_MULTIPLY_STRUCTURES
1997-
RDAT_ENUM_VALUE(LinAlgMatrixMultiply, 2)
1998-
#else
1999-
RDAT_ENUM_VALUE(LinAlgWaveMatrixMultiply, 2)
2000-
RDAT_ENUM_VALUE(LinAlgThreadGroupMatrixMultiply, 3)
2001-
#endif // UNIFY_MATRIX_MULTIPLY_STRUCTURES
2002-
RDAT_ENUM_VALUE(LinAlgOuterProduct, 4)
2003-
RDAT_ENUM_VALUE(LinAlgAccumulateStore, 5)
2120+
RDAT_ENUM_VALUE(LinAlgMatrixConstruction, 1)
2121+
RDAT_ENUM_VALUE(LinAlgThreadVectorMatrixMultiply, 2)
2122+
RDAT_ENUM_VALUE(LinAlgWaveMatrixMultiply, 3)
2123+
RDAT_ENUM_VALUE(LinAlgThreadGroupMatrixMultiply, 4)
2124+
RDAT_ENUM_VALUE(LinAlgOuterProduct, 5)
2125+
RDAT_ENUM_VALUE(LinAlgAccumulateStore, 6)
20042126
RDAT_ENUM_END()
20052127
20062128
RDAT_STRUCT_TABLE(ExtendedFunctionProperties, ExtendedFunctionPropertiesTable)
@@ -2009,30 +2131,33 @@ RDAT_STRUCT_TABLE(ExtendedFunctionProperties, ExtendedFunctionPropertiesTable)
20092131
RDAT_UNION()
20102132
RDAT_UNION_IF(Flags, getPropertyType() == FunctionPropertyType::Flags)
20112133
RDAT_VALUE(uint32_t, Flags)
2134+
RDAT_UNION_ELIF(LinAlgMatrixConstruction, getPropertyType() ==
2135+
FunctionPropertyType::LinAlgMatrixConstruction)
2136+
RDAT_RECORD_ARRAY_REF(LinAlgMatrixConstruction, MatrixConstructionArray)
20122137
RDAT_UNION_ELIF(
20132138
LinAlgThreadVectorMatrixMultiply,
20142139
getPropertyType() ==
20152140
FunctionPropertyType::LinAlgThreadVectorMatrixMultiply)
2016-
RDAT_RECORD_ARRAY_REF(LinAlgThreadVectorMatrixMultiply,
2017-
LinAlgThreadVectorMatrixMultiplyArray)
2141+
RDAT_RECORD_ARRAY_REF(LinAlgThreadVectorMatrixMultiply,
2142+
LinAlgThreadVectorMatrixMultiplyArray)
20182143
RDAT_UNION_ELIF(LinAlgWaveMatrixMultiply,
20192144
getPropertyType() ==
20202145
FunctionPropertyType::LinAlgWaveMatrixMultiply)
2021-
RDAT_RECORD_ARRAY_REF(LinAlgWaveMatrixMultiply,
2022-
LinAlgWaveMatrixMultiplyArray)
2146+
RDAT_RECORD_ARRAY_REF(LinAlgWaveMatrixMultiply,
2147+
LinAlgWaveMatrixMultiplyArray)
20232148
RDAT_UNION_ELIF(LinAlgThreadGroupMatrixMultiply,
20242149
getPropertyType() ==
20252150
FunctionPropertyType::LinAlgThreadGroupMatrixMultiply)
2026-
RDAT_RECORD_ARRAY_REF(LinAlgThreadGroupMatrixMultiply,
2027-
LinAlgThreadGroupMatrixMultiplyArray)
2151+
RDAT_RECORD_ARRAY_REF(LinAlgThreadGroupMatrixMultiply,
2152+
LinAlgThreadGroupMatrixMultiplyArray)
20282153
RDAT_UNION_ELIF(LinAlgOuterProduct,
20292154
getPropertyType() ==
20302155
FunctionPropertyType::LinAlgOuterProduct)
2031-
RDAT_RECORD_ARRAY_REF(LinAlgOuterProduct, LinAlgOuterProductArray)
2156+
RDAT_RECORD_ARRAY_REF(LinAlgOuterProduct, LinAlgOuterProductArray)
20322157
RDAT_UNION_ELIF(LinAlgAccumulateStore,
20332158
getPropertyType() ==
20342159
FunctionPropertyType::LinAlgAccumulateStore)
2035-
RDAT_RECORD_ARRAY_REF(LinAlgAccumulateStore, LinAlgAccumulateStoreArray)
2160+
RDAT_RECORD_ARRAY_REF(LinAlgAccumulateStore, LinAlgAccumulateStoreArray)
20362161
RDAT_UNION_ENDIF()
20372162
RDAT_UNION_END()
20382163

0 commit comments

Comments
 (0)