Skip to content

Commit 8b38e48

Browse files
committed
Implement SIMD-based search function
1 parent 2319ff4 commit 8b38e48

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

include/simd_search.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// NOLINTBEGIN(*)
2+
3+
#ifndef SIMD_SEARCH_H_
4+
#define SIMD_SEARCH_H_
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <string_view>
9+
10+
#ifdef __ARM_NEON
11+
#include <arm_neon.h>
12+
#endif
13+
14+
// simd_search4 reports whether `needle` occurs anywhere in `haystack`.
//
// It is a drop-in replacement for
//   haystack.find(needle) != std::string_view::npos
// with a vectorised fast path for needles that are exactly four bytes long.
// Needles of any other length (including the empty needle) are delegated to
// std::string_view::find, so both build variants below return the same answer
// for every input.
//
// The fast path needs vmaxvq_u32, which only exists in the A64 instruction
// set, so it is enabled for 64-bit ARM only; every other target (including
// 32-bit ARM with NEON) uses the portable fallback.
#if defined(__ARM_NEON) && (defined(__aarch64__) || defined(_M_ARM64))
inline bool simd_search4(std::string_view haystack, std::string_view needle) {
  constexpr std::size_t kNeedleLen = 4;
  // Four loads shifted by 0..3 bytes, each viewed as four 32-bit lanes, test
  // the 16 consecutive start offsets pos+0 .. pos+15 in one iteration.
  constexpr std::size_t kStartsPerIter = 16;
  // The load at offset +3 touches bytes [pos + 3, pos + 19).
  constexpr std::size_t kBytesPerIter = kStartsPerIter + kNeedleLen - 1;

  // The vector path compares exactly four bytes; anything else would either
  // read past the needle or match only a prefix of it.
  if (needle.size() != kNeedleLen) {
    return haystack.find(needle) != std::string_view::npos;
  }
  if (haystack.size() < kNeedleLen) {
    return false;
  }

  // Replicate the needle into all four 32-bit lanes. Going through a byte
  // buffer avoids an unaligned, type-punned uint32_t read and keeps the lane
  // byte order identical to how the haystack is loaded below.
  uint8_t pattern[16];
  for (std::size_t i = 0; i < sizeof(pattern); ++i) {
    pattern[i] = static_cast<uint8_t>(needle[i % kNeedleLen]);
  }
  const uint32x4_t needle_vec = vreinterpretq_u32_u8(vld1q_u8(pattern));

  const auto* bytes = reinterpret_cast<const uint8_t*>(haystack.data());
  const std::size_t len = haystack.size();
  std::size_t pos = 0;

  // Run the vector loop only while a full iteration stays inside the buffer.
  while (len - pos >= kBytesPerIter) {
    const uint32x4_t eq0 = vceqq_u32(vreinterpretq_u32_u8(vld1q_u8(bytes + pos + 0)), needle_vec);
    const uint32x4_t eq1 = vceqq_u32(vreinterpretq_u32_u8(vld1q_u8(bytes + pos + 1)), needle_vec);
    const uint32x4_t eq2 = vceqq_u32(vreinterpretq_u32_u8(vld1q_u8(bytes + pos + 2)), needle_vec);
    const uint32x4_t eq3 = vceqq_u32(vreinterpretq_u32_u8(vld1q_u8(bytes + pos + 3)), needle_vec);
    const uint32x4_t any = vorrq_u32(vorrq_u32(eq0, eq1), vorrq_u32(eq2, eq3));
    // A matching lane is all-ones, so a non-zero horizontal max means a hit.
    if (vmaxvq_u32(any) != 0) {
      return true;
    }
    // All 16 start offsets were covered; no overlap with the next iteration
    // is required.
    pos += kStartsPerIter;
  }

  // Fewer than kBytesPerIter bytes remain: finish with a bounds-checked scan
  // of the remaining suffix.
  return haystack.substr(pos).find(needle) != std::string_view::npos;
}
#else
#include <string>
inline bool simd_search4(std::string_view haystack, std::string_view needle) {
  // Portable fallback: plain substring search.
  return haystack.find(needle) != std::string_view::npos;
}
#endif
59+
60+
#endif // SIMD_SEARCH_H_
61+
62+
// NOLINTEND(*)

src/sparser/sparser.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "node.h"
2323
#include "raw_filter.h"
2424
#include "rdtsc.h"
25+
#include "simd_search.h"
2526

2627
EstimationResult Sparser::Calibrate(const std::vector<std::string_view>& input, const JsonQuery& json_query,
2728
const RawFilterData& rf_data) {
@@ -46,14 +47,14 @@ EstimationResult Sparser::Calibrate(const std::vector<std::string_view>& input,
4647
auto idx = GetFlatIdx(conj_idx, pred_idx, rf_idx);
4748

4849
auto grepStart = rdtsc();
49-
auto find_result = json_row.find(rf);
50+
auto find_result = simd_search4(json_row, rf);
5051
auto grepEnd = rdtsc();
5152

5253
const auto rf_runtime = static_cast<double>(grepEnd - grepStart);
5354
total_rf_time += rf_runtime;
5455
rf_count++;
5556

56-
if (find_result != std::string_view::npos) {
57+
if (find_result) {
5758
#ifndef NDEBUG
5859
std::cout << "Found: " << rf << "\n";
5960
#endif
@@ -91,10 +92,15 @@ void Sparser::Run(const std::string& input_path, const JsonQuery& json_query) {
9192

9293
const auto& disjunction = json_query.GetDisjunction();
9394
auto rf_data = RawFilterQueryGenerator::GenerateRawFilters(disjunction);
95+
96+
auto calibrate_time_start = benchmark_start();
9497
auto estimation_result = Calibrate(sparser_input, json_query, rf_data);
98+
auto calibrate_time = benchmark_stop(calibrate_time_start);
99+
std::cout << "Calibration time: " << calibrate_time << '\n';
95100

96101
auto cascade_builder = CascadeBuilder(disjunction, rf_data);
97102
auto valid_cascades = cascade_builder.GenerateValidCascades();
103+
std::cout << "Generated " << valid_cascades.size() << " valid cascades\n";
98104

99105
auto cascade_evaluator = CascadeEvaluator(estimation_result);
100106
double min_cost = std::numeric_limits<double>::max();
@@ -141,8 +147,8 @@ SparserSearchStats Sparser::SearchCascade(const std::vector<std::string_view>& i
141147
while (root->type == NodeType::INTER) {
142148
auto rf = rf_data.data.at(root->conjunction_idx).at(root->predicate_idx).at(root->raw_filter_idx);
143149

144-
auto find_result = record.find(rf);
145-
if (find_result != std::string_view::npos) {
150+
auto find_result = simd_search4(record, rf);
151+
if (find_result) {
146152
root = root->right;
147153
} else {
148154
root = root->left;

0 commit comments

Comments
 (0)