[Refactor][Norm] Extract RowNormOp base class, migrate RmsNorm, add BatchNorm validation#637
[Refactor][Norm] Extract RowNormOp base class, migrate RmsNorm, add BatchNorm validation#637lcy-seso wants to merge 6 commits intotile-ai:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors the normalization operations within the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a significant and valuable refactoring by extracting a RowNormOp base class, which greatly simplifies the RmsNormOp implementation. It also adds input validation to BatchNormFwdOp and introduces torch.compile support for both forward and backward batch normalization operations through torch.library.custom_op. The changes are well-tested. My review focuses on improving the clarity and maintainability of the new torch.compile integration code.
There was a problem hiding this comment.
Code Review
This is a great refactoring that significantly improves code structure and maintainability by introducing the RowNormOp base class and migrating RmsNormOp. The addition of input validation and torch.compile support for BatchNorm is also a valuable enhancement. The new tests are comprehensive and well-written. I've left a couple of minor suggestions in batch_norm.py to clean up some of the new torch.compile registration code by removing unused variables and simplifying a condition.
There was a problem hiding this comment.
Pull request overview
This PR refactors row-wise normalization ops by introducing a shared RowNormOp base class, migrates RmsNormOp to use it, and improves BatchNorm by adding eager input validation plus torch.library.custom_op / register_fake plumbing for torch.compile support.
Changes:
- Add
RowNormOpbase class with shared validation/reshape/pad/trim helpers for dim=-1 norm ops. - Migrate
RmsNormOpto inherit fromRowNormOp, reducing duplicated boilerplate. - Add BatchNorm eager input validation and
torch.compileintegration viacustom_opwrappers + fakes, plus a new test suite.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
tileops/ops/norm/base.py |
Introduces RowNormOp base class with shared constructor + helper methods for row-wise norm ops. |
tileops/ops/norm/rms_norm.py |
Refactors RmsNormOp to inherit from RowNormOp and use its helpers. |
tileops/ops/norm/batch_norm.py |
Adds _validate_inputs and implements BatchNorm custom_op wrappers + fake implementations for compile support. |
tileops/ops/norm/__init__.py |
Exports RowNormOp from the norm package. |
tests/ops/test_row_norm_base.py |
Adds tests for the new base class, the RmsNorm migration, and BatchNorm validation/custom_op behavior. |
891aba4 to
6bab9ba
Compare
…rm validation + custom_op Create RowNormOp base class in tileops/ops/norm/base.py that encapsulates the shared validate/reshape/pad/kernel/trim/reshape pattern for row-wise normalization ops. Migrate RmsNormOp as the pilot (class body now ~11 lines). Add input validation (CUDA device, dtype, shape) to BatchNormFwdOp.forward() and register torch.library.custom_op for both BatchNormFwdOp and BatchNormBwdOp to enable torch.compile support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…custom_op Remove dead _orig_fwd_forward and _orig_bwd_forward variables that were assigned but never read. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
_batch_norm_fwd_wrapped always returns a tensor for rstd (never None), so the `rstd is not None` check is unnecessary. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
6bab9ba to
4979ac2
Compare
Remove TestRowNormOpExists (import/attribute checks), TestRmsNormMigration (inheritance/line-count/signature checks), and custom_op attribute checks. These were PR delivery verification, not correctness guards. Retain BatchNorm validation tests (reject CPU/wrong dtype/wrong shape) and torch.compile smoke tests (fwd/bwd). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- P2: relax "3+ ops" threshold to "multiple ops with substantial shared boilerplate" — avoids premature abstraction threshold - P3: remove RmsNormOp hook example code (_init_kernel, _get_input_tensors, _call_kernel) — these hooks are unverified and specific to one family. Replace with neutral description of the declarative goal - P4: custom_op registration → "per-op module or shared utility" - P5: replace single kernel_cls with description of single-kernel vs multi-kernel dispatch patterns - Hierarchy: mark FusedAdd*/AdaLayerNorm* as candidates needing design, not confirmed RowNormOp children Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
180ccba to
1e51d44
Compare
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1e51d44 to
853ef05
Compare
Closes #628
Summary
RowNormOpbase class intileops/ops/norm/base.pythat encapsulates the shared validate/reshape/pad/kernel/trim/reshape pattern for row-wise normalization opsRmsNormOpto inherit fromRowNormOp, reducing its class body to ~15 lines while preserving the identical public APIBatchNormFwdOp.forward()torch.library.custom_opregistration andregister_fakeforBatchNormFwdOpandBatchNormBwdOptests/ops/test_row_norm_base.py(15 tests covering base class, migration, validation, custom_op)Test plan
tests/ops/test_rms_norm.py— 16 passed (existing tests, unchanged)tests/ops/test_row_norm_base.py— 15 passed (new: base class, migration, BatchNorm validation, custom_op smoke)@pytest.mark.smoke— 18 passedStructural Readiness
Regression
test_rms_norm.pytests pass without modification, confirming the refactor is behavior-preservingRmsNormOp(M, N, dtype)constructor andop(x, weight)call signature remain identical