Skip to content

Commit 03fe699

Browse files
minansysCopilot
andauthored
Fix crash on partial-window accumulation into fixed-width vectors (#2712)
* Fix crash on partial-window accumulation into fixed-width vectors * add the regression * Update partial_vec_window.ll * fix the comments Co-authored-by: Copilot <copilot@github.com> * update tests --------- Co-authored-by: Copilot <copilot@github.com>
1 parent 7e54cc3 commit 03fe699

3 files changed

Lines changed: 160 additions & 0 deletions

File tree

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,67 @@ SmallVector<SelectInst *, 4> DiffeGradientUtils::addToDiffe(
364364
return res;
365365
}
366366

367+
if (auto VecT = dyn_cast<VectorType>(VT)) {
368+
if (!VecT->getElementCount().isScalable()) {
369+
Type *elemTy = VecT->getElementType();
370+
auto elemBytes = (DL.getTypeSizeInBits(elemTy) + 7) / 8;
371+
372+
// Only handle element-aligned windows
373+
if (elemBytes != 0 && start % elemBytes == 0 && size % elemBytes == 0) {
374+
unsigned left_idx = start / elemBytes;
375+
unsigned right_idx = (start + size) / elemBytes; // exclusive
376+
377+
unsigned numElts = VecT->getElementCount().getFixedValue();
378+
if (left_idx > numElts)
379+
left_idx = numElts;
380+
if (right_idx > numElts)
381+
right_idx = numElts;
382+
383+
auto maskVec = [&](Value *dsub) -> Value * {
384+
if (left_idx == 0 && right_idx == numElts)
385+
return dsub;
386+
Value *masked = Constant::getNullValue(VT);
387+
for (unsigned i = left_idx; i < right_idx; i++) {
388+
Value *vidx =
389+
ConstantInt::get(Type::getInt32Ty(val->getContext()), i);
390+
Value *el = BuilderM.CreateExtractElement(dsub, vidx);
391+
masked = BuilderM.CreateInsertElement(masked, el, vidx);
392+
}
393+
return masked;
394+
};
395+
396+
if (getWidth() == 1) {
397+
SmallVector<unsigned, 1> eidxs;
398+
for (auto idx : idxs.slice(ignoreFirstSlicesOfDif))
399+
eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
400+
401+
Value *subdif = extractMeta(BuilderM, dif, eidxs);
402+
return addToDiffe(val, maskVec(subdif), BuilderM, addingType, idxs,
403+
mask);
404+
} else {
405+
SmallVector<SelectInst *, 4> res;
406+
for (unsigned j = 0; j < getWidth(); j++) {
407+
SmallVector<Value *, 1> lidxs;
408+
SmallVector<unsigned, 1> eidxs = {(unsigned)j};
409+
410+
lidxs.push_back(
411+
ConstantInt::get(Type::getInt32Ty(val->getContext()), j));
412+
for (auto idx : idxs.slice(ignoreFirstSlicesOfDif))
413+
eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
414+
for (auto idx : idxs)
415+
lidxs.push_back(idx);
416+
417+
Value *subdif = extractMeta(BuilderM, dif, eidxs);
418+
for (auto v : addToDiffe(val, maskVec(subdif), BuilderM, addingType,
419+
lidxs, mask))
420+
res.push_back(v);
421+
}
422+
return res;
423+
}
424+
}
425+
}
426+
}
427+
367428
llvm::errs() << " VT: " << *VT << " idxs:{";
368429
for (auto idx : idxs)
369430
llvm::errs() << *idx << ",";
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,early-cse,sroa,instsimplify,%simplifycfg,adce)" -enzyme-preopt=false -S | FileCheck %s
2+
; Regression test: partial-window accumulation into a fixed vector (<2 x float>).
3+
; Previously asserted: "unhandled accumulate with partial sizes".
4+
5+
source_filename = "partial_vec_window"
6+
target triple = "x86_64-pc-linux-gnu"
7+
8+
%ret2v = type { <2 x float>, <2 x float> }
9+
10+
define %ret2v @make(float %x) {
11+
entry:
12+
%v0 = insertelement <2 x float> zeroinitializer, float %x, i32 0
13+
%r0 = insertvalue %ret2v undef, <2 x float> %v0, 0
14+
%r1 = insertvalue %ret2v %r0, <2 x float> zeroinitializer, 1
15+
ret %ret2v %r1
16+
}
17+
18+
define float @tester(float %x) {
19+
entry:
20+
%call = call %ret2v @make(float %x)
21+
%vec = extractvalue %ret2v %call, 0
22+
23+
; Force "partial" use: only the first 4 bytes of the <2 x float>
24+
%tmp = alloca <2 x float>, align 8
25+
store <2 x float> %vec, <2 x float>* %tmp, align 8
26+
%fp = bitcast <2 x float>* %tmp to float*
27+
%a = load float, float* %fp, align 4
28+
29+
ret float %a
30+
}
31+
32+
define float @test_derivative(float %x) {
33+
entry:
34+
%d = call float (float (float)*, ...) @__enzyme_autodiff(float (float)* @tester, float %x)
35+
ret float %d
36+
}
37+
38+
declare float @__enzyme_autodiff(float (float)*, ...)
39+
; CHECK: @diffetester
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,%simplifycfg,early-cse)" -enzyme-preopt=false -S | FileCheck %s
3+
4+
; Regression test: reverse vector mode handles partial-window accumulation into a fixed vector.
5+
6+
%struct.Gradients = type { [2 x float] }
7+
%ret2v = type { <2 x float>, <2 x float> }
8+
9+
declare %struct.Gradients @__enzyme_autodiff(float (float)*, ...)
10+
11+
define %ret2v @make(float %x) {
12+
entry:
13+
%v0 = insertelement <2 x float> zeroinitializer, float %x, i32 0
14+
%r0 = insertvalue %ret2v undef, <2 x float> %v0, 0
15+
%r1 = insertvalue %ret2v %r0, <2 x float> zeroinitializer, 1
16+
ret %ret2v %r1
17+
}
18+
19+
define float @tester(float %x) {
20+
entry:
21+
%call = call %ret2v @make(float %x)
22+
%vec = extractvalue %ret2v %call, 0
23+
%tmp = alloca <2 x float>, align 8
24+
store <2 x float> %vec, <2 x float>* %tmp, align 8
25+
%fp = bitcast <2 x float>* %tmp to float*
26+
%a = load float, float* %fp, align 4
27+
ret float %a
28+
}
29+
30+
define %struct.Gradients @test_derivative(float %x) {
31+
entry:
32+
%d = call %struct.Gradients (float (float)*, ...) @__enzyme_autodiff(float (float)* @tester, metadata !"enzyme_width", i64 2, float %x)
33+
ret %struct.Gradients %d
34+
}
35+
36+
; CHECK-LABEL: define internal { [2 x float] } @diffe2tester(float %x, [2 x float] %differeturn)
37+
; CHECK: entry:
38+
; CHECK: %"vec'de" = alloca [2 x <2 x float>]
39+
; CHECK: %"call'de" = alloca [2 x %ret2v]
40+
; CHECK: %"x'de" = alloca [2 x float]
41+
; CHECK: %call_augmented = call [2 x %ret2v] @augmented_make(float %x)
42+
; CHECK: %"tmp'ipa" = alloca <2 x float>
43+
; CHECK: %"tmp'ipa1" = alloca <2 x float>
44+
; CHECK: %[[D0:.+]] = extractvalue [2 x float] %differeturn, 0
45+
; CHECK: %[[L0:.+]] = load float, {{.*}}align 4{{.*}}
46+
; CHECK: %[[A0:.+]] = fadd fast float %[[L0]], %[[D0]]
47+
; CHECK: store float %[[A0]], {{.*}}align 4{{.*}}
48+
; CHECK: %[[D1:.+]] = extractvalue [2 x float] %differeturn, 1
49+
; CHECK: %[[L1:.+]] = load float, {{.*}}align 4{{.*}}
50+
; CHECK: %[[A1:.+]] = fadd fast float %[[L1]], %[[D1]]
51+
; CHECK: store float %[[A1]], {{.*}}align 4{{.*}}
52+
; CHECK: %[[V0:.+]] = load <2 x float>, {{.*}}align 8{{.*}}
53+
; CHECK: %[[V1:.+]] = load <2 x float>, {{.*}}align 8{{.*}}
54+
; CHECK: %[[PACK:.+]] = load [2 x <2 x float>], {{.*}}align 8
55+
; CHECK: %[[LANE0V:.+]] = extractvalue [2 x <2 x float>] %[[PACK]], 0
56+
; CHECK: %[[LANE0:.+]] = extractelement <2 x float> %[[LANE0V]], i32 0
57+
; CHECK: %[[LANE1V:.+]] = extractvalue [2 x <2 x float>] %[[PACK]], 1
58+
; CHECK: %[[LANE1:.+]] = extractelement <2 x float> %[[LANE1V]], i32 0
59+
; CHECK: %[[MAKE:.+]] = call { [2 x float] } @diffe2make(float %x)
60+
; CHECK: ret { [2 x float] }

0 commit comments

Comments
 (0)