@@ -274,7 +274,6 @@ Target calculate_host_target() {
274274 bool use_64_bits = (sizeof (size_t ) == 8 );
275275 int bits = use_64_bits ? 64 : 32 ;
276276 int vector_bits = 0 ;
277- int streaming_vector_bits = 0 ;
278277 Target::Processor processor = Target::Processor::ProcessorGeneric;
279278 std::vector<Target::Feature> initial_features;
280279
@@ -386,7 +385,11 @@ Target calculate_host_target() {
386385 vector_bits = get_sve_vector_length ();
387386 }
388387 if (has_streaming_scalable_vector) {
389- streaming_vector_bits = get_sme_streaming_vector_length ();
388+ const int streaming_vector_bits = get_sme_streaming_vector_length ();
389+ Target::Feature sme_svl = Target::sme_svl_feature_from_bits (streaming_vector_bits);
390+ user_assert (sme_svl != Target::FeatureEnd)
391+ << " Detected unsupported SME streaming vector length " << streaming_vector_bits << " bits.\n " ;
392+ initial_features.push_back (sme_svl);
390393 }
391394#endif
392395
@@ -461,7 +464,7 @@ Target calculate_host_target() {
461464 processor = get_amd_processor (family, model, have_sse3);
462465
463466 if (processor == Target::Processor::ZnVer4) {
464- Target t{os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits };
467+ Target t{os, arch, bits, processor, initial_features, vector_bits};
465468 t.set_feature (Target::SSE41 );
466469 if (os_avx) {
467470 t.set_features ({Target::AVX , Target::F16C , Target::FMA , Target::AVX2 });
@@ -472,7 +475,7 @@ Target calculate_host_target() {
472475 }
473476 return t;
474477 } else if (processor == Target::Processor::ZnVer5) {
475- Target t{os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits };
478+ Target t{os, arch, bits, processor, initial_features, vector_bits};
476479 t.set_feature (Target::SSE41 );
477480 if (os_avx) {
478481 t.set_features ({Target::AVX , Target::F16C , Target::FMA ,
@@ -589,7 +592,7 @@ Target calculate_host_target() {
589592#endif
590593#endif
591594
592- return {os, arch, bits, processor, initial_features, vector_bits, streaming_vector_bits };
595+ return {os, arch, bits, processor, initial_features, vector_bits};
593596}
594597
595598bool is_using_hexagon (const Target &t) {
@@ -842,6 +845,11 @@ const std::map<std::string, Target::Feature> feature_name_map = {
842845 {" sve" , Target::SVE },
843846 {" sve2" , Target::SVE2 },
844847 {" sme2" , Target::SME2 },
848+ {" sme_svl128" , Target::SME_SVL128 },
849+ {" sme_svl256" , Target::SME_SVL256 },
850+ {" sme_svl512" , Target::SME_SVL512 },
851+ {" sme_svl1024" , Target::SME_SVL1024 },
852+ {" sme_svl2048" , Target::SME_SVL2048 },
845853 {" arm_dot_prod" , Target::ARMDotProd},
846854 {" arm_fp16" , Target::ARMFp16},
847855 {" llvm_large_code_model" , Target::LLVMLargeCodeModel},
@@ -1000,8 +1008,6 @@ bool merge_string(Target &t, const std::string &target) {
10001008 features_specified = true ;
10011009 } else if (auto vb = parse_vector_bits (tok, " vector_bits_" ); vb >= 0 ) {
10021010 t.vector_bits = vb;
1003- } else if (auto svb = parse_vector_bits (tok, " streaming_vector_bits_" ); svb >= 0 ) {
1004- t.streaming_vector_bits = svb;
10051011 } else {
10061012 return false ;
10071013 }
@@ -1123,6 +1129,11 @@ void Target::validate_features() const {
11231129 POWER_ARCH_2_07 ,
11241130 RVV ,
11251131 SME2 ,
1132+ SME_SVL128 ,
1133+ SME_SVL256 ,
1134+ SME_SVL512 ,
1135+ SME_SVL1024 ,
1136+ SME_SVL2048 ,
11261137 SVE ,
11271138 SVE2 ,
11281139 VSX ,
@@ -1184,12 +1195,31 @@ void Target::validate_features() const {
11841195 POWER_ARCH_2_07 ,
11851196 RVV ,
11861197 SSE41 ,
1198+ SME_SVL128 ,
1199+ SME_SVL256 ,
1200+ SME_SVL512 ,
1201+ SME_SVL1024 ,
1202+ SME_SVL2048 ,
11871203 SME2 ,
11881204 SVE ,
11891205 SVE2 ,
11901206 VSX ,
11911207 });
11921208 }
1209+
1210+ const int num_sme_svl_features =
1211+ (int )has_feature (SME_SVL128 ) +
1212+ (int )has_feature (SME_SVL256 ) +
1213+ (int )has_feature (SME_SVL512 ) +
1214+ (int )has_feature (SME_SVL1024 ) +
1215+ (int )has_feature (SME_SVL2048 );
1216+
1217+ user_assert (num_sme_svl_features <= 1 )
1218+ << " Target may have at most one SME_SVL feature.\n " ;
1219+ user_assert (!has_feature (SME2 ) || num_sme_svl_features == 1 )
1220+ << " Target feature sme2 requires exactly one SME_SVL feature.\n " ;
1221+ user_assert (has_feature (SME2 ) || num_sme_svl_features == 0 )
1222+ << " Target features SME_SVL128, SME_SVL256, SME_SVL512, SME_SVL1024, and SME_SVL2048 require target feature sme2.\n " ;
11931223}
11941224
11951225Target::Target (const std::string &target) {
@@ -1233,6 +1263,23 @@ Target::Feature Target::feature_from_name(const std::string &name) {
12331263 return Target::FeatureEnd;
12341264}
12351265
1266+ Target::Feature Target::sme_svl_feature_from_bits (int bits) {
1267+ switch (bits) {
1268+ case 128 :
1269+ return Target::SME_SVL128 ;
1270+ case 256 :
1271+ return Target::SME_SVL256 ;
1272+ case 512 :
1273+ return Target::SME_SVL512 ;
1274+ case 1024 :
1275+ return Target::SME_SVL1024 ;
1276+ case 2048 :
1277+ return Target::SME_SVL2048 ;
1278+ default :
1279+ return Target::FeatureEnd;
1280+ }
1281+ }
1282+
12361283std::string Target::to_string () const {
12371284 string result;
12381285 for (const auto &arch_entry : arch_name_map) {
@@ -1269,9 +1316,6 @@ std::string Target::to_string() const {
12691316 if (vector_bits != 0 ) {
12701317 result += " -vector_bits_" + std::to_string (vector_bits);
12711318 }
1272- if (streaming_vector_bits != 0 ) {
1273- result += " -streaming_vector_bits_" + std::to_string (streaming_vector_bits);
1274- }
12751319
12761320 return result;
12771321}
@@ -1600,6 +1644,31 @@ int Target::natural_vector_size(const Halide::Type &t) const {
16001644 return natural_vector_size (t, false );
16011645}
16021646
1647+ int Target::sme_streaming_vector_bits () const {
1648+ int result = 0 ;
1649+ auto set_result = [&result](int bits) {
1650+ user_assert (result == 0 )
1651+ << " Target may have at most one SME_SVL feature.\n " ;
1652+ result = bits;
1653+ };
1654+ if (has_feature (Target::SME_SVL128 )) {
1655+ set_result (128 );
1656+ }
1657+ if (has_feature (Target::SME_SVL256 )) {
1658+ set_result (256 );
1659+ }
1660+ if (has_feature (Target::SME_SVL512 )) {
1661+ set_result (512 );
1662+ }
1663+ if (has_feature (Target::SME_SVL1024 )) {
1664+ set_result (1024 );
1665+ }
1666+ if (has_feature (Target::SME_SVL2048 )) {
1667+ set_result (2048 );
1668+ }
1669+ return result;
1670+ }
1671+
16031672int Target::natural_vector_size (const Halide::Type &t, bool is_sme_streaming) const {
16041673 user_assert (!has_unknowns ())
16051674 << " natural_vector_size cannot be used on a Target with Unknown values.\n " ;
@@ -1609,9 +1678,9 @@ int Target::natural_vector_size(const Halide::Type &t, bool is_sme_streaming) co
16091678
16101679 if (arch == Target::ARM ) {
16111680 if (is_sme_streaming &&
1612- streaming_vector_bits != 0 &&
1681+ sme_streaming_vector_bits () != 0 &&
16131682 has_feature (Halide::Target::SME2 )) {
1614- return streaming_vector_bits / (data_size * 8 );
1683+ return sme_streaming_vector_bits () / (data_size * 8 );
16151684 } else if (vector_bits != 0 &&
16161685 (has_feature (Halide::Target::SVE2 ) ||
16171686 (t.is_float () && has_feature (Halide::Target::SVE )))) {
@@ -1965,14 +2034,13 @@ void target_test() {
19652034 internal_assert (Target (with_vector_bits.to_string ()).vector_bits == 512 ) << " Vector bits not round tripped properly.\n " ;
19662035 internal_assert (with_vector_bits.natural_vector_size (Int (32 )) == 16 ) << " Wrong natural_vector_size.\n " ;
19672036
1968- // Tests for streaming_vector_bits
1969- internal_assert (Target ().streaming_vector_bits == 0 ) << " Default Target streaming_vector_bits not 0.\n " ;
1970- internal_assert (Target (" arm-64-linux-sme2-vector_bits_512-streaming_vector_bits_1024" ).streaming_vector_bits == 1024 ) << " Streaming vector bits not parsed correctly.\n " ;
1971- Target with_streaming_vector_bits (Target::Linux, Target::ARM , 64 , Target::ProcessorGeneric, {Target::SVE2 , Target::SME2 }, 512 , 1024 );
1972- internal_assert (with_streaming_vector_bits.streaming_vector_bits == 1024 ) << " Streaming vector bits not populated in constructor.\n " ;
1973- internal_assert (Target (with_streaming_vector_bits.to_string ()).streaming_vector_bits == 1024 ) << " Streaming vector bits not round tripped properly.\n " ;
1974- internal_assert (with_streaming_vector_bits.natural_vector_size (Int (32 ), true ) == 32 ) << " Wrong natural_vector_size with SME streaming.\n " ;
1975- internal_assert (with_streaming_vector_bits.natural_vector_size (Int (32 ), false ) == 16 ) << " Wrong natural_vector_size without SME streaming.\n " ;
2037+ // Tests for SME streaming vector length
2038+ internal_assert (Target ().sme_streaming_vector_bits () == 0 ) << " Default Target SME SVL not 0.\n " ;
2039+ internal_assert (Target::sme_svl_feature_from_bits (1024 ) == Target::SME_SVL1024 ) << " SME SVL feature lookup failed.\n " ;
2040+ Target with_sme_svl (Target::Linux, Target::ARM , 64 , Target::ProcessorGeneric, {Target::SVE2 , Target::SME2 , Target::SME_SVL1024 }, 512 );
2041+ internal_assert (with_sme_svl.sme_streaming_vector_bits () == 1024 ) << " SME SVL not populated in constructor.\n " ;
2042+ internal_assert (with_sme_svl.natural_vector_size (Int (32 ), true ) == 32 ) << " Wrong natural_vector_size with SME streaming.\n " ;
2043+ internal_assert (with_sme_svl.natural_vector_size (Int (32 ), false ) == 16 ) << " Wrong natural_vector_size without SME streaming.\n " ;
19762044
19772045 std::cout << " Target test passed\n " ;
19782046}
0 commit comments