Skip to content

Commit 1877f41

Browse files
Support strided Load/Store in SVE2 (#8888)
* Support strided Load/Store in SVE2
  Structured load relies on shuffle_vectors for scalable vectors, which is lowered to llvm.vector.deinterleave.
* Modify test cases of load/store in simd_op_check_sve2
  Correct the vector bits in the load/store test cases for targets with 256-bit vectors.
* Skip load/store test for SVE2 with old LLVM
  With old LLVM (v20, v21), some of the load/store tests in simd_op_check_sve2 fail due to an LLVM crash with the messages:
  - Invalid size request on a scalable vector
  - Cannot select: t11: nxv16i8,nxv16i8,nxv16i8 = vector_deinterleave t43, t45, t47
* Fix load/store tests in simd_op_check_sve2
  - Keep the test scope from before this PR in case of old LLVM
* Skip load/store scalar in simd_op_check_sve2
Fixes #8584
1 parent b8c7457 commit 1877f41

File tree

2 files changed

+62
-144
lines changed

2 files changed

+62
-144
lines changed

src/CodeGen_ARM.cpp

Lines changed: 10 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,17 +1475,13 @@ void CodeGen_ARM::visit(const Store *op) {
14751475
is_float16_and_has_feature(elt) ||
14761476
elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) ||
14771477
elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) {
1478-
// TODO(zvookin): Handle vector_bits_*.
1478+
const int target_vector_bits = native_vector_bits();
14791479
if (vec_bits % 128 == 0) {
14801480
type_ok_for_vst = true;
1481-
int target_vector_bits = native_vector_bits();
1482-
if (target_vector_bits == 0) {
1483-
target_vector_bits = 128;
1484-
}
14851481
intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits());
14861482
} else if (vec_bits % 64 == 0) {
14871483
type_ok_for_vst = true;
1488-
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64;
1484+
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? target_vector_bits : 64;
14891485
intrin_type = intrin_type.with_lanes(intrin_bits / t.bits());
14901486
}
14911487
}
@@ -1494,7 +1490,9 @@ void CodeGen_ARM::visit(const Store *op) {
14941490
if (ramp && is_const_one(ramp->stride) &&
14951491
shuffle && shuffle->is_interleave() &&
14961492
type_ok_for_vst &&
1497-
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
1493+
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4 &&
1494+
// TODO: we could handle predicated_store once shuffle_vector gets robust for scalable vectors
1495+
!is_predicated_store) {
14981496

14991497
const int num_vecs = shuffle->vectors.size();
15001498
vector<Value *> args(num_vecs);
@@ -1513,7 +1511,6 @@ void CodeGen_ARM::visit(const Store *op) {
15131511
for (int i = 0; i < num_vecs; ++i) {
15141512
args[i] = codegen(shuffle->vectors[i]);
15151513
}
1516-
Value *store_pred_val = codegen(op->predicate);
15171514

15181515
bool is_sve = target.has_feature(Target::SVE2);
15191516

@@ -1559,8 +1556,8 @@ void CodeGen_ARM::visit(const Store *op) {
15591556
llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
15601557
internal_assert(fn);
15611558

1562-
// SVE2 supports predication for smaller than whole vector size.
1563-
internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes()));
1559+
// Scalable vectors support predication for sizes smaller than the whole vector.
1560+
internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));
15641561

