Skip to content

Commit 357e646

Browse files
Do some basic validation of Target Features (#7986) (#7987)
* Do some basic validation of Target Features (#7986) * Update Target.cpp * Update Target.cpp * Fixes * Update Target.cpp * Improve error messaging. * format * Update Target.cpp
1 parent 9c099c2 commit 357e646

File tree

4 files changed

+93
-8
lines changed

4 files changed

+93
-8
lines changed

python_bindings/test/correctness/target.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,14 @@ def test_target():
5050
32,
5151
[
5252
hl.TargetFeature.JIT,
53-
hl.TargetFeature.SSE41,
54-
hl.TargetFeature.AVX,
55-
hl.TargetFeature.AVX2,
5653
hl.TargetFeature.CUDA,
5754
hl.TargetFeature.OpenCL,
5855
hl.TargetFeature.OpenGLCompute,
5956
hl.TargetFeature.Debug,
6057
],
6158
)
6259
ts = t1.to_string()
63-
assert ts == "arm-32-android-avx-avx2-cuda-debug-jit-opencl-openglcompute-sse41"
60+
assert ts == "arm-32-android-cuda-debug-jit-opencl-openglcompute"
6461
assert hl.Target.validate_target_string(ts)
6562

6663
# Expected failures:

src/Target.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,90 @@ void bad_target_string(const std::string &target) {
785785
<< "On this platform, the host target is: " << get_host_target().to_string() << "\n";
786786
}
787787

788+
void do_check_bad(const Target &t, const std::initializer_list<Target::Feature> &v) {
789+
for (Target::Feature f : v) {
790+
user_assert(!t.has_feature(f))
791+
<< "Target feature " << Target::feature_to_name(f)
792+
<< " is incompatible with the Target's architecture. (" << t << ")\n";
793+
}
794+
}
795+
788796
} // namespace
789797

798+
void Target::validate_features() const {
799+
// Note that the features don't have to be exhaustive, but enough to avoid obvious mistakes is good.
800+
if (arch == X86) {
801+
do_check_bad(*this, {
802+
ARMDotProd,
803+
ARMFp16,
804+
ARMv7s,
805+
ARMv81a,
806+
NoNEON,
807+
POWER_ARCH_2_07,
808+
RVV,
809+
SVE,
810+
SVE2,
811+
VSX,
812+
WasmBulkMemory,
813+
WasmMvpOnly,
814+
WasmSimd128,
815+
WasmThreads,
816+
});
817+
} else if (arch == ARM) {
818+
do_check_bad(*this, {
819+
AVX,
820+
AVX2,
821+
AVX512,
822+
AVX512_Cannonlake,
823+
AVX512_KNL,
824+
AVX512_SapphireRapids,
825+
AVX512_Skylake,
826+
AVX512_Zen4,
827+
F16C,
828+
FMA,
829+
FMA4,
830+
POWER_ARCH_2_07,
831+
RVV,
832+
SSE41,
833+
VSX,
834+
WasmBulkMemory,
835+
WasmMvpOnly,
836+
WasmSimd128,
837+
WasmThreads,
838+
});
839+
} else if (arch == WebAssembly) {
840+
do_check_bad(*this, {
841+
ARMDotProd,
842+
ARMFp16,
843+
ARMv7s,
844+
ARMv81a,
845+
AVX,
846+
AVX2,
847+
AVX512,
848+
AVX512_Cannonlake,
849+
AVX512_KNL,
850+
AVX512_SapphireRapids,
851+
AVX512_Skylake,
852+
AVX512_Zen4,
853+
F16C,
854+
FMA,
855+
FMA4,
856+
HVX_128,
857+
HVX_128,
858+
HVX_v62,
859+
HVX_v65,
860+
HVX_v66,
861+
NoNEON,
862+
POWER_ARCH_2_07,
863+
RVV,
864+
SSE41,
865+
SVE,
866+
SVE2,
867+
VSX,
868+
});
869+
}
870+
}
871+
790872
Target::Target(const std::string &target) {
791873
Target host = get_host_target();
792874

@@ -798,6 +880,7 @@ Target::Target(const std::string &target) {
798880
bad_target_string(target);
799881
}
800882
}
883+
validate_features();
801884
}
802885

803886
Target::Target(const char *s)

src/Target.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ struct Target {
177177
for (const auto &f : initial_features) {
178178
set_feature(f);
179179
}
180+
validate_features();
180181
}
181182

182183
Target(OS o, Arch a, int b, const std::vector<Feature> &initial_features = std::vector<Feature>())
@@ -357,6 +358,11 @@ struct Target {
357358
private:
358359
/** A bitmask that stores the active features. */
359360
std::bitset<FeatureEnd> features;
361+
362+
/** Attempt to validate that all features set are sensible for the base Target.
363+
* This is *not* guaranteed to get all invalid combinations, but is intended
364+
* to catch at least the most common (e.g., setting arm-specific features on x86). */
365+
void validate_features() const;
360366
};
361367

362368
/** Return the target corresponding to the host machine. */

test/correctness/target.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ int main(int argc, char **argv) {
5151

5252
// Full specification round-trip, crazy features
5353
t1 = Target(Target::Android, Target::ARM, 32,
54-
{Target::JIT, Target::SSE41, Target::AVX, Target::AVX2,
55-
Target::CUDA, Target::OpenCL, Target::OpenGLCompute,
56-
Target::Debug});
54+
{Target::JIT, Target::CUDA, Target::OpenCL,
55+
Target::OpenGLCompute, Target::Debug});
5756
ts = t1.to_string();
58-
if (ts != "arm-32-android-avx-avx2-cuda-debug-jit-opencl-openglcompute-sse41") {
57+
if (ts != "arm-32-android-cuda-debug-jit-opencl-openglcompute") {
5958
printf("to_string failure: %s\n", ts.c_str());
6059
return 1;
6160
}

0 commit comments

Comments
 (0)