Propagate manual axes to called functions during shard map import.#1227
Closed
copybara-service[bot] wants to merge 1 commit intomainfrom
Closed
Propagate manual axes to called functions during shard map import.#1227copybara-service[bot] wants to merge 1 commit intomainfrom
copybara-service[bot] wants to merge 1 commit intomainfrom
Conversation
It is to prepare pushing shardy unflattener up past shard map export. Before this change, shard map export is the one that attaches manual axes to funcs and then the shardy unflattener use them to (not) deduplicate funcs with different manual axes. Instead, in this change, the manual axes are attached to funcs during the shard map import. Those manual axes are kept through the pipeline until the shardy unflattener and now the shardy unflattener use those manaul axes attached during the shard map import to (not) deduplicate funcs with different manual axes. The shard map import pass now sets a new `sdy.func_manual_axes` attribute on functions called within a context where manual axes are active. This attribute stores a flattened list of manual axes from the call stack. The export and unflatten passes are updated to read and remove this new attribute. PiperOrigin-RevId: 905501102
64a2704 to
4d8d5fd
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Propagate manual axes to called functions during shard map import.
It is to prepare pushing shardy unflattener up past shard map export. Before this change, shard map export is the one that attaches manual axes to funcs and then the shardy unflattener use them to (not) deduplicate funcs with different manual axes.
Instead, in this change, the manual axes are attached to funcs during the shard map import. Those manual axes are kept through the pipeline until the shardy unflattener and now the shardy unflattener use those manaul axes attached during the shard map import to (not) deduplicate funcs with different manual axes.
The shard map import pass now sets a new
sdy.func_manual_axesattribute on functions called within a context where manual axes are active. This attribute stores a flattened list of manual axes from the call stack. The export and unflatten passes are updated to read and remove this new attribute.