Skip to content

Conversation

@davidegrohmann
Copy link
Contributor

@davidegrohmann davidegrohmann commented Jan 5, 2026

…00.1)

This patch adds initial support for the TOSA Extended Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within spirv.ARM.Graph operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed with !spirv.arm.tensor<...> (corresponding to OpTypeTensorARM in SPV_ARM_tensor).

The change introduces:

  • Dialect plumbing for import, serialization, and deserialization of the TOSA extended instruction set.
  • The spirv.Tosa.ArgMax operation from TOSA extended instruction, each lowering to the corresponding OpExtInst.
  • Verification enforcing that spirv.Tosa.ArgMax appears only within spirv.ARM.Graph regions, operates on !spirv.arm.tensor<...> types, and is well-formed according to the TOSA 001000.1 specification.

Only the ArgMax operation from TOSA 001000.1 extended instructions is introduced in order to show case the work needed: [arser, printer, verifier, and round-trip tests using MLIR’s SPIR-V serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

…00.1)

This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
  the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from TOSA extended instruction, each
  lowering to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
  `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>`
  types, and is well-formed according to the TOSA 001000.1
  specification.

Only the ArgMax operation from TOSA 001000.1 extended instructions is
introduced in order to show case the work needed: [arser, printer,
verifier, and round-trip tests using MLIR’s SPIR-V
serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Signed-off-by: Davide Grohmann <[email protected]>
Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2026

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Davide Grohmann (davidegrohmann)

Changes

…00.1)

This patch adds initial support for the TOSA Extended Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within spirv.ARM.Graph operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed with !spirv.arm.tensor&lt;...&gt; (corresponding to OpTypeTensorARM in SPV_ARM_tensor).

The change introduces:

  • Dialect plumbing for import, serialization, and deserialization of the TOSA extended instruction set.
  • The spirv.Tosa.ArgMax operation from TOSA extended instruction, each lowering to the corresponding OpExtInst.
  • Verification enforcing that spirv.Tosa.ArgMax appears only within spirv.ARM.Graph regions, operates on !spirv.arm.tensor&lt;...&gt; types, and is well-formed according to the TOSA 001000.1 specification.

Only the ArgMax operation from TOSA 001000.1 extended instructions is introduced in order to show case the work needed: [arser, printer, verifier, and round-trip tests using MLIR’s SPIR-V serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html


Full diff: https://github.com/llvm/llvm-project/pull/174402.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+14)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td (+1)
  • (added) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+64)
  • (added) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+39)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp (+72)
  • (added) mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir (+27)
  • (added) mlir/test/Target/SPIRV/tosa-ops.mlir (+43)
  • (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+7-3)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ecbbf39a534e1..5677f825a8aec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,6 +4233,7 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
@@ -4908,4 +4909,17 @@ def SPIRV_FPFastMathModeAttr :
       SPIRV_FPFMM_AllowReassocINTEL
     ]>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtNanPropagationModeAttr : SPIRV_I32EnumAttr<
