Skip to content

Commit 5cdbc00

Browse files
authored
Load Weights Once (#1015)
* WIP loading weights once * WIP loading weights once Function correctly runs once Created graphmodule loads weights properly TODO: get wrapper function to successfully call generated GraphModule * Adds necessary files to get load_weights working (fixup) * Loading weights once works!!! TODO: clean up before PR * Clean up for PR part 1 * Add analysis of when to load_weights_once TODO: clean up unused code * Training flows now work by not loading weights first TODO: address any feedback * fixup * Get Data Parallel working again * Reset run_once count every time compilation occurs to fix multiple tests run in one command * Remove dead code in backend Update comment for graph module analysis pass * Address PR feedback, improves readability * Updated comments * Skip Falcon-7b test for now since it fails in Before Merge workflow but passes in Run Tests * Only load_once on end-to-end converted models to solve input caching issue
1 parent 35ce5c7 commit 5cdbc00

File tree

7 files changed

+258
-21
lines changed

7 files changed

+258
-21
lines changed

torch_ttnn/backend.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def __init__(
6767
self._n_buffers = None
6868
self._n_arguments = None
6969

70+
# Used for pre-loading model params
71+
self._is_end_to_end = False
72+
7073
def reset_containers(self):
7174
self._out_fx_graphs = list()
7275
self.original_schema_list = list()
@@ -140,6 +143,7 @@ def aten_backend(
140143

141144
# Run analysis passes to help with ttnn ops
142145
from torch.fx.passes.infra.pass_manager import PassManager
146+
from torch_ttnn.passes.analysis.graph_module_analysis_pass import GraphModuleAnalysisPass
143147
from torch_ttnn.passes.analysis.input_analysis_pass import InputAnalysisPass
144148
from torch_ttnn.passes.analysis.multi_device_shard_analysis_pass import MultiDeviceShardAnalysisPass
145149

@@ -157,13 +161,14 @@ def aten_backend(
157161
from torch_ttnn.passes.deallocation_pass import DeallocationPass
158162

159163
passes = [
164+
GraphModuleAnalysisPass(),
160165
InputAnalysisPass(option._n_parameters, option._n_buffers, option._n_arguments),
161166
MultiDeviceShardAnalysisPass(option.device),
162167
ConstantFoldingPass(),
163168
MultiDevicePass(option.device, example_inputs),
164169
ToTtPass(option.device, option.use_less_ttnn_op_types),
165170
FusionPass(),
166-
AddDataMovePass(option.device),
171+
AddDataMovePass(option.device, option._is_end_to_end),
167172
EliminateCoreopsPass(),
168173
CSEPass(),
169174
PermuteReshapeTuple(),
@@ -277,6 +282,9 @@ def ttnn_backend(
277282
options._n_buffers = len(list(gm.buffers()))
278283
options._n_arguments = len(example_inputs)
279284

285+
# Currently, we only support preprocessing weights for end-to-end converted models
286+
options._is_end_to_end = gm.compile_subgraph_reason.graph_break == False
287+
280288
tracer_option = options.tracer_option
281289
if tracer_option is not None:
282290
from ..tracer import Tracer
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import torch
5+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
6+
from enum import Enum
7+
8+
9+
class ModelType(Enum):
10+
"""Enumeration of model types differentiating between inference and training, forward and backward.
11+
12+
:param INFERENCE: Model with this tag is for the forward pass of inference.
13+
:param TRAIN_FORWARD: Model wih this tag is for the forward pass of a training run.
14+
:param TRAIN_BACKWARD: Model with this tag is for the backward pass of a training run.
15+
"""
16+
17+
INFERENCE = 1
18+
TRAIN_FORWARD = 2
19+
TRAIN_BACKWARD = 3
20+
21+
22+
# this list seems small, but is based off all the backward calls from the Core Aten IR:
23+
# https://docs.pytorch.org/docs/stable/torch.compiler_ir.html
24+
aten_backward_ops = {
25+
torch.ops.aten._adaptive_avg_pool2d_backward.default,
26+
torch.ops.aten.avg_pool2d_backward.default,
27+
torch.ops.aten.convolution_backward.default,
28+
torch.ops.aten.embedding_dense_backward.default,
29+
torch.ops.aten.max_pool2d_with_indices_backward.default,
30+
torch.ops.aten.native_group_norm_backward.default,
31+
torch.ops.aten.native_layer_norm_backward.default,
32+
}
33+
34+
35+
def is_train_backward(gm):
36+
node_list = list(gm.graph.nodes)
37+
for node in node_list:
38+
if node.op == "call_function" and node.target in aten_backward_ops:
39+
return True
40+
41+
return False
42+
43+
44+
def is_train_forward(gm):
45+
# Assume training forward calls are differentiated by directly returning one of the inputs to be used for calculating gradients
46+
# If this assumption fails in the future, we will need to update this function
47+
outputs = [node for node in gm.graph.nodes if node.op == "output"]
48+
for node in outputs:
49+
placeholder_args = [arg for arg in node.args[0] if arg.op == "placeholder"]
50+
if len(placeholder_args) > 0:
51+
return True
52+
53+
return False
54+
55+
56+
class GraphModuleAnalysisPass(PassBase):
57+
def __init__(self):
58+
super().__init__()
59+
60+
def call(self, gm: torch.fx.GraphModule):
61+
"""Marks the GraphModule as either training forward, training backward, or inference (forward).
62+
63+
This relies on commonalities between training forward, backward, and inference graphs. Namely, backward passes call backward versions of the forward functions to calculate gradients. Training forward passes return inputs unchanged. Inference forward functions do neither of these. It would be cleaner if we could just use something like `torch.is_grad_enabled()` or `gm.training` instead, but these appear to be inaccurate by the time the GraphModule is passed to our backend.
64+
65+
:param gm: Graph module for the function being compiled.
66+
:return: Pass result with the updated graph module with metadata indicating the type of graph being compiled.
67+
:rtype: PassResult[torch.fx.GraphModule, bool]
68+
"""
69+
70+
modified = False
71+
72+
# check nodes for backward function call
73+
if is_train_backward(gm):
74+
gm.meta["graph_type"] = ModelType.TRAIN_BACKWARD
75+
elif is_train_forward(gm):
76+
gm.meta["graph_type"] = ModelType.TRAIN_FORWARD
77+
else:
78+
gm.meta["graph_type"] = ModelType.INFERENCE
79+
80+
return PassResult(gm, modified)

torch_ttnn/passes/deallocation_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def call(self, gm: torch.fx.GraphModule):
9797
# We don't want to delete these too early
9898
if node.target in [target_wrappers.pack_to_tuple, operator.getitem]:
9999
continue
100+
# Skip nodes that are cached between runs
101+
elif n.meta.get("is_cached", False):
102+
continue
100103
with graph.inserting_after(node):
101104
new_node = graph.call_function(deallocate, args=(n,))
102105
modified = True

torch_ttnn/passes/lowering/add_data_move_pass.py

Lines changed: 134 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import torch
66
import ttnn
7+
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType
78
from torch_ttnn.passes.analysis.input_analysis_pass import PrimalTag
89
from torch_ttnn.utils import (
910
GraphCleanup,
@@ -269,7 +270,12 @@ class NodeInputAligner:
269270
def __init__(self, graph, device):
270271
self.graph = graph
271272
self.device = device
273+
# aligned_node_dict maps DataMoveSpec to aligned version of the node to prevent calling the same data movement twice
272274
self.aligned_node_dict = {}
275+
# marshaled_node_dict maps DataMoveSpec to index in the load_weights function that runs once. This is consumed once when adding data movement, and the DataMoveSpec will
276+
# then be populated in the aligned_node_dict for further usage
277+
self.marshaled_node_dict = {}
278+
self.input_idx = 0
273279

274280
# fields for data parallel
275281
self.shard_to_mesh = dict()
@@ -417,7 +423,7 @@ def _change_layout(self, spec):
417423

418424
return input_node
419425

420-
def _create_aligned_node(self, spec):
426+
def _extract_args_kwargs_from_spec(self, spec):
421427
if isinstance(spec, self.AlignSpecFromTorch):
422428
kwargs = {}
423429
args = (spec.input_node,)
@@ -450,7 +456,7 @@ def _create_aligned_node(self, spec):
450456
args = spec.input_node.args
451457
kwargs["mesh_mapper"] = mesh_mapper
452458

453-
return self.graph.call_function(ttnn.from_torch, args, kwargs)
459+
return args, kwargs
454460

455461
elif isinstance(spec, self.AlignSpecToTorch):
456462
kwargs = {"dtype": spec.dtype}
@@ -470,6 +476,21 @@ def _create_aligned_node(self, spec):
470476
kwargs["mesh_composer"] = composer
471477
args = (actual_inp_node,)
472478

479+
return args, kwargs
480+
481+
elif isinstance(spec, self.AlignSpecInTtnn):
482+
return (), {}
483+
484+
else:
485+
raise RuntimeError(f"Cannot create aligned node for unknown spec ({spec})")
486+
487+
def _create_aligned_node(self, spec):
488+
args, kwargs = self._extract_args_kwargs_from_spec(spec)
489+
490+
if isinstance(spec, self.AlignSpecFromTorch):
491+
return self.graph.call_function(ttnn.from_torch, args, kwargs)
492+
493+
elif isinstance(spec, self.AlignSpecToTorch):
473494
return self.graph.call_function(ttnn.to_torch, args, kwargs)
474495

475496
elif isinstance(spec, self.AlignSpecInTtnn):
@@ -500,7 +521,38 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type:
500521
new_arg[tuple_idx] = aligned_node
501522
node.update_kwarg(key, tuple(new_arg))
502523

503-
def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node):
524+
def marshal_params(self, node, input_node, input_site, input_site_type, first_node):
525+
if not isinstance(input_node, torch.fx.node.Node):
526+
return 0
527+
528+
# examine first arg for multi device case
529+
check_constant = input_node
530+
if input_node.op == "call_function" and input_node.target in [
531+
target_wrappers.shard_tensor,
532+
target_wrappers.replicate_tensor,
533+
]:
534+
check_constant = input_node.args[0]
535+
is_constant_for_inference = (check_constant.op == "placeholder") and (
536+
check_constant.meta.get("primal_tag") in [PrimalTag.PARAMETER, PrimalTag.BUFFER]
537+
)
538+
if not is_constant_for_inference:
539+
return 0
540+
541+
data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
542+
if not isinstance(data_move_spec, self.AlignSpecFromTorch):
543+
# No need to align input_node
544+
return 0
545+
546+
if data_move_spec in self.marshaled_node_dict:
547+
# already handled
548+
return 0
549+
550+
self.marshaled_node_dict[data_move_spec] = self.input_idx
551+
self.input_idx += 1
552+
553+
return 1
554+
555+
def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node, ttnn_inputs):
504556
# assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
505557
data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
506558
if data_move_spec is None:
@@ -509,11 +561,26 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi
509561

510562
if data_move_spec in self.aligned_node_dict:
511563
aligned_node = self.aligned_node_dict[data_move_spec]
564+
elif ttnn_inputs is not None and (input_idx := self.marshaled_node_dict.get(data_move_spec)) is not None:
565+
with self.graph.inserting_before(node):
566+
aligned_node = self.graph.call_function(getitem, (ttnn_inputs, input_idx))
567+
# mark node cached so it isn't deallocated by DeallocationPass
568+
aligned_node.meta["is_cached"] = True
569+
# update aligned_node_dict
570+
self.aligned_node_dict[data_move_spec] = aligned_node
512571
else:
572+
# push from_torch calls to top of forward function if they are due to a placeholder
573+
maybe_forward_input = input_node
574+
# We have to test the first arg of shard_tensor and replicate_tensor calls instead
575+
if input_node.op == "call_function" and input_node.target in [
576+
target_wrappers.shard_tensor,
577+
target_wrappers.replicate_tensor,
578+
]:
579+
maybe_forward_input = input_node.args[0]
513580
if (
514581
isinstance(data_move_spec, self.AlignSpecFromTorch)
515-
and input_node.op == "placeholder"
516-
and input_node.meta.get("primal_tag") != PrimalTag.ARGUMENT
582+
and maybe_forward_input.op == "placeholder"
583+
and maybe_forward_input.meta.get("primal_tag") != PrimalTag.ARGUMENT
517584
):
518585
# This will push all from_torch calls to the top of the forward function. This shouldn't impact performance, but it may impact memory usage since variables will be
519586
# live longer than they would if from_torch calls occurred right before usage. If we start running out of DRAM or need to be more careful about memory usage, this
@@ -530,6 +597,46 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi
530597
return 1
531598

532599

600+
def insert_load_params_once(gm, first_node, nodes, node_input_aligner):
601+
SiteType = NodeInputAligner.InputSiteType
602+
modifications_count = 0
603+
604+
for node in nodes:
605+
args = node.args
606+
for idx, arg in enumerate(args):
607+
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
608+
for tuple_idx, tuple_arg in enumerate(arg):
609+
modifications_count += node_input_aligner.marshal_params(
610+
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
611+
)
612+
else:
613+
modifications_count += node_input_aligner.marshal_params(node, arg, idx, SiteType.ARGS, first_node)
614+
615+
kwargs = node.kwargs
616+
for key, arg in kwargs.items():
617+
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
618+
for tuple_idx, tuple_arg in enumerate(arg):
619+
modifications_count += node_input_aligner.marshal_params(
620+
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
621+
)
622+
else:
623+
modifications_count += node_input_aligner.marshal_params(node, arg, key, SiteType.KWARGS, first_node)
624+
625+
# reset run_once_count so recompilation triggers loading weights
626+
with gm.graph.inserting_before(first_node):
627+
ttnn_inputs = gm.graph.call_function(
628+
target_wrappers.run_once,
629+
tuple(
630+
[
631+
node_input_aligner._extract_args_kwargs_from_spec(spec)
632+
for spec in node_input_aligner.marshaled_node_dict.keys()
633+
]
634+
),
635+
)
636+
637+
return modifications_count, ttnn_inputs
638+
639+
533640
class AddDataMovePass(PassBase):
534641
"""Pass that adds instructions to move data between host and device and align tensor dtype and layout.
535642
@@ -538,39 +645,52 @@ class AddDataMovePass(PassBase):
538645
:param device: The device on which a workload will run (either a MeshDevice or Device).
539646
"""
540647

541-
def __init__(self, device):
648+
def __init__(self, device, is_end_to_end):
542649
self.device = device
650+
self.is_end_to_end = is_end_to_end
543651

544652
def call(self, gm: torch.fx.GraphModule):
545653
SiteType = NodeInputAligner.InputSiteType
546654

547-
i = 0
655+
modifications_count = 0
548656
node_input_aligner = NodeInputAligner(gm.graph, self.device)
549657
nodes = list(gm.graph.nodes)
550658

551659
first_node = [node for node in nodes if node.op != "placeholder"][0]
552660

661+
# first load weights
662+
ttnn_inputs = None
663+
if gm.meta.get("graph_type") == ModelType.INFERENCE and self.is_end_to_end:
664+
global run_once_count
665+
target_wrappers.run_once_count = 0
666+
modifications_count, ttnn_inputs = insert_load_params_once(gm, first_node, nodes, node_input_aligner)
667+
668+
# then handle rest of the args and kwargs
553669
for node in nodes:
554670
args = node.args
555671
for idx, arg in enumerate(args):
556672
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
557673
for tuple_idx, tuple_arg in enumerate(arg):
558-
i += node_input_aligner.align(
559-
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
674+
modifications_count += node_input_aligner.align(
675+
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node, ttnn_inputs
560676
)
561677
else:
562-
i += node_input_aligner.align(node, arg, idx, SiteType.ARGS, first_node)
678+
modifications_count += node_input_aligner.align(
679+
node, arg, idx, SiteType.ARGS, first_node, ttnn_inputs
680+
)
563681

564682
kwargs = node.kwargs
565683
for key, arg in kwargs.items():
566684
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
567685
for tuple_idx, tuple_arg in enumerate(arg):
568-
i += node_input_aligner.align(
569-
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
686+
modifications_count += node_input_aligner.align(
687+
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node, ttnn_inputs
570688
)
571689
else:
572-
i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
690+
modifications_count += node_input_aligner.align(
691+
node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs
692+
)
573693

574-
modified = i > 0
694+
modified = modifications_count > 0
575695
GraphCleanup(gm)
576696
return PassResult(gm, modified)

torch_ttnn/passes/lowering/target_wrappers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,25 @@
66

77
from torch_ttnn.utils import TtnnDevice
88

9+
run_once_count = 0
10+
run_once_ans = tuple()
11+
12+
13+
@torch.fx.wrap
14+
def run_once(*args):
15+
global run_once_count
16+
global run_once_ans
17+
18+
if run_once_count == 0:
19+
20+
def convert_input(spec):
21+
return ttnn.from_torch(*spec[0], **spec[1])
22+
23+
run_once_ans = tuple([convert_input(arg) for arg in args])
24+
run_once_count += 1
25+
26+
return run_once_ans
27+
928

1029
@torch.fx.wrap
1130
def clone(t):

0 commit comments

Comments
 (0)