1212
1313#include " mlir/IR/BuiltinTypes.h"
1414#include " mlir/IR/MLIRContext.h"
15+ #include " mlir/IR/QuantizationInterface.h"
1516#include " llvm/ADT/StringRef.h"
1617#include " llvm/ADT/Twine.h"
1718#include " llvm/Support/MathExtras.h"
19+ #include < iostream>
1820
1921using namespace mlir ;
2022using namespace mlir ::quant;
@@ -32,6 +34,7 @@ LogicalResult
3234QuantizedType::verify (function_ref<InFlightDiagnostic()> emitError,
3335 unsigned flags, Type storageType, Type expressedType,
3436 int64_t storageTypeMin, int64_t storageTypeMax) {
37+ std::cout << " verify QuantizedType" << std::endl;
3538
3639 bool isSigned =
3740 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
@@ -45,17 +48,16 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
4548 return emitError () << " illegal storage type size: " << integralWidth;
4649 }
4750
51+ std::cout << " Before quantile cast" << std::endl;
4852 int64_t defaultMin, defaultMax;
49- if (storageType.isa <IntegerType>()) {
50- const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
51- defaultMin = QuantizedType::getDefaultMinimumForInteger (isSigned, width);
52- defaultMax = QuantizedType::getDefaultMaximumForInteger (isSigned, width);
53- } else if (storageType.isa <Float8E5M2Type>()) {
54- defaultMin = QuantizedType::getDefaultMinimumForF8E5M2 ();
55- defaultMax = QuantizedType::getDefaultMaximumForF8E5M2 ();
56- } else if (storageType.isa <Float8E4M3FNType>()) {
57- defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN ();
58- defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN ();
53+ if (auto quantizationInterface =
54+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
55+ // const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
56+ const auto width = quantizationInterface.getStorageWidth ();
57+ defaultMin = quantizationInterface.getDefaultMinimum (isSigned, width);
58+ defaultMax = quantizationInterface.getDefaultMaximum (isSigned, width);
59+ std::cout << " defaultMin: " << defaultMin << " , defaultMax: " << defaultMax
60+ << std::endl;
5961 } else {
6062 return emitError () << " illegal storage type, supported types are: integral "
6163 " types, Float8E4M3FNType and Float8E5M2Type " ;
@@ -67,6 +69,7 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
6769 return emitError () << " illegal storage min and storage max: ("
6870 << storageTypeMin << " :" << storageTypeMax << " )" ;
6971 }
72+ std::cout << " verify QuantizedType END" << std::endl;
7073 return success ();
7174}
7275
@@ -75,17 +78,42 @@ Type QuantizedType::getStorageType() const {
7578}
7679
7780int64_t QuantizedType::getStorageTypeMin () const {
81+ Type storageType = static_cast <ImplType *>(impl)->storageType ;
82+
83+ if (auto quantizationInterface =
84+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
85+ unsigned storageWidth = quantizationInterface.getStorageWidth ();
86+ bool isSigned = quantizationInterface.isStorageSigned ();
87+ return quantizationInterface.getDefaultMinimum (isSigned, storageWidth);
88+ }
89+
7890 return static_cast <ImplType *>(impl)->storageTypeMin ;
7991}
8092
8193int64_t QuantizedType::getStorageTypeMax () const {
94+ Type storageType = static_cast <ImplType *>(impl)->storageType ;
95+
96+ if (auto quantizationInterface =
97+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
98+ unsigned storageWidth = quantizationInterface.getStorageWidth ();
99+ bool isSigned = quantizationInterface.isStorageSigned ();
100+ return quantizationInterface.getDefaultMaximum (isSigned, storageWidth);
101+ }
102+
82103 return static_cast <ImplType *>(impl)->storageTypeMax ;
83104}
84105
85106unsigned QuantizedType::getStorageTypeIntegralWidth () const {
86107 // NOTE: If ever supporting non-integral storage types, some other scheme
87108 // for determining the width will be needed.
88- return static_cast <ImplType *>(impl)->storageType .getIntOrFloatBitWidth ();
109+ Type storageType = static_cast <ImplType *>(impl)->storageType ;
110+
111+ if (auto quantizationInterface =
112+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
113+ return quantizationInterface.getStorageWidth ();
114+ }
115+
116+ return storageType.getIntOrFloatBitWidth ();
89117}
90118
91119Type QuantizedType::getExpressedType () const {
@@ -265,6 +293,7 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
265293 int64_t zeroPoint,
266294 int64_t storageTypeMin,
267295 int64_t storageTypeMax) {
296+ std::cout << " Creating UniformQuantizedType" << std::endl;
268297 return Base::get (storageType.getContext (), flags, storageType, expressedType,
269298 scale, zeroPoint, storageTypeMin, storageTypeMax);
270299}
@@ -273,6 +302,7 @@ UniformQuantizedType UniformQuantizedType::getChecked(
273302 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
274303 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
275304 int64_t storageTypeMin, int64_t storageTypeMax) {
305+ std::cout << " getChecked UniformQuantizedType" << std::endl;
276306 return Base::getChecked (emitError, storageType.getContext (), flags,
277307 storageType, expressedType, scale, zeroPoint,
278308 storageTypeMin, storageTypeMax);
@@ -282,6 +312,8 @@ LogicalResult UniformQuantizedType::verify(
282312 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
283313 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
284314 int64_t storageTypeMin, int64_t storageTypeMax) {
315+ std::cout << " verifying UniformQuantizedType" << std::endl;
316+
285317 if (failed (QuantizedType::verify (emitError, flags, storageType, expressedType,
286318 storageTypeMin, storageTypeMax))) {
287319 return failure ();
@@ -301,6 +333,7 @@ LogicalResult UniformQuantizedType::verify(
301333 // Verify scale.
302334 if (std::isinf (scale) || std::isnan (scale))
303335 return emitError () << " illegal scale: " << scale;
336+ std::cout << " verifying UniformQuantizedType END" << std::endl;
304337
305338 return success ();
306339}
0 commit comments