+  "TosaExtNanPropagationModeType", "Tosa Ext NAN Propoagation Mode Type", "tosa_ext_nan_propagation_mode_type",
+  [
+      I32EnumAttrCase<"PROPAGATE", 1>,
+      I32EnumAttrCase<"IGNORE", 2>,
+  ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 #endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..f545c4c2fa1ae
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,64 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+  SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode, !listconcat(traits, [InGraphScope])> {
+
+  let availability = [
+    MinVersion<SPIRV_V_1_5>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+ let summary = "ArgMax - TOSA extended instruction set 001000.1";
+
+  let description = [{
+    https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+  }];
+
+
+  let arguments = (ins
+    SPIRV_Int32: $axis,
+    SPIRV_TosaExtNanPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+
+  let results = (outs
+    SPIRV_TosaInteger_TensorArmUpTo5D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    operands attr-dict `:` `(` type(operands) `)` `->` type(results)
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..9081c5fa6c156
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,39 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file ----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float16or32]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+                     string summary = "ranked tensorArm">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+      summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+  : RankedTensorArmOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaInteger_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   SPIRVOpDefinition.cpp
   SPIRVOps.cpp
   SPIRVParsingUtils.cpp
+  SPIRVTosaOps.cpp
   SPIRVTypes.cpp
   TargetAndABI.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..31f7734dd8495
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,72 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+// Get value attr from spirv::ConstantOp or
+// spirv::EXTConstantCompositeReplicateOp
+template <typename TAttr>
+static LogicalResult getConstAttr(Value value, TAttr &valAttr) {
+  if (auto constOp = value.template getDefiningOp<spirv::ConstantOp>()) {
+    valAttr = dyn_cast<TAttr>(constOp.getValue());
+  } else if (auto constCompositeReplicateOp =
+                 value.template getDefiningOp<
+                     spirv::EXTConstantCompositeReplicateOp>()) {
+    auto splatAttr = constCompositeReplicateOp.getValue();
+    auto denseValAttr = SplatElementsAttr::get(
+        cast<ShapedType>(constCompositeReplicateOp.getType()), splatAttr);
+    valAttr = dyn_cast<TAttr>(denseValAttr);
+  }
+
+  return valAttr ? success() : failure();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::TosaArgMaxOp::verify() {
+  auto inputTy = cast<ShapedType>(getInput().getType());
+  auto resultTy = cast<ShapedType>(getType());
+
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError("result rank must be max of 1 and (input rank - 1)");
+  }
+
+  auto resultETy = resultTy.getElementType();
+  if (!resultETy.isIntOrIndex()) {
+    return emitOpError("result is not of integer type");
+  }
+
+  IntegerAttr axisAttr;
+  if (getConstAttr(getAxis(), axisAttr).failed()) {
+    return emitOpError("axis type must be a constant integer");
+  }
+
+  const int axis = axisAttr.getInt();
+  if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+    return emitOpError("specified axis is outside the rank of input");
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..48a3735c2c596
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  %0 = spirv.Constant 3 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+  %0 = spirv.Constant 2 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..53e92050f5d32
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module  --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+  spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+  spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+    %0 = spirv.Constant 3 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+  spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+  spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+    %0 = spirv.Constant 2 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..83d3322ebfe13 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
 // directly use the constant value as attribute in SPIR-V dialect. So need
 // to handle them separately from normal enum attributes.
 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
-    "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
-    "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
-    "SPIRV_MatrixLayoutAttr"};
+    "SPIRV_ScopeAttr",
+    "SPIRV_KHR_CooperativeMatrixUseAttr",
+    "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+    "SPIRV_MemorySemanticsAttr",
+    "SPIRV_MatrixLayoutAttr",
+    "SPIRV_TosaExtNanPropagationModeAttr",
+};
 
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
 /// generates code extracts the attribute with name `attrName` from

let summary = "ArgMax - TOSA extended instruction set 001000.1";

let description = [{
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should copy a description from the spec here as we do for all other ops and extended instructions. Or at least have a description saying what the operation does, what are arguments and results, etc.

Also maybe let's have an example.

let hasVerifier = 1;

let assemblyFormat = [{
operands attr-dict `:` `(` type(operands) `)` `->` type(results)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we use ( ... ) for operands type in other parts of the dialect.

return emitOpError("result rank must be max of 1 and (input rank - 1)");
}

auto resultETy = resultTy.getElementType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Probably should spell out the type.

// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_TosaExtNanPropagationModeAttr : SPIRV_I32EnumAttr<
"TosaExtNanPropagationModeType", "Tosa Ext NAN Propoagation Mode Type", "tosa_ext_nan_propagation_mode_type",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/NAN/NaN/ ?



def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
let summary = "ArgMax - TOSA extended instruction set 001000.1";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would punt the info about the extended instruction set to the description and, instead, focus on what the op does

Comment on lines +50 to +53
);


let results = (outs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
);
let results = (outs
);
let results = (outs

@@ -0,0 +1,39 @@
//===- SPIRVTosaTypes.td - Tosa Types insts spec file ----*- tablegen -*-=//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the header is misaligned

include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"

def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float16or32]>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we flatten this by specifying f16 and f32 separately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no SPIRV_Float16 definition in SPIRVBase.td. I can add it if spelling out is preferred.

//===----------------------------------------------------------------------===//

LogicalResult spirv::TosaArgMaxOp::verify() {
auto inputTy = cast<ShapedType>(getInput().getType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't ODS generate getInputType()? If not, can we add it as a helper method?

}

IntegerAttr axisAttr;
if (getConstAttr(getAxis(), axisAttr).failed()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

failed(...)

@@ -0,0 +1,27 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt --split-input-file %s | FileCheck %s

I don't see any diagnostics used by this file

@@ -0,0 +1,43 @@
// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s

spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
%0 = spirv.Constant 3 : i32
// CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
%2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think other enum attributes within the spirv dialect are snake_case or CamelCase, no?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:spirv mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants