Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions forge/forge/transpiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Forge Transpiler Package

A multi-frontend transpiler for converting ML framework models to Forge intermediate representation.
Supports ONNX, with PaddlePaddle and TensorFlow coming soon.
"""

# Configure logging
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Import operations to register them (must be imported first)
from .ir.operations import *

Check warning

Code scanning / flake8

'.ir.operations.*' imported but unused Warning

'.ir.operations.*' imported but unused

Check warning

Code scanning / flake8

'from .ir.operations import *' used; unable to detect undefined names Warning

'from .ir.operations import *' used; unable to detect undefined names

# Public API - IR (common across all frontends)
from .ir.types import TensorInfo, onnx_dtype_to_torch_dtype
from .ir.nodes import TIRNode

# Public API - Core
from .core.graph import TIRGraph

# Public API - Code Generation
from .codegen import generate_forge_module

# Public API - ONNX Frontend
from .frontends.onnx import ONNXToForgeTranspiler
from .frontends.onnx.converters import (
extract_attributes,
extract_attr_value,
AutoPad,
remove_initializers_from_input,
get_inputs_names,
get_outputs_names,
)
from .frontends.onnx.debug import debug_node_output, get_activation_value

# Public API - Common Utils
from .utils import (
is_constant,
is_symmetric_padding,
extract_padding_for_conv,
get_selection,
)

__all__ = [
# IR types and dtype helper (imported from .ir.types)
'TensorInfo',
'onnx_dtype_to_torch_dtype',
# IR node base (imported from .ir.nodes)
'TIRNode',
# Graph container (imported from .core.graph)
'TIRGraph',
# Code generation entry point (imported from .codegen)
'generate_forge_module',
# ONNX frontend transpiler (imported from .frontends.onnx)
'ONNXToForgeTranspiler',
# ONNX converter helpers (imported from .frontends.onnx.converters)
'extract_attributes',
'extract_attr_value',
'AutoPad',
'remove_initializers_from_input',
'get_inputs_names',
'get_outputs_names',
# ONNX debug helpers (imported from .frontends.onnx.debug)
'debug_node_output',
'get_activation_value',
# Frontend-independent utilities (imported from .utils)
'is_constant',
'is_symmetric_padding',
'extract_padding_for_conv',
'get_selection',
]
7 changes: 7 additions & 0 deletions forge/forge/transpiler/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Code generation module for converting TIRGraph to Python code.
"""
from .generator import generate_forge_module

__all__ = ['generate_forge_module']

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
71 changes: 71 additions & 0 deletions forge/forge/transpiler/codegen/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Code generation for converting TIRGraph to Python code.
Framework-agnostic - works for all frontends.
"""
from ..core.graph import TIRGraph


def generate_forge_module(graph: TIRGraph, class_name="GeneratedForgeModule") -> str:
    """
    Generate Python source for a ``forge.Module`` subclass from a TIR graph.

    The emitted module has three parts:

    1. ``__init__`` registers one ``forge.Parameter`` per graph initializer,
       using only the initializer's shape.
    2. ``forward`` takes every graph input that is not an initializer.
    3. The body of ``forward`` contains one assignment per node, in the order
       returned by ``graph.get_topological_sort()``, followed by a ``return``
       of the graph outputs.

    Tensor names are written verbatim as Python identifiers, so they must
    already be valid identifiers for the generated source to compile.

    Args:
        graph: Graph to serialize. Must expose ``name``, ``initializers``,
            ``inputs``, ``outputs`` and ``get_topological_sort()``; every node
            must expose ``op_type`` and ``emit()``. ``emit()`` returns a dict
            with the keys ``'op_name'``, ``'forge_func'``, ``'inputs'``,
            ``'outputs'`` and optionally ``'attrs'``.
        class_name: Name of the generated class.

    Returns:
        The generated module source as a single newline-joined string.
    """
    lines = [
        "import torch",
        "import forge",
        "",
        f"class {class_name}(forge.Module):",
        # repr() keeps the literal valid even when the name contains quotes
        # or backslashes.
        f"    def __init__(self, name={graph.name!r}):",
        "        super().__init__(name=name)",
    ]

    # 1. Add Parameters (Initializers)
    for name, tensor in graph.initializers.items():
        shape_str = str(tuple(tensor.shape))
        lines.append(f"        self.add_parameter({name!r}, forge.Parameter(shape={shape_str}))")
    lines.append("")

    # 2. Forward Method
    # Initializers are module parameters, not call-time arguments.
    forward_args = [inp for inp in graph.inputs if inp not in graph.initializers]
    # Joining with "self" avoids a dangling "self, " when there are no inputs.
    signature = ", ".join(["self", *forward_args])
    lines.append(f"    def forward({signature}):")

    # 3. Operations
    for node in graph.get_topological_sort():
        op_info = node.emit()

        # Positional inputs first, then keyword attributes. Building a single
        # list avoids a leading comma when an op has attributes but no inputs.
        call_parts = list(op_info['inputs'])
        for key, value in op_info.get('attrs', {}).items():
            # Strings must be emitted as quoted (and escaped) literals; every
            # other value is emitted through its str() form, as before.
            rendered = repr(value) if isinstance(value, str) else f"{value}"
            call_parts.append(f"{key}={rendered}")
        call_args = ", ".join(call_parts)

        # Several outputs become a tuple-unpacking assignment target.
        lhs = ", ".join(op_info['outputs'])

        lines.append(f"        # {node.op_type} -> {op_info['op_name']}")
        lines.append(f"        {lhs} = {op_info['forge_func']}({call_args})")

    # 4. Return
    lines.append(f"        return {', '.join(graph.outputs)}")

    return "\n".join(lines)

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
9 changes: 9 additions & 0 deletions forge/forge/transpiler/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Core transpiler functionality - framework-agnostic.
"""
from .graph import TIRGraph

