-
Notifications
You must be signed in to change notification settings - Fork 15
Load Weights Once #1015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Load Weights Once #1015
Changes from all commits
9880665
8ed2ef7
7c07c0d
5ba66f5
7480673
f3bbcc4
02dce32
4eeab9f
2d57080
f3708f2
cb48ec6
a966ea0
de6f1ee
5d711fc
d7f40d3
dd01d91
9e341a8
9cabe65
f7ff7f2
5cfdd8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import torch | ||
from torch.fx.passes.infra.pass_base import PassBase, PassResult | ||
from enum import Enum | ||
|
||
|
||
class ModelType(Enum):
    """Enumeration of model types differentiating between inference and training, forward and backward.

    :param INFERENCE: Model with this tag is for the forward pass of inference.
    :param TRAIN_FORWARD: Model with this tag is for the forward pass of a training run.
    :param TRAIN_BACKWARD: Model with this tag is for the backward pass of a training run.
    """

    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3
|
||
|
||
# This list seems small, but is based on the backward calls from the Core Aten IR:
# https://docs.pytorch.org/docs/stable/torch.compiler_ir.html
# NOTE(review): decompositions can emit backward ops beyond the Core Aten list —
# e.g. a small linear+relu model produces torch.ops.aten.threshold_backward.default,
# which was originally missing here and caused the backward graph to be misclassified.
aten_backward_ops = {
    torch.ops.aten._adaptive_avg_pool2d_backward.default,
    torch.ops.aten.avg_pool2d_backward.default,
    torch.ops.aten.convolution_backward.default,
    torch.ops.aten.embedding_dense_backward.default,
    torch.ops.aten.max_pool2d_with_indices_backward.default,
    torch.ops.aten.native_group_norm_backward.default,
    torch.ops.aten.native_layer_norm_backward.default,
    # backward of relu/threshold; seen in practice in training backward graphs
    torch.ops.aten.threshold_backward.default,
}


def is_train_backward(gm):
    """Return True if *gm* contains a call to any known backward op.

    A graph containing any op from ``aten_backward_ops`` is assumed to be the
    backward pass of a training run.

    :param gm: torch.fx.GraphModule to inspect.
    :return: bool indicating whether a backward-op call is present.
    """
    return any(
        node.op == "call_function" and node.target in aten_backward_ops
        for node in gm.graph.nodes
    )
|
||
|
||
def is_train_forward(gm):
    """Return True if *gm* looks like the forward graph of a training run.

    Training forward graphs are assumed to be distinguished by directly
    returning one of their inputs (a placeholder), which is saved for the
    gradient computation in the backward pass. If this assumption stops
    holding in the future, this function will need to be updated.

    :param gm: torch.fx.GraphModule to inspect.
    :return: bool indicating whether any output returns a placeholder.
    """
    for output_node in (n for n in gm.graph.nodes if n.op == "output"):
        if any(arg.op == "placeholder" for arg in output_node.args[0]):
            return True
    return False
|
||
|
||
class GraphModuleAnalysisPass(PassBase):
    """Analysis pass that tags a GraphModule as inference, training forward, or training backward."""

    def __init__(self):
        super().__init__()

    def call(self, gm: torch.fx.GraphModule):
        """Marks the GraphModule as either training forward, training backward, or inference (forward).

        This relies on commonalities between training forward, backward, and inference
        graphs. Namely, backward passes call backward versions of the forward functions
        to calculate gradients, and training forward passes return inputs unchanged;
        inference forward functions do neither. It would be cleaner to use something
        like `torch.is_grad_enabled()` or `gm.training`, but these appear to be
        inaccurate by the time the GraphModule reaches our backend.

        :param gm: Graph module for the function being compiled.
        :return: Pass result with the graph module annotated with the graph type.
        :rtype: PassResult[torch.fx.GraphModule, bool]
        """
        # classify by checking for backward-op calls first, then the
        # returns-a-placeholder signature of training forward graphs
        if is_train_backward(gm):
            graph_type = ModelType.TRAIN_BACKWARD
        elif is_train_forward(gm):
            graph_type = ModelType.TRAIN_FORWARD
        else:
            graph_type = ModelType.INFERENCE
        gm.meta["graph_type"] = graph_type

        # only metadata is attached; the graph itself is never modified
        return PassResult(gm, False)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import logging | ||
