Load Weights Once #1015


Open · wants to merge 18 commits into base: main

Conversation

@jmalone-tt (Collaborator) commented May 14, 2025

This PR loads weights only once: a wrapper function moves weights onto the device on the first iteration and reuses the cached device tensors on subsequent iterations.

Cached tensors should be deallocated when the calling function ends. Since Native Device should behave the same way, this is acceptable for now. If we start running out of DRAM, we may need to be more strategic about which tensors are cached on device.
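The caching pattern described above can be sketched roughly as follows (a minimal illustration; the names `load_weights_once`, `_device_cache`, and `to_device` are hypothetical, not the PR's actual identifiers):

```python
# Minimal sketch of a load-weights-once wrapper (hypothetical names).
# The cache maps a weight's name to its on-device tensor, so the
# host-to-device transfer happens only on the first iteration.
_device_cache = {}

def load_weights_once(name, host_tensor, to_device):
    """Move a weight to device on first use; reuse the cached copy afterwards."""
    if name not in _device_cache:
        _device_cache[name] = to_device(host_tensor)
    return _device_cache[name]
```

Because the cache holds the device tensors only for the lifetime of the calling function, they can be deallocated when that function ends, matching the behavior described above.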

Note that this PR enables cached model parameters for inference only, not for training runs. This is accomplished through an additional GraphModuleAnalysisPass, which looks for characteristics typical of backward computation graphs (calls to backward aten ops) and of forward training graphs (returning one of the inputs unchanged), and marks everything else as forward inference. This seems fragile, but none of the usual torch methods for determining whether a model is in training or eval mode appear to work by the time the GraphModule reaches our backend.
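The heuristic can be sketched without torch as follows (an illustrative, simplified model of the pass; `Node`, `BACKWARD_OPS`, and `classify` are hypothetical stand-ins for the real torch.fx machinery):

```python
from collections import namedtuple

# Hypothetical stand-in for a torch.fx node: op kind, target op, and args.
Node = namedtuple("Node", "op target args")

# A few representative backward aten ops (the real pass checks the
# full list from the aten core IR).
BACKWARD_OPS = {"aten.convolution_backward", "aten.native_layer_norm_backward"}

def classify(nodes):
    """Classify a graph as TRAIN_BACKWARD, TRAIN_FORWARD, or INFERENCE."""
    inputs = {n.target for n in nodes if n.op == "placeholder"}
    # Any call to a backward aten op marks the graph as a backward graph.
    if any(n.op == "call_function" and n.target in BACKWARD_OPS for n in nodes):
        return "TRAIN_BACKWARD"
    # A forward training graph returns some inputs unchanged so the
    # backward pass can consume them.
    output = next(n for n in nodes if n.op == "output")
    if any(arg in inputs for arg in output.args):
        return "TRAIN_FORWARD"
    return "INFERENCE"
```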

This PR also fixes an issue where inputs for data parallel were not hoisted to the top of the forward function.

This PR also opens up the opportunity for preprocessing model weights. This is needed for convolution ops to be able to run metal trace, and will help speed up several other functions once implemented. Preprocessing will be addressed in a separate PR.

Local tests indicate the following performance improvements:
MNIST batch 1: 5th iteration improves from 5.29 ms to 3.32 ms.
BERT batch 8: 5th iteration now hits 23.3 sentences/sec (batch 1 on main today hits ~1.7 sentences/sec; I'm not sure about batch 8 on main).

Function correctly runs once
Created GraphModule loads weights properly
TODO: get wrapper function to successfully call generated GraphModule
TODO: clean up before PR
@jmalone-tt
Collaborator Author

Based on a quick test, this lowers 5th iteration MNIST batch 1 from 5.29 ms to 3.32 ms. I have not tested BERT, data parallel, or other models.

@jmalone-tt
Collaborator Author

With this change, BERT batch 8 hits 23.3 sentences/sec locally.

@jmalone-tt jmalone-tt marked this pull request as ready for review May 14, 2025 20:40
@jmalone-tt
Collaborator Author

I'm sure the code in here could use some additional cleaning, but at this point I'm a bit too deep in the weeds, so I'm sure I'm missing things. Please point out areas where the code could be more readable.

Contributor

@dgomezTT dgomezTT left a comment


Looks good to me, but I would wait for Kevin or Artem to check it!

Update comment for graph module analysis pass
Contributor

@kevinwuTT kevinwuTT left a comment


LGTM! I'm glad the analysis pass for determining training/eval works well enough.

i = 0

for node in nodes:
    args = node.args
Contributor

I know args and kwargs are iterated differently, but the rest of the code seems like it could be combined. Let me know if it's more difficult than it looks.

Collaborator Author

I'm not entirely sure what change you're asking for here. Can you clarify?

@jmalone-tt
Collaborator Author

@@ -270,6 +271,8 @@ def __init__(self, graph, device):
        self.graph = graph
        self.device = device
        self.aligned_node_dict = {}
Member

Can you please leave a comment on what this member is about?

@jmalone-tt jmalone-tt added this pull request to the merge queue May 17, 2025
@jmalone-tt
Collaborator Author

This passes consistently in the Run Tests workflow, and passes locally, but fails consistently in Before Merge 😕. Need to investigate further.

@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 17, 2025
@jmalone-tt jmalone-tt added this pull request to the merge queue May 19, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 19, 2025
Comment on lines 9 to 21
class ModelType(Enum):
    """Enumeration of compiled model inputs.

    It is expected that PARAMETER and BUFFER tensors do not change between inference runs, but ARGUMENT tensors may change.

    :param PARAMETER: Tensor tracked by optimizer during training. Represents weights and biases.
    :param BUFFER: Tensor not updated during training but still part of model state. Normally used for non-trainable data like fixed weights or running statistics in batch norm.
    :param ARGUMENT: Tensors not tracked by model and not persistent between calls. Can be input data or configuration options to module's methods.
    """

    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3
Contributor

I think these docs are from another enum? From PrimalTag?

Collaborator Author

Thanks for catching this!
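Based on the graph classification described in the PR description, a docstring that matches these members might look something like this (an illustrative sketch, not the PR's final wording):

```python
from enum import Enum

class ModelType(Enum):
    """Classification of a compiled graph (illustrative sketch).

    :param INFERENCE: Forward graph run in eval mode; parameters may be cached on device.
    :param TRAIN_FORWARD: Forward graph run during training; returns some inputs unchanged for use in the backward pass.
    :param TRAIN_BACKWARD: Backward graph; contains calls to backward aten ops.
    """

    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3
```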

Comment on lines +24 to +32
aten_backward_ops = {
    torch.ops.aten._adaptive_avg_pool2d_backward.default,
    torch.ops.aten.avg_pool2d_backward.default,
    torch.ops.aten.convolution_backward.default,
    torch.ops.aten.embedding_dense_backward.default,
    torch.ops.aten.max_pool2d_with_indices_backward.default,
    torch.ops.aten.native_group_norm_backward.default,
    torch.ops.aten.native_layer_norm_backward.default,
}
Contributor

Are you sure this list will always determine that the graph is train backward? It feels too small.

Collaborator Author

I agree that it seems too small, but it's based on the list of backward ops from the aten core IR. I will add a comment with a link.

    return False


def is_train_forward(gm):
Contributor

Is it documented somewhere in the torch docs that train forward will return inputs? If not, this might change in the future.

Collaborator Author

I don't know if it's documented anywhere, but inputs are used in the backward pass, so this should be reliable. I can add a comment noting this assumption, though.

Comment on lines 570 to 572
with self.graph.inserting_before(first_node):
    self.marshaled_node_dict[data_move_spec] = self.input_idx
    self.input_idx += 1
Contributor

Why is this block wrapped in with self.graph.inserting_before? You aren't inserting a node here, or am I missing something?

Collaborator Author

Good point. Earlier in the change I was adding the nodes here, but then removed that code. Will update.
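For context, graph.inserting_before only changes where nodes created inside the with block are placed; plain Python bookkeeping inside it is unaffected. A small torch.fx sketch (the traced function f and the rewrite are illustrative, not the PR's code):

```python
import operator

import torch
import torch.fx

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)
graph = gm.graph
placeholder = next(n for n in graph.nodes if n.op == "placeholder")
add_node = next(n for n in graph.nodes if n.op == "call_function")

# inserting_before sets the insertion point for *newly created* nodes;
# code that only updates Python state inside the block is unaffected.
with graph.inserting_before(add_node):
    doubled = graph.call_function(operator.mul, (placeholder, 2))

# Rewire the add to consume the new node, then validate and recompile.
add_node.replace_input_with(placeholder, doubled)
graph.lint()
gm.recompile()
```

After the rewrite, gm(3) computes 3 * 2 + 1.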

Comment on lines +9 to +14
run_once_count = 0
run_once_ans = tuple()


@torch.fx.wrap
def run_once(*args):
Contributor

@philei-tt philei-tt May 20, 2025

(nit) Using a closure or @lru_cache might be a bit better than a global variable.

Collaborator Author

I really like the lru_cache here, but it hits an error because the kwargs input is a dict, which is unhashable. The closure path seems less readable to me in this case. I agree that the global variable isn't the best, but I'm leaning towards keeping it unless you feel strongly.
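To illustrate both points: lru_cache fails here because it builds a hash key from its arguments and a dict is unhashable, while a closure can hold the run-once state without a global. A sketch (make_run_once is hypothetical; the PR's run_once takes different inputs):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def cached(*args):
    return args

# lru_cache hashes its arguments, so passing a dict raises:
try:
    cached({"a": 1})
except TypeError:
    pass  # TypeError: unhashable type: 'dict'

def make_run_once(fn):
    """Closure-based run-once: no global state, no hashability requirement."""
    state = {"done": False, "ans": None}
    def run_once(*args, **kwargs):
        if not state["done"]:
            state["ans"] = fn(*args, **kwargs)
            state["done"] = True
        return state["ans"]
    return run_once
```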

)
else:
i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs)

modified = i > 0
Contributor

(nit) Maybe use modified = modified or node_input_aligner.align(...) for readability? Not sure it's better, but the variable "i" doesn't say much about its purpose.

Contributor

@philei-tt philei-tt left a comment

Looks good, just a few small comments.

@jmalone-tt jmalone-tt enabled auto-merge May 20, 2025 20:22
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@jmalone-tt jmalone-tt enabled auto-merge May 20, 2025 22:56
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@jmalone-tt jmalone-tt added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
5 participants