__all__ = [
'TIRGraph',
]

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
185 changes: 185 additions & 0 deletions forge/forge/transpiler/core/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""
Graph representation for the transpiler intermediate representation.
Framework-agnostic - works for all frontends.
"""
import torch
import logging
from typing import Dict, List, Optional, Any
from collections import deque, defaultdict
from copy import deepcopy

from ..ir.nodes import TIRNode

logger = logging.getLogger("ForgeTranspiler")


class TIRGraph:
    """
    Computational graph in Transpiler Intermediate Representation (TIR).

    The graph stores its nodes in insertion order together with two lookup
    tables keyed by tensor name: ``producer_map`` (tensor -> name of the node
    that outputs it) and ``consumer_map`` (tensor -> names of the nodes that
    read it). Nodes only need to expose ``name``, ``inputs``, ``outputs`` and
    ``eval(inputs_dict) -> outputs_dict`` for the methods of this class.
    """

    def __init__(self, name: str, frontend_model=None, debug_mode: bool = False):
        """
        Create an empty graph.

        Args:
            name: Graph name (used in log messages).
            frontend_model: Original frontend model, kept for debug comparisons.
            debug_mode: When True and ``frontend_model`` is set, ``run`` compares
                each node's outputs against the frontend runtime.
        """
        self.name = name
        self.nodes: List[TIRNode] = []
        self.inputs: List[str] = []
        self.outputs: List[str] = []
        self.initializers: Dict[str, torch.Tensor] = {}

        # Topology info
        self.producer_map: Dict[str, str] = {}
        self.consumer_map: Dict[str, List[str]] = {}

        # Activation memory management: tensor name -> ids (first output name)
        # of the nodes that read it. None until computed.
        self.needed_by: Optional[Dict[str, set]] = None

        # Debug mode
        self.debug_mode: bool = debug_mode
        self.frontend_model = frontend_model  # Store original model for debug comparisons
        self.node_proto_map: Dict[str, Any] = {}  # Map node names to frontend node protos

    def add_node(self, node: TIRNode):
        """Add a node to the graph and update topology maps."""
        self.nodes.append(node)
        for out_name in node.outputs:
            self.producer_map[out_name] = node.name
        for in_name in node.inputs:
            self.consumer_map.setdefault(in_name, []).append(node.name)
        # A dependency table computed earlier does not know about this node;
        # drop it so that run() rebuilds it instead of using stale data.
        self.needed_by = None

    def get_node_by_name(self, name: str) -> Optional[TIRNode]:
        """Return the first node whose name equals ``name``, or None."""
        for node in self.nodes:
            if node.name == name:
                return node
        return None

    def get_topological_sort(self) -> List[TIRNode]:
        """
        Return the nodes in an order where producers precede their consumers.

        Nodes that can never become ready (for example because they are part
        of a cycle) are left out of the result; a warning is logged when that
        happens.
        """
        in_degree = {node.name: 0 for node in self.nodes}

        # Name -> node lookup built once, so that resolving a consumer is O(1)
        # instead of a linear scan per edge. setdefault keeps the first node
        # for a given name, matching get_node_by_name().
        nodes_by_name: Dict[str, TIRNode] = {}
        for node in self.nodes:
            nodes_by_name.setdefault(node.name, node)

        # Calculate in-degree based on internal graph dependencies
        for node in self.nodes:
            for input_name in node.inputs:
                # Only inputs produced by another node count; model inputs and
                # initializers have no producer.
                if input_name in self.producer_map and self.producer_map[input_name] != node.name:
                    in_degree[node.name] += 1

        # Initialize queue with nodes that have zero in-degree
        queue = deque(node for node in self.nodes if in_degree[node.name] == 0)
        sorted_nodes = []

        while queue:
            node = queue.popleft()
            sorted_nodes.append(node)

            # Decrease in-degree of consumers
            for output_name in node.outputs:
                for consumer_name in self.consumer_map.get(output_name, ()):
                    if consumer_name not in in_degree:
                        continue
                    in_degree[consumer_name] -= 1
                    if in_degree[consumer_name] == 0:
                        consumer_node = nodes_by_name.get(consumer_name)
                        if consumer_node is not None:
                            queue.append(consumer_node)

        if len(sorted_nodes) != len(self.nodes):
            logger.warning(
                "Topological sort of graph %s ordered %d of %d nodes; "
                "the graph may contain a cycle or duplicate node names.",
                self.name, len(sorted_nodes), len(self.nodes),
            )
        return sorted_nodes

    def compute_activation_dependencies(self):
        """
        Compute activation dependencies - which nodes need which activations.
        Used for memory management (garbage collection of unused activations).

        A node is identified by the name of its first output; nodes without
        outputs are ignored. The result is stored in ``self.needed_by`` and
        returned.
        """
        needed_by = defaultdict(set)

        for node in self.nodes:
            out_op_id = node.outputs[0] if node.outputs else None
            if out_op_id:
                for in_op_id in node.inputs:
                    needed_by[in_op_id].add(out_op_id)

        self.needed_by = dict(needed_by)
        return self.needed_by

    def _collect_debug_inputs(self, inputs):
        """Return graph inputs as numpy arrays for debug comparison, or None when debugging is off."""
        if not (self.debug_mode and self.frontend_model is not None):
            return None
        import numpy as np
        debug_inputs = []
        for input_name in self.inputs:
            if input_name in inputs:
                tensor = inputs[input_name]
                if isinstance(tensor, torch.Tensor):
                    debug_inputs.append(tensor.detach().cpu().numpy())
                else:
                    debug_inputs.append(np.array(tensor))
        return debug_inputs

    def _compare_with_frontend(self, node, outputs, debug_inputs):
        """Compare one node's outputs with the frontend runtime; failures are logged, never raised."""
        frontend_node = self.node_proto_map.get(node.name)
        if not frontend_node:
            return
        try:
            # Imported lazily so that a missing or broken debug module only
            # disables the comparison instead of breaking execution.
            from ..frontends.onnx.debug.validator import debug_node_output
            debug_node_output(
                self.frontend_model,
                debug_inputs,
                outputs,
                frontend_node
            )
        except Exception as e:
            logger.warning(f"Debug comparison failed for node {node.name}: {e}")

    def _release_dead_activations(self, node, still_needed_by, tensor_memory, protected):
        """Delete from ``tensor_memory`` the inputs of ``node`` that no remaining node reads."""
        out_op_id = node.outputs[0] if node.outputs else None
        if not out_op_id:
            return
        for in_op_id in node.inputs:
            if in_op_id not in still_needed_by:
                continue
            still_needed_by[in_op_id].discard(out_op_id)
            if still_needed_by[in_op_id]:
                continue
            # Initializers and graph outputs must survive until the end of
            # the run even when no later node reads them.
            if in_op_id in tensor_memory and in_op_id not in protected:
                del tensor_memory[in_op_id]
                logger.debug("GC: Deleted activation %s", in_op_id)

    def run(self, inputs: Dict[str, torch.Tensor], enable_gc: bool = True) -> Dict[str, torch.Tensor]:
        """
        Execute the graph with given inputs.

        Args:
            inputs: Input tensors dictionary
            enable_gc: Enable activation garbage collection (memory optimization)

        Returns:
            Mapping from each graph output name to its value. Outputs that
            were not produced are logged as errors and omitted. A node with a
            missing input is logged and skipped.
        """
        logger.info(f"Executing Graph: {self.name}")
        tensor_memory = {}
        tensor_memory.update(self.initializers)
        tensor_memory.update(inputs)

        # Compute dependencies if not already computed and GC is enabled
        if enable_gc and self.needed_by is None:
            self.compute_activation_dependencies()

        # Work on a copy: the table is consumed while the graph executes.
        still_needed_by = deepcopy(self.needed_by) if enable_gc and self.needed_by else None
        protected = set(self.initializers) | set(self.outputs)

        # Prepare inputs for debug mode (convert to numpy)
        debug_inputs = self._collect_debug_inputs(inputs)

        for node in self.get_topological_sort():
            missing = [inp for inp in node.inputs if inp not in tensor_memory]
            if missing:
                for inp in missing:
                    logger.error("Node %s missing input: %s", node.name, inp)
                continue

            node_inputs = {inp: tensor_memory[inp] for inp in node.inputs}
            outputs = node.eval(node_inputs)
            tensor_memory.update(outputs)

            # Debug mode: compare outputs with frontend runtime (frontend-specific)
            if debug_inputs is not None:
                self._compare_with_frontend(node, outputs, debug_inputs)

            # Garbage collection: remove activations no longer needed
            if still_needed_by is not None:
                self._release_dead_activations(node, still_needed_by, tensor_memory, protected)

        result = {}
        for out_name in self.outputs:
            if out_name in tensor_memory:
                result[out_name] = tensor_memory[out_name]
            else:
                logger.error(f"Graph output {out_name} was not produced.")
        return result

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
8 changes: 8 additions & 0 deletions forge/forge/transpiler/frontends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
Frontend modules for different ML frameworks.
Each frontend handles framework-specific model parsing and conversion.
"""
# Frontends are imported on-demand to avoid circular dependencies

__all__ = []

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
7 changes: 7 additions & 0 deletions forge/forge/transpiler/frontends/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
ONNX frontend for transpiler.
"""
from .engine import ONNXToForgeTranspiler

__all__ = ['ONNXToForgeTranspiler']

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
20 changes: 20 additions & 0 deletions forge/forge/transpiler/frontends/onnx/converters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
ONNX conversion utilities.
"""
from .attributes import extract_attributes, extract_attr_value
from .autopad import AutoPad
from .utils import (
remove_initializers_from_input,
get_inputs_names,
get_outputs_names,
)

__all__ = [
'extract_attributes',
'extract_attr_value',
'AutoPad',
'remove_initializers_from_input',
'get_inputs_names',
'get_outputs_names',
]

Check warning

Code scanning / flake8

blank line at end of file Warning

blank line at end of file
Loading
Loading