Qwen MoE BringUp #34

@jtuyls

Description

Model: Qwen1.5-MoE-A2.7B-Chat — INT4 — 24 layers, 60 experts/layer, top-4 routing, 16 KV heads, head_dim=128

Local fixes to get this to work:

A1. DequantizeLinear block_size decomposition: llvm/torch-mlir#4505

Why: torch-mlir doesn't support the block_size attribute, so INT4 QDQ models fail to import.
What: Decompose to extui→uitofp→subf→mulf in the EP (reference sketch below). Standalone.
Alternatives:

  • torch-mlir (proper fix): Add block_size support to onnx.DequantizeLinear → torch lowering. This is where it belongs — it's a missing ONNX opset feature.
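
A minimal numpy sketch of what the decomposition computes, assuming the INT4 payload has already been unpacked to one code per byte and that scales/zero-points are per block along the last axis; names and shapes are illustrative, not the EP's actual layout:

```python
import numpy as np

def dequantize_blockwise(q_u4, scale, zero_point, block_size):
    """q_u4: (rows, cols) unsigned 4-bit codes, one per uint8 byte.
    scale, zero_point: (rows, cols // block_size) per-block parameters."""
    widened = q_u4.astype(np.uint32)          # extui: widen the 4-bit codes
    as_float = widened.astype(np.float32)     # uitofp: unsigned int -> float
    # Broadcast the per-block scale/zero-point across each block of columns.
    scale_full = np.repeat(scale.astype(np.float32), block_size, axis=1)
    zp_full = np.repeat(zero_point.astype(np.float32), block_size, axis=1)
    return (as_float - zp_full) * scale_full  # subf then mulf

# Tiny example: 2x8 codes with block_size 4 -> 2x2 scales/zero-points.
q = np.arange(16, dtype=np.uint8).reshape(2, 8)
s = np.array([[0.5, 0.25], [1.0, 2.0]], dtype=np.float32)
zp = np.full((2, 2), 8, dtype=np.uint8)
print(dequantize_blockwise(q, s, zp, block_size=4))
```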

A2. ReduceMean/ReduceSum opset 18+: llvm/torch-mlir#4506

Why: Opset 18 moved axes from attribute to tensor input. torch-mlir expects attribute form.
What: Convert the constant axes tensor input back to an attribute in the EP (sketch below). Standalone.
Alternatives:

  • torch-mlir (proper fix): Handle axes-as-tensor in ONNX→torch lowering. Same story — opset version gap in torch-mlir.
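
The same rewrite expressed as an offline ONNX pass, assuming the axes input is a graph initializer (the EP does the equivalent in-memory during import; a Constant-node axes input is not handled in this sketch):

```python
import onnx
from onnx import helper, numpy_helper

def axes_input_to_attribute(model: onnx.ModelProto) -> onnx.ModelProto:
    """Fold a constant `axes` tensor input (opset 18+) back into the pre-18
    `axes` attribute for ReduceMean/ReduceSum."""
    initializers = {init.name: init for init in model.graph.initializer}
    for node in model.graph.node:
        if node.op_type not in ("ReduceMean", "ReduceSum"):
            continue
        if len(node.input) < 2 or node.input[1] not in initializers:
            continue  # dynamic axes cannot be folded into an attribute
        axes = numpy_helper.to_array(initializers[node.input[1]]).tolist()
        node.attribute.extend([helper.make_attribute("axes", axes)])
        node.input.pop()  # drop the axes tensor input
    return model
```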

A3. Shape refinement + op rewrites

Why: ORT's per-node TypeInfo loses shape precision, producing all-? dims → slow dynamic-shape codegen.
What: Track refined types and rewrite ops to the torch dialect with traceable shapes (~900 lines; toy sketch of the idea below).
Alternatives:

  • ORT (proper fix for part of it): GetOutputTypeInfo should respect add_free_dimension_override_by_name. The EP shouldn't need to re-derive static shapes that ORT already knows.
  • torch-mlir: The op rewrites (Reshape→aten.reshape, Unsqueeze→aten.unsqueeze, etc.) are essentially doing ONNX→torch lowering that torch-mlir already does. Without this fix the EP emits torch.operator "onnx.Reshape" and lets torch-mlir lower it, but the shape info is lost by then. If ORT provided correct types, the generic path would work fine.
  • IREE: Could be more aggressive about shape inference on onnx.* ops before lowering, but that's pushing ONNX knowledge into the wrong layer.
  • The real root cause is split: ORT gives wrong types + the EP uses a passthrough onnx.* emission that loses shape context torch-mlir could recover.
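
A toy sketch of the refinement idea: a side table of static shapes keyed by value name, propagated through ops whose output shape is derivable from their inputs. The real pass covers far more ops and emits torch-dialect IR; all names below are hypothetical:

```python
from typing import Dict, List, Optional

RefinedTypes = Dict[str, List[int]]  # value name -> fully static shape

def refine_unsqueeze(refined: RefinedTypes, data: str, axes: List[int],
                     out: str) -> Optional[List[int]]:
    shape = refined.get(data)
    if shape is None:
        return None
    result = list(shape)
    for ax in sorted(axes):
        result.insert(ax if ax >= 0 else ax + len(result) + 1, 1)
    refined[out] = result
    return result

def refine_reshape(refined: RefinedTypes, data: str, target: List[int],
                   out: str) -> Optional[List[int]]:
    shape = refined.get(data)
    if shape is None or 0 in target:  # 0 ("copy dim") not handled in this sketch
        return None
    total = 1
    for d in shape:
        total *= d
    known = 1
    for d in target:
        if d != -1:
            known *= d
    refined[out] = [total // known if d == -1 else d for d in target]
    return refined[out]

refined: RefinedTypes = {"hidden": [1, 16, 128]}
print(refine_unsqueeze(refined, "hidden", [1], "hidden_u"))   # [1, 1, 16, 128]
print(refine_reshape(refined, "hidden", [1, -1], "hidden_r"))  # [1, 2048]
```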

A4. Expand output type fix

Why: ORT can't infer KV cache broadcast shape through Where(Equal(Concat, -1), ones, Concat) pattern. onnx.Expand output stays [1,1,?,1] instead of [1,16,?,128]. ScatterElements then writes only 1 element per position → nearly all-zero KV cache.
What: Extend ResolveShapeVector to handle Reshape/Where chains; rewrite Expand refinement in PropagateRefinedTypes to resolve the shape tensor and compute the broadcast output (reference below). Depends on A3 infrastructure.
Alternatives:

  • ORT (proper fix): Shape inference should trace through Where(Equal(...), ...) patterns to resolve Expand output shapes. The EP is compensating for ORT's shape inference limitations.
  • torch-mlir: If the Expand op arrived with correct output type annotation, torch-mlir would lower it correctly. This is an ORT type propagation issue, not a lowering issue.
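
A numpy reference for the shape resolution, assuming the Expand target comes from Where(Equal(shape_vec, -1), ones, shape_vec) where shape_vec is a Concat of constants and a -1 placeholder; concrete sizes are illustrative:

```python
import numpy as np

def resolve_expand_shape(data_shape, shape_vec):
    # Where(Equal(shape_vec, -1), ones, shape_vec): -1 entries mean
    # "keep the input dimension", i.e. broadcast against 1.
    target = np.where(np.asarray(shape_vec) == -1, 1, shape_vec)
    # onnx.Expand uses numpy-style broadcasting between data and target.
    return list(np.broadcast_shapes(tuple(data_shape),
                                    tuple(int(d) for d in target)))

# The pattern from the issue: [1, 1, seq, 1] expanded against (1, 16, -1, 128).
print(resolve_expand_shape([1, 1, 7, 1], [1, 16, -1, 128]))  # [1, 16, 7, 128]
```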

A5. ScatterElements → tensor.insert_slice for KV cache

Why: ONNX ScatterElements with sequential indices decomposes to hundreds of memcpy dispatches in IREE. The KV cache update pattern (ScatterElements(past_kv, Expand(Reshape(Cast(Add(seq_pos, ...)))), updates)) is a contiguous slice write.
What: Pattern-match the KV cache ScatterElements and emit tensor.insert_slice instead (equivalence sketch below). Depends on A3 (uses refined types for shape resolution).
Alternatives:

  • IREE: Recognize contiguous scatter patterns and lower to insert_slice internally. The EP is doing a pattern match that the compiler should do.
  • torch-mlir: Better onnx.ScatterElements → torch lowering that preserves the contiguous-slice structure.
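
A numpy sketch of the equivalence being exploited: when the indices are the scalar write position expanded over heads and head_dim, ScatterElements along the sequence axis is exactly a contiguous slice write. Shapes and the axis are illustrative:

```python
import numpy as np

def kv_scatter_reference(past_kv, seq_pos, new_kv):
    """ScatterElements(past_kv, full_like(new_kv, seq_pos), new_kv, axis=2)."""
    out = past_kv.copy()
    idx = np.full(new_kv.shape, seq_pos, dtype=np.int64)
    np.put_along_axis(out, idx, new_kv, axis=2)
    return out

def kv_insert_slice(past_kv, seq_pos, new_kv):
    """The tensor.insert_slice form: one contiguous slice write."""
    out = past_kv.copy()
    out[:, :, seq_pos : seq_pos + new_kv.shape[2], :] = new_kv
    return out

past = np.zeros((1, 16, 8, 128), dtype=np.float16)
new = np.random.rand(1, 16, 1, 128).astype(np.float16)
assert np.array_equal(kv_scatter_reference(past, 3, new),
                      kv_insert_slice(past, 3, new))
```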

A6. Static shape propagation + Cast roundtrip elimination

Why: ORT's node-level TypeInfo returns original model shapes (with dynamic ? dims) even after add_free_dimension_override_by_name("seq_len", 1). So graph inputs have static types in the function signature but ? in every op annotation downstream. Cast(f16→f32)→Cast(f32→f16) roundtrips after Softmax add 24 unnecessary dispatches.
What: Seed refined_types_ from graph input signatures so static shapes propagate everywhere. Add Gather shape refinement and Reshape -1 resolution. Detect Cast roundtrips and alias through them for DCE (sketch below). Depends on A3.
Alternatives:

  • ORT (proper fix for shape propagation): Node-level GetOutputTypeInfo should incorporate free dimension overrides. This is the root cause — the EP is re-deriving information ORT already has in the graph-level type but doesn't expose per-node.
  • torch-mlir / IREE (Cast elimination): Cast(A→B)→Cast(B→A) roundtrip canonicalization is a standard compiler optimization. Both torch-mlir and IREE should already handle this. The EP shouldn't need to do DCE prep.
  • The Gather refinement and Reshape -1 resolution are shape inference that torch-mlir could do if it received correct input types — again falls out of the ORT type info bug.
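
A toy sketch of the Cast-roundtrip aliasing over a simplified node-list view of the graph; the Node shape and all names are hypothetical, and the real EP records the alias so both Casts become dead:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Node:
    op_type: str
    inputs: List[str]
    outputs: List[str]
    to_type: str = ""  # only meaningful for Cast

def alias_cast_roundtrips(nodes: List[Node]) -> Dict[str, str]:
    producer = {out: n for n in nodes for out in n.outputs}
    alias: Dict[str, str] = {}
    for n in nodes:
        if n.op_type != "Cast" or n.to_type != "f16":
            continue
        prev = producer.get(n.inputs[0])
        if prev is not None and prev.op_type == "Cast" and prev.to_type == "f32":
            # f16 -> f32 -> f16 is exact, so forward the original f16 value.
            alias[n.outputs[0]] = prev.inputs[0]
    return alias

nodes = [
    Node("Softmax", ["attn_f16"], ["probs_f16"]),
    Node("Cast", ["probs_f16"], ["probs_f32"], to_type="f32"),
    Node("Cast", ["probs_f32"], ["probs_rt"], to_type="f16"),
]
print(alias_cast_roundtrips(nodes))  # {'probs_rt': 'probs_f16'}
```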

A7. MatMul f16→f32 accumulation promotion

Why: IREE selects WMMAR3_F16_16x16x16_F16 (f16 accumulator) for f16 matmuls on gfx1100, which produces NaN for transposed-B operands (the QxK^T pattern in attention). This appears to be a WMMA bug.
What: Wrap every f16 MatMul with Cast(f16→f32) + MatMul(f32) + Cast(f32→f16), forcing IREE onto the f32 accumulation path (reference below). Depends on A3 (uses ParseVtensorType, BuildVtensorType, ResolveAlias). ~57 lines.
Alternatives:

  • IREE (proper fix): Fix WMMA instruction selection to use WMMAR3_F32_16x16x16_F16 (f32 accumulator) for transposed-B cases, or fix the f16 accumulator path itself. This is where the bug lives: [GPU] Fix RDNA3 WMMA f16/bf16 accumulator layout iree#23806
  • IREE (alternative): Recognize extf→matmul→truncf as a request for f32 accumulation and select the right WMMA variant. This would make the EP workaround at least intentional rather than a blunt hammer.
  • The EP workaround promotes all f16 matmuls, not just the buggy transposed-B ones — unnecessary overhead for matmuls that would be correct with f16 accumulation. A targeted fix in IREE would be more precise.
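
A numpy reference for the numeric effect of the wrapper (cast up, accumulate in f32, cast back down); the EP only rewrites the graph, it does not execute anything:

```python
import numpy as np

def matmul_f16_f32_accum(a_f16: np.ndarray, b_f16: np.ndarray) -> np.ndarray:
    acc = a_f16.astype(np.float32) @ b_f16.astype(np.float32)  # extf + matmul
    return acc.astype(np.float16)                              # truncf

q = np.random.rand(16, 128).astype(np.float16)
k = np.random.rand(16, 128).astype(np.float16)
scores = matmul_f16_f32_accum(q, k.T)  # the QxK^T pattern from attention
```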
