Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions cactus-engine/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,26 +1072,31 @@ Model::ChunkedPrefillResult Model::run_chunked_prefill(const std::vector<uint32_
size_t effective_chunk = chunk_size > 0 ? std::min(chunk_size, component_tokens) : component_tokens;
if (effective_chunk != component_tokens) effective_chunk = component_tokens;
size_t whole_chunks_end = (tokens.size() / effective_chunk) * effective_chunk;
const bool has_recurrent_state = [&]() {
auto any_cache_node = [&](auto predicate) {
if (!decoder_prefill_->graph) return false;
for (const auto& state : decoder_prefill_->cache_states) {
for (int node_id : {state.key_node_id, state.value_node_id}) {
if (node_id < 0) continue;
if (decoder_prefill_->graph->get_node_op_type(static_cast<size_t>(node_id))
== OpType::RECURRENT_CACHE_STATE) {
return true;
}
if (predicate(static_cast<size_t>(node_id))) return true;
}
}
return false;
}();
};
const bool has_recurrent_state = any_cache_node([&](size_t id) {
return decoder_prefill_->graph->get_node_op_type(id) == OpType::RECURRENT_CACHE_STATE;
});
if (has_recurrent_state && whole_chunks_end > effective_chunk) {
whole_chunks_end = effective_chunk;
}
const bool has_sliding_window_cache = any_cache_node([&](size_t id) {
return decoder_prefill_->graph->get_node_op_type(id) == OpType::KV_CACHE_STATE
&& decoder_prefill_->graph->get_node_window_size(id) > 0;
});
const size_t tail_tokens = tokens.size() - whole_chunks_end;
const size_t padding_cutoff = std::max<size_t>(1, effective_chunk / 16);
const bool pad_tail = family_ != "lfm2_vl"
&& !has_recurrent_state
&& !has_sliding_window_cache
&& tail_tokens >= padding_cutoff;
const size_t executable_tokens = whole_chunks_end + (pad_tail ? effective_chunk : 0);
if (executable_tokens == 0) {
Expand Down
1 change: 1 addition & 0 deletions cactus-graph/cactus_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ class CactusGraph {
const std::vector<size_t>& output_shape, const OpParams& params = {});
const BufferDesc& get_output_buffer(size_t node_id) const;
OpType get_node_op_type(size_t node_id) const;
size_t get_node_window_size(size_t node_id) const;
void allocate_buffers();
size_t get_node_count() const;

Expand Down
4 changes: 4 additions & 0 deletions cactus-graph/src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,10 @@ OpType CactusGraph::get_node_op_type(size_t node_id) const {
return nodes_[node_index_map_.at(node_id)]->op_type;
}

// Returns the `window_size` stored in the op params of the node registered
// under `node_id`. The id is first translated to a position in `nodes_`
// through `node_index_map_`.
// NOTE(review): `.at()` presumably throws std::out_of_range for an unknown id
// (standard map semantics) — confirm against the declared type of
// `node_index_map_`.
size_t CactusGraph::get_node_window_size(size_t node_id) const {
    const auto node_slot = node_index_map_.at(node_id);
    const auto& graph_node = nodes_[node_slot];
    return graph_node->params.window_size;
}

size_t CactusGraph::persistent(size_t source_node) {
const auto& source_buffer = get_output_buffer(source_node);
OpParams params;
Expand Down
121 changes: 121 additions & 0 deletions python/cactus/transpile/optimize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from cactus.transpile.fusion import match_rope
from cactus.transpile.fusion import match_self_attention_block
from cactus.transpile.fusion.common import producer
from cactus.transpile.fusion.common import strip_layout_passthrough
from cactus.transpile.fusion.common import strip_passthrough
from cactus.transpile.fusion.linear import match_linear
from cactus.transpile.fusion.rope import _find_constant_ancestor
from cactus.transpile.graph_ir import IRGraph
from cactus.transpile.graph_ir import IRNode
from cactus.transpile.graph_ir import IRValue
Expand Down Expand Up @@ -106,6 +108,9 @@ def optimize_graph(graph: IRGraph, *, max_passes: int = 8, config: FusionConfig
break
canonicalize_exported_graph(graph)

if precompute_rope_tables(graph):
canonicalize_exported_graph(graph)

annotate_gold_patterns(graph)
_prune_unused_inputs(graph)
verify_ir(graph)
Expand Down Expand Up @@ -361,6 +366,122 @@ def fuse_rope(graph: IRGraph) -> bool:
return changed


def precompute_rope_tables(graph: IRGraph) -> bool:
    """Replace the runtime ``cos/sin(position * inv_freq)`` rope angle with an
    fp64-precomputed fp16 table gathered by position id (the fp16 angle matmul
    cannot represent positions > 2048, randomising cos/sin past long context).

    Every decoder ``scalar_cos`` / ``scalar_sin`` node whose input matches the
    ``cat(angle, angle)`` rope pattern is rewritten in place into an
    ``embedding`` node reading ``[table, position_ids]``.

    Args:
        graph: IR graph to rewrite in place.

    Returns:
        True when at least one node was rewritten (the graph is rebuilt in that
        case), False otherwise.

    Raises:
        ValueError: propagated from ``_rope_table_max_seq`` when a node is
            about to be rewritten but the graph carries no usable sequence
            length metadata.
    """

    changed = False
    # Position column shared by every table. Built lazily so that a graph with
    # no rewritable node never needs the sequence-length metadata.
    positions = None
    for node_id in list(graph.order):
        node = graph.nodes.get(node_id)
        if node is None or node.op not in {"scalar_cos", "scalar_sin"} or len(node.inputs) != 1:
            continue
        # Only the text decoder's rope. Vision/audio encoders carry cos/sin with
        # the same shape but an axial/2D position this scalar table cannot model.
        if node.meta.get("component") != "decoder":
            continue

        output_id = node.outputs[0]
        output_value = graph.values.get(output_id)
        if output_value is None or output_value.shape is None:
            continue
        # A rank-0 output has no trailing feature axis to gather into.
        if len(output_value.shape) == 0:
            continue
        index_shape = tuple(int(dim) for dim in output_value.shape[:-1])

        match = _match_rope_angle_table_source(graph, node.inputs[0], index_shape)
        if match is None:
            continue
        inv_freq_value_id, position_value_id = match

        inv_freq = graph.constants[inv_freq_value_id].detach().cpu().to(torch.float64).reshape(-1)
        head_dim = int(inv_freq.numel()) * 2
        # The table is cat(freqs, freqs), so its width is always ``head_dim``.
        # The meaningful invariant is that this width equals the last dimension
        # the node already declares; otherwise the gather would yield a tensor
        # that disagrees with the recorded output shape. Leave such nodes alone.
        if int(output_value.shape[-1]) != head_dim:
            continue

        if positions is None:
            max_seq = _rope_table_max_seq(graph)
            positions = torch.arange(max_seq, dtype=torch.float64).reshape(max_seq, 1)
        freqs = positions * inv_freq.reshape(1, -1)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Angles are evaluated in fp64 and only the final cos/sin is rounded to fp16.
        table = (torch.cos(emb) if node.op == "scalar_cos" else torch.sin(emb)).to(torch.float16)

        table_value_id = _materialize_rope_table_constant(graph, node.id, table)
        output_value.dtype = "fp16"
        node.op = "embedding"
        node.inputs = [table_value_id, position_value_id]
        node.attrs = {}
        node.kind = "generic"
        changed = True

    if changed:
        rebuild_graph(graph)
    return changed


def _match_rope_angle_table_source(graph: IRGraph, value_id: str, index_shape: tuple[int, ...]) -> tuple[str, str] | None:
    """Match the ``cat(angle, angle)`` rope pattern feeding ``value_id``.

    The expected producer chain is a two-input ``cat`` whose operands resolve
    to the same value once layout passthroughs are stripped, optionally
    followed by one ``permute``, ending in a two-input ``matmul``.

    Args:
        graph: IR graph being inspected.
        value_id: input value of the cos/sin node.
        index_shape: shape the integer position value must have.

    Returns:
        ``(inv_freq_constant_id, position_value_id)`` when the pattern matches,
        otherwise ``None``.
    """
    concat = producer(graph, strip_passthrough(graph, value_id))
    if concat is None:
        return None
    if concat.op != "cat" or len(concat.inputs) != 2:
        return None

    # Both halves of the concatenation must be the very same angle value.
    first_half = strip_layout_passthrough(graph, concat.inputs[0])
    second_half = strip_layout_passthrough(graph, concat.inputs[1])
    if first_half != second_half:
        return None

    matmul_node = producer(graph, first_half)
    if matmul_node is not None and matmul_node.op == "permute":
        # Step over a single permute sitting between the matmul and the cat.
        matmul_node = producer(graph, strip_layout_passthrough(graph, matmul_node.inputs[0]))
    if matmul_node is None:
        return None
    if matmul_node.op != "matmul" or len(matmul_node.inputs) != 2:
        return None

    inv_freq_id: str | None = None
    position_id: str | None = None
    for operand_id in matmul_node.inputs:
        constant_id = _find_constant_ancestor(graph, operand_id)
        if constant_id is None:
            # Operand without a constant ancestor: must trace to the position ids.
            position_id = _trace_integer_position_source(graph, operand_id, index_shape)
        else:
            inv_freq_id = constant_id

    if inv_freq_id is not None and position_id is not None:
        return inv_freq_id, position_id
    return None


def _trace_integer_position_source(graph: IRGraph, value_id: str, index_shape: tuple[int, ...]) -> str | None:
current = value_id
visited: set[str] = set()
while current not in visited:
visited.add(current)
value = graph.values.get(current)
if (
value is not None
and str(value.dtype) in {"int32", "int64", "i32", "i64"}
and value.shape is not None
and tuple(int(dim) for dim in value.shape) == index_shape
):
return current
node = producer(graph, current)
if node is None or node.op not in {"precision_cast", "view", "reshape", "unsqueeze", "squeeze", "contiguous"} or len(node.inputs) != 1:
return None
current = node.inputs[0]
return None


def _rope_table_max_seq(graph: IRGraph) -> int:
    """Return the number of positions the rope table must cover.

    ``graph.meta["max_cache_seq_len"]`` is preferred;
    ``graph.meta["max_position_embeddings"]`` is the fallback. A missing or
    non-positive entry counts as absent.

    Raises:
        ValueError: when neither entry yields a positive integer.
    """
    for meta_key in ("max_cache_seq_len", "max_position_embeddings"):
        candidate = _coerce_optional_int(graph.meta.get(meta_key))
        if candidate is not None and candidate > 0:
            return int(candidate)
    raise ValueError("rope table precompute requires max_cache_seq_len or max_position_embeddings in graph.meta")


def _materialize_rope_table_constant(graph: IRGraph, node_id: str, table: torch.Tensor) -> str:
    """Register ``table`` as a graph constant and return its value id.

    The id is ``c_rope_table_<node_id>``. The tensor is stored in
    ``graph.constants`` and a matching ``IRValue`` (no producer, empty user
    list) is stored in ``graph.values`` under the same id.
    """
    constant_id = f"c_rope_table_{node_id}"
    graph.constants[constant_id] = table

    table_shape = tuple(int(extent) for extent in table.shape)
    table_value = IRValue(
        id=constant_id,
        shape=table_shape,
        dtype=dtype_to_ir(table.dtype),
        producer=None,
        users=[],
    )
    graph.values[constant_id] = table_value
    return constant_id


def fuse_attention(graph: IRGraph) -> bool:
changed = False
for node_id in list(graph.order):
Expand Down
Loading
Loading