Model: Qwen1.5-MoE-A2.7B-Chat — INT4 — 24 layers, 60 experts/layer, top-4 routing, 16 KV heads, head_dim=128
Local fixes to get this to work:
A1. DequantizeLinear block_size decomposition: llvm/torch-mlir#4505
Why: torch-mlir doesn't support the block_size attr, so INT4 QDQ models fail.
What: Decompose to extui→uitofp→subf→mulf in the EP. Standalone.
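The blockwise arithmetic behind that decomposition can be sketched in plain Python (illustrative helper names; the EP emits the equivalent MLIR ops rather than anything like this function):

```python
def dequantize_blockwise(quant, scales, zero_points, block_size):
    """Blockwise DequantizeLinear along one axis:
    out[i] = (float(quant[i]) - zero_point[i // block_size]) * scale[i // block_size]
    mirroring the extui -> uitofp -> subf -> mulf decomposition."""
    out = []
    for i, q in enumerate(quant):
        b = i // block_size                  # which quantization block
        q_f = float(q)                       # extui + uitofp: widen and convert
        centered = q_f - zero_points[b]      # subf: subtract the block zero point
        out.append(centered * scales[b])     # mulf: apply the block scale
    return out

# 8 INT4 values, block_size=4: two blocks with independent scale/zero-point.
vals = dequantize_blockwise([0, 7, 15, 8, 1, 2, 3, 4], [0.5, 2.0], [8, 0], 4)
```

Each block carries its own scale and zero point, which is exactly what the missing block_size support would select per element.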
Alternatives:
torch-mlir (proper fix): Add block_size support to the onnx.DequantizeLinear → torch lowering. This is where it belongs — it's a missing ONNX opset feature.

A2. ReduceMean/ReduceSum opset 18+: llvm/torch-mlir#4506
Why: Opset 18 moved axes from attribute to tensor input. torch-mlir expects attribute form.
What: Convert constant tensor input back to attribute in the EP. Standalone.
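The conversion is mechanical when the axes input is a constant. A minimal sketch, using a hypothetical dict-based node representation (not the real ONNX or ORT API):

```python
# Hypothetical, simplified node representation: a node is a dict with
# op_type, inputs, and attrs. Opset 18 passes the reduction axes as a
# second (constant) tensor input; torch-mlir's lowering expects the
# pre-18 form, where axes is an attribute.

def axes_input_to_attribute(node, constant_tensors):
    """If a Reduce* node takes its axes as a constant tensor input,
    fold that input back into an 'axes' attribute (the opset <18 form)."""
    if node["op_type"] not in ("ReduceMean", "ReduceSum"):
        return node
    if len(node["inputs"]) < 2:
        return node                      # already attribute form (or default axes)
    axes_name = node["inputs"][1]
    if axes_name not in constant_tensors:
        return node                      # dynamic axes: cannot convert
    return {
        "op_type": node["op_type"],
        "inputs": node["inputs"][:1],    # drop the axes input
        "attrs": {**node["attrs"], "axes": list(constant_tensors[axes_name])},
    }

node = {"op_type": "ReduceMean", "inputs": ["x", "axes_const"], "attrs": {"keepdims": 1}}
fixed = axes_input_to_attribute(node, {"axes_const": [-1]})
```

Dynamic (non-constant) axes can't be folded this way, which is why the proper fix still belongs in torch-mlir.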
Alternatives:
torch-mlir (proper fix): Handle axes-as-tensor in the ONNX→torch lowering. Same story — an opset version gap in torch-mlir.
A3. Shape refinement + op rewrites
Why: ORT's per-node TypeInfo loses shape precision, producing all-? dims → slow dynamic codegen.
What: Track refined types, rewrite ops to torch dialect with traceable shapes. (~900 lines)
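The core idea of refined-type tracking can be sketched as a side table mapping value names to known static shapes, preferred over the (often all-dynamic) per-node annotations. This is an illustration of the approach, not the EP's ~900-line implementation; None stands in for a dynamic ? dim:

```python
def refine_shapes(graph, input_shapes):
    """graph: list of (op, inputs, output, annotated_shape) tuples.
    Seed known shapes from the graph input signatures, then propagate
    them through shape-transparent ops, falling back to the annotation."""
    refined = dict(input_shapes)
    for op, inputs, output, annotated in graph:
        if op == "Reshape":
            data, target = inputs
            shape = list(target)
            if -1 in shape and all(d is not None for d in refined.get(data, [None])):
                known = 1
                for d in shape:
                    if d != -1:
                        known *= d
                total = 1
                for d in refined[data]:
                    total *= d
                shape[shape.index(-1)] = total // known   # resolve the -1 dim
            refined[output] = shape
        elif op == "Unsqueeze":
            data, axis = inputs
            shape = list(refined.get(data, annotated))
            shape.insert(axis, 1)
            refined[output] = shape
        else:
            refined[output] = annotated  # fall back to ORT's annotation
    return refined

graph = [
    ("Reshape", ("x", (1, -1, 128)), "y", [None, None, None]),
    ("Unsqueeze", ("y", 1), "z", [None, None, None, None]),
]
shapes = refine_shapes(graph, {"x": [1, 16, 128]})
```

With refined inputs, the Reshape -1 resolves to a concrete dim and the static shape survives downstream, which is exactly the information ORT's per-node TypeInfo drops.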
Alternatives:
ORT (proper fix for part of it): GetOutputTypeInfo should respect add_free_dimension_override_by_name. The EP shouldn't need to re-derive static shapes that ORT already knows.
torch-mlir: The op rewrites (Reshape→aten.reshape, Unsqueeze→aten.unsqueeze, etc.) are essentially doing ONNX→torch lowering that torch-mlir already does. The EP emits torch.operator "onnx.Reshape" and lets torch-mlir lower it — but the shape info is lost by then. If ORT provided correct types, the generic path would work fine.
IREE: Could be more aggressive about shape inference on onnx.* ops before lowering, but that's pushing ONNX knowledge into the wrong layer.
The real root cause is split: ORT gives wrong types, and the EP uses a passthrough onnx.* emission that loses shape context torch-mlir could recover.

A4. Expand output type fix
Why: ORT can't infer the KV cache broadcast shape through the Where(Equal(Concat, -1), ones, Concat) pattern. The onnx.Expand output stays [1,1,?,1] instead of [1,16,?,128]. ScatterElements then writes only 1 element per position → a nearly all-zero KV cache.
What: Extend ResolveShapeVector to handle Reshape/Where chains; rewrite the Expand refinement in PropagateRefinedTypes to resolve the shape tensor and compute the broadcast output. Depends on A3 infrastructure.
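Once the shape tensor is resolved, the refined Expand output is just a two-way broadcast of the input shape and the shape vector. A small sketch of that computation (illustrative, with concrete dims in place of the dynamic seq dim):

```python
def expand_output_shape(input_shape, target):
    """ONNX Expand semantics: the output is the broadcast of the input
    shape and the target shape vector (dims align from the right; a 1
    broadcasts against any size)."""
    n = max(len(input_shape), len(target))
    a = [1] * (n - len(input_shape)) + list(input_shape)
    b = [1] * (n - len(target)) + list(target)
    out = []
    for x, y in zip(a, b):
        if x != 1 and y != 1 and x != y:
            raise ValueError(f"not broadcastable: {x} vs {y}")
        out.append(max(x, y))
    return out

# The KV-cache case: a [1,1,?,1] index tensor expanded by a resolved shape
# vector [1,16,seq,128] (seq=1 here for illustration).
shape = expand_output_shape([1, 1, 1, 1], [1, 16, 1, 128])
```

This is the [1,1,?,1] → [1,16,?,128] refinement the EP has to compute because ORT's inference stops at the Where/Equal chain.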
Alternatives:
ORT (proper fix): Shape inference should trace through Where(Equal(...), ...) patterns to resolve Expand output shapes. The EP is compensating for ORT's shape inference limitations.
torch-mlir: If the Expand op arrived with a correct output type annotation, torch-mlir would lower it correctly. This is an ORT type propagation issue, not a lowering issue.

A5. ScatterElements → tensor.insert_slice for KV cache
Why: ONNX ScatterElements with sequential indices decomposes to hundreds of memcpy dispatches in IREE. The KV cache update pattern (ScatterElements(past_kv, Expand(Reshape(Cast(Add(seq_pos, ...)))), updates)) is a contiguous slice write.
What: Pattern-match the KV cache ScatterElements and emit tensor.insert_slice instead. Depends on A3 (uses refined types for shape resolution).
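The equivalence the pattern match relies on can be checked in a few lines: a ScatterElements whose indices are sequential along the scatter axis is exactly a contiguous slice write, which tensor.insert_slice expresses as a single copy. An illustrative 1-D version (plain Python lists, not the EP's code):

```python
def scatter_elements_1d(data, indices, updates):
    """Reference ScatterElements on one axis: one write per index."""
    out = list(data)
    for idx, upd in zip(indices, updates):
        out[idx] = upd
    return out

past = [0.0] * 8                          # KV cache slots
updates = [1.0, 2.0, 3.0]
pos = 4                                   # current sequence position
indices = [pos + i for i in range(len(updates))]   # sequential -> contiguous

scattered = scatter_elements_1d(past, indices, updates)

sliced = list(past)
sliced[pos:pos + len(updates)] = updates  # the equivalent insert_slice
```

Both paths produce the same cache; the difference is one bulk copy versus one dispatch per element.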
Alternatives:
IREE: Recognize contiguous scatter patterns and lower to insert_slice internally. The EP is doing a pattern match that the compiler should do.
torch-mlir: A better onnx.ScatterElements → torch lowering that preserves the contiguous-slice structure.

A6. Static shape propagation + Cast roundtrip elimination
Why: ORT's node-level TypeInfo returns the original model shapes (with dynamic ? dims) even after add_free_dimension_override_by_name("seq_len", 1). So graph inputs have static types in the function signature but ? in every op annotation downstream. Cast(f16→f32)→Cast(f32→f16) roundtrips after Softmax add 24 unnecessary dispatches.
What: Seed refined_types_ from the graph input signatures so static shapes propagate everywhere. Add Gather shape refinement and Reshape -1 resolution. Detect Cast roundtrips and alias through them for DCE. Depends on A3.
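The Cast-roundtrip detection amounts to spotting Cast(A→B) feeding Cast(B→A) and aliasing the second output back to the original value, so both casts become dead. A sketch with illustrative names (not the EP's data structures):

```python
def alias_cast_roundtrips(nodes):
    """nodes: list of (op, input_name, output_name, from_type, to_type).
    Returns a map output_name -> aliased source for roundtripped casts:
    consumers can read the source directly and DCE removes both casts."""
    producer = {out: (op, inp, src, dst) for op, inp, out, src, dst in nodes}
    alias = {}
    for op, inp, out, src, dst in nodes:
        if op != "Cast" or inp not in producer:
            continue
        p_op, p_inp, p_src, p_dst = producer[inp]
        if p_op == "Cast" and p_src == dst and p_dst == src:
            alias[out] = p_inp           # skip both casts entirely
    return alias

nodes = [
    ("Softmax", "x", "s", "f32", "f32"),
    ("Cast", "s", "s16", "f32", "f16"),   # f32 -> f16
    ("Cast", "s16", "s32", "f16", "f32"), # f16 -> f32: roundtrip
]
aliases = alias_cast_roundtrips(nodes)
```

Note this is only value-preserving in the f16→f32→f16 direction trivially; the f32→f16→f32 direction drops precision, which is why a real canonicalization has to reason about which cast came first.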
Alternatives:
ORT (proper fix for shape propagation): Node-level GetOutputTypeInfo should incorporate free dimension overrides. This is the root cause — the EP is re-deriving information ORT already has in the graph-level type but doesn't expose per-node.
torch-mlir / IREE (Cast elimination): Cast(A→B)→Cast(B→A) roundtrip canonicalization is a standard compiler optimization. Both torch-mlir and IREE should already handle this. The EP shouldn't need to do DCE prep.
The Gather refinement and Reshape -1 resolution are shape inference that torch-mlir could do if it received correct input types — again, this falls out of the ORT type info bug.

A7. MatMul f16→f32 accumulation promotion
Why: IREE selects WMMAR3_F16_16x16x16_F16 (f16 accumulator) for f16 matmuls on gfx1100, which produces NaN for transposed-B operands (the QxK^T pattern in attention). This appears to be a WMMA bug.
What: Wrap every f16 MatMul with Cast(f16→f32) + MatMul(f32) + Cast(f32→f16), forcing IREE to use the f32 path. Depends on A3 (uses ParseVtensorType, BuildVtensorType, ResolveAlias). ~57 lines.
Alternatives:
IREE (proper fix): Fix WMMA instruction selection to use WMMAR3_F32_16x16x16_F16 (f32 accumulator) for transposed-B cases, or fix the f16 accumulator path itself. This is where the bug lives: [GPU] Fix RDNA3 WMMA f16/bf16 accumulator layout, iree#23806.
IREE (alternative): Recognize extf→matmul→truncf as a request for f32 accumulation and select the right WMMA variant. This would make the EP workaround at least intentional rather than a blunt hammer.
The EP workaround promotes all f16 matmuls, not just the buggy transposed-B ones — unnecessary overhead for matmuls that would be correct with f16 accumulation. A targeted fix in IREE would be more precise.
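The general precision argument for the f32 path can be illustrated without a GPU: a long f16-accumulated reduction stalls once the partial sum outgrows half-precision spacing. This sketch emulates binary16 rounding with Python's struct "e" format; it shows why f32 accumulation matters for long dot products like QxK^T, not the WMMA NaN bug itself:

```python
import struct

def round_f16(x):
    """Round a Python float to the nearest IEEE binary16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def dot_f16_acc(a, b):
    """Dot product with an f16 accumulator: every partial sum is rounded
    back to half precision, as an f16 WMMA accumulator would do."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_f16(acc + round_f16(x * y))
    return acc

def dot_f32_acc(a, b):
    """Same dot product accumulated in higher precision, only the final
    result truncated: the Cast(f16->f32) + MatMul + Cast(f32->f16) path."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += round_f16(x * y)
    return round_f16(acc)

# A long reduction: 3000 terms of 1.0 each. Above 2048, f16 spacing is 2,
# so 2048 + 1 rounds back to 2048 and the f16 accumulator stops growing.
a = [1.0] * 3000
b = [1.0] * 3000
half_acc = dot_f16_acc(a, b)   # stalls at 2048.0
full_acc = dot_f32_acc(a, b)   # exact: 3000.0
```

The EP's blanket promotion buys this accumulation behavior for every matmul, at the cost noted above for the ones that were already correct in f16.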