Sub-optimal LLVM IR vectorization for reduce over small non-power-of-2 dimensions on CPU (interleaved shufflevector instead of contiguous loads) #40677

@othakkar

Description

When XLA:CPU lowers a reduce over a small non-power-of-2 innermost dimension (e.g., size 7), the MLIR-emitted LLVM IR is poorly vectorized by LLVM's loop vectorizer. Instead of contiguous vector loads followed by a horizontal reduction, LLVM produces wide interleaved loads (<28 x float>) that are deinterleaved into 7 × <4 x float> via shufflevector, followed by a sequential fadd chain. This happens when the reduce goes through the FusionWrapper rather than the YNNPACK path. The result is significantly worse than what the old direct-to-LLVM emitter produced: contiguous <8 x float> loads followed by a vectorized llvm.vector.reduce.fadd.
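
In scalar form, the emitted reduction loop looks roughly like the sketch below (a hypothetical illustration, not the actual MLIR-emitted code). When LLVM vectorizes the outer loop at VF=4, the inner loads for a fixed j become stride-7 accesses, which the vectorizer handles as an interleave group:

```python
# Sketch of the scalar reduction loop over a flat, row-major f32 buffer
# (hypothetical names and shapes; not the actual emitted code).
def reduce_innermost(buf, n_rows, k=7):
    out = [0.0] * n_rows
    for i in range(n_rows):          # outer loop over output elements
        acc = 0.0
        for j in range(k):           # innermost, non-power-of-2 trip count
            acc += buf[i * k + j]    # contiguous within one row
        out[i] = acc
    return out

# Vectorizing the outer loop (VF=4) turns buf[i*k + j] for fixed j into
# a stride-7 access pattern -> an interleave group of 7.
assert reduce_innermost(list(range(14)), 2) == [21.0, 70.0]
```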

Reproducer

HloModule minimal_reduce

%add_computation (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(%lhs, %rhs)
}

ENTRY %main (input: f32[256,12,197,7]) -> f32[256,12,197] {
  %input = f32[256,12,197,7]{3,2,1,0} parameter(0)
  %zero = f32[] constant(0)
  ROOT %reduce = f32[256,12,197]{2,1,0} reduce(%input, %zero), dimensions={3}, to_apply=%add_computation
}
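
Semantically, the HLO above is just a sum over the last axis; a NumPy equivalent (shapes taken from the reproducer):

```python
import numpy as np

# NumPy equivalent of the HLO reproducer: sum-reduce over dimensions={3},
# f32[256,12,197,7] -> f32[256,12,197].
x = np.ones((256, 12, 197, 7), dtype=np.float32)
out = np.sum(x, axis=3)

assert out.shape == (256, 12, 197)
assert out[0, 0, 0] == 7.0  # each output element sums 7 ones
```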

Compile and inspect the generated LLVM IR

XLA_FLAGS="--xla_cpu_experimental_ynn_fusion_type=invalid --xla_dump_to=/tmp/xla_dump" bazel run //xla/tools:run_hlo_module -- --input_format=hlo --platform=cpu minimal_reduce.hlo

Observed (bad) - after LLVM optimization

The inner loop is vectorized as an interleave group of 7 with VF=4, producing:

%wide.vec = load <28 x float>, ptr %ptr, align 4
%strided.vec   = shufflevector <28 x float> %wide.vec, <28 x float> poison, <4 x i32> <i32 0, i32 7, i32 14, i32 21>
%strided.vec35 = shufflevector <28 x float> %wide.vec, <28 x float> poison, <4 x i32> <i32 1, i32 8, i32 15, i32 22>
%strided.vec36 = shufflevector <28 x float> %wide.vec, <28 x float> poison, <4 x i32> <i32 2, i32 9, i32 16, i32 23>
; ... 4 more shufflevectors ...
%27 = fadd reassoc <4 x float> %broadcast.splat, %strided.vec
%28 = fadd reassoc <4 x float> %27, %strided.vec35
; ... 5 more sequential fadds ...
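
A NumPy emulation of the IR above, for 4 output elements (VF=4) over a flat row-major buffer (an illustrative sketch, not the generated code): one wide 28-element load, seven strided shuffles, and a sequential fadd chain whose steps each depend on the previous one.

```python
import numpy as np

# Emulate the "bad" codegen: one <28 x float> wide load covering 4 rows
# of 7 elements, deinterleaved into 7 strided <4 x float> vectors.
buf = np.arange(28, dtype=np.float32)    # 4 rows x 7 elements, row-major

wide_vec = buf[:28]                      # %wide.vec = load <28 x float>
acc = np.zeros(4, dtype=np.float32)      # %broadcast.splat (init = 0)
for lane in range(7):                    # 7 shufflevectors + 7 fadds
    strided = wide_vec[lane::7][:4]      # mask <lane, lane+7, lane+14, lane+21>
    acc = acc + strided                  # sequential dependence chain

# Same result as four contiguous row sums
assert np.allclose(acc, buf.reshape(4, 7).sum(axis=1))
```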

Expected (good) - from old direct-to-LLVM emitter

The direct-to-LLVM lowering (removed by this commit) generated contiguous <8 x float> loads followed by a tree reduction and llvm.vector.reduce.fadd:

  %storemerge1328.us.us = phi i64 [ 0, %reduce-window.97.loop_header.dim.2.preheader.us ], [ %invar.inc6.us.us, %reduce-window.97.loop_body.dim.2.us.us ]
  %.idx.us.us = shl nuw nsw i64 %storemerge1328.us.us, 7
  %5 = getelementptr inbounds nuw i8, ptr %.split19.us, i64 %.idx.us.us
  %6 = getelementptr inbounds nuw i8, ptr %5, i64 96
  %wide.load45 = load <8 x float>, ptr %6, align 32, !invariant.load !4, !noalias !9
  %7 = getelementptr inbounds nuw i8, ptr %5, i64 64
  %wide.load44 = load <8 x float>, ptr %7, align 64, !invariant.load !4, !noalias !9
  %8 = getelementptr inbounds nuw i8, ptr %5, i64 32
  %wide.load43 = load <8 x float>, ptr %8, align 32, !invariant.load !4, !noalias !9
  %wide.load = load <8 x float>, ptr %5, align 64, !invariant.load !4, !noalias !9
  %9 = fadd reassoc <8 x float> %3, %wide.load
  %bin.rdx = fadd reassoc <8 x float> %wide.load43, %9
  %bin.rdx46 = fadd reassoc <8 x float> %wide.load44, %bin.rdx
  %bin.rdx47 = fadd reassoc <8 x float> %wide.load45, %bin.rdx46
  %10 = tail call reassoc float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> %bin.rdx47)
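
The strategy in that IR can be emulated as follows for one output element whose reduced slice is 32 contiguous floats (a sketch matching the loads at byte offsets 0/32/64/96, not the actual emitted code): four contiguous 8-wide loads, a tree of vector fadds, then one horizontal reduction corresponding to llvm.vector.reduce.fadd.

```python
import numpy as np

# Emulate the old emitter's strategy: contiguous 8-wide loads, vector
# fadd tree, then a single horizontal reduction at the end.
row = np.arange(32, dtype=np.float32)

w0 = row[0:8]     # %wide.load   (byte offset 0)
w1 = row[8:16]    # %wide.load43 (byte offset 32)
w2 = row[16:24]   # %wide.load44 (byte offset 64)
w3 = row[24:32]   # %wide.load45 (byte offset 96)

acc = w0 + w1     # tree of <8 x float> fadds
acc = acc + w2
acc = acc + w3
result = float(np.sum(acc))  # horizontal @llvm.vector.reduce.fadd.v8f32

assert result == float(row.sum())
```

The key difference from the interleaved version is that all loads are contiguous and the per-element work collapses into one horizontal reduction, rather than a 7-deep sequential fadd chain fed by shuffles.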

Metadata

Labels

CPU: Related to XLA on CPU