import torch | ||
import ttnn | ||
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType | ||
from torch_ttnn.passes.analysis.input_analysis_pass import PrimalTag | ||
from torch_ttnn.utils import ( | ||
GraphCleanup, | ||
|
@@ -269,7 +270,12 @@ class NodeInputAligner: | |
def __init__(self, graph, device): | ||
self.graph = graph | ||
self.device = device | ||
# aligned_node_dict maps DataMoveSpec to aligned version of the node to prevent calling the same data movement twice | ||
self.aligned_node_dict = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please leave a comment on what this member is about? |
||
# marshaled_node_dict maps DataMoveSpec to index in the load_weights function that runs once. This is consumed once when adding data movement, and the DataMoveSpec will | ||
# then be populated in the aligned_node_dict for further usage | ||
self.marshaled_node_dict = {} | ||
self.input_idx = 0 | ||
|
||
# fields for data parallel | ||
self.shard_to_mesh = dict() | ||
|
@@ -417,7 +423,7 @@ def _change_layout(self, spec): | |
|
||
return input_node | ||
|
||
def _create_aligned_node(self, spec): | ||
def _extract_args_kwargs_from_spec(self, spec): | ||
if isinstance(spec, self.AlignSpecFromTorch): | ||
kwargs = {} | ||
args = (spec.input_node,) | ||
|
@@ -450,7 +456,7 @@ def _create_aligned_node(self, spec): | |
args = spec.input_node.args | ||
kwargs["mesh_mapper"] = mesh_mapper | ||
|
||
return self.graph.call_function(ttnn.from_torch, args, kwargs) | ||
return args, kwargs | ||
|
||
elif isinstance(spec, self.AlignSpecToTorch): | ||
kwargs = {"dtype": spec.dtype} | ||
|
@@ -470,6 +476,21 @@ def _create_aligned_node(self, spec): | |
kwargs["mesh_composer"] = composer | ||
args = (actual_inp_node,) | ||
|
||
return args, kwargs | ||
|
||
elif isinstance(spec, self.AlignSpecInTtnn): | ||
return (), {} | ||
|
||
else: | ||
raise RuntimeError(f"Cannot create aligned node for unknown spec ({spec})") | ||
|
||
def _create_aligned_node(self, spec): | ||
args, kwargs = self._extract_args_kwargs_from_spec(spec) | ||
|
||
if isinstance(spec, self.AlignSpecFromTorch): | ||
return self.graph.call_function(ttnn.from_torch, args, kwargs) | ||
|
||
elif isinstance(spec, self.AlignSpecToTorch): | ||
return self.graph.call_function(ttnn.to_torch, args, kwargs) | ||
|
||
elif isinstance(spec, self.AlignSpecInTtnn): | ||
|
@@ -500,7 +521,38 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type: | |
new_arg[tuple_idx] = aligned_node | ||
node.update_kwarg(key, tuple(new_arg)) | ||
|
||
def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node): | ||
def marshal_params(self, node, input_node, input_site, input_site_type, first_node):
    """Record a constant input (parameter/buffer) whose from_torch conversion should
    run once in the load-weights step instead of on every forward invocation.

    Adds a new entry to ``marshaled_node_dict`` keyed by the data-move spec and
    mapping to its index among the load-once inputs.

    :return: 1 if a new spec was marshaled, 0 otherwise.
    """
    if not isinstance(input_node, torch.fx.node.Node):
        return 0

    # shard_tensor/replicate_tensor wrap the real input in the multi-device
    # case; examine their first argument instead
    candidate = input_node
    if candidate.op == "call_function" and candidate.target in [
        target_wrappers.shard_tensor,
        target_wrappers.replicate_tensor,
    ]:
        candidate = candidate.args[0]

    # only placeholders tagged as parameters or buffers are constants for inference
    if candidate.op != "placeholder" or candidate.meta.get("primal_tag") not in [
        PrimalTag.PARAMETER,
        PrimalTag.BUFFER,
    ]:
        return 0

    data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
    if not isinstance(data_move_spec, self.AlignSpecFromTorch):
        # only host->device (from_torch) moves are worth hoisting; nothing to align
        return 0

    if data_move_spec in self.marshaled_node_dict:
        # already scheduled for the load-weights step
        return 0

    self.marshaled_node_dict[data_move_spec] = self.input_idx
    self.input_idx += 1

    return 1
