Skip to content

Commit 13d7348

Browse files
jtuyls authored and claude committed
Emit torch.vtensor.literal for small initializers
Small initializers (<=256 bytes) now emit torch.vtensor.literal instead of flow.tensor.constant + torch_c.from_builtin_tensor. This allows torch-mlir conversion patterns to directly match constant values.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Jorn <jorn.tuyls@gmail.com>
1 parent 4f595c0 commit 13d7348

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

src/mlir_gen.cc

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ ErrorOr<std::string> FormatTensorType(const Ort::ConstTypeInfo& type_info) {
156156
// Formats a tensor type as tensor<dimsxdtype> (standard MLIR format).
157157
// Uses signless integer types as required by MLIR tensor dialect.
158158
// Dynamic dims are always emitted as "?".
159-
ErrorOr<std::string> FormatMlirTensorType(const Ort::ConstTypeInfo& type_info) {
159+
ErrorOr<std::string> FormatMlirTensorType(const Ort::ConstTypeInfo& type_info,
160+
bool signless = true) {
160161
if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
161162
return errorWithCode(ErrorCode::kNotImplemented,
162163
"Non-tensor type {} not supported",
@@ -178,7 +179,7 @@ ErrorOr<std::string> FormatMlirTensorType(const Ort::ConstTypeInfo& type_info) {
178179
ss << "x";
179180
}
180181
IREE_EP_ASSIGN_OR_RETURN(std::string elem_type,
181-
GetElementType(dtype, /*signless=*/true));
182+
GetElementType(dtype, signless));
182183
ss << elem_type << ">";
183184
return ss.str();
184185
}
@@ -414,7 +415,7 @@ class MlirGenerator {
414415
// Emit dim constraints (util.assume.int + flow.tensor.tie_shape).
415416
IREE_EP_RETURN_IF_ERROR(EmitDimConstraints());
416417

417-
// Emit initializers as flow.tensor.constant ops.
418+
// Emit initializers (small as vtensor.literal, large as flow.parameter).
418419
for (const auto& init : initializers_) {
419420
IREE_EP_RETURN_IF_ERROR(EmitInitializer(init));
420421
}
@@ -429,15 +430,14 @@ class MlirGenerator {
429430
return EmitReturn();
430431
}
431432

432-
// Emits an initializer as a flow.tensor.constant with a
433-
// torch_c.from_builtin_tensor cast. Small initializers use dense<> with
434-
// inline hex-encoded data. Large initializers use #flow.parameter.named
435-
// (data stored in IRPA archive).
433+
// Emits an initializer. Small initializers (<=256 bytes) use
434+
// torch.vtensor.literal so torch-mlir conversion patterns can directly
435+
// match constant values. Large initializers use flow.tensor.constant
436+
// with #flow.parameter.named (data stored in IRPA archive).
436437
//
437438
// Output format (small):
438-
// %__raw_name = flow.tensor.constant dense<"0x..."> : tensor<...>
439-
// %name = torch_c.from_builtin_tensor %__raw_name : tensor<...>
440-
// -> !torch.vtensor<[...],dtype>
439+
// %name = torch.vtensor.literal(dense<"0x..."> : tensor<...>)
440+
// : !torch.vtensor<[...],dtype>
441441
//
442442
// Output format (large):
443443
// %__raw_name = flow.tensor.constant
@@ -456,18 +456,24 @@ class MlirGenerator {
456456
OnnxElementTypeSize(tensor_info.GetElementType());
457457

458458
if (byte_size <= kMaxInlineInitializerSize) {
459-
// Small: inline with dense<> DenseElementsAttr.
459+
// Small: emit as torch.vtensor.literal so torch-mlir conversion
460+
// patterns can directly match constant values.
460461
Ort::ConstValue tensor_value{nullptr};
461462
IREE_EP_RETURN_IF_ORT_STATUS(init.GetInitializer(tensor_value).release());
462463
const auto* data =
463464
static_cast<const uint8_t*>(tensor_value.GetTensorRawData());
464465
std::string hex = HexEncode(data, tensor_value.GetTensorSizeInBytes());
465466

467+
// vtensor.literal requires signed integer types (si64, si32, etc.)
468+
// in the inner dense attribute, not signless (i64, i32).
469+
IREE_EP_ASSIGN_OR_RETURN(
470+
std::string signed_tensor_type,
471+
FormatMlirTensorType(init.TypeInfo(), /*signless=*/false));
472+
466473
constexpr std::string_view schema =
467-
R"( %__raw_{0} = flow.tensor.constant dense<"{3}"> : {1}
468-
%{0} = torch_c.from_builtin_tensor %__raw_{0} : {1} -> {2}
474+
R"( %{0} = torch.vtensor.literal(dense<"{2}"> : {1}) : {3}
469475
)";
470-
out_ << std::format(schema, name, tensor_type, vtensor_type, hex);
476+
out_ << std::format(schema, name, signed_tensor_type, hex, vtensor_type);
471477
} else {
472478
// Large: parameter reference. Data stored in IRPA archive.
473479
constexpr std::string_view schema =

test/test_initializers.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414
# Fixed seed for reproducibility.
1515
np.random.seed(42)
1616

17-
# Test data. Four initializers, each handled differently:
18-
# D_small: [1, 64] float32 = 256 bytes -> inline dense<>
17+
# Test data. Five initializers, each handled differently:
18+
# D_small: [1, 64] float32 = 256 bytes -> vtensor.literal (inline)
1919
# D_large: [64, 64] float32 = 16384 bytes -> IRPA parameter
2020
# D_ext: [64, 64] float32 = 16384 bytes -> external file (parameter, not in IRPA)
21-
# D_ext_small: [1, 64] float32 = 256 bytes -> external file (inlined as dense<>)
22-
# Graph: C = (((A + D_small) + D_large) + D_ext) + D_ext_small
21+
# D_ext_small: [1, 64] float32 = 256 bytes -> external file (vtensor.literal)
22+
# axes: [1] int64 = 8 bytes -> vtensor.literal (int, tests si64 type)
23+
# Graph: C = ReduceMean((((A + D_small) + D_large) + D_ext) + D_ext_small, axes=[1])
2324
SHAPE = [64, 64]
2425
A_DATA = np.random.rand(*SHAPE).astype(np.float32)
2526
B_SMALL = np.random.rand(1, 64).astype(np.float32)
2627
B_LARGE = np.random.rand(*SHAPE).astype(np.float32)
2728
B_EXT = np.random.rand(*SHAPE).astype(np.float32)
2829
B_EXT_SMALL = np.random.rand(1, 64).astype(np.float32)
29-
EXPECTED = (((A_DATA + B_SMALL) + B_LARGE) + B_EXT) + B_EXT_SMALL
30+
SUM_ALL = (((A_DATA + B_SMALL) + B_LARGE) + B_EXT) + B_EXT_SMALL
31+
EXPECTED = np.mean(SUM_ALL, axis=1, keepdims=True)
3032

3133

3234
def _create_model():
@@ -71,7 +73,7 @@ def _create_model():
7173
ext_tensor.ClearField("raw_data")
7274
ext_tensor.data_location = TensorProto.EXTERNAL
7375

74-
# Small external initializer (should be inlined as dense<>).
76+
# Small external initializer (should be inlined as vtensor.literal).
7577
ext_small_filename = "ext_small_weights.bin"
7678
ext_small_path = os.path.join(model_dir, ext_small_filename)
7779
ext_small_tensor = from_array(B_EXT_SMALL, name="D_ext_small")
@@ -84,27 +86,33 @@ def _create_model():
8486
ext_small_tensor.ClearField("raw_data")
8587
ext_small_tensor.data_location = TensorProto.EXTERNAL
8688

89+
# Int64 axes initializer for ReduceMean (8 bytes — tests si64 signedness).
90+
axes_tensor = from_array(np.array([1], dtype=np.int64), name="axes")
91+
8792
input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, SHAPE)
88-
output = helper.make_tensor_value_info("C", TensorProto.FLOAT, SHAPE)
93+
output = helper.make_tensor_value_info("C", TensorProto.FLOAT, [64, 1])
8994

9095
add1 = helper.make_node("Add", inputs=["A", "D_small"], outputs=["T1"])
9196
add2 = helper.make_node("Add", inputs=["T1", "D_large"], outputs=["T2"])
9297
add3 = helper.make_node("Add", inputs=["T2", "D_ext"], outputs=["T3"])
93-
add4 = helper.make_node("Add", inputs=["T3", "D_ext_small"], outputs=["C"])
98+
add4 = helper.make_node("Add", inputs=["T3", "D_ext_small"], outputs=["T4"])
99+
reduce_mean = helper.make_node(
100+
"ReduceMean", inputs=["T4", "axes"], outputs=["C"], keepdims=1
101+
)
94102

95103
graph = helper.make_graph(
96-
[add1, add2, add3, add4, const_small, const_large],
104+
[add1, add2, add3, add4, reduce_mean, const_small, const_large],
97105
"test_graph",
98106
[input_a],
99107
[output],
100-
initializer=[ext_tensor, ext_small_tensor],
108+
initializer=[ext_tensor, ext_small_tensor, axes_tensor],
101109
)
102110
model = helper.make_model(
103111
graph,
104112
producer_name="iree_test",
105-
opset_imports=[helper.make_opsetid("", 17)],
113+
opset_imports=[helper.make_opsetid("", 18)],
106114
)
107-
model.ir_version = 8
115+
model.ir_version = 9
108116

109117
model_path = os.path.join(model_dir, "model.onnx")
110118
onnx.save(model, model_path)
@@ -164,10 +172,17 @@ def test_with_save_intermediates(iree_device):
164172

165173
mlir_content = open(list(new_mlir)[0]).read()
166174

167-
# D_small and D_ext_small should be inlined via dense<>.
175+
# D_small, D_ext_small, and axes should be vtensor.literal.
176+
assert (
177+
"torch.vtensor.literal" in mlir_content
178+
), "MLIR should contain torch.vtensor.literal for small constants"
168179
assert (
169180
'dense<"0x' in mlir_content
170181
), "MLIR should contain inline dense<> attributes"
182+
# Int64 axes initializer should use signed type (si64).
183+
assert (
184+
"si64" in mlir_content
185+
), "int64 initializer should use signed type (si64) in vtensor.literal"
171186
assert (
172187
"dense_resource" not in mlir_content
173188
), "MLIR should not contain dense_resource (replaced by dense<>)"

0 commit comments

Comments (0)