@@ -160,7 +160,7 @@ void Qwen3p5Model::post_init() {
160160 throw std::runtime_error (" Qwen3p5Model computed zero cache row dim" );
161161 }
162162
163- conv_cache_.init (config_.num_layers , deltanet_cache_row_dim_, 1 , Precision::FP16 );
163+ conv_cache_.init (config_.num_layers , deltanet_cache_row_dim_, 1 , Precision::FP32 );
164164 last_forward_used_cache_ = false ;
165165 deltanet_total_seq_len_ = 0 ;
166166}
@@ -451,6 +451,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
451451 size_t history_2d = 0 ;
452452 if (prev_conv_flat != 0 ) {
453453 history_2d = gb->reshape (prev_conv_flat, {deltanet_conv_history_len_, mixed_proj_dim});
454+ history_2d = gb->precision_cast (history_2d, Precision::FP16 );
454455 } else {
455456 history_2d = gb->input ({deltanet_conv_history_len_, mixed_proj_dim}, Precision::FP16 );
456457 std::vector<__fp16> zeros (deltanet_conv_flat_dim_, static_cast <__fp16>(0 .0f ));
@@ -532,11 +533,11 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
532533
533534 size_t deltanet_out;
534535 if (use_cache && seq_len == 1 ) {
535- deltanet_out = gb->gated_deltanet_decode (q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, 0 .0f , num_k_heads );
536+ deltanet_out = gb->gated_deltanet_decode (q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, 0 .0f );
536537 } else {
537538 const size_t chunk_for_op = std::min<size_t >(64 , std::max<size_t >(1 , seq_len));
538539 deltanet_out = gb->gated_deltanet_prefill (
539- q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, chunk_for_op, 0 .0f , num_k_heads );
540+ q_4d, k_4d, v_4d, gate_3d, beta_3d, initial_state, chunk_for_op, 0 .0f );
540541 }
541542
542543 size_t y_4d = gb->slice (deltanet_out, 1 , 0 , seq_len);
@@ -560,6 +561,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
560561 }
561562 } else if (prev_conv_flat != 0 ) {
562563 history_2d = gb->reshape (prev_conv_flat, {deltanet_conv_history_len_, mixed_proj_dim});
564+ history_2d = gb->precision_cast (history_2d, Precision::FP16 );
563565 } else {
564566 history_2d = gb->input ({deltanet_conv_history_len_, mixed_proj_dim}, Precision::FP16 );
565567 std::vector<__fp16> zeros (deltanet_conv_flat_dim_, static_cast <__fp16>(0 .0f ));
@@ -569,6 +571,7 @@ size_t Qwen3p5Model::build_gated_deltanet(CactusGraph* gb, size_t normalized_inp
569571 size_t history_flat = gb->reshape (history_2d, {1 , deltanet_conv_flat_dim_});
570572 packed_cache = gb->concat (packed_cache, history_flat, 1 );
571573 }
574+ packed_cache = gb->precision_cast (packed_cache, conv_cache_.precision );
572575 conv_cache_state_nodes_[layer_idx] = packed_cache;
573576 cache_k_output_nodes_[layer_idx] = 0 ;
574577 cache_v_output_nodes_[layer_idx] = 0 ;
0 commit comments