Skip to content

Commit bd662ce

Browse files
committed
Implemented parallel execution of PredictExpressions
1 parent 9b5815e commit bd662ce

File tree

5 files changed

+115
-25
lines changed

5 files changed

+115
-25
lines changed

src/execution/expression_executor/execute_conjunction.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
44
#include "duckdb/execution/adaptive_filter.hpp"
55

6+
#include <iostream>
67
#include <random>
8+
#include <future>
9+
10+
#define EXECUTE_PARALLEL 0
711

812
namespace duckdb {
913

@@ -56,6 +60,10 @@ void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, Express
5660
idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, ExpressionState *state_p,
5761
const SelectionVector *sel, idx_t count, SelectionVector *true_sel,
5862
SelectionVector *false_sel) {
63+
#if EXECUTE_PARALLEL == 1
64+
return SelectParallel(expr, state_p, sel, count, true_sel, false_sel);
65+
#endif
66+
5967
auto &state = state_p->Cast<ConjunctionState>();
6068

6169
if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND) {
@@ -140,4 +148,104 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express
140148
}
141149
}
142150

151+
idx_t ExpressionExecutor::AsyncSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
152+
idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) {
153+
idx_t out = Select(expr, state, sel, count, true_sel, false_sel);
154+
// for (idx_t j = 0; j < out; j++) {
155+
// std::cout << "AsyncSelect- " << " idx- " << j << ", mapping- " << true_sel->get_index(j) << std::endl;
156+
// }
157+
return out;
158+
}
159+
160+
idx_t ExpressionExecutor::SelectParallel(const BoundConjunctionExpression &expr, ExpressionState *state_p,
161+
const SelectionVector *sel, idx_t count, SelectionVector *true_sel,
162+
SelectionVector *false_sel) {
163+
auto &state = state_p->Cast<ConjunctionState>();
164+
165+
if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND) {
166+
// get runtime statistics
167+
auto filter_state = state.adaptive_filter->BeginFilter();
168+
const SelectionVector *current_sel = sel;
169+
idx_t false_count = 0;
170+
171+
unique_ptr<SelectionVector> temp_true, temp_false;
172+
if (false_sel) {
173+
temp_false = make_uniq<SelectionVector>(STANDARD_VECTOR_SIZE);
174+
}
175+
if (!true_sel) {
176+
temp_true = make_uniq<SelectionVector>(STANDARD_VECTOR_SIZE);
177+
true_sel = temp_true.get();
178+
}
179+
180+
vector<std::future<idx_t>> futures;
181+
// vector<idx_t> futures_i;
182+
vector<unique_ptr<SelectionVector>> children_true;
183+
for (idx_t i = 0; i < expr.children.size(); i++) {
184+
children_true.push_back(make_uniq<SelectionVector>(STANDARD_VECTOR_SIZE));
185+
}
186+
187+
vector<SelectionVector> children_false{expr.children.size(), SelectionVector{STANDARD_VECTOR_SIZE}};
188+
vector<idx_t> true_counts{expr.children.size(), 0};
189+
190+
for (idx_t i = 0; i < expr.children.size(); i++) {
191+
// futures_i.push_back(AsyncSelect(*expr.children[state.adaptive_filter->permutation[i]],
192+
// state.child_states[state.adaptive_filter->permutation[i]].get(),
193+
// current_sel, count, children_true[i].get(), &children_false[i]));
194+
195+
// for (idx_t j = 0; j < futures_i[i]; j++) {
196+
// std::cout << "InLoopFilter- " << i << " , idx- " << j << ", mapping- " << children_true[i]->get_index(j) << std::endl;
197+
// }
198+
futures.push_back(std::async(std::launch::async, &ExpressionExecutor::AsyncSelect,
199+
this, std::cref(*expr.children[state.adaptive_filter->permutation[i]]),
200+
state.child_states[state.adaptive_filter->permutation[i]].get(),
201+
std::cref(current_sel), count, children_true[i].get(), &children_false[i]));
202+
}
203+
204+
for (idx_t i = 0; i < expr.children.size(); i++) {
205+
// idx_t tcount = futures_i[i];
206+
idx_t tcount = futures[i].get();
207+
idx_t fcount = count - tcount;
208+
if (fcount > 0 && false_sel) {
209+
// move failing tuples into the false_sel
210+
// tuples passed, move them into the actual result vector
211+
for (idx_t j = 0; j < fcount; j++) {
212+
false_sel->set_index(false_count++, children_false[i].get_index(j));
213+
}
214+
}
215+
true_counts[i] = tcount;
216+
// std::cout << "Filtered: " << tcount << std::endl;
217+
}
218+
idx_t true_count = 0;
219+
idx_t lp = 0;
220+
idx_t rp = 0;
221+
while (lp < true_counts[0] && rp < true_counts[1]) {
222+
idx_t lidx = children_true[0]->get_index(lp);
223+
idx_t ridx = children_true[1]->get_index(rp);
224+
if (lidx == ridx) {
225+
// std::cout << "Match: " << lidx << std::endl;
226+
true_sel->set_index(true_count, lidx);
227+
true_count++;
228+
lp++;
229+
rp++;
230+
} else if (lidx < ridx) {
231+
lp++;
232+
} else if (ridx < lidx) {
233+
rp++;
234+
}
235+
}
236+
// for (idx_t i = 0; i < children_true.size(); i++) {
237+
// for (idx_t j = 0; j < true_counts[i]; j++) {
238+
// std::cout << "Filter- " << i << " , idx- " << j << ", mapping- " << children_true[i]->get_index(j) << std::endl;
239+
// true_sel->set_index(true_count, children_true[i]->get_index(j));
240+
// true_count++;
241+
// }
242+
// }
243+
// adapt runtime statistics
244+
state.adaptive_filter->EndFilter(filter_state);
245+
return true_count;
246+
} else {
247+
throw std::runtime_error("Expression executor does not support select_parallel() with !AND");
248+
}
249+
}
250+
143251
} // namespace duckdb

