Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cactus-engine/src/cloud.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ static std::string call_cloud_endpoint(const std::string& url,
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_cb);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_body);
curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_ms);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, std::min<long>(timeout_ms, 2000L));
curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L);

if (!env_flag_enabled("CACTUS_CLOUD_STRICT_SSL")) {
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L);
Expand Down
199 changes: 138 additions & 61 deletions cactus-engine/src/complete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "wav.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstdint>
#include <cstring>
Expand Down Expand Up @@ -456,8 +457,14 @@ PreparedPrompt prepare_prompt(
prompt.model_type = handle->model->get_config().model_type;

if (prompt.options.confidence_threshold < 0.0f) {
float model_default = handle->model->get_config().default_cloud_handoff_threshold;
prompt.options.confidence_threshold = (model_default > 0.0f) ? model_default : 0.7f;
if (handle->model->has_handoff_probe()) {
// The Gemma4 probe returns p_wrong; confidence is 1 - p_wrong.
// Route when the probe is less than 50% confident in local output.
prompt.options.confidence_threshold = 0.50f;
} else {
float model_default = handle->model->get_config().default_cloud_handoff_threshold;
prompt.options.confidence_threshold = (model_default > 0.0f) ? model_default : 0.7f;
}
}

if (prompt.model_type == Config::ModelType::GEMMA4) {
Expand Down Expand Up @@ -694,6 +701,83 @@ int cactus_complete(

bool has_images = prompt.has_images();
bool has_audio = prompt.has_audio();
// Cloud handoff can be turned off through the CACTUS_DISABLE_CLOUD_HANDOFF env flag.
const bool cloud_disabled = env_flag_enabled("CACTUS_DISABLE_CLOUD_HANDOFF");
// Eligible only when auto handoff is requested and the prompt either carries no
// images or the caller allows handing off prompts that contain images.
const bool cloud_eligible = !cloud_disabled &&
prompt.options.auto_handoff && (!has_images || prompt.options.handoff_with_images);
// Reset the model's handoff-probe rollout state before this completion runs.
handle->model->reset_handoff_probe_rollout();
// True when cloud handoff is eligible and the model has a handoff probe.
// NOTE(review): the per-token streaming callbacks later in this function appear
// to be gated on this flag being false — confirm against the full function.
const bool defer_local_stream_until_probe = cloud_eligible && handle->model->has_handoff_probe();
// Records that the up-front (pre-generation) cloud attempt has already run.
bool pre_generation_cloud_attempted = false;

auto make_cloud_request = [&](const std::string& local_output_hint,
const std::vector<std::string>& local_calls_hint) {
CloudCompletionRequest request;
request.messages = prompt.messages;
request.tools = prompt.tools;
request.local_output = local_output_hint;
request.local_function_calls = local_calls_hint;
request.has_images = has_images;
request.has_audio = has_audio;
if (has_audio && pcm_buffer != nullptr && pcm_buffer_size > 0) {
request.audio_pcm.assign(pcm_buffer, pcm_buffer + pcm_buffer_size);
}
request.cloud_key = resolve_cloud_api_key(nullptr);
return request;
};

// Finishes a completion that was answered by the cloud.
//
// Steps, in order: forward the cloud text to the streaming callback (one call,
// token id 0), serialize the response JSON, copy it into response_buffer, and
// record telemetry marked as a cloud handoff.
//
// Returns the length of the JSON written to response_buffer, or -1 when the
// JSON (plus terminator) does not fit in buffer_size. In that case an error
// response is written instead and no telemetry is recorded.
auto return_cloud_completion = [&](const CloudCompletionResult& cloud_result,
                                   double ttft_ms,
                                   double total_ms,
                                   float confidence,
                                   size_t prompt_token_count) {
    std::string text_from_cloud = cloud_result.response;
    std::vector<std::string> calls_from_cloud = cloud_result.function_calls;

    const bool notify_listener = callback && !text_from_cloud.empty();
    if (notify_listener) {
        callback(text_from_cloud.c_str(), 0, user_data);
    }

    // Throughput and decode-token figures are zero: nothing here was decoded locally.
    std::string payload = construct_response_json(text_from_cloud, calls_from_cloud, ttft_ms,
                                                  total_ms, 0.0, 0.0, prompt_token_count,
                                                  0, confidence, true, "");

    // One byte must remain for the terminating NUL.
    const bool fits_in_buffer = payload.size() < buffer_size;
    if (!fits_in_buffer) {
        handle_error_response("Response buffer too small", response_buffer, buffer_size);
        return -1;
    }
    std::strcpy(response_buffer, payload.c_str());

    cactus::telemetry::CompletionMetrics metrics{};
    // Outcome.
    metrics.success = true;
    metrics.cloud_handoff = true;
    metrics.error_message = nullptr;
    metrics.function_calls_json = nullptr;
    // Timing.
    metrics.ttft_ms = ttft_ms;
    metrics.response_time_ms = total_ms;
    // Throughput and token counts.
    metrics.prefill_tps = 0.0;
    metrics.decode_tps = 0.0;
    metrics.prefill_tokens = prompt_token_count;
    metrics.decode_tokens = 0;
    // Quality and resources.
    metrics.confidence = confidence;
    metrics.ram_usage_mb = get_ram_usage_mb();
    cactus::telemetry::recordCompletion(handle->model_name.c_str(), metrics);

    return static_cast<int>(payload.size());
};

// A confidence threshold of 1.0 or more sends the request to the cloud up
// front, before any local generation is attempted.
if (cloud_eligible && prompt.options.confidence_threshold >= 1.0f) {
pre_generation_cloud_attempted = true;
// NOTE(review): this message says "before falling back", but the failure path
// at the end of this branch returns -1 instead of continuing with local
// generation. Confirm whether the message or the early return is intended.
CACTUS_LOG_WARN("cloud_handoff", "Cloud handoff triggered before local generation; waiting up to "
<< prompt.options.cloud_timeout_ms << " ms before falling back");
// Direct (blocking) call; no hints are passed because nothing has been generated locally.
auto cloud_result = cloud_complete_request(
make_cloud_request("", {}),
static_cast<long>(prompt.options.cloud_timeout_ms));
// Time since start_time, in milliseconds; passed as both the TTFT and the total time.
auto now = std::chrono::high_resolution_clock::now();
double elapsed_ms = std::chrono::duration_cast<std::chrono::microseconds>(now - start_time).count() / 1000.0;
// Success requires an ok result that carries text or at least one function call.
if (cloud_result.ok && (!cloud_result.response.empty() || !cloud_result.function_calls.empty())) {
// Confidence is passed as 0.0f on this path.
return return_cloud_completion(cloud_result, elapsed_ms, elapsed_ms, 0.0f, prompt.tokens.size());
}
// Failure: log it, write an error response, and abort the completion.
std::string cloud_error = cloud_result.error.empty() ? "cloud completion failed" : cloud_result.error;
CACTUS_LOG_WARN("cloud_handoff", "Cloud completion failed before local generation: " << cloud_error);
handle_error_response(("cloud handoff failed before local generation: " + cloud_error).c_str(),
response_buffer, buffer_size);
return -1;
}

auto stop_token_sequences = build_stop_sequences(tokenizer, prompt.options.stop_sequences, prompt.model_type, !prompt.tools.empty());

Expand Down Expand Up @@ -737,40 +821,7 @@ int cactus_complete(
time_to_first_token = std::chrono::duration_cast<std::chrono::microseconds>(token_end - start_time).count() / 1000.0;

float confidence = 1.0f - first_token_entropy;
bool cloud_used = false;
std::string cloud_error;
std::future<CloudCompletionResult> cloud_future;
bool cloud_future_started = false;
const bool cloud_disabled = env_flag_enabled("CACTUS_DISABLE_CLOUD_HANDOFF");
const bool cloud_eligible = !cloud_disabled &&
prompt.options.auto_handoff && (!has_images || prompt.options.handoff_with_images);

auto maybe_start_cloud_handoff = [&](const std::string& local_output_hint,
const std::vector<std::string>& local_calls_hint) {
if (!cloud_eligible || cloud_future_started) {
return;
}
CloudCompletionRequest request;
request.messages = prompt.messages;
request.tools = prompt.tools;
request.local_output = local_output_hint;
request.local_function_calls = local_calls_hint;
request.has_images = has_images;
request.has_audio = has_audio;
if (has_audio && pcm_buffer != nullptr && pcm_buffer_size > 0) {
request.audio_pcm.assign(pcm_buffer, pcm_buffer + pcm_buffer_size);
}
request.cloud_key = resolve_cloud_api_key(nullptr);

cloud_future_started = true;
cloud_future = std::async(std::launch::async, [request, &prompt]() {
return cloud_complete_request(request, static_cast<long>(prompt.options.cloud_timeout_ms));
});
};

if (confidence < prompt.options.confidence_threshold) {
maybe_start_cloud_handoff("", {});
}

generated_tokens.push_back(next_token);
handle->processed_tokens.push_back(next_token);
Expand All @@ -787,7 +838,27 @@ int cactus_complete(
entropy.add(first_token_entropy);

if (!matches_stop_sequence(generated_tokens, stop_token_sequences)) {
if (callback) {
if (!defer_local_stream_until_probe
&& !pre_generation_cloud_attempted
&& confidence < prompt.options.confidence_threshold) {
CACTUS_LOG_WARN("cloud_handoff", "Cloud handoff triggered before local streaming; waiting up to "
<< prompt.options.cloud_timeout_ms << " ms before falling back");
CloudCompletionResult cloud_result = cloud_complete_request(
make_cloud_request("", {}),
static_cast<long>(prompt.options.cloud_timeout_ms));
auto now = std::chrono::high_resolution_clock::now();
double elapsed_ms = std::chrono::duration_cast<std::chrono::microseconds>(now - start_time).count() / 1000.0;
if (cloud_result.ok && (!cloud_result.response.empty() || !cloud_result.function_calls.empty())) {
if (prompt.options.force_tools && !prompt.tools.empty()) {
handle->model->clear_tool_constraints();
}
return return_cloud_completion(cloud_result, elapsed_ms, elapsed_ms, confidence, prompt_tokens);
}
cloud_error = cloud_result.error.empty() ? "cloud completion failed" : cloud_result.error;
CACTUS_LOG_WARN("cloud_handoff", "Cloud completion failed before local streaming, falling back to local output: " << cloud_error);
}

if (callback && !defer_local_stream_until_probe) {
std::string new_text = tokenizer->decode({next_token});
callback(new_text.c_str(), next_token, user_data);
}
Expand All @@ -813,7 +884,6 @@ int cactus_complete(

if (entropy.rolling_confidence() < prompt.options.confidence_threshold) {
entropy.spike_handoff = true;
maybe_start_cloud_handoff("", {});
}

if (prompt.options.force_tools && !prompt.tools.empty()) {
Expand All @@ -825,7 +895,7 @@ int cactus_complete(
break;
}

if (callback) {
if (callback && !defer_local_stream_until_probe) {
std::string new_text = tokenizer->decode({next_token});
callback(new_text.c_str(), next_token, user_data);
}
Expand All @@ -835,6 +905,14 @@ int cactus_complete(
}

confidence = entropy.mean_confidence();
// When streaming was deferred for the probe and a probe rollout is available,
// the probe's estimate replaces the entropy-based confidence computed above.
if (defer_local_stream_until_probe && handle->model->has_handoff_probe_rollout()) {
float wrong_probability = handle->model->handoff_probe_wrong_probability();
// A non-finite probe output is ignored, leaving the existing confidence untouched.
if (std::isfinite(wrong_probability)) {
// confidence = 1 - p_wrong, limited to the range [0, 1].
confidence = std::max(0.0f, std::min(1.0f, 1.0f - wrong_probability));
CACTUS_LOG_DEBUG("cloud_handoff", "Gemma4 handoff probe p_wrong="
<< wrong_probability << " confidence=" << confidence);
}
}

if (prompt.options.force_tools && !prompt.tools.empty()) {
handle->model->clear_tool_constraints();
Expand Down Expand Up @@ -868,40 +946,39 @@ int cactus_complete(
}
}

if (confidence < prompt.options.confidence_threshold) {
maybe_start_cloud_handoff(regular_response, function_calls);
}

std::string local_completion = regular_response;
if (local_completion.empty() && function_calls.empty()) {
local_completion = response_text;
}
std::string primary_response = local_completion;
std::vector<std::string> primary_function_calls = function_calls;

if (cloud_future_started) {
auto status = cloud_future.wait_for(std::chrono::milliseconds(prompt.options.cloud_timeout_ms));
if (status == std::future_status::ready) {
CloudCompletionResult cloud_result = cloud_future.get();
if (cloud_result.ok && (!cloud_result.response.empty() || !cloud_result.function_calls.empty())) {
cloud_used = true;
if (!cloud_result.response.empty()) {
primary_response = cloud_result.response;
}
if (!cloud_result.function_calls.empty()) {
primary_function_calls = cloud_result.function_calls;
}
} else {
cloud_error = cloud_result.error.empty() ? "cloud completion failed" : cloud_result.error;
CACTUS_LOG_WARN("cloud_handoff", "Cloud completion failed, falling back to local output: " << cloud_error);
bool handoff_succeeded = false;
if (defer_local_stream_until_probe && !pre_generation_cloud_attempted
&& confidence < prompt.options.confidence_threshold) {
CACTUS_LOG_WARN("cloud_handoff", "Cloud handoff triggered by Gemma4 probe: p_wrong="
<< (1.0f - confidence) << " confidence=" << confidence
<< "; waiting up to " << prompt.options.cloud_timeout_ms << " ms");
CloudCompletionResult cloud_result = cloud_complete_request(
make_cloud_request(local_completion, function_calls),
static_cast<long>(prompt.options.cloud_timeout_ms));
auto now = std::chrono::high_resolution_clock::now();
double elapsed_ms = std::chrono::duration_cast<std::chrono::microseconds>(now - start_time).count() / 1000.0;
if (cloud_result.ok && (!cloud_result.response.empty() || !cloud_result.function_calls.empty())) {
if (prompt.options.force_tools && !prompt.tools.empty()) {
handle->model->clear_tool_constraints();
}
} else {
cloud_error = "timeout";
CACTUS_LOG_WARN("cloud_handoff", "Cloud completion timed out, falling back to local output: " << cloud_error);
return return_cloud_completion(cloud_result, elapsed_ms, elapsed_ms, confidence, prompt_tokens);
}
cloud_error = cloud_result.error.empty() ? "cloud completion failed" : cloud_result.error;
CACTUS_LOG_WARN("cloud_handoff", "Cloud completion failed after probe handoff, falling back to local output: "
<< cloud_error);
}

if (callback && defer_local_stream_until_probe && !primary_response.empty()) {
callback(primary_response.c_str(), 0, user_data);
}

const bool handoff_succeeded = cloud_used;
std::string result = construct_response_json(primary_response, primary_function_calls, time_to_first_token,
total_time_ms, prefill_tps, decode_tps, prompt_tokens,
completion_tokens, confidence, handoff_succeeded,
Expand Down
31 changes: 28 additions & 3 deletions cactus-engine/src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ bool is_npu_available();
} // namespace npu
namespace engine {


struct Config {
uint32_t vocab_size = 151936;
uint32_t bos_token_id = 151643;
Expand Down Expand Up @@ -631,6 +630,11 @@ class Model {
bool has_vocab_bias() const { return !vocab_bias_.empty(); }
const std::unordered_map<uint32_t, float>& get_vocab_bias() const { return vocab_bias_; }

// Returns true when the handoff probe has been loaded (handoff_probe_loaded_).
bool has_handoff_probe() const { return handoff_probe_loaded_; }
// Discards the captured probe hidden-state buffer (handoff_probe_hidden_).
void reset_handoff_probe_rollout() { handoff_probe_hidden_.clear(); }
// Defined out of line. NOTE(review): presumably reports whether enough hidden
// state has been captured to evaluate the probe — confirm in the implementation.
bool has_handoff_probe_rollout() const;
// Defined out of line. NOTE(review): presumably the probe's estimated probability
// that the local output is wrong; whether it can be non-finite is not visible
// from this header — confirm in the implementation.
float handoff_probe_wrong_probability() const;

private:
struct Binding {
int node_id = -1;
Expand Down Expand Up @@ -688,13 +692,16 @@ class Model {
ChunkedPrefillResult run_chunked_prefill(const std::vector<uint32_t>& tokens, size_t start_position,
size_t chunk_size, bool prepare_decode);
void run_full_context_text();
uint32_t argmax_component_logits(Component& comp, size_t logit_row = std::numeric_limits<size_t>::max());
uint32_t argmax_component_logits(Component& comp, size_t logit_row = std::numeric_limits<size_t>::max(),
float* out_uncertainty = nullptr);
void write_int_input(Component& comp, const std::string& name, int64_t value);
void write_int_input_at(Component& comp, const std::string& name, size_t index, int64_t value);
void write_bytes_input(Component& comp, const std::string& name, const void* data, size_t byte_size);
int input_index(const Component& comp, const std::string& name) const;
int output_index(const Component& comp, const std::string& name) const;
uint32_t argmax_last_logits();
uint32_t argmax_last_logits(float* out_uncertainty = nullptr);
bool load_handoff_probe();
void maybe_capture_handoff_probe_hidden(const Component& comp);
void run_vision_encoder(const std::string& image_path);
void run_audio_encoder(const std::vector<float>& audio_features);
void run_audio_encoder_messages(const std::vector<std::vector<float>>& audio_features_per_message);
Expand Down Expand Up @@ -754,6 +761,24 @@ class Model {
ToolCallConstrainer tool_constrainer_;
std::unordered_map<uint32_t, float> vocab_bias_;

// Handoff-probe state.
// handoff_probe_loaded_ backs has_handoff_probe().
bool handoff_probe_loaded_ = false;
// NOTE(review): the four values below look like probe dimensions (feature
// size, a sequence/window length, two hidden-layer widths) — confirm against
// load_handoff_probe().
uint32_t handoff_probe_feat_dim_ = 0;
uint32_t handoff_probe_t_h_ = 0;
uint32_t handoff_probe_h1_ = 0;
uint32_t handoff_probe_h2_ = 0;
// NOTE(review): presumably the probe's parameters (normalization, projection,
// attention query, and three head layers), stored as flat float arrays —
// layout is not visible here; confirm against load_handoff_probe().
std::vector<float> handoff_probe_norm_weight_;
std::vector<float> handoff_probe_norm_bias_;
std::vector<float> handoff_probe_proj_weight_;
std::vector<float> handoff_probe_proj_bias_;
std::vector<float> handoff_probe_attn_query_;
std::vector<float> handoff_probe_head0_weight_;
std::vector<float> handoff_probe_head0_bias_;
std::vector<float> handoff_probe_head2_weight_;
std::vector<float> handoff_probe_head2_bias_;
std::vector<float> handoff_probe_head4_weight_;
std::vector<float> handoff_probe_head4_bias_;
// Buffer cleared by reset_handoff_probe_rollout().
std::vector<float> handoff_probe_hidden_;

mutable std::vector<DebugNode> debug_nodes_;
};

Expand Down
Loading
Loading