Skip to content

Remove redundant flatten/unflatten of PyTrees #1648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

rauletorresc
Copy link
Contributor

Context: When calling get_abstract_signature from within get_decomposed_signature we are flattening/unflattening the PyTree one more time without need. In this sense get_abstract_signature is just a wrapper for the shaped_abstractify function that we don't need here.

Description of the Change: Avoid the wrapper function call and use shaped_abstractify directly.

Benefits: Avoid code that does nothing.

@rauletorresc rauletorresc requested a review from a team April 14, 2025 23:13
@rauletorresc rauletorresc self-assigned this Apr 14, 2025
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@rauletorresc rauletorresc force-pushed the raultorres/redundant_flatten_unflatten branch from bbda3a7 to 8e43f76 Compare April 14, 2025 23:14
Copy link

codecov bot commented Apr 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.88%. Comparing base (1a6df92) to head (8e43f76).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1648   +/-   ##
=======================================
  Coverage   96.88%   96.88%           
=======================================
  Files          80       80           
  Lines        8861     8861           
  Branches      841      841           
=======================================
  Hits         8585     8585           
  Misses        222      222           
  Partials       54       54           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@dime10
Copy link
Contributor

dime10 commented Apr 15, 2025

This change looks good, but if we want to resolve the core issue we have to tackle it here:

dynamic_sig = get_abstract_signature(dynamic_args)

  • We generate an abstract signature re-wrapped into pytrees.
  • The question is whether we really need the abstract values at some point, and if so whether we need them in pytrees. We can modify either of those conditions to become compliant.

The abstract signature is used for example to generate jaxpr:

jaxpr, out_type, treedef, plugins = trace_to_jaxpr(

Then it is used again during promotion:

args = promote_arguments(self.c_sig, dynamic_args)

  • Here we can probably just do the promotion on a flattened argument list instead.

The last use case, in the compilation cache, is already addressed in this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants