[mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… #174402
…00.1)

This patch adds initial support for the TOSA Extended Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within `spirv.ARM.Graph` operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed with `!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in SPV_ARM_tensors).

The change introduces:

* Dialect plumbing for import, serialization, and deserialization of the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction set, lowering to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>` types, and is well-formed according to the TOSA 001000.1 specification.

Only the ArgMax operation from the TOSA 001000.1 extended instructions is introduced, in order to showcase the work needed: parser, printer, verifier, and round-trip tests using MLIR's SPIR-V serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification: https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Signed-off-by: Davide Grohmann <[email protected]>
Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
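The rank and axis rules that the patch's `TosaArgMaxOp::verify()` enforces can be sketched as a standalone check. This is only an illustration under stated assumptions: `checkArgMaxShape` is a hypothetical helper that does not exist in the patch, it works on plain integers instead of MLIR types, and it reuses the two diagnostic strings from the verifier.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical helper, not part of the patch: returns a diagnostic message
// when the ranks or the axis violate the rules checked by the verifier, or
// std::nullopt when they are consistent.
std::optional<std::string> checkArgMaxShape(std::int64_t inputRank,
                                            std::int64_t resultRank,
                                            std::int64_t axis) {
  // Reducing along one axis drops a dimension, but the result rank never
  // goes below 1.
  const std::int64_t expectedRank = inputRank > 1 ? inputRank - 1 : 1;
  if (resultRank != expectedRank)
    return "result rank must be max of 1 and (input rank - 1)";

  // The axis must name an existing dimension of the input.
  if (axis < 0 || axis >= inputRank)
    return "specified axis is outside the rank of input";

  return std::nullopt;
}
```

For the integer test case in the patch (a rank-4 input, a rank-3 result, and axis 3), this check reports no error.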
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir
Author: Davide Grohmann (davidegrohmann)
Full diff: https://github.com/llvm/llvm-project/pull/174402.diff
9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ecbbf39a534e1..5677f825a8aec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,6 +4233,7 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
def SPIRV_Void : TypeAlias<NoneType, "void">;
def SPIRV_Bool : TypeAlias<I1, "bool">;
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
@@ -4908,4 +4909,17 @@ def SPIRV_FPFastMathModeAttr :
SPIRV_FPFMM_AllowReassocINTEL
]>;
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtNanPropagationModeAttr : SPIRV_I32EnumAttr<
+ "TosaExtNanPropagationModeType", "Tosa Ext NAN Propoagation Mode Type", "tosa_ext_nan_propagation_mode_type",
+ [
+ I32EnumAttrCase<"PROPAGATE", 1>,
+ I32EnumAttrCase<"IGNORE", 2>,
+ ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
#endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..f545c4c2fa1ae
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,64 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode, !listconcat(traits, [InGraphScope])> {
+
+ let availability = [
+ MinVersion<SPIRV_V_1_5>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+ Capability<[SPIRV_C_GraphARM]>
+ ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+ let summary = "ArgMax - TOSA extended instruction set 001000.1";
+
+ let description = [{
+ https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+ }];
+
+
+ let arguments = (ins
+ SPIRV_Int32: $axis,
+ SPIRV_TosaExtNanPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm: $input
+ );
+
+
+ let results = (outs
+ SPIRV_TosaInteger_TensorArmUpTo5D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ operands attr-dict `:` `(` type(operands) `)` `->` type(results)
+ }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..9081c5fa6c156
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,39 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file ----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float16or32]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "ranked tensorArm">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+ summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+ : RankedTensorArmOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>],
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaInteger_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
SPIRVOpDefinition.cpp
SPIRVOps.cpp
SPIRVParsingUtils.cpp
+ SPIRVTosaOps.cpp
SPIRVTypes.cpp
TargetAndABI.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..31f7734dd8495
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,72 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+// Get value attr from spirv::ConstantOp or
+// spirv::EXTConstantCompositeReplicateOp
+template <typename TAttr>
+static LogicalResult getConstAttr(Value value, TAttr &valAttr) {
+ if (auto constOp = value.template getDefiningOp<spirv::ConstantOp>()) {
+ valAttr = dyn_cast<TAttr>(constOp.getValue());
+ } else if (auto constCompositeReplicateOp =
+ value.template getDefiningOp<
+ spirv::EXTConstantCompositeReplicateOp>()) {
+ auto splatAttr = constCompositeReplicateOp.getValue();
+ auto denseValAttr = SplatElementsAttr::get(
+ cast<ShapedType>(constCompositeReplicateOp.getType()), splatAttr);
+ valAttr = dyn_cast<TAttr>(denseValAttr);
+ }
+
+ return valAttr ? success() : failure();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::TosaArgMaxOp::verify() {
+ auto inputTy = cast<ShapedType>(getInput().getType());
+ auto resultTy = cast<ShapedType>(getType());
+
+ if (inputTy.hasRank() && resultTy.hasRank() &&
+ resultTy.getRank() !=
+ (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+ return emitOpError("result rank must be max of 1 and (input rank - 1)");
+ }
+
+ auto resultETy = resultTy.getElementType();
+ if (!resultETy.isIntOrIndex()) {
+ return emitOpError("result is not of integer type");
+ }
+
+ IntegerAttr axisAttr;
+ if (getConstAttr(getAxis(), axisAttr).failed()) {
+ return emitOpError("axis type must be a constant integer");
+ }
+
+ const int axis = axisAttr.getInt();
+ if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+ return emitOpError("specified axis is outside the rank of input");
+ }
+
+ return success();
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..48a3735c2c596
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ %0 = spirv.Constant 3 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ %0 = spirv.Constant 2 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..53e92050f5d32
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+ spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+ spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ %0 = spirv.Constant 3 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+ spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ %0 = spirv.Constant 2 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+ }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..83d3322ebfe13 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
// directly use the constant value as attribute in SPIR-V dialect. So need
// to handle them separately from normal enum attributes.
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
- "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
- "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
- "SPIRV_MatrixLayoutAttr"};
+ "SPIRV_ScopeAttr",
+ "SPIRV_KHR_CooperativeMatrixUseAttr",
+ "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+ "SPIRV_MemorySemanticsAttr",
+ "SPIRV_MatrixLayoutAttr",
+ "SPIRV_TosaExtNanPropagationModeAttr",
+};
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
  let summary = "ArgMax - TOSA extended instruction set 001000.1";

  let description = [{
    https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
We should copy a description from the spec here as we do for all other ops and extended instructions. Or at least have a description saying what the operation does, what are arguments and results, etc.
Also maybe let's have an example.
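As a rough illustration of what such a description could convey, here is a minimal reference sketch of argmax along one axis of a row-major tensor. It is not taken from the patch or the specification: `argMaxAlongAxis` is a hypothetical name, the tie-breaking rule (lowest index wins) is an assumption about the TOSA semantics, and `nan_mode` handling is deliberately omitted.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative reference only: argmax along `axis` of a row-major tensor.
// Assumes axis < shape.size(), every dimension is non-zero, and
// input.size() equals the product of the dimensions.
// Ties resolve to the lowest index (an assumption in this sketch).
std::vector<std::int32_t> argMaxAlongAxis(const std::vector<float> &input,
                                          const std::vector<std::size_t> &shape,
                                          std::size_t axis) {
  // View the tensor as [outer, axisLen, inner] around the reduced axis.
  std::size_t outer = 1, inner = 1;
  for (std::size_t i = 0; i < axis; ++i)
    outer *= shape[i];
  for (std::size_t i = axis + 1; i < shape.size(); ++i)
    inner *= shape[i];
  const std::size_t axisLen = shape[axis];

  // The output has the input shape with `axis` removed, stored flat.
  std::vector<std::int32_t> output(outer * inner, 0);
  for (std::size_t o = 0; o < outer; ++o) {
    for (std::size_t in = 0; in < inner; ++in) {
      const std::size_t base = o * axisLen * inner + in;
      float best = input[base];
      std::int32_t bestIdx = 0;
      for (std::size_t a = 1; a < axisLen; ++a) {
        const float v = input[base + a * inner];
        if (v > best) { // strict '>' keeps the first maximum on ties
          best = v;
          bestIdx = static_cast<std::int32_t>(a);
        }
      }
      output[o * inner + in] = bestIdx;
    }
  }
  return output;
}
```

The output drops the reduced dimension, which matches the rank rule checked by the verifier for inputs of rank greater than 1.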
  let hasVerifier = 1;

  let assemblyFormat = [{
    operands attr-dict `:` `(` type(operands) `)` `->` type(results)
I don't think we use ( ... ) for operands type in other parts of the dialect.
    return emitOpError("result rank must be max of 1 and (input rank - 1)");
  }

  auto resultETy = resultTy.getElementType();
nit: Probably should spell out the type.
// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_TosaExtNanPropagationModeAttr : SPIRV_I32EnumAttr<
  "TosaExtNanPropagationModeType", "Tosa Ext NAN Propoagation Mode Type", "tosa_ext_nan_propagation_mode_type",
nit: s/NAN/NaN/ ?
def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
  let summary = "ArgMax - TOSA extended instruction set 001000.1";
I would punt the info about the extended instruction set to the description and, instead, focus on what the op does
  );


  let results = (outs
Suggested change:
-  );
-  let results = (outs
+  );
+  let results = (outs
@@ -0,0 +1,39 @@
//===- SPIRVTosaTypes.td - Tosa Types insts spec file ----*- tablegen -*-=//
the header is misaligned
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"

def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float16or32]>;
Could we flatten this by specifying f16 and f32 separately?
There is no SPIRV_Float16 definition in SPIRVBase.td. I can add it if spelling out is preferred.
//===----------------------------------------------------------------------===//

LogicalResult spirv::TosaArgMaxOp::verify() {
  auto inputTy = cast<ShapedType>(getInput().getType());
Doesn't ODS generate getInputType()? If not, can we add it as a helper method?
  }

  IntegerAttr axisAttr;
  if (getConstAttr(getAxis(), axisAttr).failed()) {
failed(...)
@@ -0,0 +1,27 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
Suggested change:
- // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+ // RUN: mlir-opt --split-input-file %s | FileCheck %s
I don't see any diagnostics used by this file
@@ -0,0 +1,43 @@
// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
Suggested change:
- // RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
+ // RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
  %0 = spirv.Constant 3 : i32
  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
I think other enum attributes within the spirv dialect are snake_case or CamelCase, no?