Skip to content
This repository was archived by the owner on Nov 27, 2025. It is now read-only.

Commit d6153a3

Browse files
committed
Added a new type interface to let UniformQuantizedType accept storage types other than the built-in types. Updated the parser and printer in the Quant dialect.
1 parent 544ee1a commit d6153a3

File tree

10 files changed

+248
-61
lines changed

10 files changed

+248
-61
lines changed

mlir/cmake/modules/AddMLIR.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ function(add_mlir_interface interface)
196196
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
197197
endfunction()
198198

199+
# Declare a type interface in the include directory
200+
function(add_mlir_type_interface interface)
201+
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
202+
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
203+
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
204+
add_public_tablegen_target(MLIR${interface}IncGen)
205+
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
206+
endfunction()
199207

200208
# Generate Documentation
201209
function(add_mlir_doc doc_filename output_file output_directory command)

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ class FloatType : public Type {
9191
// Tablegen Type Declarations
9292
//===----------------------------------------------------------------------===//
9393

94+
// Include QuantizationInterface before BuiltinTypes to resolve dependencies
95+
#include "mlir/IR/QuantizationInterface.h"
96+
9497
#define GET_TYPEDEF_CLASSES
9598
#include "mlir/IR/BuiltinTypes.h.inc"
9699

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/BuiltinDialect.td"
1919
include "mlir/IR/BuiltinTypeInterfaces.td"
20+
include "mlir/IR/QuantizationInterface.td"
2021

2122
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
2223
// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
7879
//===----------------------------------------------------------------------===//
7980

8081
// Base class for Builtin dialect float types.
81-
class Builtin_FloatType<string name, string mnemonic>
82-
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
82+
class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = []>
83+
: Builtin_Type<name, mnemonic, traits, "::mlir::FloatType"> {
8384
let extraClassDeclaration = [{
8485
static }] # name # [{Type get(MLIRContext *context);
8586
}];
@@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic>
8889
//===----------------------------------------------------------------------===//
8990
// Float8E5M2Type
9091

91-
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
92+
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
93+
[QuantizationInterface]> {
9294
let summary = "8-bit floating point with 2 bit mantissa";
9395
let description = [{
9496
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
104106

105107
Described in: https://arxiv.org/abs/2209.05433
106108
}];
109+
110+
let extraClassDeclaration = [{
  static Float8E5M2Type get(MLIRContext *context);

  /// QuantizationInterface method implementations.

  /// f8E5M2 has a sign bit, so it is reported as a signed storage type.
  bool isStorageSigned() const { return true; }

  /// The storage occupies 8 bits.
  unsigned getStorageWidth() const { return 8; }

  /// Largest finite f8E5M2 value: 1.75 * 2^15 = 57344. Both arguments are
  /// ignored because the range of this type does not depend on them.
  int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
    return 57344;
  }

  /// The representable range is symmetric, so the minimum is the negated
  /// maximum.
  int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
    return -getDefaultMaximum(isSigned, integralWidth);
  }

  /// Returns the textual name of this storage type. Both arguments are
  /// ignored.
  std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
    return "f8E5M2";
  }
}];
107126
}
108127

109128
//===----------------------------------------------------------------------===//
@@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
128147
//===----------------------------------------------------------------------===//
129148
// Float8E4M3FNType
130149

131-
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
150+
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
151+
[QuantizationInterface]> {
132152
let summary = "8-bit floating point with 3 bit mantissa";
133153
let description = [{
134154
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
145165

146166
Described in: https://arxiv.org/abs/2209.05433
147167
}];
168+
169+
let extraClassDeclaration = [{
  static Float8E4M3FNType get(MLIRContext *context);

  /// QuantizationInterface method implementations.

  /// f8E4M3FN has a sign bit, so it is reported as a signed storage type.
  bool isStorageSigned() const { return true; }

  /// The storage occupies 8 bits.
  unsigned getStorageWidth() const { return 8; }

  /// Largest finite f8E4M3FN value: 1.75 * 2^8 = 448. Both arguments are
  /// ignored because the range of this type does not depend on them.
  int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
    return 448;
  }

  /// The representable range is symmetric, so the minimum is the negated
  /// maximum.
  int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
    return -getDefaultMaximum(isSigned, integralWidth);
  }

  /// Returns the textual name of this storage type. Both arguments are
  /// ignored.
  std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
    return "f8E4M3FN";
  }
}];
148185
}
149186

