Skip to content

Commit 8a4e029

Browse files
authored
[MLIR][Quant] Introduce BlockFloatQuantizedType to represent block-float quantized types like BFP16 (#730)
2 parents a874197 + 17f4aca commit 8a4e029

9 files changed

Lines changed: 447 additions & 15 deletions

File tree

mlir/include/mlir/Dialect/Quant/IR/Quant.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
7+
// affiliates
68
//
79
//===----------------------------------------------------------------------===//
810

@@ -25,6 +27,7 @@ namespace mlir {
2527
namespace quant {
2628

2729
class QuantizedType;
30+
class BlockFloatQuantizedType;
2831
class UniformQuantizedType;
2932
class UniformQuantizedPerAxisType;
3033

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
7+
// affiliates
68
//
79
//===----------------------------------------------------------------------===//
810
//
@@ -81,6 +83,14 @@ def UniformQuantizedPerAxisType: DialectType<(type
8183
}];
8284
}
8385

86+
def BlockFloatQuantizedType: DialectType<(type
87+
WithGetter<"static_cast<uint32_t>($_attrType.getBlockMode())", VarInt>:$blockMode,
88+
VarInt:$axis
89+
)> {
90+
let cBuilder = "get<$_resultType>(context, "
91+
" static_cast<BlockFloatQuantizedType::BlockMode>(blockMode), axis)";
92+
}
93+
8494
/// This enum contains marker codes used to indicate which attribute is
8595
/// currently being decoded, and how it should be decoded. The order of these
8696
/// codes should generally be unchanged, as any changes will inevitably break
@@ -93,7 +103,8 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
93103
AnyQuantizedTypeWithExpressedType,
94104
CalibratedQuantizedType,
95105
UniformQuantizedType,
96-
UniformQuantizedPerAxisType
106+
UniformQuantizedPerAxisType,
107+
BlockFloatQuantizedType
97108
];
98109
}
99110

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
7+
// affiliates
68
//
79
//===----------------------------------------------------------------------===//
810

@@ -23,6 +25,7 @@ namespace detail {
2325

2426
struct QuantizedTypeStorage;
2527
struct AnyQuantizedTypeStorage;
28+
struct BlockFloatQuantizedTypeStorage;
2629
struct UniformQuantizedTypeStorage;
2730
struct UniformQuantizedPerAxisTypeStorage;
2831
struct CalibratedQuantizedTypeStorage;
@@ -224,6 +227,94 @@ class AnyQuantizedType
224227
int64_t storageTypeMax);
225228
};
226229

230+
/// Represents block floating point quantization where multiple elements share
231+
/// data along a particular axis (e.g. BFP16). The concrete block format
232+
/// determines the implied storage characteristics and is not exposed in the IR.
233+
/// This class is experimental and may be subject to change.
234+
/// Design decisions:
235+
/// - The base class requires an integral storage type. For
236+
/// block-quantized/packed types, the required storage with depends on the
237+
/// number of elements. For example, a single BFP16 element requires 16 bits to
238+
/// be represented, but a block of 8 BFP16 elements can be packed into 9 bits
239+
/// per element on average (72 bits total). The storage type for
240+
/// an element from BlockFloatQuantizedType is the "packed" type
241+
/// divided by the number of packed elements, so for BFP16 i9.
242+
/// -- As accessing properties like min/max storage values and integral width
243+
/// depend on the block size, these methods are overridden to return errors.
244+
/// - The expressed type is not stored yet, this may change if there is a use
245+
/// for it.
246+
/// - The axis is signed to match MLIR convention, but enforced to be
247+
/// non-negative.
248+
class BlockFloatQuantizedType
249+
: public Type::TypeBase<BlockFloatQuantizedType, QuantizedType,
250+
detail::BlockFloatQuantizedTypeStorage> {
251+
public:
252+
using Base::Base;
253+
using Base::getChecked;
254+
255+
static constexpr StringLiteral name = "quant.block_float";
256+
257+
// MX6 refers to the MicoExponent format, not to the OCP MicroScaling format
258+
// with the same name.
259+
enum class BlockMode : uint32_t { BFP16 = 0, MX6 = 1, MAX_VALUE = MX6 };
260+
261+
static std::optional<BlockMode> parseBlockMode(StringRef name);
262+
static StringRef getBlockModeName(BlockMode blockMode);
263+
264+
static BlockFloatQuantizedType get(MLIRContext *ctx, BlockMode blockMode,
265+
int32_t axis);
266+
static BlockFloatQuantizedType
267+
getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *ctx,
268+
BlockMode blockMode, int32_t axis);
269+
270+
static LogicalResult
271+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
272+
uint32_t blockModeRaw, int32_t axis, unsigned flags,
273+
Type storageType, Type expressedType, int64_t storageTypeMin,
274+
int64_t storageTypeMax);
275+
276+
Type getStorageType() const {
277+
assert(false &&
278+
"BlockFloatQuantizedType does not have a direct storage type");
279+
return QuantizedType::getStorageType();
280+
}
281+
282+
int64_t getStorageTypeMin() const {
283+
assert(false &&
284+
"BlockFloatQuantizedType does not have a direct storage type");
285+
return QuantizedType::getStorageTypeMin();
286+
}
287+
288+
int64_t getStorageTypeMax() const {
289+
assert(false &&
290+
"BlockFloatQuantizedType does not have a direct storage type");
291+
return QuantizedType::getStorageTypeMax();
292+
}
293+
294+
bool hasStorageTypeBounds() const {
295+
assert(false &&
296+
"BlockFloatQuantizedType does not have a direct storage type");
297+
return QuantizedType::hasStorageTypeBounds();
298+
}
299+
300+
unsigned getStorageTypeIntegralWidth() const {
301+
assert(false &&
302+
"BlockFloatQuantizedType does not have a direct storage type");
303+
return QuantizedType::getStorageTypeIntegralWidth();
304+
}
305+
306+
BlockMode getBlockMode() const;
307+
int32_t getAxis() const;
308+
309+
/// Number of elements in a block
310+
unsigned getBlockSize() const;
311+
/// Average number of bits used to represent each element in the block
312+
unsigned getAverageBitsPerElement() const;
313+
/// Returns the size in bits required to represent a single, not
314+
/// blocked/packed element.
315+
unsigned getSingleElementStorageSize() const;
316+
};
317+
227318
/// Represents a family of uniform, quantized types.
228319
///
229320
/// Each instance of this type expresses a mapping between real values (most

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
7+
// affiliates
68
//
79
//===----------------------------------------------------------------------===//
810

