From c0df8c2ed078e3d980eb0b43d820dc8af102b452 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Sat, 28 Jun 2025 13:58:42 +0530 Subject: [PATCH] Refactor `BasicBlockGraphBuilder::AddBasicBlockFromInstructions`. * Factor out the per-instruction add to graph logic. * This change is in preparation for further changes to the graph construction logic. * Specifically, this should make the trace graph construction logic cleaner. --- gematria/granite/graph_builder.cc | 104 +++++++++++++---------- gematria/granite/graph_builder.h | 4 + gematria/granite/python/graph_builder.cc | 2 - 3 files changed, 61 insertions(+), 49 deletions(-) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 30f6c4ab..9d2f7e09 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -189,6 +189,61 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( } } +BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddInstruction( + const Instruction& instruction, NodeIndex previous_instruction_node) { + // Add the instruction node; give up if AddNode yields kInvalidNode. + const NodeIndex instruction_node = + AddNode(NodeType::kInstruction, instruction.mnemonic); + if (instruction_node == kInvalidNode) { + return kInvalidNode; + } + + // Store the annotations for later use (inclusion in embeddings), using -1 + // where an annotation is missing; unknown annotation names are skipped. + std::vector row = std::vector(annotation_names_.size(), -1); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); + + // Add a node for each prefix and connect it to the instruction node. 
+ for (const std::string& prefix : instruction.prefixes) { + const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); + if (prefix_node == kInvalidNode) { + return kInvalidNode; + } + AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); + } + + // Add a structural dependency edge from the previous instruction, if any. + if (previous_instruction_node >= 0) { + AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, + instruction_node); + } + + // Add edges (and nodes, if necessary) for explicit and implicit inputs. + for (const InstructionOperand& operand : instruction.input_operands) { + if (!AddInputOperand(instruction_node, operand)) return kInvalidNode; + } + for (const InstructionOperand& operand : + instruction.implicit_input_operands) { + if (!AddInputOperand(instruction_node, operand)) return kInvalidNode; + } + + // Add edges and nodes for explicit and implicit output operands. + for (const InstructionOperand& operand : instruction.output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return kInvalidNode; + } + for (const InstructionOperand& operand : + instruction.implicit_output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return kInvalidNode; + } + + return instruction_node; +} + bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const std::vector& instructions) { if (instructions.empty()) return false; @@ -203,56 +258,11 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( NodeIndex previous_instruction_node = kInvalidNode; for (const Instruction& instruction : instructions) { - // Add the instruction node. - const NodeIndex instruction_node = - AddNode(NodeType::kInstruction, instruction.mnemonic); + NodeIndex instruction_node = + AddInstruction(instruction, previous_instruction_node); if (instruction_node == kInvalidNode) { return false; } - - // Store the annotations for later use (inclusion in embeddings), using -1 - // as a default value wherever annotations are missing. 
- std::vector row = std::vector(annotation_names_.size(), -1); - for (const auto& [name, value] : instruction.instruction_annotations) { - const auto annotation_index = annotation_name_to_idx_.find(name); - if (annotation_index == annotation_name_to_idx_.end()) continue; - row[annotation_index->second] = value; - } - instruction_annotations_.push_back(row); - - // Add nodes for prefixes of the instruction. - for (const std::string& prefix : instruction.prefixes) { - const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); - if (prefix_node == kInvalidNode) { - return false; - } - AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); - } - - // Add a structural dependency edge from the previous instruction. - if (previous_instruction_node >= 0) { - AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, - instruction_node); - } - - // Add edges for input operands. And nodes too, if necessary. - for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - - // Add edges and nodes for output operands. - for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - previous_instruction_node = instruction_node; } diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index c4684eea..27eac6f8 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -360,6 +360,10 @@ class BasicBlockGraphBuilder { size_t prev_global_features_size_; }; + // Adds nodes and edges for one instruction; returns kInvalidNode on failure. 
+ NodeIndex AddInstruction(const Instruction& instruction, + NodeIndex previous_instruction_node); + // Adds nodes and edges for a single input operand of an instruction. bool AddInputOperand(NodeIndex instruction_node, const InstructionOperand& operand); diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..6949cdc6 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,13 +14,11 @@ #include "gematria/granite/graph_builder.h" -#include #include #include #include "absl/strings/string_view.h" #include "gematria/model/oov_token_behavior.h" -#include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" #include "pybind11/detail/common.h" #include "pybind11/pybind11.h"