@@ -677,79 +677,81 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
677677 vector<tuple<Type, CastFuncTy>> test_params = {
678678 {Int (8 ), in_i8}, {Int (16 ), in_i16}, {Int (32 ), in_i32}, {Int (64 ), in_i64}, {UInt (8 ), in_u8}, {UInt (16 ), in_u16}, {UInt (32 ), in_u32}, {UInt (64 ), in_u64}, {Float (16 ), in_f16}, {Float (32 ), in_f32}, {Float (64 ), in_f64}};
679679
680+ const int base_vec_bits = has_sve () ? target.vector_bits : 128 ;
681+ const int vscale = base_vec_bits / 128 ;
682+
680683 for (const auto &[elt, in_im] : test_params) {
681684 const int bits = elt.bits ();
682685 if ((elt == Float (16 ) && !is_float16_supported ()) ||
683686 (is_arm32 () && bits == 64 )) {
684687 continue ;
685688 }
686689
687- // LD/ST - Load/Store
688- for (int width = 64 ; width <= 64 * 4 ; width *= 2 ) {
689- const int total_lanes = width / bits;
690- const int instr_lanes = min (total_lanes, 128 / bits);
691- if (instr_lanes < 2 ) continue ; // bail out scalar op
690+ // LD/ST - Load/Store scalar
691+ // We skip the scalar load/store tests because of the following challenges.
692+ // First, the rule by which LLVM selects instructions does not seem to be simple.
693+ // For example, ld1, ldr, or ldp may be used as the instruction, and a z or q register as the operand,
694+ // depending on the data type, vscale, what is performed before/after the load, and the LLVM version.
695+ // Second, load/store instructions also appear in places other than the one we want to check,
696+ // which makes the test prone to false-positive detection, as we only search strings line by line.
692697
693- // In case of arm32, instruction selection looks inconsistent due to optimization by LLVM
694- AddTestFunctor add (*this , bits, total_lanes, target.bits == 64 );
695- // NOTE: if the expr is too simple, LLVM might generate "bl memcpy"
696- Expr load_store_1 = in_im (x) * 3 ;
698+ // LDn - Structured Load strided elements
699+ if (Halide::Internal::get_llvm_version () >= 220 ) {
700+ for (int stride = 2 ; stride <= 4 ; ++stride) {
697701
698- if (has_sve ()) {
699- // This pattern has changed with LLVM 21, see https://github.com/halide/Halide/issues/8584 for more
700- // details.
701- if (Halide::Internal::get_llvm_version () < 210 ) {
702- // in native width, ld1b/st1b is used regardless of data type
703- const bool allow_byte_ls = (width == target.vector_bits );
704- add ({get_sve_ls_instr (" ld1" , bits, bits, " " , allow_byte_ls ? " b" : " " )}, total_lanes, load_store_1);
705- add ({get_sve_ls_instr (" st1" , bits, bits, " " , allow_byte_ls ? " b" : " " )}, total_lanes, load_store_1);
702+ for (int factor : {1 , 2 , 4 }) {
703+ const int vector_lanes = base_vec_bits * factor / bits;
704+
705+ // In StageStridedLoads.cpp, (stride < r->lanes) is the condition for staging to happen,
706+ // so skip vectors with too few lanes. See https://github.com/halide/Halide/issues/8819
707+ if (vector_lanes <= stride) continue ;
708+
709+ AddTestFunctor add_ldn (*this , bits, vector_lanes);
710+
711+ Expr load_n = in_im (x * stride) + in_im (x * stride + stride - 1 );
712+
713+ const string ldn_str = " ld" + to_string (stride);
714+ if (has_sve ()) {
715+ add_ldn ({get_sve_ls_instr (ldn_str, bits)}, vector_lanes, load_n);
716+ } else {
717+ add_ldn (sel_op (" v" + ldn_str + " ." , ldn_str), load_n);
718+ }
706719 }
707- } else {
708- // vector register is not used for simple load/store
709- string reg_prefix = (width <= 64 ) ? " d" : " q" ;
710- add ({{" st[rp]" , reg_prefix + R"( \d\d?)" }}, total_lanes, load_store_1);
711- add ({{" ld[rp]" , reg_prefix + R"( \d\d?)" }}, total_lanes, load_store_1);
712720 }
713721 }
714722
715- // LD2/ ST2 - Load/ Store two-element structures
716- int base_vec_bits = has_sve () ? target. vector_bits : 128 ;
717- for ( int width = base_vec_bits; width < = base_vec_bits * 4 ; width *= 2 ) {
723+ // ST2 - Store two-element structures
724+ for ( int factor : { 1 , 2 }) {
725+ const int width = base_vec_bits * 2 * factor;
718726 const int total_lanes = width / bits;
719727 const int vector_lanes = total_lanes / 2 ;
720728 const int instr_lanes = min (vector_lanes, base_vec_bits / bits);
721- if (instr_lanes < 2 ) continue ; // bail out scalar op
729+ if (instr_lanes < 2 || (vector_lanes / vscale < 2 )) continue ; // bail out scalar and <vscale x 1 x ty>
722730
723- AddTestFunctor add_ldn (*this , bits, vector_lanes);
724731 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
725732
726733 Func tmp1, tmp2;
727734 tmp1 (x) = cast (elt, x);
728735 tmp1.compute_root ();
729736 tmp2 (x, y) = select (x % 2 == 0 , tmp1 (x / 2 ), tmp1 (x / 2 + 16 ));
730737 tmp2.compute_root ().vectorize (x, total_lanes);
731- Expr load_2 = in_im (x * 2 ) + in_im (x * 2 + 1 );
732738 Expr store_2 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
733739
734740 if (has_sve ()) {
735- // TODO(inssue needed): Added strided load support.
736- #if 0
737- add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
738- #endif
739741 add_stn ({get_sve_ls_instr (" st2" , bits)}, total_lanes, store_2);
740742 } else {
741- add_ldn (sel_op (" vld2." , " ld2" ), load_2);
742743 add_stn (sel_op (" vst2." , " st2" ), store_2);
743744 }
744745 }
745746
746747 // Also check when the two expressions interleaved have a common
747748 // subexpression, which results in a vector var being lifted out.
748- for (int width = base_vec_bits; width <= base_vec_bits * 4 ; width *= 2 ) {
749+ for (int factor : {1 , 2 }) {
750+ const int width = base_vec_bits * 2 * factor;
749751 const int total_lanes = width / bits;
750752 const int vector_lanes = total_lanes / 2 ;
751753 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
752- if (instr_lanes < 2 ) continue ; // bail out scalar op
754+ if (instr_lanes < 2 || (vector_lanes / vscale < 2 )) continue ; // bail out scalar and <vscale x 1 x ty>
753755
754756 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
755757
@@ -768,14 +770,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
768770 }
769771 }
770772
771- // LD3/ST3 - Store three-element structures
772- for (int width = 192 ; width <= 192 * 4 ; width *= 2 ) {
773+ // ST3 - Store three-element structures
774+ for (int factor : {1 , 2 }) {
775+ const int width = base_vec_bits * 3 * factor;
773776 const int total_lanes = width / bits;
774777 const int vector_lanes = total_lanes / 3 ;
775778 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
776- if (instr_lanes < 2 ) continue ; // bail out scalar op
779+ if (instr_lanes < 2 || (vector_lanes / vscale < 2 )) continue ; // bail out scalar and <vscale x 1 x ty>
777780
778- AddTestFunctor add_ldn (*this , bits, vector_lanes);
779781 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
780782
781783 Func tmp1, tmp2;
@@ -785,29 +787,25 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
785787 x % 3 == 1 , tmp1 (x / 3 + 16 ),
786788 tmp1 (x / 3 + 32 ));
787789 tmp2.compute_root ().vectorize (x, total_lanes);
788- Expr load_3 = in_im (x * 3 ) + in_im (x * 3 + 1 ) + in_im (x * 3 + 2 );
789790 Expr store_3 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
790791
791792 if (has_sve ()) {
792- // TODO(issue needed): Added strided load support.
793- #if 0
794- add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
795- add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
796- #endif
793+ if (Halide::Internal::get_llvm_version () >= 220 ) {
794+ add_stn ({get_sve_ls_instr (" st3" , bits)}, total_lanes, store_3);
795+ }
797796 } else {
798- add_ldn (sel_op (" vld3." , " ld3" ), load_3);
799797 add_stn (sel_op (" vst3." , " st3" ), store_3);
800798 }
801799 }
802800
803- // LD4/ST4 - Store four-element structures
804- for (int width = 256 ; width <= 256 * 4 ; width *= 2 ) {
801+ // ST4 - Store four-element structures
802+ for (int factor : {1 , 2 }) {
803+ const int width = base_vec_bits * 4 * factor;
805804 const int total_lanes = width / bits;
806805 const int vector_lanes = total_lanes / 4 ;
807806 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
808- if (instr_lanes < 2 ) continue ; // bail out scalar op
807+ if (instr_lanes < 2 || (vector_lanes / vscale < 2 )) continue ; // bail out scalar and <vscale x 1 x ty>
809808
810- AddTestFunctor add_ldn (*this , bits, vector_lanes);
811809 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
812810
813811 Func tmp1, tmp2;
@@ -818,17 +816,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
818816 x % 4 == 2 , tmp1 (x / 4 + 32 ),
819817 tmp1 (x / 4 + 48 ));
820818 tmp2.compute_root ().vectorize (x, total_lanes);
821- Expr load_4 = in_im (x * 4 ) + in_im (x * 4 + 1 ) + in_im (x * 4 + 2 ) + in_im (x * 4 + 3 );
822819 Expr store_4 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
823820
824821 if (has_sve ()) {
825- // TODO(issue needed): Added strided load support.
826- #if 0
827- add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
828- add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
829- #endif
822+ if (Halide::Internal::get_llvm_version () >= 220 ) {
823+ add_stn ({get_sve_ls_instr (" st4" , bits)}, total_lanes, store_4);
824+ }
830825 } else {
831- add_ldn (sel_op (" vld4." , " ld4" ), load_4);
832826 add_stn (sel_op (" vst4." , " st4" ), store_4);
833827 }
834828 }
@@ -838,7 +832,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
838832 for (int width = 64 ; width <= 64 * 4 ; width *= 2 ) {
839833 const int total_lanes = width / bits;
840834 const int instr_lanes = min (total_lanes, 128 / bits);
841- if (instr_lanes < 2 ) continue ; // bail out scalar op
835+ if (instr_lanes < 2 || (total_lanes / vscale < 2 )) continue ; // bail out scalar and <vscale x 1 x ty>
842836
843837 AddTestFunctor add (*this , bits, total_lanes);
844838 Expr index = clamp (cast<int >(in_im (x)), 0 , W - 1 );
0 commit comments