@@ -37,30 +39,62 @@ namespace {
3739
LogicalResult verifyPerAxisQuantization(Operation *op,
3840
QuantizedType quantizedType,
3941
Type containerType) {
40-
auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
42+
auto quantizedPerAxisType =
43+
dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
4144
if (!quantizedPerAxisType)
4245
return success();
4346

44-
auto tensorType = dyn_cast<TensorType>(containerType);
45-
if (!tensorType)
47+
auto shapedType = dyn_cast<ShapedType>(containerType);
48+
if (!shapedType)
4649
return op->emitError("scalar types may not use per-axis quantization");
4750

48-
if (!tensorType.hasRank())
51+
if (!shapedType.hasRank())
4952
return success();
5053

5154
int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
52-
if (quantizedDimension >= tensorType.getRank())
55+
if (quantizedDimension >= shapedType.getRank())
5356
return op->emitError("quantized dimension must be less than tensor rank");
5457

55-
int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
58+
int64_t quantizedDimensionSize = shapedType.getDimSize(quantizedDimension);
5659
if (quantizedDimensionSize != ShapedType::kDynamic &&
57-
quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
60+
quantizedDimensionSize !=
61+
(int64_t)quantizedPerAxisType.getScales().size())
5862
return op->emitError(
5963
"quantized dimension size does not match number of scales");
6064

6165
return success();
6266
}
6367

68+
// Verify the integrity of block float quantization information, if present.
69+
//
70+
// - quantizedType
71+
// Any quantized type. Any quantized type with no block float quantization is
72+
// ignored.
73+
//
74+
// - containerType
75+
// Original input or result type of the operation using the provided quantized
76+
// type. Used to ensure that the quantized type appears within a tensor and
77+
// that the tensor is compatible with block float quantization information.
78+
//
79+
LogicalResult verifyBlockFloatQuantization(Operation *op,
80+
QuantizedType quantizedType,
81+
Type containerType) {
82+
auto blockModeType = dyn_cast<BlockFloatQuantizedType>(quantizedType);
83+
if (!blockModeType)
84+
return success();
85+
86+
auto shapedType = dyn_cast<ShapedType>(containerType);
87+
if (!shapedType)
88+
return op->emitError("scalar types may not use block float quantization");
89+
if (!shapedType.hasRank())
90+
return success();
91+
// We could also check that the tensor is a multiple of the block size, but
92+
// that requires that all padding is visible in MLIR
93+
if (blockModeType.getAxis() >= shapedType.getRank())
94+
return op->emitError("block axis must be less than tensor rank");
95+
return success();
96+
}
97+
6498
// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
6599
//
66100
// - quantizedType
@@ -76,12 +110,18 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
76110
//
77111
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
78112
FloatType floatType, Type containerType) {
79-
if (quantizedType.getExpressedType() != floatType)
113+
if (!isa<BlockFloatQuantizedType>(quantizedType) &&
114+
quantizedType.getExpressedType() != floatType)
80115
return op->emitError(
81116
"expressed type in quantized type expected to match float type");
82117

83-
// Veriy integrity of per-axis quantization information, if present.
84-
return verifyPerAxisQuantization(op, quantizedType, containerType);
118+
if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
119+
return failure();
120+
121+
if (failed(verifyBlockFloatQuantization(op, quantizedType, containerType)))
122+
return failure();
123+
124+
return success();
85125
}
86126

87127
} // namespace
@@ -92,8 +132,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
92132
//===----------------------------------------------------------------------===//
93133

94134
void QuantDialect::initialize() {
95-
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
96-
UniformQuantizedPerAxisType>();
135+
addTypes<AnyQuantizedType, BlockFloatQuantizedType, CalibratedQuantizedType,
136+
UniformQuantizedType, UniformQuantizedPerAxisType>();
97137
addOperations<
98138
#define GET_OP_LIST
99139
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
@@ -167,6 +207,9 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
167207

168208
LogicalResult StorageCastOp::verify() {
169209
auto quantizedType = getQuantizedType();
210+
if (isa<BlockFloatQuantizedType>(quantizedType))
211+
return getOperation()->emitError(
212+
"storage cast not supported for block float quantized types");
170213
auto integerType = getIntegerType();
171214
if (quantizedType.getStorageType() != integerType)
172215
return emitError(

0 commit comments

Comments
 (0)