Load Weights Once #1015
base: main
Conversation
Function correctly runs once. Created GraphModule loads weights properly. TODO: get wrapper function to successfully call generated GraphModule.
TODO: clean up before PR
Based on a quick test, this lowers 5th-iteration MNIST batch 1 from 5.29 ms to 3.32 ms. I have not tested BERT, data parallel, or other models.
With this change, BERT batch 8 hits 23.3 sen/sec locally.
TODO: clean up unused code
TODO: address any feedback
I'm sure the code in here could use some additional cleaning, but at this point I'm a bit too deep in the weeds, so I'm sure I'm missing things. Please point out areas where the code could be more readable.
run in one command
Looks good to me, but I would wait for Kevin or Artem to check it!
Update comment for graph module analysis pass
LGTM! I'm glad the analysis pass for determining training/eval works well enough.
i = 0

for node in nodes:
    args = node.args
I know args and kwargs are iterated differently, but the rest of the code seems like it could be combined. Let me know if it's more difficult than it looks.
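One way the two loops could be combined, as the reviewer suggests: a small generator that yields every input site of a node, keyed by position for args and by name for kwargs, so a single alignment loop handles both. This is an illustrative sketch, not the PR's code; `iter_input_sites` is a hypothetical helper, and `SiteType` here is a stand-in for the real enum in the backend.

```python
from enum import Enum


class SiteType(Enum):
    # Stand-in for the backend's real SiteType enum
    ARGS = 1
    KWARGS = 2


def iter_input_sites(node):
    """Yield (site_type, key, value) for every input site of a node.

    Positional args are keyed by index, keyword args by name, so the
    caller can run one alignment loop instead of two near-duplicates.
    """
    for idx, arg in enumerate(node.args):
        yield SiteType.ARGS, idx, arg
    for key, arg in node.kwargs.items():
        yield SiteType.KWARGS, key, arg
```

The caller would then do `for site_type, key, arg in iter_input_sites(node): i += node_input_aligner.align(node, arg, key, site_type, ...)`, assuming `align` already branches on the site type internally.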
I'm not entirely sure what change you're asking for here. Can you clarify?
@@ -270,6 +271,8 @@ def __init__(self, graph, device):
        self.graph = graph
        self.device = device
        self.aligned_node_dict = {}
Can you please leave a comment on what this member is about?
This passes consistently in the
class ModelType(Enum):
    """Enumeration of compiled model inputs.

    It is expected that PARAMETER and BUFFER tensors do not change between inference runs, but ARGUMENT tensors may change.

    :param PARAMETER: Tensor tracked by optimizer during training. Represents weights and biases.
    :param BUFFER: Tensor not updated during training but still part of model state. Normally used for non-trainable data like fixed weights or running statistics in batch norm.
    :param ARGUMENT: Tensors not tracked by model and not persistent between calls. Can be input data or configuration options to module's methods.
    """

    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3
I think the docs are from another enum? From PrimalTag?
Thanks for catching this!
aten_backward_ops = {
    torch.ops.aten._adaptive_avg_pool2d_backward.default,
    torch.ops.aten.avg_pool2d_backward.default,
    torch.ops.aten.convolution_backward.default,
    torch.ops.aten.embedding_dense_backward.default,
    torch.ops.aten.max_pool2d_with_indices_backward.default,
    torch.ops.aten.native_group_norm_backward.default,
    torch.ops.aten.native_layer_norm_backward.default,
}
Are you sure this list will always determine that a graph is train backward? It feels too small.
I agree that it seems too small, but it's based on the list of backward ops from the aten core IR. I will add a comment with a link
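The detection being discussed can be sketched as a membership check over the graph's `call_function` nodes: if any node targets an op from the backward-op set, the graph is classified as a training backward graph. This is an illustrative sketch, not the PR's exact code; the string targets in the test stand in for the real `torch.ops.aten.*` op objects.

```python
def is_train_backward(gm, backward_ops):
    """Return True if any call_function node in the graph targets one of
    the backward ops (per the aten core IR list referenced above)."""
    return any(
        node.op == "call_function" and node.target in backward_ops
        for node in gm.graph.nodes
    )
```

A caveat the review raises: any backward op missing from the set makes a backward graph silently classify as something else, so the set must track the aten core IR list.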
    return False


def is_train_forward(gm):
Is it documented somewhere in the torch docs that train forward will return inputs? If not, this might change in the future.
I don't know if it's documented anywhere, but inputs are used in the backward pass, so this should be reliable. I can add a comment that notes this assumption, though.
with self.graph.inserting_before(first_node):
    self.marshaled_node_dict[data_move_spec] = self.input_idx
    self.input_idx += 1
Why is this block wrapped with `with self.graph.inserting_before`? You aren't inserting a node here, or am I missing something?
Good point, earlier in the change I was adding the nodes here, but then removed that code. Will update
run_once_count = 0
run_once_ans = tuple()


@torch.fx.wrap
def run_once(*args):
(nit) using closure or @lru_cache might be a bit better than global variable.
I really like the lru_cache here, but it hits an error with the kwargs inputs being unhashable since it's a dict. The closure path seems less readable to me in this case. I agree that the global variable isn't the best, but I'm leaning towards keeping it unless you feel strongly
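The trade-off in this exchange can be demonstrated concretely: `lru_cache` requires hashable arguments, so a dict of kwargs breaks it, while the global-variable run-once pattern side-steps hashing entirely. This sketch omits the `@torch.fx.wrap` decorator and uses a trivial body in place of the real weight loading.

```python
import functools


@functools.lru_cache(maxsize=None)
def cached_load(*args):
    # lru_cache hashes its arguments; passing a dict raises TypeError.
    return args


# The global-variable pattern kept in the PR: run the body once,
# replay the saved answer on every later call. No hashing involved.
run_once_count = 0
run_once_ans = tuple()


def run_once(*args):
    global run_once_count, run_once_ans
    if run_once_count == 0:
        run_once_ans = args  # stand-in for loading weights on device
        run_once_count += 1
    return run_once_ans
```

A closure over `run_once_count`/`run_once_ans` would avoid the module-level globals at the cost of an extra factory function, which matches the readability concern raised above.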
            )
        else:
-           i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node)
+           i += node_input_aligner.align(node, arg, key, SiteType.KWARGS, first_node, ttnn_inputs)

    modified = i > 0
(nit) Maybe use `modified = modified or node_input_aligner.align(...)` for readability purposes? Not sure it's better, but the variable `i` doesn't tell much about its purpose.
Looks good, just a few small comments.
passes in Run Tests
This change loads weights only once by adding a wrapper function with logic to load weights on device and use the cached tensors for future iterations.
Cached tensors should be deallocated when the calling function ends. Since Native Device should have the same behavior, this is ok for now. If we start running into issues with running out of DRAM, we may need to be more strategic about which tensors are cached on device.
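A minimal sketch of the load-once caching described above, with hypothetical names throughout: `load_weights_once` and `to_device` are illustrative stand-ins, not the PR's API, and a plain dict stands in for however the backend actually tracks device tensors.

```python
_device_cache = {}  # parameter name -> tensor already moved to device


def load_weights_once(name, host_tensor, to_device):
    """Move a parameter to device on first use; reuse the cached copy after.

    `to_device` stands in for the real host-to-device transfer. Entries
    live until the cache is cleared, mirroring the note that cached
    tensors are deallocated when the calling function ends.
    """
    if name not in _device_cache:
        _device_cache[name] = to_device(host_tensor)
    return _device_cache[name]
```

On the second and later iterations the transfer is skipped, which is where the measured speedup comes from; the DRAM concern above is exactly that this cache grows with the number of distinct weights.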
Note that this PR does not enable cached model parameters for training runs, only inference. This is accomplished through an additional GraphModuleAnalysisPass. This pass looks for characteristics that are typical of backward computation graphs (calls to backward aten ops) and forward training graphs (outputting one of the inputs unchanged), and marks everything else as forward inference. This seems fragile, but the normal torch methods of determining whether a model is in training or eval mode no longer work by the time the GraphModule reaches our backend.
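The classification order described in this paragraph could be sketched as follows; this is an illustrative outline, not the pass's actual code, and the string op targets in the test stand in for real `torch.ops.aten.*` objects.

```python
from enum import Enum


class ModelType(Enum):
    INFERENCE = 1
    TRAIN_FORWARD = 2
    TRAIN_BACKWARD = 3


def classify(gm, backward_ops):
    """Order matters: backward aten ops are checked first, then the
    returns-an-input heuristic; everything else is forward inference."""
    nodes = list(gm.graph.nodes)
    if any(n.op == "call_function" and n.target in backward_ops for n in nodes):
        return ModelType.TRAIN_BACKWARD
    placeholders = {n for n in nodes if n.op == "placeholder"}
    for n in nodes:
        if n.op == "output" and any(a in placeholders for a in n.args[0]):
            return ModelType.TRAIN_FORWARD
    return ModelType.INFERENCE
```

The fragility noted above is visible here: an unlisted backward op or a forward graph that happens not to return an input would fall through to `INFERENCE` and wrongly enable parameter caching.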
This PR also fixes an issue where inputs for data parallel were not hoisted to the top of the forward function.
This PR also opens up the opportunity for preprocessing model weights. This is needed for convolution ops to be able to run metal trace, and will help speed up several other functions once implemented. Preprocessing will be addressed in a separate PR.
Local tests indicate the following performance improvements:
MNIST Batch 1, 5th iteration changes 5.29 ms -> 3.32 ms
BERT Batch 8, 5th iteration now hits 23.3 sen/sec (batch 1 on main today hits ~1.7 sen/sec, not sure about batch 8)