Load Weights Once #1015

Open · wants to merge 18 commits into base: main
2 changes: 1 addition & 1 deletion tests/models/falcon/test_falcon.py
@@ -26,7 +26,7 @@ def _load_inputs(self):

@pytest.mark.parametrize(
"mode",
["eval"],
[pytest.param("eval", marks=pytest.mark.compilation_xfail(reason="Before Merge fails, Run Tests passes"))],
)
def test_falcon(record_property, mode):
model_name = "Falcon"
2 changes: 2 additions & 0 deletions torch_ttnn/backend.py
@@ -140,6 +140,7 @@ def aten_backend(

# Run analysis passes to help with ttnn ops
from torch.fx.passes.infra.pass_manager import PassManager
from torch_ttnn.passes.analysis.graph_module_analysis_pass import GraphModuleAnalysisPass
from torch_ttnn.passes.analysis.input_analysis_pass import InputAnalysisPass
from torch_ttnn.passes.analysis.multi_device_shard_analysis_pass import MultiDeviceShardAnalysisPass

@@ -157,6 +158,7 @@ def aten_backend(
from torch_ttnn.passes.deallocation_pass import DeallocationPass

passes = [
GraphModuleAnalysisPass(),
InputAnalysisPass(option._n_parameters, option._n_buffers, option._n_arguments),
MultiDeviceShardAnalysisPass(option.device),
ConstantFoldingPass(),
80 changes: 80 additions & 0 deletions torch_ttnn/passes/analysis/graph_module_analysis_pass.py
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from enum import Enum


class ModelType(Enum):
"""Enumeration of model types differentiating between inference and training, forward and backward.

:param INFERENCE: Model with this tag is for the forward pass of inference.
:param TRAIN_FORWARD: Model with this tag is for the forward pass of a training run.
:param TRAIN_BACKWARD: Model with this tag is for the backward pass of a training run.
"""

INFERENCE = 1
TRAIN_FORWARD = 2
TRAIN_BACKWARD = 3


# this list seems small, but is based off all the backward calls from the Core Aten IR:
# https://docs.pytorch.org/docs/stable/torch.compiler_ir.html
aten_backward_ops = {
torch.ops.aten._adaptive_avg_pool2d_backward.default,
torch.ops.aten.avg_pool2d_backward.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.embedding_dense_backward.default,
torch.ops.aten.max_pool2d_with_indices_backward.default,
torch.ops.aten.native_group_norm_backward.default,
torch.ops.aten.native_layer_norm_backward.default,
}
Comment on lines +24 to +32

Contributor:
Are you sure this list will always determine that the graph is a training backward pass? It feels too small.

Collaborator (Author):
I agree that it seems too small, but it is based on the list of backward ops from the Core ATen IR. I will add a comment with a link.



def is_train_backward(gm):
node_list = list(gm.graph.nodes)
for node in node_list:
if node.op == "call_function" and node.target in aten_backward_ops:
return True

return False


def is_train_forward(gm):
Contributor:
Is it documented anywhere in the torch docs that a training forward pass returns its inputs? If not, this might change in the future.

Collaborator (Author):
I don't know whether it is documented anywhere, but the inputs are used in the backward pass, so this should be reliable. I can add a comment noting this assumption.

# Assume training forward calls are differentiated by directly returning one of the inputs to be used for calculating gradients
# If this assumption fails in the future, we will need to update this function
outputs = [node for node in gm.graph.nodes if node.op == "output"]
for node in outputs:
placeholder_args = [arg for arg in node.args[0] if arg.op == "placeholder"]
if len(placeholder_args) > 0:
return True

return False


class GraphModuleAnalysisPass(PassBase):
def __init__(self):
super().__init__()

def call(self, gm: torch.fx.GraphModule):
"""Marks the GraphModule as either training forward, training backward, or inference (forward).

This relies on commonalities between training forward, backward, and inference graphs. Namely, backward passes call backward versions of the forward functions to calculate gradients. Training forward passes return inputs unchanged. Inference forward functions do neither of these. It would be cleaner if we could just use something like `torch.is_grad_enabled()` or `gm.training` instead, but these appear to be inaccurate by the time the GraphModule is passed to our backend.

:param gm: Graph module for the function being compiled.
:return: Pass result with the updated graph module with metadata indicating the type of graph being compiled.
:rtype: PassResult[torch.fx.GraphModule, bool]
"""

modified = False

# check nodes for backward function call
if is_train_backward(gm):
gm.meta["graph_type"] = ModelType.TRAIN_BACKWARD
elif is_train_forward(gm):
gm.meta["graph_type"] = ModelType.TRAIN_FORWARD
else:
gm.meta["graph_type"] = ModelType.INFERENCE

return PassResult(gm, modified)
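As a quick reference, here is a minimal sketch, grounded in the diff above but not part of this PR, of how a downstream pass can branch on the metadata this pass records; the helper name is illustrative:

```python
# Minimal sketch (helper name is hypothetical): a later pass can branch on the
# graph_type metadata that GraphModuleAnalysisPass stores on the GraphModule.
import torch
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType


def can_cache_weights(gm: torch.fx.GraphModule) -> bool:
    # AddDataMovePass performs effectively this check before marshaling
    # parameters and buffers into the one-time run_once call.
    return gm.meta.get("graph_type") == ModelType.INFERENCE
```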
3 changes: 3 additions & 0 deletions torch_ttnn/passes/deallocation_pass.py
@@ -97,6 +97,9 @@ def call(self, gm: torch.fx.GraphModule):
# We don't want to delete these too early
if node.target in [target_wrappers.pack_to_tuple, operator.getitem]:
continue
# Skip nodes that are cached between runs
elif n.meta.get("is_cached", False):
continue
with graph.inserting_after(node):
new_node = graph.call_function(deallocate, args=(n,))
modified = True
145 changes: 132 additions & 13 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -4,6 +4,7 @@
import logging
import torch
import ttnn
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType
from torch_ttnn.passes.analysis.input_analysis_pass import PrimalTag
from torch_ttnn.utils import (
GraphCleanup,
@@ -269,7 +270,12 @@ class NodeInputAligner:
def __init__(self, graph, device):
self.graph = graph
self.device = device
# aligned_node_dict maps DataMoveSpec to aligned version of the node to prevent calling the same data movement twice
self.aligned_node_dict = {}
Member:
Can you please leave a comment on what this member is about?

# marshaled_node_dict maps DataMoveSpec to index in the load_weights function that runs once. This is consumed once when adding data movement, and the DataMoveSpec will
# then be populated in the aligned_node_dict for further usage
self.marshaled_node_dict = {}
self.input_idx = 0

# fields for data parallel
self.shard_to_mesh = dict()
@@ -417,7 +423,7 @@ def _change_layout(self, spec):

return input_node

def _create_aligned_node(self, spec):
def _extract_args_kwargs_from_spec(self, spec):
if isinstance(spec, self.AlignSpecFromTorch):
kwargs = {}
args = (spec.input_node,)
Expand Down Expand Up @@ -450,7 +456,7 @@ def _create_aligned_node(self, spec):
args = spec.input_node.args
kwargs["mesh_mapper"] = mesh_mapper

return self.graph.call_function(ttnn.from_torch, args, kwargs)
return args, kwargs

elif isinstance(spec, self.AlignSpecToTorch):
kwargs = {"dtype": spec.dtype}
@@ -470,6 +476,21 @@ def _create_aligned_node(self, spec):
kwargs["mesh_composer"] = composer
args = (actual_inp_node,)

return args, kwargs

elif isinstance(spec, self.AlignSpecInTtnn):
return (), {}

else:
raise RuntimeError(f"Cannot create aligned node for unknown spec ({spec})")

def _create_aligned_node(self, spec):
args, kwargs = self._extract_args_kwargs_from_spec(spec)

if isinstance(spec, self.AlignSpecFromTorch):
return self.graph.call_function(ttnn.from_torch, args, kwargs)

elif isinstance(spec, self.AlignSpecToTorch):
return self.graph.call_function(ttnn.to_torch, args, kwargs)

elif isinstance(spec, self.AlignSpecInTtnn):
@@ -500,7 +521,38 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type:
new_arg[tuple_idx] = aligned_node
node.update_kwarg(key, tuple(new_arg))

def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node):
def marshal_params(self, node, input_node, input_site, input_site_type, first_node):
if not isinstance(input_node, torch.fx.node.Node):
return 0

# examine first arg for multi device case
check_constant = input_node
if input_node.op == "call_function" and input_node.target in [
target_wrappers.shard_tensor,
target_wrappers.replicate_tensor,
]:
check_constant = input_node.args[0]
is_constant_for_inference = (check_constant.op == "placeholder") and (
check_constant.meta.get("primal_tag") in [PrimalTag.PARAMETER, PrimalTag.BUFFER]
)
if not is_constant_for_inference:
return 0

data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
if not isinstance(data_move_spec, self.AlignSpecFromTorch):
# No need to align input_node
return 0

if data_move_spec in self.marshaled_node_dict:
# already handled
return 0

self.marshaled_node_dict[data_move_spec] = self.input_idx
self.input_idx += 1

return 1

def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node, ttnn_inputs):
# assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
if data_move_spec is None:
@@ -509,11 +561,26 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi

if data_move_spec in self.aligned_node_dict:
aligned_node = self.aligned_node_dict[data_move_spec]
elif ttnn_inputs is not None and (input_idx := self.marshaled_node_dict.get(data_move_spec)) is not None:
with self.graph.inserting_before(node):
aligned_node = self.graph.call_function(getitem, (ttnn_inputs, input_idx))
# mark node cached so it isn't deallocated by DeallocationPass
aligned_node.meta["is_cached"] = True
# update aligned_node_dict
self.aligned_node_dict[data_move_spec] = aligned_node
else:
# push from_torch calls to top of forward function if they are due to a placeholder
maybe_forward_input = input_node
# We have to test the first arg of shard_tensor and replicate_tensor calls instead
if input_node.op == "call_function" and input_node.target in [
target_wrappers.shard_tensor,
target_wrappers.replicate_tensor,
]:
maybe_forward_input = input_node.args[0]
if (
isinstance(data_move_spec, self.AlignSpecFromTorch)
and input_node.op == "placeholder"
and input_node.meta.get("primal_tag") != PrimalTag.ARGUMENT
and maybe_forward_input.op == "placeholder"
and maybe_forward_input.meta.get("primal_tag") != PrimalTag.ARGUMENT
):
# This will push all from_torch calls to the top of the forward function. This shouldn't impact performance, but it may impact memory usage since variables will be
# live longer than they would if from_torch calls occurred right before usage. If we start running out of DRAM or need to be more careful about memory usage, this
Expand All @@ -530,6 +597,46 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi
return 1


def insert_load_params_once(gm, first_node, nodes, node_input_aligner):
SiteType = NodeInputAligner.InputSiteType
modifications_count = 0

for node in nodes:
args = node.args
Contributor:
I know args and kwargs are iterated differently, but the rest of the code looks like it could be combined. Let me know if it's more difficult than it looks.

Collaborator (Author):
I'm not entirely sure what change you're asking for here. Can you clarify?

for idx, arg in enumerate(args):
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
modifications_count += node_input_aligner.marshal_params(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
)
else:
modifications_count += node_input_aligner.marshal_params(node, arg, idx, SiteType.ARGS, first_node)

kwargs = node.kwargs
for key, arg in kwargs.items():
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
modifications_count += node_input_aligner.marshal_params(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
)
else:
modifications_count += node_input_aligner.marshal_params(node, arg, key, SiteType.KWARGS, first_node)

# reset run_once_count so recompilation triggers loading weights
with gm.graph.inserting_before(first_node):
ttnn_inputs = gm.graph.call_function(
target_wrappers.run_once,
tuple(
[
node_input_aligner._extract_args_kwargs_from_spec(spec)
for spec in node_input_aligner.marshaled_node_dict.keys()
]
),
)

return modifications_count, ttnn_inputs


class AddDataMovePass(PassBase):
"""Pass that adds instructions to move data between host and device and align tensor dtype and layout.

@@ -544,33 +651,45 @@ def __init__(self, device):
def call(self, gm: torch.fx.GraphModule):
SiteType = NodeInputAligner.InputSiteType

i = 0
modifications_count = 0
node_input_aligner = NodeInputAligner(gm.graph, self.device)
nodes = list(gm.graph.nodes)

first_node = [node for node in nodes if node.op != "placeholder"][0]

# first load weights
ttnn_inputs = None
if gm.meta.get("graph_type") == ModelType.INFERENCE:
global run_once_count
target_wrappers.run_once_count = 0
modifications_count, ttnn_inputs = insert_load_params_once(gm, first_node, nodes, node_input_aligner)

# then handle rest of the args and kwargs
for node in nodes:
args = node.args
for idx, arg in enumerate(args):
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
i += node_input_aligner.align(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
modifications_count += node_input_aligner.align(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node, ttnn_inputs
)
else:
i += node_input_aligner.align(node, arg, idx, SiteType.ARGS, first_node)
modifications_count += node_input_aligner.align(
node, arg, idx, SiteType.ARGS, first_node, ttnn_inputs
)

kwargs = node.kwargs
for key, arg in kwargs.items():
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
i += node_input_aligner.align(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
modifications_count += node_input_aligner.align(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node, ttnn_inputs
)
else:
i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
modifications_count += node_input_aligner.align(
node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs
)

modified = i > 0
modified = modifications_count > 0
GraphCleanup(gm)
return PassResult(gm, modified)
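To summarize the rewrite this file now performs for inference graphs, here is a self-contained sketch; the toy module, the stand-in helpers, and the weight selection are assumptions, while the hoisted run-once call, the getitem indexing, and the is_cached tag mirror the diff above:

```python
# Sketch (not the pass itself): hoist a single "load weights" call to the top
# of the graph, turn each weight conversion into a getitem into its result,
# and tag the getitem nodes is_cached so DeallocationPass leaves them alone.
import operator
import torch
import torch.fx as fx


@torch.fx.wrap
def to_device(t):
    return t  # stand-in for ttnn.from_torch


def load_weights_once(*tensors):
    return tuple(to_device(t) for t in tensors)  # stand-in for target_wrappers.run_once


class Tiny(torch.nn.Module):
    def forward(self, x, w1, w2):
        return to_device(w1) @ x + to_device(w2) @ x


gm = fx.symbolic_trace(Tiny())
graph = gm.graph
placeholders = [n for n in graph.nodes if n.op == "placeholder"]
first_compute = next(n for n in graph.nodes if n.op != "placeholder")
weights = placeholders[1:]  # pretend w1/w2 were tagged PARAMETER by InputAnalysisPass

with graph.inserting_before(first_compute):
    ttnn_inputs = graph.call_function(load_weights_once, tuple(weights))
    for idx, w in enumerate(weights):
        cached = graph.call_function(operator.getitem, (ttnn_inputs, idx))
        cached.meta["is_cached"] = True
        for user in list(w.users):
            if user.op == "call_function" and user.target is to_device:
                user.replace_all_uses_with(cached)

graph.eliminate_dead_code()
gm.recompile()

x, w1, w2 = (torch.randn(3, 3) for _ in range(3))
print(gm(x, w1, w2).shape)  # torch.Size([3, 3])
```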
19 changes: 19 additions & 0 deletions torch_ttnn/passes/lowering/target_wrappers.py
@@ -6,6 +6,25 @@

from torch_ttnn.utils import TtnnDevice

run_once_count = 0
run_once_ans = tuple()


@torch.fx.wrap
def run_once(*args):
Comment on lines +9 to +14

Contributor (@philei-tt, May 20, 2025):
(nit) Using a closure or @lru_cache might be a bit better than a global variable.

Collaborator (Author):
I really like the lru_cache idea, but it hits an error because the kwargs inputs are dicts and therefore unhashable. The closure path seems less readable to me in this case. I agree the global variable isn't ideal, but I'm leaning towards keeping it unless you feel strongly.

global run_once_count
global run_once_ans

if run_once_count == 0:

def convert_input(spec):
return ttnn.from_torch(*spec[0], **spec[1])

run_once_ans = tuple([convert_input(arg) for arg in args])
run_once_count += 1

return run_once_ans


@torch.fx.wrap
def clone(t):
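For the review thread above, a hypothetical sketch of the closure-based alternative follows; it is not what the PR adopts, and note that @torch.fx.wrap generally needs to be applied at module scope, which is part of why the closure route is awkward here:

```python
# Hypothetical closure-based variant of run_once (not adopted by this PR).
import ttnn


def make_run_once():
    ans = None

    def run_once(*args):
        nonlocal ans
        if ans is None:
            # each arg is an (args, kwargs) spec produced by the data-move pass
            ans = tuple(ttnn.from_torch(*a, **kw) for a, kw in args)
        return ans

    def reset():
        nonlocal ans
        ans = None  # call on recompilation so weights are converted again

    return run_once, reset


run_once, reset_run_once = make_run_once()
```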
6 changes: 6 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -166,6 +166,12 @@ def __init__(self, module, device, use_less_ttnn_op_types):
self.device = device
self.use_less_ttnn_op_types = use_less_ttnn_op_types

def transform(self):
old_meta = self.module.meta
result = super().transform()
result.meta = old_meta
return result

def get_attr(self, target, args, kwargs):
# Restore original metadata for get_attr nodes
proxy = super().get_attr(target, args, kwargs)
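The transform() override above exists because fx.Transformer.transform() builds a new GraphModule whose meta dict starts empty, which would drop the graph_type recorded earlier. A minimal sketch of the underlying behavior (toy module, illustrative only):

```python
# Sketch: metadata does not survive Transformer.transform() unless copied.
import torch
import torch.fx as fx


class NoopTransformer(fx.Transformer):
    pass


gm = fx.symbolic_trace(torch.nn.ReLU())
gm.meta["graph_type"] = "inference"
out = NoopTransformer(gm).transform()
print(out.meta.get("graph_type"))  # None without the meta-copying override
```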