Load Weights Once #1015

Merged
merged 20 commits into from
May 21, 2025

Commits
9880665
WIP loading weights once
jmalone-tt May 6, 2025
8ed2ef7
WIP loading weights once
jmalone-tt May 8, 2025
7c07c0d
Adds necessary files to get load_weights working (fixup)
jmalone-tt May 8, 2025
5ba66f5
Loading weights once works!!!
jmalone-tt May 13, 2025
7480673
Clean up for PR part 1
jmalone-tt May 14, 2025
f3bbcc4
Add analysis of when to load_weights_once
jmalone-tt May 14, 2025
02dce32
Training flows now work by not loading weights first
jmalone-tt May 14, 2025
4eeab9f
fixup
jmalone-tt May 14, 2025
2d57080
Get Data Parallel working again
jmalone-tt May 14, 2025
f3708f2
Merge branch 'main' into jmalone/load_weights_once
jmalone-tt May 14, 2025
cb48ec6
Reset run_once count every time compilation occurs to fix multiple tests
jmalone-tt May 14, 2025
a966ea0
Remove dead code in backend
jmalone-tt May 15, 2025
de6f1ee
Address PR feedback, improves readability
jmalone-tt May 16, 2025
5d711fc
Merge branch 'main' into jmalone/load_weights_once
jmalone-tt May 17, 2025
d7f40d3
Updated comments
jmalone-tt May 20, 2025
dd01d91
Merge remote-tracking branch 'origin/main' into jmalone/load_weights_…
jmalone-tt May 20, 2025
9e341a8
Merge branch 'main' into jmalone/load_weights_once
jmalone-tt May 20, 2025
9cabe65
Skip Falcon-7b test for now since it fails in Before Merge workflow but
jmalone-tt May 20, 2025
f7ff7f2
Only load_once on end-to-end converted models to solve input caching
jmalone-tt May 21, 2025
5cfdd8a
Merge branch 'main' into jmalone/load_weights_once
jmalone-tt May 21, 2025
10 changes: 9 additions & 1 deletion torch_ttnn/backend.py
@@ -67,6 +67,9 @@ def __init__(
self._n_buffers = None
self._n_arguments = None

# Used for pre-loading model params
self._is_end_to_end = False

def reset_containers(self):
self._out_fx_graphs = list()
self.original_schema_list = list()
@@ -140,6 +143,7 @@ def aten_backend(

# Run analysis passes to help with ttnn ops
from torch.fx.passes.infra.pass_manager import PassManager
from torch_ttnn.passes.analysis.graph_module_analysis_pass import GraphModuleAnalysisPass
from torch_ttnn.passes.analysis.input_analysis_pass import InputAnalysisPass
from torch_ttnn.passes.analysis.multi_device_shard_analysis_pass import MultiDeviceShardAnalysisPass

@@ -157,13 +161,14 @@
from torch_ttnn.passes.deallocation_pass import DeallocationPass

passes = [
GraphModuleAnalysisPass(),
InputAnalysisPass(option._n_parameters, option._n_buffers, option._n_arguments),
MultiDeviceShardAnalysisPass(option.device),
ConstantFoldingPass(),
MultiDevicePass(option.device, example_inputs),
ToTtPass(option.device, option.use_less_ttnn_op_types),
FusionPass(),
AddDataMovePass(option.device),
AddDataMovePass(option.device, option._is_end_to_end),
EliminateCoreopsPass(),
CSEPass(),
PermuteReshapeTuple(),
@@ -277,6 +282,9 @@ def ttnn_backend(
options._n_buffers = len(list(gm.buffers()))
options._n_arguments = len(example_inputs)

# Currently, we only support preprocessing weights for end-to-end converted models
options._is_end_to_end = gm.compile_subgraph_reason.graph_break == False

tracer_option = options.tracer_option
if tracer_option is not None:
from ..tracer import Tracer
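Note on the `_is_end_to_end` check above: a model counts as "end-to-end converted" when Dynamo captures its entire forward pass as a single graph, i.e. with no graph breaks. Below is a minimal sketch of that distinction using plain torch.compile (illustrative only, not part of this PR, and not using the ttnn backend):

import torch
import torch.nn as nn

class NoBreak(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

class HasBreak(nn.Module):
    def forward(self, x):
        print("side effect")  # untraceable Python side effect, forces a graph break
        return torch.relu(x) + 1

x = torch.randn(4)
# fullgraph=True asks Dynamo to capture one end-to-end graph and to raise otherwise
torch.compile(NoBreak(), fullgraph=True)(x)  # compiles as a single graph
try:
    torch.compile(HasBreak(), fullgraph=True)(x)
except Exception as err:
    print("not end-to-end:", type(err).__name__)

With this PR, only graphs in the first category get their parameters pre-converted once; models with graph breaks keep the existing per-run data movement.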
80 changes: 80 additions & 0 deletions torch_ttnn/passes/analysis/graph_module_analysis_pass.py
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from enum import Enum


class ModelType(Enum):
"""Enumeration of model types differentiating between inference and training, forward and backward.

:param INFERENCE: Model with this tag is for the forward pass of inference.
:param TRAIN_FORWARD: Model wih this tag is for the forward pass of a training run.
:param TRAIN_BACKWARD: Model with this tag is for the backward pass of a training run.
"""

INFERENCE = 1
TRAIN_FORWARD = 2
TRAIN_BACKWARD = 3


# this list seems small, but is based off all the backward calls from the Core Aten IR:
# https://docs.pytorch.org/docs/stable/torch.compiler_ir.html
aten_backward_ops = {
torch.ops.aten._adaptive_avg_pool2d_backward.default,
torch.ops.aten.avg_pool2d_backward.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.embedding_dense_backward.default,
torch.ops.aten.max_pool2d_with_indices_backward.default,
torch.ops.aten.native_group_norm_backward.default,
torch.ops.aten.native_layer_norm_backward.default,
}
Comment on lines +24 to +32

Contributor:

Are you sure this list will always determine that graph is train backward? Feels too small)

Collaborator Author:

I agree that it seems too small, but it's based on the list of backward ops from the aten core IR. I will add a comment with a link.

Contributor:

To be honest, I'm not convinced that all models will have aten.*_backward calls in the backward pass. Here is an example that produces forward+backward aten IR for a small linear model that doesn't have any functions from the list:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx

# Define a simple model with 2 linear layers and ReLU
class SimpleModel(nn.Module):
    def __init__(self, in_features=10, hidden_features=5, out_features=2):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        print(f"Input shape: {x.shape}")
        x = self.fc1(x)
        print(f"After fc1: {x.shape}")
        x = self.relu(x)
        print(f"After relu: {x.shape}")
        x = self.fc2(x)
        print(f"After fc2: {x.shape}")
        return x

if __name__ == "__main__":
    # Instantiate model and dummy input
    model = SimpleModel()
    x = torch.randn(1, 10, requires_grad=True)

    # Forward pass with shape printing
    y = model(x)

    # Trace the forward graph with torch.fx
    traced = symbolic_trace(model)
    print("\n--- Forward FX Graph ---")
    print(traced.graph)
    print("\n--- Forward FX Graph (Python IR) ---")
    print(traced.code)

    # Define a function combining forward & backward
    def forward_and_backward(inp):
        out = model(inp)
        loss = out.sum()
        loss.backward()
        return out

    # Use make_fx to capture both forward and backward aten IR
    fx_g = make_fx(forward_and_backward)(x)
    print("\n--- Forward+Backward FX Graph (aten IR) ---")
    print(fx_g.graph)
    
    print("\n --- Forward+Backward FX Graph (Python IR) ---")
    print(fx_g.code)

Forward + backward generated python code:

 --- Forward+Backward FX Graph (Python IR) ---



def forward(self, inp_1):
    _param_constant0 = self._param_constant0
    t = torch.ops.aten.t.default(_param_constant0);  _param_constant0 = None
    _param_constant1 = self._param_constant1
    addmm = torch.ops.aten.addmm.default(_param_constant1, inp_1, t);  _param_constant1 = None
    relu = torch.ops.aten.relu.default(addmm);  addmm = None
    detach = torch.ops.aten.detach.default(relu)
    _param_constant2 = self._param_constant2
    t_1 = torch.ops.aten.t.default(_param_constant2);  _param_constant2 = None
    _param_constant3 = self._param_constant3
    addmm_1 = torch.ops.aten.addmm.default(_param_constant3, relu, t_1);  _param_constant3 = None
    sum_1 = torch.ops.aten.sum.default(addmm_1)
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None
    expand = torch.ops.aten.expand.default(ones_like, [1, 2]);  ones_like = None
    t_2 = torch.ops.aten.t.default(t_1);  t_1 = None
    mm = torch.ops.aten.mm.default(expand, t_2);  t_2 = None
    t_3 = torch.ops.aten.t.default(expand)
    mm_1 = torch.ops.aten.mm.default(t_3, relu);  t_3 = relu = None
    t_4 = torch.ops.aten.t.default(mm_1);  mm_1 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(expand, [0], True);  expand = None
    view = torch.ops.aten.view.default(sum_2, [2]);  sum_2 = None
    detach_1 = torch.ops.aten.detach.default(view);  view = None
    detach_2 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    t_5 = torch.ops.aten.t.default(t_4);  t_4 = None
    detach_3 = torch.ops.aten.detach.default(t_5);  t_5 = None
    detach_4 = torch.ops.aten.detach.default(detach_3);  detach_3 = None
    detach_5 = torch.ops.aten.detach.default(detach);  detach = None
    threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_5, 0);  mm = detach_5 = None
    t_6 = torch.ops.aten.t.default(t);  t = None
    mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6);  t_6 = None
    t_7 = torch.ops.aten.t.default(threshold_backward)
    mm_3 = torch.ops.aten.mm.default(t_7, inp_1);  t_7 = inp_1 = None
    t_8 = torch.ops.aten.t.default(mm_3);  mm_3 = None
    sum_3 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True);  threshold_backward = None
    view_1 = torch.ops.aten.view.default(sum_3, [5]);  sum_3 = None
    detach_6 = torch.ops.aten.detach.default(view_1);  view_1 = None
    detach_7 = torch.ops.aten.detach.default(detach_6);  detach_6 = None
    detach_8 = torch.ops.aten.detach.default(mm_2);  mm_2 = None
    detach_9 = torch.ops.aten.detach.default(detach_8);  detach_8 = None
    t_9 = torch.ops.aten.t.default(t_8);  t_8 = None
    detach_10 = torch.ops.aten.detach.default(t_9);  t_9 = None
    detach_11 = torch.ops.aten.detach.default(detach_10);  detach_10 = None
    return addmm_1

Collaborator Author:

Ok, looks like that graph contains torch.ops.aten.threshold_backward.default, which isn't in the list. I will go through and check the models that we support and make sure each of the backward passes would be caught in practice.
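A broader heuristic along the lines discussed here (purely illustrative, not part of this PR) would be to flag any aten call whose operator name mentions "backward", which would also catch ops like threshold_backward:

def looks_like_backward_graph(gm):
    # str(node.target) on an aten OpOverload renders as e.g. "aten.threshold_backward.default"
    for node in gm.graph.nodes:
        if node.op == "call_function" and "backward" in str(node.target):
            return True
    return False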



def is_train_backward(gm):
    node_list = list(gm.graph.nodes)
    for node in node_list:
        if node.op == "call_function" and node.target in aten_backward_ops:
            return True

    return False


def is_train_forward(gm):
    # Assume training forward calls are differentiated by directly returning one of the inputs to be used for calculating gradients
    # If this assumption fails in the future, we will need to update this function
    outputs = [node for node in gm.graph.nodes if node.op == "output"]
    for node in outputs:
        placeholder_args = [arg for arg in node.args[0] if arg.op == "placeholder"]
        if len(placeholder_args) > 0:
            return True

    return False


class GraphModuleAnalysisPass(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, gm: torch.fx.GraphModule):
        """Marks the GraphModule as either training forward, training backward, or inference (forward).

        This relies on commonalities between training forward, backward, and inference graphs. Namely, backward passes call backward versions of the forward functions to calculate gradients. Training forward passes return inputs unchanged. Inference forward functions do neither of these. It would be cleaner if we could just use something like `torch.is_grad_enabled()` or `gm.training` instead, but these appear to be inaccurate by the time the GraphModule is passed to our backend.

        :param gm: Graph module for the function being compiled.
        :return: Pass result with the updated graph module with metadata indicating the type of graph being compiled.
        :rtype: PassResult[torch.fx.GraphModule, bool]
        """

        modified = False

        # check nodes for backward function call
        if is_train_backward(gm):
            gm.meta["graph_type"] = ModelType.TRAIN_BACKWARD
        elif is_train_forward(gm):
            gm.meta["graph_type"] = ModelType.TRAIN_FORWARD
        else:
            gm.meta["graph_type"] = ModelType.INFERENCE

        return PassResult(gm, modified)
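A usage sketch of the pass above (the make_fx capture is an assumption for illustration; in the real flow Dynamo hands the aten-level GraphModule to aten_backend, and graph outputs are tuples, which is why the sketch returns one):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_ttnn.passes.analysis.graph_module_analysis_pass import GraphModuleAnalysisPass, ModelType

def f(x, w):
    # return a tuple so the output node holds an iterable, as in Dynamo-produced graphs
    return (torch.relu(x @ w),)

gm = make_fx(f)(torch.randn(2, 3), torch.randn(3, 4))
result = GraphModuleAnalysisPass()(gm)  # PassBase.__call__ wraps .call() and returns a PassResult
assert result.graph_module.meta["graph_type"] is ModelType.INFERENCE  # no backward ops, no returned placeholders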
3 changes: 3 additions & 0 deletions torch_ttnn/passes/deallocation_pass.py
@@ -97,6 +97,9 @@ def call(self, gm: torch.fx.GraphModule):
# We don't want to delete these too early
if node.target in [target_wrappers.pack_to_tuple, operator.getitem]:
continue
# Skip nodes that are cached between runs
elif n.meta.get("is_cached", False):
continue
with graph.inserting_after(node):
new_node = graph.call_function(deallocate, args=(n,))
modified = True
148 changes: 134 additions & 14 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -4,6 +4,7 @@
import logging
import torch
import ttnn
from torch_ttnn.passes.analysis.graph_module_analysis_pass import ModelType
from torch_ttnn.passes.analysis.input_analysis_pass import PrimalTag
from torch_ttnn.utils import (
GraphCleanup,
@@ -269,7 +270,12 @@ class NodeInputAligner:
def __init__(self, graph, device):
self.graph = graph
self.device = device
# aligned_node_dict maps DataMoveSpec to aligned version of the node to prevent calling the same data movement twice
self.aligned_node_dict = {}

Member:

Can you please leave a comment on what this member is about?

# marshaled_node_dict maps DataMoveSpec to index in the load_weights function that runs once. This is consumed once when adding data movement, and the DataMoveSpec will
# then be populated in the aligned_node_dict for further usage
self.marshaled_node_dict = {}
self.input_idx = 0

# fields for data parallel
self.shard_to_mesh = dict()
@@ -417,7 +423,7 @@ def _change_layout(self, spec):

return input_node

def _create_aligned_node(self, spec):
def _extract_args_kwargs_from_spec(self, spec):
if isinstance(spec, self.AlignSpecFromTorch):
kwargs = {}
args = (spec.input_node,)
@@ -450,7 +456,7 @@ def _create_aligned_node(self, spec):
args = spec.input_node.args
kwargs["mesh_mapper"] = mesh_mapper

return self.graph.call_function(ttnn.from_torch, args, kwargs)
return args, kwargs

elif isinstance(spec, self.AlignSpecToTorch):
kwargs = {"dtype": spec.dtype}
@@ -470,6 +476,21 @@
kwargs["mesh_composer"] = composer
args = (actual_inp_node,)

return args, kwargs

elif isinstance(spec, self.AlignSpecInTtnn):
return (), {}

else:
raise RuntimeError(f"Cannot create aligned node for unknown spec ({spec})")

def _create_aligned_node(self, spec):
args, kwargs = self._extract_args_kwargs_from_spec(spec)

if isinstance(spec, self.AlignSpecFromTorch):
return self.graph.call_function(ttnn.from_torch, args, kwargs)

elif isinstance(spec, self.AlignSpecToTorch):
return self.graph.call_function(ttnn.to_torch, args, kwargs)

elif isinstance(spec, self.AlignSpecInTtnn):
@@ -500,7 +521,38 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type:
new_arg[tuple_idx] = aligned_node
node.update_kwarg(key, tuple(new_arg))

def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node):
def marshal_params(self, node, input_node, input_site, input_site_type, first_node):
if not isinstance(input_node, torch.fx.node.Node):
return 0

# examine first arg for multi device case
check_constant = input_node
if input_node.op == "call_function" and input_node.target in [
target_wrappers.shard_tensor,
target_wrappers.replicate_tensor,
]:
check_constant = input_node.args[0]
is_constant_for_inference = (check_constant.op == "placeholder") and (
check_constant.meta.get("primal_tag") in [PrimalTag.PARAMETER, PrimalTag.BUFFER]
)
if not is_constant_for_inference:
return 0

data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
if not isinstance(data_move_spec, self.AlignSpecFromTorch):
# No need to align input_node
return 0

if data_move_spec in self.marshaled_node_dict:
# already handled
return 0

self.marshaled_node_dict[data_move_spec] = self.input_idx
self.input_idx += 1

return 1

def align(self, node, input_node, input_site, input_site_type: InputSiteType, first_node, ttnn_inputs):
# assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
data_move_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
if data_move_spec is None:
@@ -509,11 +561,26 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi

if data_move_spec in self.aligned_node_dict:
aligned_node = self.aligned_node_dict[data_move_spec]
elif ttnn_inputs is not None and (input_idx := self.marshaled_node_dict.get(data_move_spec)) is not None:
with self.graph.inserting_before(node):
aligned_node = self.graph.call_function(getitem, (ttnn_inputs, input_idx))
# mark node cached so it isn't deallocated by DeallocationPass
aligned_node.meta["is_cached"] = True
# update aligned_node_dict
self.aligned_node_dict[data_move_spec] = aligned_node
else:
# push from_torch calls to top of forward function if they are due to a placeholder
maybe_forward_input = input_node
# We have to test the first arg of shard_tensor and replicate_tensor calls instead
if input_node.op == "call_function" and input_node.target in [
target_wrappers.shard_tensor,
target_wrappers.replicate_tensor,
]:
maybe_forward_input = input_node.args[0]
if (
isinstance(data_move_spec, self.AlignSpecFromTorch)
and input_node.op == "placeholder"
and input_node.meta.get("primal_tag") != PrimalTag.ARGUMENT
and maybe_forward_input.op == "placeholder"
and maybe_forward_input.meta.get("primal_tag") != PrimalTag.ARGUMENT
):
# This will push all from_torch calls to the top of the forward function. This shouldn't impact performance, but it may impact memory usage since variables will be
# live longer than they would if from_torch calls occurred right before usage. If we start running out of DRAM or need to be more careful about memory usage, this
Expand All @@ -530,6 +597,46 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType, fi
return 1


def insert_load_params_once(gm, first_node, nodes, node_input_aligner):
SiteType = NodeInputAligner.InputSiteType
modifications_count = 0

for node in nodes:
args = node.args
Contributor:

I know args and kwargs are iterated differently, but the rest of the code seems like it could be combined. Let me know if it's more difficult than it looks.

Collaborator Author:

I'm not entirely sure what change you're asking for here. Can you clarify?

for idx, arg in enumerate(args):
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
modifications_count += node_input_aligner.marshal_params(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
)
else:
modifications_count += node_input_aligner.marshal_params(node, arg, idx, SiteType.ARGS, first_node)

kwargs = node.kwargs
for key, arg in kwargs.items():
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
modifications_count += node_input_aligner.marshal_params(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
)
else:
modifications_count += node_input_aligner.marshal_params(node, arg, key, SiteType.KWARGS, first_node)

# reset run_once_count so recompilation triggers loading weights
with gm.graph.inserting_before(first_node):
ttnn_inputs = gm.graph.call_function(
target_wrappers.run_once,
tuple(
[
node_input_aligner._extract_args_kwargs_from_spec(spec)
for spec in node_input_aligner.marshaled_node_dict.keys()
]
),
)

return modifications_count, ttnn_inputs


class AddDataMovePass(PassBase):
"""Pass that adds instructions to move data between host and device and align tensor dtype and layout.

@@ -538,39 +645,52 @@ class AddDataMovePass(PassBase):
:param device: The device on which a workload will run (either a MeshDevice or Device).
"""

def __init__(self, device):
def __init__(self, device, is_end_to_end):
self.device = device
self.is_end_to_end = is_end_to_end

def call(self, gm: torch.fx.GraphModule):
SiteType = NodeInputAligner.InputSiteType

i = 0
modifications_count = 0
node_input_aligner = NodeInputAligner(gm.graph, self.device)
nodes = list(gm.graph.nodes)

first_node = [node for node in nodes if node.op != "placeholder"][0]

# first load weights
ttnn_inputs = None
if gm.meta.get("graph_type") == ModelType.INFERENCE and self.is_end_to_end:
global run_once_count
target_wrappers.run_once_count = 0
modifications_count, ttnn_inputs = insert_load_params_once(gm, first_node, nodes, node_input_aligner)

# then handle rest of the args and kwargs
for node in nodes:
args = node.args
for idx, arg in enumerate(args):
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
i += node_input_aligner.align(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node
modifications_count += node_input_aligner.align(
node, tuple_arg, [idx, tuple_idx], SiteType.ARGS_TUPLE, first_node, ttnn_inputs
)
else:
i += node_input_aligner.align(node, arg, idx, SiteType.ARGS, first_node)
modifications_count += node_input_aligner.align(
node, arg, idx, SiteType.ARGS, first_node, ttnn_inputs
)

kwargs = node.kwargs
for key, arg in kwargs.items():
if isinstance(arg, (tuple, list, torch.fx.immutable_collections.immutable_list)):
for tuple_idx, tuple_arg in enumerate(arg):
i += node_input_aligner.align(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node
modifications_count += node_input_aligner.align(
node, tuple_arg, [key, tuple_idx], SiteType.KWARGS_TUPLE, first_node, ttnn_inputs
)
else:
i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
modifications_count += node_input_aligner.align(
node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs
)

modified = i > 0
modified = modifications_count > 0
GraphCleanup(gm)
return PassResult(gm, modified)
19 changes: 19 additions & 0 deletions torch_ttnn/passes/lowering/target_wrappers.py
@@ -6,6 +6,25 @@

from torch_ttnn.utils import TtnnDevice

run_once_count = 0
run_once_ans = tuple()


@torch.fx.wrap
def run_once(*args):
    global run_once_count
    global run_once_ans

    if run_once_count == 0:

        def convert_input(spec):
            return ttnn.from_torch(*spec[0], **spec[1])

        run_once_ans = tuple([convert_input(arg) for arg in args])
        run_once_count += 1

    return run_once_ans


@torch.fx.wrap
def clone(t):
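To make the caching behaviour of run_once concrete from the caller's side, here is a self-contained sketch of the same pattern with a stand-in converter in place of ttnn.from_torch (fake_from_torch and the spec variables are hypothetical, for illustration only; no ttnn device needed):

import torch

run_once_count = 0
run_once_ans = tuple()

def fake_from_torch(t):
    # stand-in for ttnn.from_torch
    print("converting tensor of shape", tuple(t.shape))
    return ("device_tensor", t)

def run_once(*specs):
    # each spec mirrors the (args, kwargs) pair built by _extract_args_kwargs_from_spec
    global run_once_count, run_once_ans
    if run_once_count == 0:
        run_once_ans = tuple(fake_from_torch(*args, **kwargs) for args, kwargs in specs)
        run_once_count += 1
    return run_once_ans

weight_spec = ((torch.randn(2, 2),), {})
bias_spec = ((torch.randn(2),), {})

first = run_once(weight_spec, bias_spec)   # converts both tensors
second = run_once(weight_spec, bias_spec)  # returns the cached tuple, no conversion
assert first is second

run_once_count = 0                         # what AddDataMovePass does on recompilation
third = run_once(weight_spec, bias_spec)   # converts again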