// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
810
@@ -37,30 +39,62 @@ namespace {
3739LogicalResult verifyPerAxisQuantization (Operation *op,
3840 QuantizedType quantizedType,
3941 Type containerType) {
40- auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
42+ auto quantizedPerAxisType =
43+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
4144 if (!quantizedPerAxisType)
4245 return success ();
4346
44- auto tensorType = dyn_cast<TensorType >(containerType);
45- if (!tensorType )
47+ auto shapedType = dyn_cast<ShapedType >(containerType);
48+ if (!shapedType )
4649 return op->emitError (" scalar types may not use per-axis quantization" );
4750
48- if (!tensorType .hasRank ())
51+ if (!shapedType .hasRank ())
4952 return success ();
5053
5154 int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension ();
52- if (quantizedDimension >= tensorType .getRank ())
55+ if (quantizedDimension >= shapedType .getRank ())
5356 return op->emitError (" quantized dimension must be less than tensor rank" );
5457
55- int64_t quantizedDimensionSize = tensorType .getDimSize (quantizedDimension);
58+ int64_t quantizedDimensionSize = shapedType .getDimSize (quantizedDimension);
5659 if (quantizedDimensionSize != ShapedType::kDynamic &&
57- quantizedDimensionSize != (int64_t )quantizedPerAxisType.getScales ().size ())
60+ quantizedDimensionSize !=
61+ (int64_t )quantizedPerAxisType.getScales ().size ())
5862 return op->emitError (
5963 " quantized dimension size does not match number of scales" );
6064
6165 return success ();
6266}
6367
68+ // Verify the integrity of block float quantization information, if present.
69+ //
70+ // - quantizedType
71+ // Any quantized type. Any quantized type with no block float quantization is
72+ // ignored.
73+ //
74+ // - containerType
75+ // Original input or result type of the operation using the provided quantized
76+ // type. Used to ensure that the quantized type appears within a tensor and
77+ // that the tensor is compatible with block float quantization information.
78+ //
79+ LogicalResult verifyBlockFloatQuantization (Operation *op,
80+ QuantizedType quantizedType,
81+ Type containerType) {
82+ auto blockModeType = dyn_cast<BlockFloatQuantizedType>(quantizedType);
83+ if (!blockModeType)
84+ return success ();
85+
86+ auto shapedType = dyn_cast<ShapedType>(containerType);
87+ if (!shapedType)
88+ return op->emitError (" scalar types may not use block float quantization" );
89+ if (!shapedType.hasRank ())
90+ return success ();
91+ // We could also check that the tensor is a multiple of the block size, but
92+ // that requires that all padding is visible in MLIR
93+ if (blockModeType.getAxis () >= shapedType.getRank ())
94+ return op->emitError (" block axis must be less than tensor rank" );
95+ return success ();
96+ }
97+
6498// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
6599//
66100// - quantizedType
@@ -76,12 +110,18 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
76110//
77111LogicalResult verifyQuantizationOp (Operation *op, QuantizedType quantizedType,
78112 FloatType floatType, Type containerType) {
79- if (quantizedType.getExpressedType () != floatType)
113+ if (!isa<BlockFloatQuantizedType>(quantizedType) &&
114+ quantizedType.getExpressedType () != floatType)
80115 return op->emitError (
81116 " expressed type in quantized type expected to match float type" );
82117
83- // Veriy integrity of per-axis quantization information, if present.
84- return verifyPerAxisQuantization (op, quantizedType, containerType);
118+ if (failed (verifyPerAxisQuantization (op, quantizedType, containerType)))
119+ return failure ();
120+
121+ if (failed (verifyBlockFloatQuantization (op, quantizedType, containerType)))
122+ return failure ();
123+
124+ return success ();
85125}
86126
87127} // namespace
@@ -92,8 +132,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
92132// ===----------------------------------------------------------------------===//
93133
94134void QuantDialect::initialize () {
95- addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType ,
96- UniformQuantizedPerAxisType>();
135+ addTypes<AnyQuantizedType, BlockFloatQuantizedType, CalibratedQuantizedType ,
136+ UniformQuantizedType, UniformQuantizedPerAxisType>();
97137 addOperations<
98138#define GET_OP_LIST
99139#include " mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
@@ -167,6 +207,9 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
167207
168208LogicalResult StorageCastOp::verify () {
169209 auto quantizedType = getQuantizedType ();
210+ if (isa<BlockFloatQuantizedType>(quantizedType))
211+ return getOperation ()->emitError (
212+ " storage cast not supported for block float quantized types" );
170213 auto integerType = getIntegerType ();
171214 if (quantizedType.getStorageType () != integerType)
172215 return emitError (
0 commit comments