15651562
for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
15661563
Expr slice_base = simplify(ramp->base + i * num_vecs);
@@ -1581,15 +1578,10 @@ void CodeGen_ARM::visit(const Store *op) {
15811578
slice_args.push_back(ConstantInt::get(i32_t, alignment));
15821579
} else {
15831580
if (is_sve) {
1584-
// Set the predicate argument
1581+
// Set the predicate argument to mask active lanes
15851582
auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
1586-
Value *vpred_val;
1587-
if (is_predicated_store) {
1588-
vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes());
1589-
} else {
1590-
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
1591-
vpred_val = codegen(vpred);
1592-
}
1583+
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
1584+
Value *vpred_val = codegen(vpred);
15931585
slice_args.push_back(vpred_val);
15941586
}
15951587
// Set the pointer argument
@@ -1810,74 +1802,6 @@ void CodeGen_ARM::visit(const Load *op) {
18101802
CodeGen_Posix::visit(op);
18111803
return;
18121804
}
1813-
} else if (stride && (2 <= stride->value && stride->value <= 4)) {
1814-
// Structured load ST2/ST3/ST4 of SVE
1815-
1816-
Expr base = ramp->base;
1817-
ModulusRemainder align = op->alignment;
1818-
1819-
int aligned_stride = gcd(stride->value, align.modulus);
1820-
int offset = 0;
1821-
if (aligned_stride == stride->value) {
1822-
offset = mod_imp((int)align.remainder, aligned_stride);
1823-
} else {
1824-
const Add *add = base.as<Add>();
1825-
if (const IntImm *add_c = add ? add->b.as<IntImm>() : base.as<IntImm>()) {
1826-
offset = mod_imp(add_c->value, stride->value);
1827-
}
1828-
}
1829-
1830-
if (offset) {
1831-
base = simplify(base - offset);
1832-
}
1833-
1834-
Value *load_pred_val = codegen(op->predicate);
1835-
1836-
// We need to slice the result in to native vector lanes to use sve intrin.
1837-
// LLVM will optimize redundant ld instructions afterwards
1838-
const int slice_lanes = target.natural_vector_size(op->type);
1839-
vector<Value *> results;
1840-
for (int i = 0; i < op->type.lanes(); i += slice_lanes) {
1841-
int load_base_i = i * stride->value;
1842-
Expr slice_base = simplify(base + load_base_i);
1843-
Expr slice_index = Ramp::make(slice_base, stride, slice_lanes);
1844-
std::ostringstream instr;
1845-
instr << "llvm.aarch64.sve.ld"
1846-
<< stride->value
1847-
<< ".sret.nxv"
1848-
<< slice_lanes
1849-
<< (op->type.is_float() ? 'f' : 'i')
1850-
<< op->type.bits();
1851-
llvm::Type *elt = llvm_type_of(op->type.element_of());
1852-
llvm::Type *slice_type = get_vector_type(elt, slice_lanes);
1853-
StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type));
1854-
std::vector<llvm::Type *> arg_types{get_vector_type(i1_t, slice_lanes), ptr_t};
1855-
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
1856-
FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
1857-
1858-
// Set the predicate argument
1859-
int active_lanes = std::min(op->type.lanes() - i, slice_lanes);
1860-
1861-
Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes);
1862-
Value *vpred_val = codegen(vpred);
1863-
vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes));
1864-
if (is_predicated_load) {
1865-
Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes);
1866-
vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
1867-
}
1868-
1869-
Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
1870-
CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr});
1871-
add_tbaa_metadata(load_i, op->name, slice_index);
1872-
// extract one element out of returned struct
1873-
Value *extracted = builder->CreateExtractValue(load_i, offset);
1874-
results.push_back(extracted);
1875-
}
1876-
1877-
// Retrieve original lanes
1878-
value = concat_vectors(results);
1879-
value = slice_vector(value, 0, op->type.lanes());
1880-
return;
18811805
} else if (op->index.type().is_vector()) {
18821806
// General Gather Load
18831807

test/correctness/simd_op_check_sve2.cpp

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -677,79 +677,81 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
677677
vector<tuple<Type, CastFuncTy>> test_params = {
678678
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};
679679

680+
const int base_vec_bits = has_sve() ? target.vector_bits : 128;
681+
const int vscale = base_vec_bits / 128;
682+
680683
for (const auto &[elt, in_im] : test_params) {
681684
const int bits = elt.bits();
682685
if ((elt == Float(16) && !is_float16_supported()) ||
683686
(is_arm32() && bits == 64)) {
684687
continue;
685688
}
686689

687-
// LD/ST - Load/Store
688-
for (int width = 64; width <= 64 * 4; width *= 2) {
689-
const int total_lanes = width / bits;
690-
const int instr_lanes = min(total_lanes, 128 / bits);
691-
if (instr_lanes < 2) continue; // bail out scalar op
690+
// LD/ST - Load/Store scalar
691+
// We skip the scalar load/store tests due to the following challenges.
692+
// The rules by which LLVM selects instructions do not seem simple.
693+
// For example, ld1, ldr, or ldp may be used as the instruction, and a z or q register as the operand,
694+
// depending on the data type, vscale, what is performed before/after the load, and the LLVM version.
695+
// In addition, load/store instructions appear in places other than the one we want to check,
696+
// which makes it prone to false-positive detection, as we only search for strings line by line.
692697

693-
// In case of arm32, instruction selection looks inconsistent due to optimization by LLVM
694-
AddTestFunctor add(*this, bits, total_lanes, target.bits == 64);
695-
// NOTE: if the expr is too simple, LLVM might generate "bl memcpy"
696-
Expr load_store_1 = in_im(x) * 3;
698+
// LDn - Structured Load strided elements
699+
if (Halide::Internal::get_llvm_version() >= 220) {
700+
for (int stride = 2; stride <= 4; ++stride) {
697701

698-
if (has_sve()) {
699-
// This pattern has changed with LLVM 21, see https://github.com/halide/Halide/issues/8584 for more
700-
// details.
701-
if (Halide::Internal::get_llvm_version() < 210) {
702-
// in native width, ld1b/st1b is used regardless of data type
703-
const bool allow_byte_ls = (width == target.vector_bits);
704-
add({get_sve_ls_instr("ld1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
705-
add({get_sve_ls_instr("st1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
702+
for (int factor : {1, 2, 4}) {
703+
const int vector_lanes = base_vec_bits * factor / bits;
704+
705+
// In StageStridedLoads.cpp (stride < r->lanes) is the condition for staging to happen
706+
// See https://github.com/halide/Halide/issues/8819
707+
if (vector_lanes <= stride) continue;
708+
709+
AddTestFunctor add_ldn(*this, bits, vector_lanes);
710+
711+
Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1);
712+
713+
const string ldn_str = "ld" + to_string(stride);
714+
if (has_sve()) {
715+
add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n);
716+
} else {
717+
add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n);
718+
}
706719
}
707-
} else {
708-
// vector register is not used for simple load/store
709-
string reg_prefix = (width <= 64) ? "d" : "q";
710-
add({{"st[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
711-
add({{"ld[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
712720
}
713721
}
714722

715-
// LD2/ST2 - Load/Store two-element structures
716-
int base_vec_bits = has_sve() ? target.vector_bits : 128;
717-
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
723+
// ST2 - Store two-element structures
724+
for (int factor : {1, 2}) {
725+
const int width = base_vec_bits * 2 * factor;
718726
const int total_lanes = width / bits;
719727
const int vector_lanes = total_lanes / 2;
720728
const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
721-
if (instr_lanes < 2) continue; // bail out scalar op
729+
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
722730

723-
AddTestFunctor add_ldn(*this, bits, vector_lanes);
724731
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
725732

726733
Func tmp1, tmp2;
727734
tmp1(x) = cast(elt, x);
728735
tmp1.compute_root();
729736
tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
730737
tmp2.compute_root().vectorize(x, total_lanes);
731-
Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
732738
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);
733739

734740
if (has_sve()) {
735-
// TODO(inssue needed): Added strided load support.
736-
#if 0
737-
add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
738-
#endif
739741
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
740742
} else {
741-
add_ldn(sel_op("vld2.", "ld2"), load_2);
742743
add_stn(sel_op("vst2.", "st2"), store_2);
743744
}
744745
}
745746

746747
// Also check when the two expressions interleaved have a common
747748
// subexpression, which results in a vector var being lifted out.
748-
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
749+
for (int factor : {1, 2}) {
750+
const int width = base_vec_bits * 2 * factor;
749751
const int total_lanes = width / bits;
750752
const int vector_lanes = total_lanes / 2;
751753
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
752-
if (instr_lanes < 2) continue; // bail out scalar op
754+
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
753755

754756
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
755757

@@ -768,14 +770,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
768770
}
769771
}
770772

771-
// LD3/ST3 - Store three-element structures
772-
for (int width = 192; width <= 192 * 4; width *= 2) {
773+
// ST3 - Store three-element structures
774+
for (int factor : {1, 2}) {
775+
const int width = base_vec_bits * 3 * factor;
773776
const int total_lanes = width / bits;
774777
const int vector_lanes = total_lanes / 3;
775778
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
776-
if (instr_lanes < 2) continue; // bail out scalar op
779+
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
777780

778-
AddTestFunctor add_ldn(*this, bits, vector_lanes);
779781
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
780782

781783
Func tmp1, tmp2;
@@ -785,29 +787,25 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
785787
x % 3 == 1, tmp1(x / 3 + 16),
786788
tmp1(x / 3 + 32));
787789
tmp2.compute_root().vectorize(x, total_lanes);
788-
Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
789790
Expr store_3 = tmp2(0, 0) + tmp2(0, 127);
790791

791792
if (has_sve()) {
792-
// TODO(issue needed): Added strided load support.
793-
#if 0
794-
add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
795-
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
796-
#endif
793+
if (Halide::Internal::get_llvm_version() >= 220) {
794+
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
795+
}
797796
} else {
798-
add_ldn(sel_op("vld3.", "ld3"), load_3);
799797
add_stn(sel_op("vst3.", "st3"), store_3);
800798
}
801799
}
802800

803-
// LD4/ST4 - Store four-element structures
804-
for (int width = 256; width <= 256 * 4; width *= 2) {
801+
// ST4 - Store four-element structures
802+
for (int factor : {1, 2}) {
803+
const int width = base_vec_bits * 4 * factor;
805804
const int total_lanes = width / bits;
806805
const int vector_lanes = total_lanes / 4;
807806
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
808-
if (instr_lanes < 2) continue; // bail out scalar op
807+
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
809808

810-
AddTestFunctor add_ldn(*this, bits, vector_lanes);
811809
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
812810

813811
Func tmp1, tmp2;
@@ -818,17 +816,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
818816
x % 4 == 2, tmp1(x / 4 + 32),
819817
tmp1(x / 4 + 48));
820818
tmp2.compute_root().vectorize(x, total_lanes);
821-
Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
822819
Expr store_4 = tmp2(0, 0) + tmp2(0, 127);
823820

824821
if (has_sve()) {
825-
// TODO(issue needed): Added strided load support.
826-
#if 0
827-
add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
828-
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
829-
#endif
822+
if (Halide::Internal::get_llvm_version() >= 220) {
823+
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
824+
}
830825
} else {
831-
add_ldn(sel_op("vld4.", "ld4"), load_4);
832826
add_stn(sel_op("vst4.", "st4"), store_4);
833827
}
834828
}
@@ -838,7 +832,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
838832
for (int width = 64; width <= 64 * 4; width *= 2) {
839833
const int total_lanes = width / bits;
840834
const int instr_lanes = min(total_lanes, 128 / bits);
841-
if (instr_lanes < 2) continue; // bail out scalar op
835+
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
842836

843837
AddTestFunctor add(*this, bits, total_lanes);
844838
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);

0 commit comments

Comments
 (0)