150187
//===----------------------------------------------------------------------===//
@@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
358395
// IntegerType
359396
//===----------------------------------------------------------------------===//
360397

361-
def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
398+
def Builtin_Integer : Builtin_Type<"Integer", "integer",
399+
[QuantizationInterface]> {
362400
let summary = "Integer type with arbitrary precision up to a fixed limit";
363401
let description = [{
364402
Syntax:
@@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
415453
/// Integer representation maximal bitwidth.
416454
/// Note: This is aligned with the maximum width of llvm::IntegerType.
417455
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
456+
457+
/// QuantizationInterface method implementations
458+
bool isStorageSigned() const { return !isUnsigned(); }
459+
unsigned getStorageWidth() const { return getWidth(); }
460+
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
461+
if (isSigned) {
462+
return llvm::minIntN(integralWidth);
463+
}
464+
return 0;
465+
}
466+
int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const {
467+
if (isSigned) {
468+
return llvm::maxIntN(integralWidth);
469+
}
470+
return llvm::maxUIntN(integralWidth);
471+
}
472+
std::string printStorageType(bool isSigned, unsigned storageWidth) const {
473+
return (isSigned ? "i" : "u") + std::to_string(storageWidth);
474+
}
418475
}];
419476
}
420477

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ add_mlir_interface(OpAsmInterface)
22
add_mlir_interface(SymbolInterfaces)
33
add_mlir_interface(RegionKindInterface)
44

5+
add_mlir_type_interface(QuantizationInterface)
6+
57
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
68
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
79
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- QuantizationInterface.h - Quantization type interface ------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_IR_QuantizationInterface_H
11+
#define MLIR_IR_QuantizationInterface_H
12+
13+
#include "mlir/IR/Types.h"
14+
15+
// Forward declarations for the types we need in the implementation
16+
namespace mlir {
17+
class IntegerType;
18+
class FloatType;
19+
} // namespace mlir
20+
21+
#include "mlir/IR/QuantizationInterface.h.inc"
22+
23+
#endif // MLIR_IR_QuantizationInterface_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef MLIR_IR_QUANTIZATIONINTERFACE
2+
#define MLIR_IR_QUANTIZATIONINTERFACE
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
7+
let description = [{
8+
Interface for types that can be used as quantile storage types.
9+
This interface provides methods to determine storage characteristics
10+
like width and signedness for quantization purposes.
11+
}];
12+
let cppNamespace = "::mlir";
13+
14+
let methods = [
15+
InterfaceMethod<[{
16+
Get the storage type width in bits.
17+
Returns the number of bits used to store values of this type.
18+
}],
19+
"unsigned", "getStorageWidth", (ins)>,
20+
21+
InterfaceMethod<[{
22+
Check if the storage type is signed.
23+
Returns true if the type represents signed values, false for unsigned.
24+
}],
25+
"bool", "isStorageSigned", (ins)>,
26+
27+
InterfaceMethod<[{
28+
Get the default minimum value for the storage type.
29+
}],
30+
"int64_t", "getDefaultMinimum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
31+
32+
InterfaceMethod<[{
33+
Get the default maximum value for the storage type.
34+
}],
35+
"int64_t", "getDefaultMaximum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
36+
37+
InterfaceMethod<[{
38+
Get the name of the storage type.
39+
}],
40+
"std::string", "printStorageType", (ins "bool":$isSigned, "unsigned":$storageWidth)>
41+
];
42+
43+
}
44+
45+
#endif // MLIR_IR_QUANTIZATIONINTERFACE

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212

1313
#include "mlir/IR/BuiltinTypes.h"
1414
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/IR/QuantizationInterface.h"
1516
#include "llvm/ADT/StringRef.h"
1617
#include "llvm/ADT/Twine.h"
1718
#include "llvm/Support/MathExtras.h"
19+
#include <iostream>
1820

1921
using namespace mlir;
2022
using namespace mlir::quant;
@@ -32,6 +34,7 @@ LogicalResult
3234
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
3335
unsigned flags, Type storageType, Type expressedType,
3436
int64_t storageTypeMin, int64_t storageTypeMax) {
37+
std::cout << "verify QuantizedType" << std::endl;
3538

3639
bool isSigned =
3740
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
@@ -45,17 +48,16 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
4548
return emitError() << "illegal storage type size: " << integralWidth;
4649
}
4750

