Load Weights Once #1015
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from enum import Enum


class ModelType(Enum):
    """Enumeration of model types differentiating between inference and training, forward and backward.

    :param INFERENCE: Model with this tag is for the forward pass of inference.
    :param TRAIN_FORWARD: Model with this tag is for the forward pass of a training run.
    :param TRAIN_BACKWARD: Model with this tag is for the backward pass of a training run.
    """

    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3


# This list seems small, but it is based on all the backward calls from the Core ATen IR:
# https://docs.pytorch.org/docs/stable/torch.compiler_ir.html
aten_backward_ops = {
    torch.ops.aten._adaptive_avg_pool2d_backward.default,
    torch.ops.aten.avg_pool2d_backward.default,
    torch.ops.aten.convolution_backward.default,
    torch.ops.aten.embedding_dense_backward.default,
    torch.ops.aten.max_pool2d_with_indices_backward.default,
    torch.ops.aten.native_group_norm_backward.default,
    torch.ops.aten.native_layer_norm_backward.default,
}


def is_train_backward(gm):
    node_list = list(gm.graph.nodes)
    for node in node_list:
        if node.op == "call_function" and node.target in aten_backward_ops:
            return True

    return False


def is_train_forward(gm):
    # Assume training forward graphs are differentiated by directly returning one of their inputs, which is later used for calculating gradients.
    # If this assumption fails in the future, we will need to update this function.
    outputs = [node for node in gm.graph.nodes if node.op == "output"]
    for node in outputs:
        placeholder_args = [arg for arg in node.args[0] if arg.op == "placeholder"]
        if len(placeholder_args) > 0:
            return True

    return False

Review comment on is_train_forward:
Reviewer: Is this documented somewhere in the torch docs, that a training forward graph will return its inputs? If not, this might change in the future.
Author: I don't know if it's documented anywhere, but the inputs are used in the backward pass, so this should be reliable. I can add a comment that notes this assumption, though.


class GraphModuleAnalysisPass(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, gm: torch.fx.GraphModule):
        """Marks the GraphModule as either training forward, training backward, or inference (forward).

        This relies on commonalities between training forward, backward, and inference graphs. Namely, backward passes call backward versions of the forward functions to calculate gradients, training forward passes return inputs unchanged, and inference forward functions do neither of these. It would be cleaner if we could just use something like `torch.is_grad_enabled()` or `gm.training` instead, but these appear to be inaccurate by the time the GraphModule is passed to our backend.

        :param gm: Graph module for the function being compiled.
        :return: Pass result with the updated graph module, whose metadata indicates the type of graph being compiled.
        :rtype: PassResult[torch.fx.GraphModule, bool]
        """

        modified = False

        # check nodes for backward function calls
        if is_train_backward(gm):
            gm.meta["graph_type"] = ModelType.TRAIN_BACKWARD
        elif is_train_forward(gm):
            gm.meta["graph_type"] = ModelType.TRAIN_FORWARD
        else:
            gm.meta["graph_type"] = ModelType.INFERENCE

        return PassResult(gm, modified)
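
For orientation, here is a minimal sketch (not part of the diff) of how the new pass tags a graph, assuming this PR's module layout. The toy model returns a tuple because `is_train_forward` iterates the output node's first argument, as the dynamo/AOT-produced graphs this pass normally sees do:

```python
import torch
from torch_ttnn.passes.analysis.graph_module_analysis_pass import GraphModuleAnalysisPass, ModelType


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # return a tuple of outputs, matching the compiled graphs this pass is applied to
        return (torch.relu(x),)


gm = torch.fx.symbolic_trace(TinyModel())
result = GraphModuleAnalysisPass()(gm)  # PassBase.__call__ dispatches to call()
assert result.graph_module.meta["graph_type"] is ModelType.INFERENCE  # no backward ops, no pass-through inputs
```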
@@ -4,6 +4,7 @@
import logging
import torch
import ttnn
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType
from torch_ttnn.passes.analysis.input_analysis_pass import PrimalTag
from torch_ttnn.utils import (
    GraphCleanup,
@@ -269,7 +270,12 @@ class NodeInputAligner:
    def __init__(self, graph, device):
        self.graph = graph
        self.device = device
        # aligned_node_dict maps DataMoveSpec to the aligned version of the node to prevent calling the same data movement twice
        self.aligned_node_dict = {}
        # marshaled_node_dict maps DataMoveSpec to its index in the load_weights function that runs once. This is consumed once when adding data movement, and the DataMoveSpec will
        # then be populated in the aligned_node_dict for further usage
        self.marshaled_node_dict = {}
        self.input_idx = 0

        # fields for data parallel
        self.shard_to_mesh = dict()

Review comment on self.aligned_node_dict:
Reviewer: Can you please leave a comment on what this member is about?
@@ -417,7 +423,7 @@ def _change_layout(self, spec):

        return input_node

    def _create_aligned_node(self, spec):
    def _extract_args_kwargs_from_spec(self, spec):
        if isinstance(spec, self.AlignSpecFromTorch):
            kwargs = {}
            args = (spec.input_node,)

@@ -450,7 +456,7 @@ def _create_aligned_node(self, spec):
            args = spec.input_node.args
            kwargs["mesh_mapper"] = mesh_mapper

            return self.graph.call_function(ttnn.from_torch, args, kwargs)
            return args, kwargs

        elif isinstance(spec, self.AlignSpecToTorch):
            kwargs = {"dtype": spec.dtype}
@@ -470,6 +476,21 @@ def _create_aligned_node(self, spec):
            kwargs["mesh_composer"] = composer
            args = (actual_inp_node,)

            return args, kwargs

        elif isinstance(spec, self.AlignSpecInTtnn):
            return (), {}

        else:
            raise RuntimeError(f"Cannot create aligned node for unknown spec ({spec})")

    def _create_aligned_node(self, spec):
        args, kwargs = self._extract_args_kwargs_from_spec(spec)

        if isinstance(spec, self.AlignSpecFromTorch):
            return self.graph.call_function(ttnn.from_torch, args, kwargs)

        elif isinstance(spec, self.AlignSpecToTorch):
            return self.graph.call_function(ttnn.to_torch, args, kwargs)

        elif isinstance(spec, self.AlignSpecInTtnn):
@@ -500,7 +521,38 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type:
                new_arg[tuple_idx] = aligned_node
                node.update_kwarg(key, tuple(new_arg))

    def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node):
    def marshal_params(self, node, input_node, input_site, input_site_type, first_node):
        if not isinstance(input_node, torch.fx.node.Node):
            return 0

        # examine first arg for multi device case
        check_constant = input_node
        if input_node.op == "call_function" and input_node.target in [
            target_wrappers.shard_tensor,
            target_wrappers.replicate_tensor,
        ]:
            check_constant = input_node.args[0]
        is_constant_for_inference = (check_constant.op == "placeholder") and (
            check_constant.meta.get("primal_tag") in [PrimalTag.PARAMETER, PrimalTag.BUFFER]
        )
        if not is_constant_for_inference:
            return 0

        data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
        if not isinstance(data_move_spec, self.AlignSpecFromTorch):
            # No need to align input_node
            return 0

        if data_move_spec in self.marshaled_node_dict:
            # already handled
            return 0

        self.marshaled_node_dict[data_move_spec] = self.input_idx
        self.input_idx += 1

        return 1

    def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node, ttnn_inputs):
        # assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
        data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
        if data_move_spec is None:

@@ -509,11 +561,26 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi

        if data_move_spec in self.aligned_node_dict:
            aligned_node = self.aligned_node_dict[data_move_spec]
        elif ttnn_inputs is not None and (input_idx := self.marshaled_node_dict.get(data_move_spec)) is not None:
            with self.graph.inserting_before(node):
                aligned_node = self.graph.call_function(getitem, (ttnn_inputs, input_idx))
            # mark node cached so it isn't deallocated by DeallocationPass
            aligned_node.meta["is_cached"] = True
            # update aligned_node_dict
            self.aligned_node_dict[data_move_spec] = aligned_node
        else:
            # push from_torch calls to top of forward function if they are due to a placeholder
            maybe_forward_input = input_node
            # We have to test the first arg of shard_tensor and replicate_tensor calls instead
            if input_node.op == "call_function" and input_node.target in [
                target_wrappers.shard_tensor,
                target_wrappers.replicate_tensor,
            ]:
                maybe_forward_input = input_node.args[0]
            if (
                isinstance(data_move_spec, self.AlignSpecFromTorch)
                and input_node.op == "placeholder"
                and input_node.meta.get("primal_tag") != PrimalTag.ARGUMENT
                and maybe_forward_input.op == "placeholder"
                and maybe_forward_input.meta.get("primal_tag") != PrimalTag.ARGUMENT
            ):
                # This will push all from_torch calls to the top of the forward function. This shouldn't impact performance, but it may impact memory usage since variables will be
                # live longer than they would if from_torch calls occurred right before usage. If we start running out of DRAM or need to be more careful about memory usage, this
@@ -530,6 +597,46 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi
            return 1


def insert_load_params_once(gm, first_node, nodes, node_input_aligner):
    SiteType = NodeInputAligner.InputSiteType
    modifications_count = 0

    for node in nodes:
        args = node.args
        for idx, arg in enumerate(args):
            if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                for tuple_idx, tuple_arg in enumerate(arg):
                    modifications_count += node_input_aligner.marshal_params(
                        node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
                    )
            else:
                modifications_count += node_input_aligner.marshal_params(node, arg, idx, SiteType.ARGS, first_node)

        kwargs = node.kwargs
        for key, arg in kwargs.items():
            if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                for tuple_idx, tuple_arg in enumerate(arg):
                    modifications_count += node_input_aligner.marshal_params(
                        node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
                    )
            else:
                modifications_count += node_input_aligner.marshal_params(node, arg, key, SiteType.KWARGS, first_node)

    # reset run_once_count so recompilation triggers loading weights
    with gm.graph.inserting_before(first_node):
        ttnn_inputs = gm.graph.call_function(
            target_wrappers.run_once,
            tuple(
                [
                    node_input_aligner._extract_args_kwargs_from_spec(spec)
                    for spec in node_input_aligner.marshaled_node_dict.keys()
                ]
            ),
        )

    return modifications_count, ttnn_inputs

Review comment on insert_load_params_once:
Reviewer: I know args and kwargs are iterated differently, but the rest of the code seems like it could be combined. Let me know if it's more difficult than it looks.
Author: I'm not entirely sure what change you're asking for here. Can you clarify?


class AddDataMovePass(PassBase):
    """Pass that adds instructions to move data between host and device and align tensor dtype and layout.
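
One way the two marshalling loops in insert_load_params_once could be folded together, roughly along the lines the reviewer asks about above (a sketch against this PR's helpers, assuming the module's existing imports; the helper name is hypothetical and not part of the diff):

```python
def _iter_input_sites(node, SiteType):
    """Yield (input_value, input_site, site_type) uniformly for args and kwargs."""
    for idx, arg in enumerate(node.args):
        if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
            for tuple_idx, tuple_arg in enumerate(arg):
                yield tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE
        else:
            yield arg, idx, SiteType.ARGS
    for key, arg in node.kwargs.items():
        if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
            for tuple_idx, tuple_arg in enumerate(arg):
                yield tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE
        else:
            yield arg, key, SiteType.KWARGS


# The marshalling loop would then reduce to:
#     for node in nodes:
#         for input_node, site, site_type in _iter_input_sites(node, SiteType):
#             modifications_count += node_input_aligner.marshal_params(node, input_node, site, site_type, first_node)
```

The same helper could serve the nearly identical align loops in AddDataMovePass.call below.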
@@ -544,33 +651,45 @@ def __init__(self, device):
    def call(self, gm: torch.fx.GraphModule):
        SiteType = NodeInputAligner.InputSiteType

        i = 0
        modifications_count = 0
        node_input_aligner = NodeInputAligner(gm.graph, self.device)
        nodes = list(gm.graph.nodes)

        first_node = [node for node in nodes if node.op != "placeholder"][0]

        # first load weights
        ttnn_inputs = None
        if gm.meta.get("graph_type") == ModelType.INFERENCE:
            global run_once_count
            target_wrappers.run_once_count = 0
            modifications_count, ttnn_inputs = insert_load_params_once(gm, first_node, nodes, node_input_aligner)

        # then handle rest of the args and kwargs
        for node in nodes:
            args = node.args
            for idx, arg in enumerate(args):
                if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                    for tuple_idx, tuple_arg in enumerate(arg):
                        i += node_input_aligner.align(
                            node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
                        modifications_count += node_input_aligner.align(
                            node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node, ttnn_inputs
                        )
                else:
                    i += node_input_aligner.align(node, arg, idx, SiteType.ARGS, first_node)
                    modifications_count += node_input_aligner.align(
                        node, arg, idx, SiteType.ARGS, first_node, ttnn_inputs
                    )

            kwargs = node.kwargs
            for key, arg in kwargs.items():
                if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
                    for tuple_idx, tuple_arg in enumerate(arg):
                        i += node_input_aligner.align(
                            node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
                        modifications_count += node_input_aligner.align(
                            node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node, ttnn_inputs
                        )
                else:
                    i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
                    modifications_count += node_input_aligner.align(
                        node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs
                    )

        modified = i > 0
        modified = modifications_count > 0
        GraphCleanup(gm)
        return PassResult(gm, modified)
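
For context, the two passes are meant to run in sequence: the analysis pass tags the graph, and the data-move pass consults that tag to decide whether to batch weight loading. A minimal sketch of that ordering (assuming both classes are importable from their torch_ttnn modules and a ttnn device handle is already open; the real backend's pass pipeline may wire this differently):

```python
def run_weight_loading_passes(gm, device):
    gm = GraphModuleAnalysisPass()(gm).graph_module  # tags gm.meta["graph_type"]
    gm = AddDataMovePass(device)(gm).graph_module    # batches from_torch calls for parameters/buffers on inference graphs
    return gm
```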
@@ -6,6 +6,25 @@

from torch_ttnn.utils import TtnnDevice

run_once_count = 0
run_once_ans = tuple()


@torch.fx.wrap
def run_once(*args):
    global run_once_count
    global run_once_ans

    if run_once_count == 0:

        def convert_input(spec):
            return ttnn.from_torch(*spec[0], **spec[1])

        run_once_ans = tuple([convert_input(arg) for arg in args])
        run_once_count += 1

    return run_once_ans


@torch.fx.wrap
def clone(t):

Review comment on lines +9 to +14 (the run_once globals):
Reviewer: (nit) Using a closure or @lru_cache might be a bit better than a global variable.
Author: I really like the lru_cache here, but it hits an error with the kwargs inputs being unhashable since it's a dict. The closure path seems less readable to me in this case. I agree that the global variable isn't the best, but I'm leaning towards keeping it unless you feel strongly.
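
For reference, a closure-based variant along the lines the reviewer floated might look like the sketch below. This is an illustration only, not the PR's implementation: torch.fx.wrap has to be applied at module scope, and the data-move pass would need a reset hook (here a hypothetical reset_run_once) instead of setting target_wrappers.run_once_count = 0 to force a reload after recompilation.

```python
import torch
import ttnn


def _make_run_once():
    state = {"ans": None}

    def run_once(*args):
        # convert each (args, kwargs) spec with ttnn.from_torch exactly once
        if state["ans"] is None:
            state["ans"] = tuple(ttnn.from_torch(*a, **kw) for a, kw in args)
        return state["ans"]

    def reset_run_once():
        state["ans"] = None

    return run_once, reset_run_once


run_once, reset_run_once = _make_run_once()
torch.fx.wrap("run_once")  # register by name so FX tracing treats run_once as a leaf call
```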
Review comment on the aten_backward_ops list:
Reviewer: Are you sure this list will always determine that the graph is a training backward pass? It feels too small.
Author: I agree that it seems too small, but it's based on the list of backward ops from the Core ATen IR. I will add a comment with a link.
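
As a quick way to test that concern on a real workload, a hypothetical helper like the one below could scan a captured backward graph for backward-looking ops the list does not cover (the function name and the substring heuristic are illustrative, not part of the PR; it assumes aten_backward_ops from the new analysis module is in scope):

```python
def find_unlisted_backward_ops(gm):
    """Report call_function targets that look like backward ops but are not in aten_backward_ops."""
    unlisted = set()
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        if "_backward" in str(node.target) and node.target not in aten_backward_ops:
            unlisted.add(node.target)
    return unlisted
```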