Skip to content

Commit c860160

Browse files
Replace streaming_vector_bits with Feature::SME_SVL
Reason: While vector_bits is used across multiple target architectures, streaming_vector_bits is AArch64-specific, so we use Target::Feature rather than adding a new member that holds an arbitrary bit count. - Removed the Target::streaming_vector_bits member variable - Added Feature::SME_SVL{128,256,512,1024,2048}
1 parent 75596df commit c860160

11 files changed

Lines changed: 162 additions & 41 deletions

File tree

python_bindings/src/halide/halide_/PyEnums.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@ void define_enums(py::module &m) {
188188
.value("SVE", Target::Feature::SVE)
189189
.value("SVE2", Target::Feature::SVE2)
190190
.value("SME2", Target::Feature::SME2)
191+
.value("SME_SVL128", Target::Feature::SME_SVL128)
192+
.value("SME_SVL256", Target::Feature::SME_SVL256)
193+
.value("SME_SVL512", Target::Feature::SME_SVL512)
194+
.value("SME_SVL1024", Target::Feature::SME_SVL1024)
195+
.value("SME_SVL2048", Target::Feature::SME_SVL2048)
191196
.value("ARMDotProd", Target::Feature::ARMDotProd)
192197
.value("ARMFp16", Target::Feature::ARMFp16)
193198
.value("LLVMLargeCodeModel", Target::Feature::LLVMLargeCodeModel)

python_bindings/src/halide/halide_/PyTarget.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,11 @@ void define_target(py::module &m) {
5555
.def("supports_device_api", &Target::supports_device_api, py::arg("device"))
5656
.def("natural_vector_size", natural_vector_size1_method, py::arg("type"))
5757
.def("natural_vector_size", natural_vector_size2_method, py::arg("type"), py::arg("is_sme_streaming"))
58+
.def("sme_streaming_vector_bits", &Target::sme_streaming_vector_bits)
5859
.def("has_large_buffers", &Target::has_large_buffers)
5960
.def("maximum_buffer_size", &Target::maximum_buffer_size)
6061
.def("supported", &Target::supported)
62+
.def_static("sme_svl_feature_from_bits", &Target::sme_svl_feature_from_bits, py::arg("bits"))
6163
.def_static("validate_target_string", &Target::validate_target_string, py::arg("name"));
6264
;
6365

src/CodeGen_ARM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,8 +1070,8 @@ void CodeGen_ARM::init_module() {
10701070
user_warning << "Halide does not support SVE for now. Use SVE2 if your target device supports it.\n";
10711071
}
10721072
if (target.has_feature(Target::SME2)) {
1073-
user_assert(target.streaming_vector_bits != 0) << "For SME2 support, Target::streaming_vector_bits must be set. For generator target strings, add \"streaming_vector_bits_<bits>\".\n";
1074-
user_assert((target.streaming_vector_bits % 128) == 0) << "For SME2 support, Target::streaming_vector_bits must be a multiple of 128.\n";
1073+
user_assert(target.sme_streaming_vector_bits() != 0)
1074+
<< "For SME2 support, exactly one Target::SME_SVL* feature must be set. For generator target strings, add \"sme_svl<bits>\".\n";
10751075
}
10761076

10771077
const bool has_neon = !target.has_feature(Target::NoNEON);
@@ -1141,7 +1141,7 @@ void CodeGen_ARM::init_module() {
11411141
intrinsics_map = &intrinsics_sve2;
11421142
break;
11431143
case SIMDFlavors::Streaming:
1144-
vscale = target.streaming_vector_bits / 128;
1144+
vscale = target.sme_streaming_vector_bits() / 128;
11451145
intrinsics_map = &intrinsics_streaming;
11461146
break;
11471147
default:
@@ -1295,7 +1295,7 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,
12951295
});
12961296

12971297
if (is_streaming_task) {
1298-
feasible_vscale = check_feasible_vscale(target.streaming_vector_bits, // SVL
1298+
feasible_vscale = check_feasible_vscale(target.sme_streaming_vector_bits(), // SVL
12991299
lanes_used, "streaming_", simple_name);
13001300
}
13011301
in_streaming = (feasible_vscale > 0) && is_streaming_task;
@@ -1344,7 +1344,7 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,
13441344
// We check regardless of whether streaming mode is enabled,
13451345
// because a streaming task basically has internal linkage.
13461346
Expr runtime_vscale = Call::make(Int(32), Call::get_runtime_streaming_vscale, {}, Call::PureIntrinsic);
1347-
Expr compiletime_vscale = Expr(target.streaming_vector_bits / 128);
1347+
Expr compiletime_vscale = Expr(target.sme_streaming_vector_bits() / 128);
13481348
std::vector<Expr> args{simple_name, std::string("streaming"), runtime_vscale, compiletime_vscale};
13491349
Expr error = Call::make(Int(32), "halide_error_vscale_invalid", args, Call::Extern);
13501350
func.body = Block::make(AssertStmt::make(runtime_vscale == compiletime_vscale, error), func.body);

src/Target.cpp

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ Target calculate_host_target() {
274274
bool use_64_bits = (sizeof(size_t) == 8);
275275
int bits = use_64_bits ? 64 : 32;
276276
int vector_bits = 0;
277-
int streaming_vector_bits = 0;
278277
Target::Processor processor = Target::Processor::ProcessorGeneric;
279278
std::vector<Target::Feature> initial_features;
280279

@@ -386,7 +385,11 @@ Target calculate_host_target() {
386385
vector_bits = get_sve_vector_length();
387386
}
388387
if (has_streaming_scalable_vector) {
389-
streaming_vector_bits = get_sme_streaming_vector_length();
388+
const int streaming_vector_bits = get_sme_streaming_vector_length();
389+
Target::Feature sme_svl = Target::sme_svl_feature_from_bits(streaming_vector_bits);
390+
user_assert(sme_svl != Target::FeatureEnd)
391+
<< "Detected unsupported SME streaming vector length " << streaming_vector_bits << " bits.\n";
392+
initial_features.push_back(sme_svl);
390393
}
391394
#endif
392395

@@ -461,7 +464,7 @@ Target calculate_host_target() {
461464
processor = get_amd_processor(family, model, have_sse3);
462465

463466
if (processor == Target::Processor::ZnVer4) {
464-
Target t{os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits};
467+
Target t{os, arch, bits, processor, initial_features, vector_bits};
465468
t.set_feature(Target::SSE41);
466469
if (os_avx) {
467470
t.set_features({Target::AVX, Target::F16C, Target::FMA, Target::AVX2});
@@ -472,7 +475,7 @@ Target calculate_host_target() {
472475
}
473476
return t;
474477
} else if (processor == Target::Processor::ZnVer5) {
475-
Target t{os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits};
478+
Target t{os, arch, bits, processor, initial_features, vector_bits};
476479
t.set_feature(Target::SSE41);
477480
if (os_avx) {
478481
t.set_features({Target::AVX, Target::F16C, Target::FMA,
@@ -589,7 +592,7 @@ Target calculate_host_target() {
589592
#endif
590593
#endif
591594

592-
return {os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits};
595+
return {os, arch, bits, processor, initial_features, vector_bits};
593596
}
594597

595598
bool is_using_hexagon(const Target &t) {
@@ -842,6 +845,11 @@ const std::map<std::string, Target::Feature> feature_name_map = {
842845
{"sve", Target::SVE},
843846
{"sve2", Target::SVE2},
844847
{"sme2", Target::SME2},
848+
{"sme_svl128", Target::SME_SVL128},
849+
{"sme_svl256", Target::SME_SVL256},
850+
{"sme_svl512", Target::SME_SVL512},
851+
{"sme_svl1024", Target::SME_SVL1024},
852+
{"sme_svl2048", Target::SME_SVL2048},
845853
{"arm_dot_prod", Target::ARMDotProd},
846854
{"arm_fp16", Target::ARMFp16},
847855
{"llvm_large_code_model", Target::LLVMLargeCodeModel},
@@ -1000,8 +1008,6 @@ bool merge_string(Target &t, const std::string &target) {
10001008
features_specified = true;
10011009
} else if (auto vb = parse_vector_bits(tok, "vector_bits_"); vb >= 0) {
10021010
t.vector_bits = vb;
1003-
} else if (auto svb = parse_vector_bits(tok, "streaming_vector_bits_"); svb >= 0) {
1004-
t.streaming_vector_bits = svb;
10051011
} else {
10061012
return false;
10071013
}
@@ -1123,6 +1129,11 @@ void Target::validate_features() const {
11231129
POWER_ARCH_2_07,
11241130
RVV,
11251131
SME2,
1132+
SME_SVL128,
1133+
SME_SVL256,
1134+
SME_SVL512,
1135+
SME_SVL1024,
1136+
SME_SVL2048,
11261137
SVE,
11271138
SVE2,
11281139
VSX,
@@ -1184,12 +1195,31 @@ void Target::validate_features() const {
11841195
POWER_ARCH_2_07,
11851196
RVV,
11861197
SSE41,
1198+
SME_SVL128,
1199+
SME_SVL256,
1200+
SME_SVL512,
1201+
SME_SVL1024,
1202+
SME_SVL2048,
11871203
SME2,
11881204
SVE,
11891205
SVE2,
11901206
VSX,
11911207
});
11921208
}
1209+
1210+
const int num_sme_svl_features =
1211+
(int)has_feature(SME_SVL128) +
1212+
(int)has_feature(SME_SVL256) +
1213+
(int)has_feature(SME_SVL512) +
1214+
(int)has_feature(SME_SVL1024) +
1215+
(int)has_feature(SME_SVL2048);
1216+
1217+
user_assert(num_sme_svl_features <= 1)
1218+
<< "Target may have at most one SME_SVL feature.\n";
1219+
user_assert(!has_feature(SME2) || num_sme_svl_features == 1)
1220+
<< "Target feature sme2 requires exactly one SME_SVL feature.\n";
1221+
user_assert(has_feature(SME2) || num_sme_svl_features == 0)
1222+
<< "Target features SME_SVL128, SME_SVL256, SME_SVL512, SME_SVL1024, and SME_SVL2048 require target feature sme2.\n";
11931223
}
11941224

11951225
Target::Target(const std::string &target) {
@@ -1233,6 +1263,23 @@ Target::Feature Target::feature_from_name(const std::string &name) {
12331263
return Target::FeatureEnd;
12341264
}
12351265

1266+
// Translate an SME streaming vector length, given in bits, into the
// matching SME_SVL feature. Only 128, 256, 512, 1024 and 2048 have a
// corresponding feature; any other value yields FeatureEnd.
Target::Feature Target::sme_svl_feature_from_bits(int bits) {
    if (bits == 128) {
        return Target::SME_SVL128;
    }
    if (bits == 256) {
        return Target::SME_SVL256;
    }
    if (bits == 512) {
        return Target::SME_SVL512;
    }
    if (bits == 1024) {
        return Target::SME_SVL1024;
    }
    if (bits == 2048) {
        return Target::SME_SVL2048;
    }
    // No SME_SVL feature exists for this length.
    return Target::FeatureEnd;
}
1282+
12361283
std::string Target::to_string() const {
12371284
string result;
12381285
for (const auto &arch_entry : arch_name_map) {
@@ -1269,9 +1316,6 @@ std::string Target::to_string() const {
12691316
if (vector_bits != 0) {
12701317
result += "-vector_bits_" + std::to_string(vector_bits);
12711318
}
1272-
if (streaming_vector_bits != 0) {
1273-
result += "-streaming_vector_bits_" + std::to_string(streaming_vector_bits);
1274-
}
12751319

12761320
return result;
12771321
}
@@ -1600,6 +1644,31 @@ int Target::natural_vector_size(const Halide::Type &t) const {
16001644
return natural_vector_size(t, false);
16011645
}
16021646

1647+
int Target::sme_streaming_vector_bits() const {
1648+
int result = 0;
1649+
auto set_result = [&result](int bits) {
1650+
user_assert(result == 0)
1651+
<< "Target may have at most one SME_SVL feature.\n";
1652+
result = bits;
1653+
};
1654+
if (has_feature(Target::SME_SVL128)) {
1655+
set_result(128);
1656+
}
1657+
if (has_feature(Target::SME_SVL256)) {
1658+
set_result(256);
1659+
}
1660+
if (has_feature(Target::SME_SVL512)) {
1661+
set_result(512);
1662+
}
1663+
if (has_feature(Target::SME_SVL1024)) {
1664+
set_result(1024);
1665+
}
1666+
if (has_feature(Target::SME_SVL2048)) {
1667+
set_result(2048);
1668+
}
1669+
return result;
1670+
}
1671+
16031672
int Target::natural_vector_size(const Halide::Type &t, bool is_sme_streaming) const {
16041673
user_assert(!has_unknowns())
16051674
<< "natural_vector_size cannot be used on a Target with Unknown values.\n";
@@ -1609,9 +1678,9 @@ int Target::natural_vector_size(const Halide::Type &t, bool is_sme_streaming) co
16091678

16101679
if (arch == Target::ARM) {
16111680
if (is_sme_streaming &&
1612-
streaming_vector_bits != 0 &&
1681+
sme_streaming_vector_bits() != 0 &&
16131682
has_feature(Halide::Target::SME2)) {
1614-
return streaming_vector_bits / (data_size * 8);
1683+
return sme_streaming_vector_bits() / (data_size * 8);
16151684
} else if (vector_bits != 0 &&
16161685
(has_feature(Halide::Target::SVE2) ||
16171686
(t.is_float() && has_feature(Halide::Target::SVE)))) {
@@ -1965,14 +2034,13 @@ void target_test() {
19652034
internal_assert(Target(with_vector_bits.to_string()).vector_bits == 512) << "Vector bits not round tripped properly.\n";
19662035
internal_assert(with_vector_bits.natural_vector_size(Int(32)) == 16) << "Wrong natural_vector_size.\n";
19672036

1968-
// Tests for streaming_vector_bits
1969-
internal_assert(Target().streaming_vector_bits == 0) << "Default Target streaming_vector_bits not 0.\n";
1970-
internal_assert(Target("arm-64-linux-sme2-vector_bits_512-streaming_vector_bits_1024").streaming_vector_bits == 1024) << "Streaming vector bits not parsed correctly.\n";
1971-
Target with_streaming_vector_bits(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE2, Target::SME2}, 512, 1024);
1972-
internal_assert(with_streaming_vector_bits.streaming_vector_bits == 1024) << "Streaming vector bits not populated in constructor.\n";
1973-
internal_assert(Target(with_streaming_vector_bits.to_string()).streaming_vector_bits == 1024) << "Streaming vector bits not round tripped properly.\n";
1974-
internal_assert(with_streaming_vector_bits.natural_vector_size(Int(32), true) == 32) << "Wrong natural_vector_size with SME streaming.\n";
1975-
internal_assert(with_streaming_vector_bits.natural_vector_size(Int(32), false) == 16) << "Wrong natural_vector_size without SME streaming.\n";
2037+
// Tests for SME streaming vector length
2038+
internal_assert(Target().sme_streaming_vector_bits() == 0) << "Default Target SME SVL not 0.\n";
2039+
internal_assert(Target::sme_svl_feature_from_bits(1024) == Target::SME_SVL1024) << "SME SVL feature lookup failed.\n";
2040+
Target with_sme_svl(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE2, Target::SME2, Target::SME_SVL1024}, 512);
2041+
internal_assert(with_sme_svl.sme_streaming_vector_bits() == 1024) << "SME SVL not populated in constructor.\n";
2042+
internal_assert(with_sme_svl.natural_vector_size(Int(32), true) == 32) << "Wrong natural_vector_size with SME streaming.\n";
2043+
internal_assert(with_sme_svl.natural_vector_size(Int(32), false) == 16) << "Wrong natural_vector_size without SME streaming.\n";
19762044

19772045
std::cout << "Target test passed\n";
19782046
}

src/Target.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ struct Target {
5151

5252
/** The bit-width of a vector register for targets where this is configurable and
5353
* targeting a fixed size is desired. The default of 0 indicates no assumption of
54-
* fixed size is allowed.
55-
* streaming_vector_bits is for streaming mode in aarch64 SME. */
54+
* fixed size is allowed. */
5655
int vector_bits = 0;
57-
int streaming_vector_bits = 0;
5856

5957
/** The specific processor to be targeted, tuned for.
6058
* Corresponds to processor_name_map in Target.cpp.
@@ -157,6 +155,11 @@ struct Target {
157155
SVE = halide_target_feature_sve,
158156
SVE2 = halide_target_feature_sve2,
159157
SME2 = halide_target_feature_sme2,
158+
SME_SVL128 = halide_target_feature_sme_svl128,
159+
SME_SVL256 = halide_target_feature_sme_svl256,
160+
SME_SVL512 = halide_target_feature_sme_svl512,
161+
SME_SVL1024 = halide_target_feature_sme_svl1024,
162+
SME_SVL2048 = halide_target_feature_sme_svl2048,
160163
ARMDotProd = halide_target_feature_arm_dot_prod,
161164
ARMFp16 = halide_target_feature_arm_fp16,
162165
LLVMLargeCodeModel = halide_llvm_large_code_model,
@@ -191,8 +194,8 @@ struct Target {
191194
};
192195
Target() = default;
193196
Target(OS o, Arch a, int b, Processor pt, const std::vector<Feature> &initial_features = std::vector<Feature>(),
194-
int vb = 0, int svb = 0)
195-
: os(o), arch(a), bits(b), vector_bits(vb), streaming_vector_bits(svb), processor_tune(pt) {
197+
int vb = 0)
198+
: os(o), arch(a), bits(b), vector_bits(vb), processor_tune(pt) {
196199
for (const auto &f : initial_features) {
197200
set_feature(f);
198201
}
@@ -285,8 +288,7 @@ struct Target {
285288
bits == other.bits &&
286289
processor_tune == other.processor_tune &&
287290
features == other.features &&
288-
vector_bits == other.vector_bits &&
289-
streaming_vector_bits == other.streaming_vector_bits;
291+
vector_bits == other.vector_bits;
290292
}
291293

292294
bool operator!=(const Target &other) const {
@@ -325,6 +327,10 @@ struct Target {
325327
* for that data type in streaming mode in aarch64 SME. */
326328
int natural_vector_size(const Halide::Type &t, bool is_sme_streaming) const;
327329

330+
/** Return the fixed SME streaming vector length in bits selected by this target,
331+
* or 0 if no SME_SVL feature is set. */
332+
int sme_streaming_vector_bits() const;
333+
328334
/** Given a data type, return an estimate of the "natural" vector size
329335
* for that data type when compiling for this Target. */
330336
template<typename data_t>
@@ -381,6 +387,10 @@ struct Target {
381387
* If the string is not a known feature name, return FeatureEnd. */
382388
static Target::Feature feature_from_name(const std::string &name);
383389

390+
/** Return the SME_SVL feature corresponding to an SME streaming vector
391+
* length in bits, or FeatureEnd if no exact SME_SVL feature exists. */
392+
static Target::Feature sme_svl_feature_from_bits(int bits);
393+
384394
private:
385395
/** A bitmask that stores the active features. */
386396
std::bitset<FeatureEnd> features;

src/runtime/HalideRuntime.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,11 @@ typedef enum halide_target_feature_t {
14481448
halide_target_feature_sve, ///< Enable ARM Scalable Vector Extensions
14491449
halide_target_feature_sve2, ///< Enable ARM Scalable Vector Extensions v2
14501450
halide_target_feature_sme2, ///< Enable ARM Scalable Matrix Extensions v2
1451+
halide_target_feature_sme_svl128, ///< Assume ARM SME streaming vector length is 128 bits.
1452+
halide_target_feature_sme_svl256, ///< Assume ARM SME streaming vector length is 256 bits.
1453+
halide_target_feature_sme_svl512, ///< Assume ARM SME streaming vector length is 512 bits.
1454+
halide_target_feature_sme_svl1024, ///< Assume ARM SME streaming vector length is 1024 bits.
1455+
halide_target_feature_sme_svl2048, ///< Assume ARM SME streaming vector length is 2048 bits.
14511456
halide_target_feature_egl, ///< Force use of EGL support.
14521457
halide_target_feature_arm_dot_prod, ///< Enable ARMv8.2-a dotprod extension (i.e. udot and sdot instructions)
14531458
halide_target_feature_arm_fp16, ///< Enable ARMv8.2-a half-precision floating point data processing

src/runtime/aarch64_cpu_features.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) {
141141
halide_set_known_cpu_feature(features, halide_target_feature_sve);
142142
halide_set_known_cpu_feature(features, halide_target_feature_sve2);
143143
halide_set_known_cpu_feature(features, halide_target_feature_sme2);
144+
halide_set_known_cpu_feature(features, halide_target_feature_sme_svl128);
145+
halide_set_known_cpu_feature(features, halide_target_feature_sme_svl256);
146+
halide_set_known_cpu_feature(features, halide_target_feature_sme_svl512);
147+
halide_set_known_cpu_feature(features, halide_target_feature_sme_svl1024);
148+
halide_set_known_cpu_feature(features, halide_target_feature_sme_svl2048);
144149

145150
// All ARM architectures support "No Neon".
146151
halide_set_available_cpu_feature(features, halide_target_feature_no_neon);

test/correctness/fallback_vscale_sve.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,12 @@ bool test_streaming_vscale(int vectorization_factor, int vector_bits, int stream
9696
}
9797
if (streaming_vector_bits != 0) {
9898
t = t.with_feature(Target::SME2);
99-
t.streaming_vector_bits = streaming_vector_bits;
99+
Target::Feature sme_svl = Target::sme_svl_feature_from_bits(streaming_vector_bits);
100+
if (sme_svl == Target::FeatureEnd) {
101+
printf("[%s] Unsupported streaming_vector_bits %d\n", name.c_str(), streaming_vector_bits);
102+
return false;
103+
}
104+
t.set_feature(sme_svl);
100105
}
101106

102107
// sve or neon

0 commit comments

Comments
 (0)