Skip to content

Commit 732d2eb

Browse files
Experimental (#114)
* Fixed auto tagging for dockerimages * Fixed typo. * Added check + support for wget if curl does not exist on the target system in setup_tsl.sh * Fixed regex error * fixed left shift warnings * Added better functor overload indication for HTML output * added archid to possible cli params * Added functionality to check whether a primitive is supported by the backend * Added functionality to check whether a primitive is supported by the backend * Fixed wrong bits set (integral all true) if number of lanes < 8 --------- Co-authored-by: Alexander Krause <alexander.krause@tu-dresden.de>
1 parent 5c3e595 commit 732d2eb

12 files changed

Lines changed: 181 additions & 68 deletions

File tree

.github/workflows/push.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@ jobs:
3232
- name: Parse Target Specs
3333
id: parse-target-specs
3434
run: |
35-
echo "x86_flags_matrix=$(jq '{include: .x86}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
35+
echo "x86_flags_matrix=$(jq '{include: .x86}' target_specs.json -c)" >> $GITHUB_OUTPUT
3636
echo "x86_flags_compiler_matrix=$(jq '.x86 |= map([
3737
. + {compiler: "g++"},
3838
. + {compiler: "clang++"}
39-
]) | .x86 |= flatten | {include: .x86}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
40-
echo "aarch64_flags_matrix=$(jq '{include: .aarch64}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
39+
]) | .x86 |= flatten | {include: .x86}' target_specs.json -c)" >> $GITHUB_OUTPUT
40+
echo "aarch64_flags_matrix=$(jq '{include: .aarch64}' target_specs.json -c)" >> $GITHUB_OUTPUT
4141
echo "aarch64_flags_compiler_matrix=$(jq '.aarch64 |= map([
4242
. + {compiler: "aarch64-linux-gnu-g++"},
4343
. + {compiler: "clang++"}
44-
]) | .aarch64 |= flatten | {include: .aarch64}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
45-
echo "target_matrix=$(jq -c '{include: ([keys[] as $k | .[$k][] | (. + { arch: $k })])}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
44+
]) | .aarch64 |= flatten | {include: .aarch64}' target_specs.json -c)" >> $GITHUB_OUTPUT
45+
echo "target_matrix=$(jq -c '{include: ([keys[] as $k | .[$k][] | (. + { arch: $k })])}' target_specs.json -c)" >> $GITHUB_OUTPUT
4646
4747
build-generation-image:
4848
name: Generation Image (build and push on demand)

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- name: Parse Target Specs
3030
id: parse-target-specs
3131
run: |
32-
echo "target_matrix=$(jq -c '{include: ([keys[] as $k | .[$k][] | (. + { arch: $k })])}' .github/workflows/target_specs.json -c)" >> $GITHUB_OUTPUT
32+
echo "target_matrix=$(jq -c '{include: ([keys[] as $k | .[$k][] | (. + { arch: $k })])}' target_specs.json -c)" >> $GITHUB_OUTPUT
3333
3434
build-generation-image:
3535
if: ${{ github.event_name != 'workflow_run' }}
@@ -142,7 +142,7 @@ jobs:
142142
python3 ${{ github.workspace }}/repository/.github/workflows/release/prepare_select_flavor.py \
143143
--install-sh ${{ github.workspace }}/release_tmp/select_flavor.sh \
144144
--folder-prefix "tsl" \
145-
--targets-spec-file repository/.github/workflows/target_specs.json \
145+
--targets-spec-file repository/target_specs.json \
146146
--tsl-folder-ph "$<< TslFolderArrayValues >>" \
147147
--default-flags-array-ph "$<< DefaultFlagsArrayValues >>" \
148148
--alt-flags-array-ph "$<< AlternativeFlagsArrayValues >>" \

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ else()
2929
set(DTD "DRAW_TEST_DEPENDENCIES")
3030
endif()
3131