51+
std::cout << "Before quantile cast" << std::endl;
4852
int64_t defaultMin, defaultMax;
49-
if (storageType.isa<IntegerType>()) {
50-
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
51-
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
52-
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
53-
} else if (storageType.isa<Float8E5M2Type>()) {
54-
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
55-
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
56-
} else if (storageType.isa<Float8E4M3FNType>()) {
57-
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
58-
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
53+
if (auto quantizationInterface =
54+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
55+
// const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
56+
const auto width = quantizationInterface.getStorageWidth();
57+
defaultMin = quantizationInterface.getDefaultMinimum(isSigned, width);
58+
defaultMax = quantizationInterface.getDefaultMaximum(isSigned, width);
59+
std::cout << "defaultMin: " << defaultMin << ", defaultMax: " << defaultMax
60+
<< std::endl;
5961
} else {
6062
return emitError() << "illegal storage type, supported types are: integral "
6163
"types, Float8E4M3FNType and Float8E5M2Type ";
@@ -67,6 +69,7 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
6769
return emitError() << "illegal storage min and storage max: ("
6870
<< storageTypeMin << ":" << storageTypeMax << ")";
6971
}
72+
std::cout << "verify QuantizedType END" << std::endl;
7073
return success();
7174
}
7275

@@ -75,17 +78,42 @@ Type QuantizedType::getStorageType() const {
7578
}
7679

7780
int64_t QuantizedType::getStorageTypeMin() const {
81+
Type storageType = static_cast<ImplType *>(impl)->storageType;
82+
83+
if (auto quantizationInterface =
84+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
85+
unsigned storageWidth = quantizationInterface.getStorageWidth();
86+
bool isSigned = quantizationInterface.isStorageSigned();
87+
return quantizationInterface.getDefaultMinimum(isSigned, storageWidth);
88+
}
89+
7890
return static_cast<ImplType *>(impl)->storageTypeMin;
7991
}
8092

8193
int64_t QuantizedType::getStorageTypeMax() const {
94+
Type storageType = static_cast<ImplType *>(impl)->storageType;
95+
96+
if (auto quantizationInterface =
97+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
98+
unsigned storageWidth = quantizationInterface.getStorageWidth();
99+
bool isSigned = quantizationInterface.isStorageSigned();
100+
return quantizationInterface.getDefaultMaximum(isSigned, storageWidth);
101+
}
102+
82103
return static_cast<ImplType *>(impl)->storageTypeMax;
83104
}
84105

85106
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
86107
// NOTE: If ever supporting non-integral storage types, some other scheme
87108
// for determining the width will be needed.
88-
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
109+
Type storageType = static_cast<ImplType *>(impl)->storageType;
110+
111+
if (auto quantizationInterface =
112+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
113+
return quantizationInterface.getStorageWidth();
114+
}
115+
116+
return storageType.getIntOrFloatBitWidth();
89117
}
90118

91119
Type QuantizedType::getExpressedType() const {
@@ -265,6 +293,7 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
265293
int64_t zeroPoint,
266294
int64_t storageTypeMin,
267295
int64_t storageTypeMax) {
296+
std::cout << "Creating UniformQuantizedType" << std::endl;
268297
return Base::get(storageType.getContext(), flags, storageType, expressedType,
269298
scale, zeroPoint, storageTypeMin, storageTypeMax);
270299
}
@@ -273,6 +302,7 @@ UniformQuantizedType UniformQuantizedType::getChecked(
273302
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
274303
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
275304
int64_t storageTypeMin, int64_t storageTypeMax) {
305+
std::cout << "getChecked UniformQuantizedType" << std::endl;
276306
return Base::getChecked(emitError, storageType.getContext(), flags,
277307
storageType, expressedType, scale, zeroPoint,
278308
storageTypeMin, storageTypeMax);
@@ -282,6 +312,8 @@ LogicalResult UniformQuantizedType::verify(
282312
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
283313
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
284314
int64_t storageTypeMin, int64_t storageTypeMax) {
315+
std::cout << "verifying UniformQuantizedType" << std::endl;
316+
285317
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
286318
storageTypeMin, storageTypeMax))) {
287319
return failure();
@@ -301,6 +333,7 @@ LogicalResult UniformQuantizedType::verify(
301333
// Verify scale.
302334
if (std::isinf(scale) || std::isnan(scale))
303335
return emitError() << "illegal scale: " << scale;
336+
std::cout << "verifying UniformQuantizedType END" << std::endl;
304337

305338
return success();
306339
}

0 commit comments

Comments
 (0)