
[Refactor][Norm] Extract RowNormOp base class, migrate RmsNorm, add BatchNorm validation#637

Draft
lcy-seso wants to merge 6 commits into tile-ai:main from lcy-seso:refactor/norm/issue-628

Conversation

@lcy-seso
Collaborator

Closes #628

Summary

  • Extract RowNormOp base class in tileops/ops/norm/base.py that encapsulates the shared validate/reshape/pad/kernel/trim/reshape pattern for row-wise normalization ops
  • Migrate RmsNormOp to inherit from RowNormOp, reducing its class body to ~15 lines while preserving the identical public API
  • Add input validation (CUDA device, dtype, shape) to BatchNormFwdOp.forward()
  • Add torch.library.custom_op registration and register_fake for BatchNormFwdOp and BatchNormBwdOp
  • Add comprehensive test suite in tests/ops/test_row_norm_base.py (15 tests covering base class, migration, validation, custom_op)
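The shared pipeline the base class encapsulates can be sketched as follows. This is an illustrative outline only, not the actual tileops API: the method names (`_validate`, `_call_kernel`), the `pad_to` parameter, and the constructor signature are assumptions for the sake of the example.

```python
import torch
import torch.nn.functional as F


class RowNormOp:
    """Sketch of the validate -> reshape -> pad -> kernel -> trim ->
    reshape pattern shared by row-wise (dim=-1) normalization ops."""

    def __init__(self, M: int, N: int, dtype: torch.dtype, pad_to: int = 1):
        self.M, self.N, self.dtype, self.pad_to = M, N, dtype, pad_to

    def _validate(self, x: torch.Tensor) -> None:
        # Runtime checks raise TypeError/ValueError, never assert.
        if x.dtype != self.dtype:
            raise TypeError(f"expected dtype {self.dtype}, got {x.dtype}")
        if x.shape[-1] != self.N:
            raise ValueError(f"expected last dim {self.N}, got {x.shape[-1]}")

    def forward(self, x: torch.Tensor, *params: torch.Tensor) -> torch.Tensor:
        self._validate(x)
        orig_shape = x.shape
        x2d = x.reshape(-1, self.N)            # collapse leading dims to rows
        pad = (-self.N) % self.pad_to          # pad columns to a tile multiple
        if pad:
            x2d = F.pad(x2d, (0, pad))
        out = self._call_kernel(x2d, *params)  # subclass-supplied GPU work
        if pad:
            out = out[:, : self.N]             # trim the padding back off
        return out.reshape(orig_shape)

    def _call_kernel(self, x2d: torch.Tensor,
                     *params: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
```

A subclass then only overrides `_call_kernel`, which is how the migrated RmsNormOp can shrink to a handful of lines while keeping its public `forward` behavior.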

Test plan

  • pre-commit passed (all hooks green)
  • tests/ops/test_rms_norm.py — 16 passed (existing tests, unchanged)
  • tests/ops/test_row_norm_base.py — 15 passed (new: base class, migration, BatchNorm validation, custom_op smoke)
  • Smoke tests with @pytest.mark.smoke — 18 passed

Structural Readiness

  • [REQ] Op.forward validates input shape/numel before kernel launch
  • [REQ] Runtime validation uses ValueError/TypeError, never assert
  • [REQ] Op.forward owns pre/post-processing and delegates GPU work to self.kernel only
  • [REQ] @torch.library.custom_op + .register_fake for torch.compile (BatchNorm)
  • [REQ] Unit tests in tests/ops/ with reference or behavioral validation
  • [REC] __init__.py exports synchronized
  • SKIP: [REQ] Tests cover unsupported-dtype rejection — not added for RowNormOp/RmsNormOp in this refactor scope; BatchNorm dtype rejection tested via forward validation
  • SKIP: Benchmark items — this is a pure refactor with no new kernel; performance is unchanged

Regression

  • All 16 existing test_rms_norm.py tests pass without modification, confirming the refactor is behavior-preserving
  • RmsNormOp(M, N, dtype) constructor and op(x, weight) call signature remain identical

@lcy-seso lcy-seso added the refactor (Code restructuring without behavior change) and all-ai-powered (Produced entirely by automated contributors) labels Mar 23, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the normalization operations within the tileops library. It introduces a foundational RowNormOp base class to standardize the implementation of row-wise normalization layers, leading to cleaner and more maintainable code. Concurrently, it enhances the BatchNorm operations by adding crucial input validation and integrating them with torch.compile for improved performance and compatibility within the PyTorch ecosystem.

Highlights

  • RowNormOp Base Class Extraction: A new abstract base class, RowNormOp, was extracted to tileops/ops/norm/base.py. This class encapsulates the common validation, reshape, pad, kernel execution, trim, and reshape pattern shared by row-wise normalization operations like RmsNormOp and LayerNormOp.
  • RmsNormOp Migration: RmsNormOp was refactored to inherit from the new RowNormOp base class. This significantly reduced its class body to approximately 15 lines while preserving its public API and ensuring no behavioral regressions.
  • BatchNorm Input Validation: Input validation was added to BatchNormFwdOp.forward() to ensure inputs are on the correct CUDA device, have the expected data type, and match the required shape, improving robustness.
  • BatchNorm torch.library.custom_op Registration: torch.library.custom_op registration and register_fake implementations were added for both BatchNormFwdOp and BatchNormBwdOp, enabling better compatibility with torch.compile.
  • New Test Suite: A comprehensive test suite was introduced in tests/ops/test_row_norm_base.py to cover the RowNormOp base class functionality, the RmsNormOp migration, BatchNormFwdOp input validation, and the torch.library.custom_op smoke tests for BatchNorm.



@lcy-seso
Copy link
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and valuable refactoring by extracting a RowNormOp base class, which greatly simplifies the RmsNormOp implementation. It also adds input validation to BatchNormFwdOp and introduces torch.compile support for both forward and backward batch normalization operations through torch.library.custom_op. The changes are well-tested. My review focuses on improving the clarity and maintainability of the new torch.compile integration code.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is a great refactoring that significantly improves code structure and maintainability by introducing the RowNormOp base class and migrating RmsNormOp. The addition of input validation and torch.compile support for BatchNorm is also a valuable enhancement. The new tests are comprehensive and well-written. I've left a couple of minor suggestions in batch_norm.py to clean up some of the new torch.compile registration code by removing unused variables and simplifying a condition.


Copilot AI left a comment


Pull request overview

This PR refactors row-wise normalization ops by introducing a shared RowNormOp base class, migrates RmsNormOp to use it, and improves BatchNorm by adding eager input validation plus torch.library.custom_op / register_fake plumbing for torch.compile support.

Changes:

  • Add RowNormOp base class with shared validation/reshape/pad/trim helpers for dim=-1 norm ops.
  • Migrate RmsNormOp to inherit from RowNormOp, reducing duplicated boilerplate.
  • Add BatchNorm eager input validation and torch.compile integration via custom_op wrappers + fakes, plus a new test suite.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Summary per file:

  • tileops/ops/norm/base.py — Introduces the RowNormOp base class with a shared constructor and helper methods for row-wise norm ops.
  • tileops/ops/norm/rms_norm.py — Refactors RmsNormOp to inherit from RowNormOp and use its helpers.
  • tileops/ops/norm/batch_norm.py — Adds _validate_inputs and implements BatchNorm custom_op wrappers plus fake implementations for compile support.
  • tileops/ops/norm/__init__.py — Exports RowNormOp from the norm package.
  • tests/ops/test_row_norm_base.py — Adds tests for the new base class, the RmsNorm migration, and BatchNorm validation/custom_op behavior.

@lcy-seso lcy-seso marked this pull request as ready for review March 23, 2026 12:56
@lcy-seso lcy-seso requested a review from a team March 23, 2026 12:56
@lcy-seso lcy-seso marked this pull request as draft March 23, 2026 13:16
@lcy-seso lcy-seso force-pushed the refactor/norm/issue-628 branch from 891aba4 to 6bab9ba on March 24, 2026 07:35
lcy-seso and others added 3 commits March 24, 2026 16:56
…rm validation + custom_op

Create RowNormOp base class in tileops/ops/norm/base.py that encapsulates
the shared validate/reshape/pad/kernel/trim/reshape pattern for row-wise
normalization ops. Migrate RmsNormOp as the pilot (class body now ~11 lines).

Add input validation (CUDA device, dtype, shape) to BatchNormFwdOp.forward()
and register torch.library.custom_op for both BatchNormFwdOp and BatchNormBwdOp
to enable torch.compile support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…custom_op

Remove dead _orig_fwd_forward and _orig_bwd_forward variables that were
assigned but never read.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
_batch_norm_fwd_wrapped always returns a tensor for rstd (never None),
so the `rstd is not None` check is unnecessary.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lcy-seso lcy-seso force-pushed the refactor/norm/issue-628 branch from 6bab9ba to 4979ac2 on March 24, 2026 08:56
lcy-seso and others added 2 commits March 24, 2026 17:29
Remove TestRowNormOpExists (import/attribute checks), TestRmsNormMigration
(inheritance/line-count/signature checks), and custom_op attribute checks.
These were PR delivery verification, not correctness guards.

Retain BatchNorm validation tests (reject CPU/wrong dtype/wrong shape)
and torch.compile smoke tests (fwd/bwd).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- P2: relax "3+ ops" threshold to "multiple ops with substantial shared
  boilerplate" — avoids premature abstraction threshold
- P3: remove RmsNormOp hook example code (_init_kernel, _get_input_tensors,
  _call_kernel) — these hooks are unverified and specific to one family.
  Replace with neutral description of the declarative goal
- P4: custom_op registration → "per-op module or shared utility"
- P5: replace single kernel_cls with description of single-kernel vs
  multi-kernel dispatch patterns
- Hierarchy: mark FusedAdd*/AdaLayerNorm* as candidates needing design,
  not confirmed RowNormOp children

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lcy-seso lcy-seso force-pushed the refactor/norm/issue-628 branch from 180ccba to 1e51d44 on March 24, 2026 10:35
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lcy-seso lcy-seso force-pushed the refactor/norm/issue-628 branch from 1e51d44 to 853ef05 on March 24, 2026 10:38

Development

Successfully merging this pull request may close these issues.

[Refactor][Norm] Implement RowNormOp base class and migrate RmsNormOp

2 participants