32+
if (DEFINED ArchId)
33+
message(STATUS "ArchId: ${ArchId}")
34+
else()
35+
message(STATUS "ArchId not defined. Defaulting to 0.")
36+
set(ArchId "")
37+
endif()
38+
3239
create_tsl(
3340
${WAW}
3441
${UC}
@@ -37,6 +44,7 @@ else()
3744
TSLGENERATOR_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
3845
DESTINATION ${DESTINATION}
3946
TARGETS_FLAGS ${TARGETS_FLAGS}
47+
TARGET_ARCHID ${ArchId}
4048
APPEND_TARGETS_FLAGS ${APPEND_TARGETS_FLAGS}
4149
PRIMITIVES_FILTER ${PRIMITIVES_FILTER}
4250
DATATYPES_FILTER ${DATATYPES_FILTER}

generator/core/tsl_config.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import copy
3+
import json
34
import logging.config
45
import pathlib
56
import os
@@ -422,6 +423,7 @@ def parse_args(**kwargs) -> dict:
422423
parser.add_argument('--targets', default=None, nargs='*',
423424
help='List of target flags which match the lscpu_flags from the extension/primitive files.',
424425
dest='targets')
426+
parser.add_argument('--archid', type=str, dest='target_archid', metavar="ArchId", help="Identifier of a target platform (e.g., 'skylake').")
425427
types_help = 'List of types which should be considered for generation.'
426428
if "known_types" in kwargs:
427429
types_help += f" Choose from the following list: [{', '.join(kwargs['known_types'])}]"
@@ -464,9 +466,29 @@ def parse_args(**kwargs) -> dict:
464466
current_dict = current_dict[match]
465467
current_dict[matches[-1]] = value
466468

469+
targets_set = set()
470+
if "target_archid" in args_dict:
471+
with open("target_specs.json", 'r') as specs_file:
472+
target_specs = json.load(specs_file)
473+
found = False
474+
for targets in target_specs.values():
475+
for target in targets:
476+
if target['name'] == args_dict["target_archid"]:
477+
found = True
478+
targets_set.update([flag.strip() for flag in target['flags'].split(" ")])
479+
break
480+
if not found:
481+
print(f"Unknown target archid: {args_dict['target_archid']}")
482+
exit(1)
483+
467484
if "targets" not in args_dict:
468-
args_dict["targets"] = None
485+
if len(targets_set) == 0:
486+
args_dict["targets"] = None
487+
else:
488+
args_dict["targets"] = list(targets_set)
489+
else:
490+
targets_set.update(args_dict["targets"])
491+
args_dict["targets"] = list(targets_set)
469492
# if "relevant_types" not in args_dict:
470493
# args_dict["relevant_types"] = None
471494
return args_dict
472-

generator/core/tsl_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def generate(self, relevant_hardware_flags: List[str] = None, relevant_primitive
141141
for primitive in pClass:
142142
implregex_list.append( primitive.declaration.functor_name )
143143

144-
implregex_string = f'(?<!_)({"|".join(implregex_list)})(?!([a-zA-Z]|_|\[|\.))'
144+
implregex_string = rf'(?<!_)({"|".join(implregex_list)})(?!([a-zA-Z]|_|\[|\.))'
145145

146146
implregex = re.compile(implregex_string, re.IGNORECASE)
147147

generator/static_files/core/utils/preprocessor.yaml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,30 @@ implementations:
136136
};
137137
template <class... Functors>
138138
inline constexpr bool tsl_primitives_redeclared_v = tsl_primitives_redeclared<Functors...>::is_ambiguous::value;
139-
139+
- |
140+
template <class Func, typename ArgsTuple, typename = std::void_t<>>
141+
struct has_function : std::false_type {};
142+
// Specialization: When Func<Ts...>(Args...) is valid
143+
template <class Func, typename... Args>
144+
struct has_function<Func, std::tuple<Args...>,
145+
std::void_t<decltype(Func(std::declval<Args>()...))>> : std::true_type {};
146+
template <typename... Args>
147+
struct callable_by_types {
148+
bool available;
149+
template <typename Fun>
150+
constexpr callable_by_types(Fun&&): available(std::is_invocable_v<Fun, Args...>) {}
151+
};
152+
struct callable_by_values {
153+
bool available;
154+
template <typename Fun, typename... Args>
155+
constexpr callable_by_values(Fun&&, Args&&... args): available(has_function<Fun, std::tuple<decltype(args)...>>::value) {}
156+
};
157+
#define TSL_MAKE_CALLABLE(func) \
158+
[](auto&&... args) -> decltype(func(std::forward<decltype(args)>(args)...)) { \
159+
return func(std::forward<decltype(args)>(args)...); \
160+
}
161+
#define TSL_BACKEND_SUPPORTS_BY_TYPE(func, ...) \
162+
tsl::callable_by_types<__VA_ARGS__>(TSL_MAKE_CALLABLE(func)).available
163+
#define TSL_BACKEND_SUPPORTS_BY_VALUE(func, ...) \
164+
tsl::callable_by_values(TSL_MAKE_CALLABLE(func), __VA_ARGS__).available
140165
...

generator/static_files/core/utils/runtime.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ implementations:
1313
class executor {
1414
public:
1515
template<typename T, int Par>
16-
using simd_ext_t = typename ExecTarget::template simd_ext_t<T, Par>;
16+
using simd_ext_by_par_t = typename ExecTarget::template simd_ext_by_par_t<T, Par>;
1717
private:
1818
ExecTarget target;
1919
public:
@@ -83,6 +83,17 @@ implementations:
8383
return ExecTarget::template available_parallelism<T>();
8484
}
8585
86+
template<TSLArithmetic T, int Par>
87+
static constexpr bool parallelism_available() {
88+
constexpr auto avail = available_parallelism<T>();
89+
for (size_t i = 0; i < avail.size(); ++i) {
90+
if (avail[i] == Par) {
91+
return true;
92+
}
93+
}
94+
return false;
95+
}
96+
8697
void wait() {
8798
target.wait();
8899
}

parseForPrimitiveTable.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from generator.core.ctrl.tsl_lib import TSLLib
1414

1515
class PrintablePrimitive:
16-
def __init__(self, name: str, description: str, ctype_to_extension_dict: dict, vec_type: List[str], parameters: List[str], return_type: str) -> None:
16+
def __init__(self, name: str, functor_name: str, description: str, ctype_to_extension_dict: dict, vec_type: List[str], parameters: List[str], return_type: str) -> None:
1717
self.name = name
18+
self.functor_name = functor_name
1819
self.description = description
1920
self.ctype_to_extension_dict = ctype_to_extension_dict
2021
self.vec_type = vec_type
@@ -26,9 +27,14 @@ def __repr__(self) -> str:
2627
return f"{self.name}: {self.ctype_to_extension_dict}"
2728

2829
def to_html(self, considered_types: list, considered_exts: list) -> str:
30+
button_label = ""
31+
if self.name != self.functor_name:
32+
button_label = f"{self.name} <div style='font-size: small;'><i>overload</i> (calls tsl::functors::{self.functor_name})</div>"
33+
else:
34+
button_label = self.name
2935
call_signature = f"{self.name}&lt;{', '.join(self.vec_type)}&gt;({', '.join(self.parameters)}) -> {self.return_type}"
30-
primitive_button = f"""<div class="primitiveContainer"><div class="primitive"><button id ="{self.name}_link" onclick="togglePrimitive(event, '{self.name}')">{self.name}</button></div>"""
31-
primitive_table_start = f"""<div id="{self.name}" class="primitiveinfo"><p>Brief: <span class="description">{self.description}</span></p><p>Call signature: <br><code>{call_signature}</code></p><center><table border=1 cellpadding=10 cellspacing=0>"""
36+
primitive_button = f"""<div class="primitiveContainer"><div class="primitive"><button id ="{self.functor_name}_link" onclick="togglePrimitive(event, '{self.functor_name}')">{button_label}</button></div>"""
37+
primitive_table_start = f"""<div id="{self.functor_name}" class="primitiveinfo"><p>Brief: <span class="description">{self.description}</span></p><p>Call signature: <br><code>{call_signature}</code></p><center><table border=1 cellpadding=10 cellspacing=0>"""
3238
primitive_table_end = """</table></center></div><br/></div>"""
3339

3440
top_left_corner = """<td style="border-top:0;border-left:0;"></td>"""
@@ -115,7 +121,8 @@ def create_primitive_index_html(tsl_lib: TSLLib) -> None:
115121
checkbox_html += add_checkbox(primitive_class.name)
116122
primitive_html += f"""<div class="primitiveCategory" id="{primitive_class.name}">"""
117123
for primitive in primitive_class:
118-
name = primitive.declaration.functor_name
124+
name = primitive.declaration.name
125+
functor_name = primitive.declaration.functor_name
119126
brief_description = primitive.declaration.data["brief_description"]
120127
detailed_description = primitive.declaration.data["detailed_description"]
121128
ctype_ext_dict = copy.deepcopy(raw_primitive_dict)
@@ -127,7 +134,7 @@ def create_primitive_index_html(tsl_lib: TSLLib) -> None:
127134
vnames = [primitive.declaration.data["vector_name"]]
128135
if primitive.declaration.data["additional_simd_template_parameter"]['name'] != "":
129136
vnames.append(primitive.declaration.data["additional_simd_template_parameter"]['name'])
130-
pP = PrintablePrimitive(name, brief_description, ctype_ext_dict,
137+
pP = PrintablePrimitive(name, functor_name, brief_description, ctype_ext_dict,
131138
vnames,
132139
[f"{html.escape(param['ctype'])} {param['name']}" for param in primitive.declaration.data["parameters"]],
133140
html.escape(primitive.declaration.data['returns']['ctype']))

primitive_data/primitives/mask.yaml

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -798,42 +798,60 @@ definitions:
798798
implementation: "return ~0;"
799799
#INTEL - AVX2
800800
- target_extension: "avx2"
801-
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double" ]
801+
ctype: [ "uint8_t", "uint16_t", "uint32_t", "int8_t", "int16_t", "int32_t", "float"]
802802
lscpu_flags: [ "avx2" ]
803803
includes: ["<type_traits>"]
804804
implementation: |
805-
if constexpr(Vec::vector_element_count() < 8) {
806-
return ((static_cast<typename Vec::imask_type>(1)<<Vec::vector_element_count()) - 1);
807-
} else {
808-
return ~0;
809-
}
805+
return static_cast<typename Vec::imask_type>(~0);
806+
- target_extension: "avx2"
807+
ctype: [ "uint64_t", "int64_t", "double" ]
808+
lscpu_flags: [ "avx2" ]
809+
includes: ["<type_traits>"]
810+
implementation: |
811+
return static_cast<typename Vec::imask_type>(0b1111);
810812
#INTEL - SSE
811813
- target_extension: "sse"
812-
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double" ]
814+
ctype: [ "uint8_t", "uint16_t", "int8_t", "int16_t"]
813815
lscpu_flags: [ "sse" ]
814816
includes: ["<type_traits>"]
815817
implementation: |
816-
if constexpr(Vec::vector_element_count() < 8) {
817-
return ((static_cast<typename Vec::imask_type>(1)<<Vec::vector_element_count()) - 1);
818-
} else {
819-
return ~0;
820-
}
818+
return static_cast<typename Vec::imask_type>(~0);
819+
- target_extension: "sse"
820+
ctype: [ "uint32_t", "int32_t", "float" ]
821+
lscpu_flags: [ "sse" ]
822+
includes: ["<type_traits>"]
823+
implementation: |
824+
return static_cast<typename Vec::imask_type>(0b1111);
825+
- target_extension: "sse"
826+
ctype: [ "uint64_t", "int64_t", "double" ]
827+
lscpu_flags: [ "sse" ]
828+
includes: ["<type_traits>"]
829+
implementation: |
830+
return static_cast<typename Vec::imask_type>(0b11);
821831
#SCALAR
822832
- target_extension: "scalar"
823833
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double" ]
824834
lscpu_flags: []
825835
implementation: return true;
826836
#ARM - NEON
827837
- target_extension: "neon"
828-
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double" ]
838+
ctype: [ "uint8_t", "uint16_t", "int8_t", "int16_t"]
829839
lscpu_flags: [ "neon" ]
830840
includes: ["<type_traits>"]
831841
implementation: |
832-
if constexpr(Vec::vector_element_count() < 8) {
833-
return ((static_cast<typename Vec::imask_type>(1)<<Vec::vector_element_count()) - 1);
834-
} else {
835-
return ~0;
836-
}
842+
return static_cast<typename Vec::imask_type>(~0);
843+
- target_extension: "neon"
844+
ctype: [ "uint32_t", "int32_t", "float" ]
845+
lscpu_flags: [ "neon" ]
846+
includes: ["<type_traits>"]
847+
implementation: |
848+
return static_cast<typename Vec::imask_type>(0b1111);
849+
- target_extension: "neon"
850+
ctype: [ "uint64_t", "int64_t", "double" ]
851+
lscpu_flags: [ "neon" ]
852+
includes: ["<type_traits>"]
853+
implementation: |
854+
return static_cast<typename Vec::imask_type>(0b11);
837855
...
838856
---
839857
primitive_name: "integral_all_false"
@@ -957,19 +975,19 @@ definitions:
957975
- target_extension: "avx512"
958976
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
959977
lscpu_flags: ["avx512f"]
960-
implementation: "return (mask >> position) & ((1ULL << Vec::vector_element_count()) - 1);"
978+
implementation: "return (mask >> position) & (static_cast<typename Vec::imask_type>(~0));"
961979
- target_extension: "avx2"
962980
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
963981
lscpu_flags: ["avx"]
964-
implementation: "return (mask >> position) & ((1ULL << Vec::vector_element_count()) - 1);"
982+
implementation: "return (mask >> position) & (static_cast<typename Vec::imask_type>(~0));"
965983
- target_extension: "sse"
966984
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
967985
lscpu_flags: ["sse"]
968-
implementation: "return (mask >> position) & ((1ULL << Vec::vector_element_count()) - 1);"
986+
implementation: "return (mask >> position) & (static_cast<typename Vec::imask_type>(~0));"
969987
- target_extension: "scalar"
970988
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t", "float", "double"]
971989
lscpu_flags: []
972-
implementation: "return (mask >> position) & ((1ULL << Vec::vector_element_count()) - 1);"
990+
implementation: "return (mask >> position) & (static_cast<typename Vec::imask_type>(~0));"
973991
...
974992
#---
975993
#primitive_name: "mask_reduce"

supplementary/runtime/cpu/include/tslCPUrt.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ namespace tsl {
1919
}
2020
class cpu {
2121
public:
22-
template<typename T, int Par>
23-
using simd_ext_t = typename std::conditional_t<
24-
(Par==1), scalar, typename details::simd_ext_helper_t<sizeof(T)*CHAR_BIT*Par>::extension_t
25-
>;
22+
template <typename T, int Par>
23+
struct simd_ext_by_par_t {
24+
using type = typename details::simd_ext_helper_t<sizeof(T) * CHAR_BIT * Par>::extension_t;
25+
};
26+
template <typename T>
27+
struct simd_ext_by_par_t<T, 1> {
28+
using type = scalar;
29+
};
2630

2731
using max_width_extension_t = {{ avail_extension_types_dict[avail_extension_types_dict.keys()|max] }};
2832
using min_width_extension_t = {{ avail_extension_types_dict[avail_extension_types_dict.keys()|min] }};

0 commit comments

Comments
 (0)