|
||
def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node, ttnn_inputs): | ||
# assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"] | ||
data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type) | ||
if data_move_spec is None: | ||
|
@@ -509,11 +561,26 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi | |
|
||
if data_move_spec in self.aligned_node_dict: | ||
aligned_node = self.aligned_node_dict[data_move_spec] | ||
elif ttnn_inputs is not None and (input_idx := self.marshaled_node_dict.get(data_move_spec)) is not None: | ||
with self.graph.inserting_before(node): | ||
aligned_node = self.graph.call_function(getitem, (ttnn_inputs, input_idx)) | ||
# mark node cached so it isn't deallocated by DeallocationPass | ||
aligned_node.meta["is_cached"] = True | ||
# update aligned_node_dict | ||
self.aligned_node_dict[data_move_spec] = aligned_node | ||
else: | ||
# push from_torch calls to top of forward function if they are due to a placeholder | ||
maybe_forward_input = input_node | ||
# We have to test the first arg of shard_tensor and replicate_tensor calls instead | ||
if input_node.op == "call_function" and input_node.target in [ | ||
target_wrappers.shard_tensor, | ||
target_wrappers.replicate_tensor, | ||
]: | ||
maybe_forward_input = input_node.args[0] | ||
if ( | ||
isinstance(data_move_spec, self.AlignSpecFromTorch) | ||
and input_node.op == "placeholder" | ||
and input_node.meta.get("primal_tag") != PrimalTag.ARGUMENT | ||
and maybe_forward_input.op == "placeholder" | ||
and maybe_forward_input.meta.get("primal_tag") != PrimalTag.ARGUMENT | ||
): | ||
# This will push all from_torch calls to the top of the forward function. This shouldn't impact performance, but it may impact memory usage since variables will be | ||
# live longer than they would if from_torch calls occurred right before usage. If we start running out of DRAM or need to be more careful about memory usage, this | ||
|
@@ -530,6 +597,46 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi | |
return 1 | ||
|
||
|
||
def insert_load_params_once(gm, first_node, nodes, node_input_aligner):
    """Marshal every constant (parameter/buffer) from_torch conversion and emit a
    single ``run_once`` call at the top of the graph that performs them all once.

    :param gm: graph module being compiled.
    :param first_node: first non-placeholder node; the run_once call is inserted before it.
    :param nodes: snapshot list of the graph's nodes.
    :param node_input_aligner: NodeInputAligner collecting the marshaled specs.
    :return: tuple of (number of marshaled inputs, fx node holding the run_once results).
    """
    SiteType = NodeInputAligner.InputSiteType
    modifications_count = 0

    for node in nodes:
        # args and kwargs are handled identically; only the site key type and
        # the SiteType constants differ (reviewer-requested unification)
        sites = [(idx, arg, SiteType.ARGS, SiteType.ARGS_TUPLE) for idx, arg in enumerate(node.args)]
        sites += [(key, arg, SiteType.KWARGS, SiteType.KWARGS_TUPLE) for key, arg in node.kwargs.items()]
        for site, arg, site_type, tuple_site_type in sites:
            if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                for tuple_idx, tuple_arg in enumerate(arg):
                    modifications_count += node_input_aligner.marshal_params(
                        node, tuple_arg, [site, tuple_idx], tuple_site_type, first_node
                    )
            else:
                modifications_count += node_input_aligner.marshal_params(node, arg, site, site_type, first_node)

    # emit one run_once call that performs every marshaled from_torch conversion
    # exactly once; align() later indexes its results via getitem. (Resetting
    # run_once_count so recompilation reloads weights is done by the caller.)
    with gm.graph.inserting_before(first_node):
        ttnn_inputs = gm.graph.call_function(
            target_wrappers.run_once,
            tuple(
                node_input_aligner._extract_args_kwargs_from_spec(spec)
                for spec in node_input_aligner.marshaled_node_dict
            ),
        )

    return modifications_count, ttnn_inputs
|
||
|
||
class AddDataMovePass(PassBase): | ||
"""Pass that adds instructions to move data between host and device and align tensor dtype and layout. | ||
|
||
|
@@ -538,39 +645,52 @@ class AddDataMovePass(PassBase): | |
:param device: The device on which a workload will run (either a MeshDevice or Device). | ||
""" | ||
|
||
def __init__(self, device): | ||
def __init__(self, device, is_end_to_end):
    """Initialize the pass.

    :param device: The device on which a workload will run (either a MeshDevice or Device).
    :param is_end_to_end: presumably whether the whole model is compiled end-to-end,
        enabling the load-weights-once path for inference graphs — TODO confirm.
    """
    # call PassBase.__init__ for consistency with the other passes in this package
    super().__init__()
    self.device = device
    self.is_end_to_end = is_end_to_end
|
||
def call(self, gm: torch.fx.GraphModule):
    """Insert host<->device data-movement and layout/dtype alignment nodes.

    For end-to-end inference graphs, constant inputs (parameters/buffers) are
    first marshaled into a single run_once call so weights are converted and
    loaded to device only once; all remaining inputs are then aligned in place.

    :param gm: graph module being compiled.
    :return: PassResult with the updated graph module and a modified flag.
    :rtype: PassResult[torch.fx.GraphModule, bool]
    """
    SiteType = NodeInputAligner.InputSiteType

    modifications_count = 0
    node_input_aligner = NodeInputAligner(gm.graph, self.device)
    nodes = list(gm.graph.nodes)

    first_node = [node for node in nodes if node.op != "placeholder"][0]

    # first load weights (inference only)
    ttnn_inputs = None
    if gm.meta.get("graph_type") == ModelType.INFERENCE and self.is_end_to_end:
        # reset the counter so a recompilation reloads the weights; the state
        # lives on target_wrappers, so no `global` declaration is needed here
        # (the previous `global run_once_count` statement was dead code)
        target_wrappers.run_once_count = 0
        modifications_count, ttnn_inputs = insert_load_params_once(gm, first_node, nodes, node_input_aligner)

    # then handle the rest of the args and kwargs; both are aligned identically,
    # differing only in the site key type and SiteType constants
    for node in nodes:
        sites = [(idx, arg, SiteType.ARGS, SiteType.ARGS_TUPLE) for idx, arg in enumerate(node.args)]
        sites += [(key, arg, SiteType.KWARGS, SiteType.KWARGS_TUPLE) for key, arg in node.kwargs.items()]
        for site, arg, site_type, tuple_site_type in sites:
            if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                for tuple_idx, tuple_arg in enumerate(arg):
                    modifications_count += node_input_aligner.align(
                        node, tuple_arg, [site, tuple_idx], tuple_site_type, first_node, ttnn_inputs
                    )
            else:
                modifications_count += node_input_aligner.align(
                    node, arg, site, site_type, first_node, ttnn_inputs
                )

    modified = modifications_count > 0
    GraphCleanup(gm)
    return PassResult(gm, modified)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure this list will always determine that graph is train backward? Feels too small)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that it seems too small, but it's based on the list of backward ops from the aten core IR. I will add a comment with a link
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be honest, I'm not convinced that all models will have `aten.*_backward.*` calls in the backward pass. Here is an example that produces forward+backward aten IR for a small linear model that doesn't have any functions from the list:
Forward + backward generated python code:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, looks like that graph contains
torch.ops.aten.threshold_backward.default
, which isn't in the list. I will go through and check the models that we support and make sure each of the backward passes would be caught in practice