Skip to content

Commit 6182793

Browse files
committed
slower but more stable now with fp32 accumulation
Signed-off-by: Karen Mosoyan <karen.mossoyan@gmail.com>
1 parent 859a1ac commit 6182793

5 files changed

Lines changed: 15 additions & 21 deletions

File tree

cactus/engine/engine_tokenizer.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,7 @@ std::string Tokenizer::format_qwen_style(const std::vector<ChatMessage>& message
157157
}
158158

159159
if (add_generation_prompt) {
160-
const bool template_has_think = !chat_template_.empty() && chat_template_.find("<think>") != std::string::npos;
161-
if (template_has_think) {
162-
result += "<|im_start|>assistant\n<think>\n";
163-
} else if (!tools_json.empty()) {
164-
result += "<|im_start|>assistant\n<think>\n</think>\n\n";
165-
} else {
166-
result += "<|im_start|>assistant\n";
167-
}
160+
result += "<|im_start|>assistant\n<think>\n\n</think>\n\n";
168161
}
169162

170163
return result;

cactus/graph/graph.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,10 +549,9 @@ class CactusGraph {
549549

550550
size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
551551
size_t gated_deltanet_decode(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
552-
size_t initial_state, float scale = 0.0f, size_t num_qk_heads = 0);
552+
size_t initial_state, float scale = 0.0f);
553553
size_t gated_deltanet_prefill(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
554-
size_t initial_state, size_t chunk_size = 64, float scale = 0.0f,
555-
size_t num_qk_heads = 0);
554+
size_t initial_state, size_t chunk_size = 64, float scale = 0.0f);
556555
size_t stft(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
557556

558557
size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,

cactus/graph/graph_builder.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ size_t CactusGraph::lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t
957957
}
958958

959959
size_t CactusGraph::gated_deltanet_decode(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
960-
size_t initial_state, float scale, size_t num_qk_heads) {
960+
size_t initial_state, float scale) {
961961
const auto& q = get_output_buffer(query);
962962
const auto& k = get_output_buffer(key);
963963
const auto& v = get_output_buffer(value);
@@ -981,7 +981,9 @@ size_t CactusGraph::gated_deltanet_decode(size_t query, size_t key, size_t value
981981
const size_t K = q.shape[3];
982982
const size_t Hv = v.shape[2];
983983
const size_t V = v.shape[3];
984-
984+
if (T != 1) {
985+
throw std::runtime_error("gated_deltanet_decode expects sequence length T=1");
986+
}
985987
auto is_supported_precision = [](Precision p) {
986988
return p == Precision::FP16 || p == Precision::FP32;
987989
};
@@ -1007,8 +1009,7 @@ size_t CactusGraph::gated_deltanet_decode(size_t query, size_t key, size_t value
10071009
}
10081010

10091011
size_t CactusGraph::gated_deltanet_prefill(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
1010-
size_t initial_state, size_t chunk_size, float scale,
1011-
size_t num_qk_heads) {
1012+
size_t initial_state, size_t chunk_size, float scale) {
10121013
const auto& q = get_output_buffer(query);
10131014
const auto& k = get_output_buffer(key);
10141015
const auto& v = get_output_buffer(value);
@@ -1032,7 +1033,6 @@ size_t CactusGraph::gated_deltanet_prefill(size_t query, size_t key, size_t valu
10321033
const size_t K = q.shape[3];
10331034
const size_t Hv = v.shape[2];
10341035
const size_t V = v.shape[3];
1035-
10361036
auto is_supported_precision = [](Precision p) {
10371037
return p == Precision::FP16 || p == Precision::FP32;
10381038
};

cactus/graph/graph_ops_nn.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,7 +2078,7 @@ void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<G
20782078
}
20792079
}
20802080
}
2081-
2081+
}
20822082

20832083
void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map) {
20842084
const auto& input_buffer = nodes[node_index_map.at(node.input_ids[0])]->output_buffer;
@@ -2137,4 +2137,3 @@ void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<G
21372137
}
21382138
}
21392139
}
2140-
}

cactus/models/model_qwen3p5.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ void Qwen3p5Model::post_init() {
160160
throw std::runtime_error("Qwen3p5Model computed zero cache row dim");
161161
}
162162

163-
conv_cache_.init(config_.num_layers, deltanet_cache_row_dim_, 1, Precision::FP16);
163+
conv_cache_.init(config_.num_layers, deltanet_cache_row_dim_, 1, Precision::FP32);
164164
last_forward_used_cache_ = false;
165165
deltanet_total_seq_len_ = 0;
166166
}
@@ -451,6 +451,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
451451
size_t history_2d = 0;
452452
if (prev_conv_flat != 0) {
453453
history_2d = gb->reshape(prev_conv_flat, {deltanet_conv_history_len_, mixed_proj_dim});
454+
history_2d = gb->precision_cast(history_2d, Precision::FP16);
454455
} else {
455456
history_2d = gb->input({deltanet_conv_history_len_, mixed_proj_dim}, Precision::FP16);
456457
std::vector<__fp16> zeros(deltanet_conv_flat_dim_, static_cast<__fp16>(0.0f));
@@ -532,11 +533,11 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
532533

533534
size_t deltanet_out;
534535
if (use_cache && seq_len == 1) {
535-
deltanet_out = gb->gated_deltanet_decode(q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, 0.0f, num_k_heads);
536+
deltanet_out = gb->gated_deltanet_decode(q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, 0.0f);
536537
} else {
537538
const size_t chunk_for_op = std::min<size_t>(64, std::max<size_t>(1, seq_len));
538539
deltanet_out = gb->gated_deltanet_prefill(
539-
q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, chunk_for_op, 0.0f, num_k_heads);
540+
q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, chunk_for_op, 0.0f);
540541
}
541542

542543
size_t y_4d = gb->slice(deltanet_out, 1, 0, seq_len);
@@ -560,6 +561,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
560561
}
561562
} else if (prev_conv_flat != 0) {
562563
history_2d = gb->reshape(prev_conv_flat, {deltanet_conv_history_len_, mixed_proj_dim});
564+
history_2d = gb->precision_cast(history_2d, Precision::FP16);
563565
} else {
564566
history_2d = gb->input({deltanet_conv_history_len_, mixed_proj_dim}, Precision::FP16);
565567
std::vector<__fp16> zeros(deltanet_conv_flat_dim_, static_cast<__fp16>(0.0f));
@@ -569,6 +571,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
569571
size_t history_flat = gb->reshape(history_2d, {1, deltanet_conv_flat_dim_});
570572
packed_cache = gb->concat(packed_cache, history_flat, 1);
571573
}
574+
packed_cache = gb->precision_cast(packed_cache, conv_cache_.precision);
572575
conv_cache_state_nodes_[layer_idx] = packed_cache;
573576
cache_k_output_nodes_[layer_idx] = 0;
574577
cache_v_output_nodes_[layer_idx] = 0;

0 commit comments

Comments
 (0)