src/include/duckdb/execution/expression_executor.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ class ExpressionExecutor {
165165
SelectionVector *true_sel, SelectionVector *false_sel);
166166
idx_t Select(const BoundConjunctionExpression &expr, ExpressionState *state, const SelectionVector *sel,
167167
idx_t count, SelectionVector *true_sel, SelectionVector *false_sel);
168+
idx_t SelectParallel(const BoundConjunctionExpression &expr, ExpressionState *state, const SelectionVector *sel,
169+
idx_t count, SelectionVector *true_sel, SelectionVector *false_sel);
170+
idx_t AsyncSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
171+
idx_t count, SelectionVector *true_sel, SelectionVector *false_sel);
168172

169173
//! Verify that the output of a step in the ExpressionExecutor is correct
170174
void Verify(const Expression &expr, Vector &result, idx_t count);

third_party/predictors/common/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class PromptUtil {
1515
static std::string extract_json(const std::string &text) {
1616
const size_t start = text.find_first_of("{[");
1717
if (start == std::string::npos) {
18-
throw std::runtime_error("No JSON start found");
18+
throw std::runtime_error("No JSON start found: " + text);
1919
}
2020

2121
const char open = text[start];

third_party/predictors/llm_api/duckdb_llm_api.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ std::unique_ptr<BatchResult> LlmApiPredictor::PredictBatch(OpenAI &api, const ve
202202
request["model"] = this->model_path;
203203
request["messages"] = {{{"content", GenerateSystemMessage(true)}, {"role", "system"}},
204204
{{"content", rewritten}, {"role", "user"}}};
205+
request["temperature"] = 0.2;
205206
#if IS_SCHEMA
206207
std::stringstream sch;
207208
sch << "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"json_response\",\"strict\":true,";
@@ -224,6 +225,7 @@ std::unique_ptr<BatchResult> LlmApiPredictor::PredictBatch(OpenAI &api, const ve
224225
LLM_LOG( "Too much requests!\n");
225226
}
226227
} else {
228+
LLM_LOG("LLM Output: " + completion.dump());
227229
tokens += completion["usage"]["total_tokens"].get<int>();
228230
in_tokens += completion["usage"]["prompt_tokens"].get<int>();
229231
out_tokens += completion["usage"]["completion_tokens"].get<int>();

third_party/predictors/llm_api/https_client.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,6 @@ unique_ptr<HTTPParams> HTTPSUtil::InitializeParameters(optional_ptr<FileOpener>
147147
info);
148148
FileOpener::TryGetCurrentSetting(opener, "ca_cert_file", result->ca_cert_file, info);
149149

150-
// HTTP Secret lookups
151-
KeyValueSecretReader settings_reader(*opener, info, "http");
152-
153-
if (string proxy_setting;
154-
settings_reader.TryGetSecretKey<string>("http_proxy", proxy_setting) && !proxy_setting.empty()) {
155-
idx_t port;
156-
string host;
157-
ParseHTTPProxyHost(proxy_setting, host, port);
158-
result->http_proxy = host;
159-
result->http_proxy_port = port;
160-
}
161-
settings_reader.TryGetSecretKey<string>("http_proxy_username", result->http_proxy_username);
162-
settings_reader.TryGetSecretKey<string>("http_proxy_password", result->http_proxy_password);
163-
settings_reader.TryGetSecretKey<string>("bearer_token", result->bearer_token);
164-
165-
if (Value extra_headers; settings_reader.TryGetSecretKey("extra_http_headers", extra_headers)) {
166-
auto children = MapValue::GetChildren(extra_headers);
167-
for (const auto &child : children) {
168-
auto kv = StructValue::GetChildren(child);
169-
D_ASSERT(kv.size() == 2);
170-
result->extra_headers[kv[0].GetValue<string>()] = kv[1].GetValue<string>();
171-
}
172-
}
173-
174150
return std::move(result);
175151
}
176152

0 commit comments

Comments
 (0)