Split lowering rules out of jax_primitives.py#2753
Conversation
|
Hello. You may have forgotten to update the changelog!
|
| return ctx.module, ctx.context | ||
|
|
||
|
|
||
| def get_mlir_attribute_from_pyval(value): |
There was a problem hiding this comment.
Moved to a more appropriate place (jax_primitives_utils.py).
I wrote this helper function a while ago, and placed it here in lowering.py somewhat arbitrarily. This lowering.py file needs the CUSTOM_LOWERING_RULES registry, so I was moving it to break out of circular imports.
| # pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access | ||
|
|
||
|
|
||
| CUSTOM_LOWERING_RULES = () |
There was a problem hiding this comment.
What would you think of having this be a dictionary and then just converting it via tuple(CUSTOM_LOWERING_RULES.items()) when needed? It would make it so we easily fetch the corresponding lowering rule when we need it for a given primitive.
There was a problem hiding this comment.
Ah! Good point, I kept it as a tuple because the connection with jax, jax.interpreters.mlir.LoweringParameters(override_lowering_rules=CUSTOM_LOWERING_RULES)
But it completely didn't occur to me to just convert 😅 I'll add it
There was a problem hiding this comment.
Could also get rid of the need for global keyword since the registry is mutable now.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2753 +/- ##
=======================================
Coverage 96.99% 97.00%
=======================================
Files 165 166 +1
Lines 18460 18517 +57
Branches 1783 1781 -2
=======================================
+ Hits 17906 17962 +56
Misses 398 398
- Partials 156 157 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
dime10
left a comment
There was a problem hiding this comment.
Thanks @paul0403! This is actually a great opportunity to move these files into the tracing module, which was intended to be sub-module for everything jax/tracing related.
The root-level files stem from when we first started the repo and didn't have any subdirectories. After a while we refactored things into proper sub-modules like api-extensions, passes, autograph, etc., but didn't get to refactor the jax code. As a result, it's split across a growing number of files/folders:
- jax_tracer/primitives/primitives_utils.py
- jax_extras (for patches/functionality considered "core jax")
- the started but never completed tracing module
- and maybe even some things in utils/ ?
If we're able move all of that into one sub-module that would be amazing in terms of code org 😍
As usual, a general reminder to avoid "utils"-style files and folders, in favour of meaningfully grouped code by purpose/functionality.
Ah, nice! Them being at root-level and not having their own subdirectory (and hence not having their own I'll see what I can come up with 👍 |
|
Note that I can do the file reorganization after finishing reference semantics. In the meantime, this PR (which just splits |
kipawaa
left a comment
There was a problem hiding this comment.
This is so great, thanks for doing this!
Context:
The
jax_primitives.pyfile was getting too heavy.Description of the Change:
Split the lowering rules out into their own file,
primitive_lowering_rules.py.Remove all
NotImplementedError()on the primitives' unuseddef_impl.Slightly improve import structure.
Delete a unused function
catalyst.Pass.get_options().Benefits:
Better organization.