Skip to content

Propagate manual axes to called functions during shard map import.#1227

Closed
copybara-service[bot] wants to merge 1 commit intomainfrom
test_905501102
Closed

Propagate manual axes to called functions during shard map import.#1227
copybara-service[bot] wants to merge 1 commit intomainfrom
test_905501102

Conversation

@copybara-service
Copy link
Copy Markdown

Propagate manual axes to called functions during shard map import.

It is to prepare pushing shardy unflattener up past shard map export. Before this change, shard map export is the one that attaches manual axes to funcs and then the shardy unflattener use them to (not) deduplicate funcs with different manual axes.

Instead, in this change, the manual axes are attached to funcs during the shard map import. Those manual axes are kept through the pipeline until the shardy unflattener and now the shardy unflattener use those manaul axes attached during the shard map import to (not) deduplicate funcs with different manual axes.

The shard map import pass now sets a new sdy.func_manual_axes attribute on functions called within a context where manual axes are active. This attribute stores a flattened list of manual axes from the call stack. The export and unflatten passes are updated to read and remove this new attribute.

It is to prepare pushing shardy unflattener up past shard map export. Before this change, shard map export is the one that attaches manual axes to funcs and then the shardy unflattener use them to (not) deduplicate funcs with different manual axes.

Instead, in this change, the manual axes are attached to funcs during the shard map import. Those manual axes are kept through the pipeline until the shardy unflattener and now the shardy unflattener use those manaul axes attached during the shard map import to (not) deduplicate funcs with different manual axes.

The shard map import pass now sets a new `sdy.func_manual_axes` attribute on functions called within a context where manual axes are active. This attribute stores a flattened list of manual axes from the call stack. The export and unflatten passes are updated to read and remove this new attribute.

PiperOrigin-RevId: 905501102
@copybara-service copybara-service Bot closed this Apr 27, 2026
@copybara-service copybara-service Bot deleted the test_905501102 branch April 27, 2026 15:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant