enable pseudo-inverse behavior in eigendecomposition-based amortized computation #120

Closed
wants to merge 1 commit into from

@anana10c anana10c commented Mar 31, 2025

Enable pseudo-inverse behavior in eigendecomposition-based computation via a dedicated config, PseudoInverseConfig, which inherits from RankDeficientHandlingConfig. (The pseudo-inverse implementation is adapted from #11.) Refactor the enhance_stability flag into an argument of PerturbationConfig, another RankDeficientHandlingConfig that perturbs the matrix with a given epsilon, i.e. the default behavior.
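
For readers skimming the description, the config hierarchy it names might look roughly like the following. This is a minimal sketch, not the code in this PR: the field names `perturb_before_computation`, `rank_rtol`, and `rank_atol` and the helper `describe` are assumptions made up for illustration, and a plain dataclass stands in for the abstract base.

```python
from dataclasses import dataclass


@dataclass
class RankDeficientHandlingConfig:
    """Stand-in for the abstract base config named in the description."""


@dataclass
class PerturbationConfig(RankDeficientHandlingConfig):
    # Hypothetical field taking over the role of the old enhance_stability flag.
    perturb_before_computation: bool = True


@dataclass
class PseudoInverseConfig(RankDeficientHandlingConfig):
    # Hypothetical tolerances for deciding which eigenvalues count as zero.
    rank_rtol: float = 1e-6
    rank_atol: float = 0.0


def describe(config: RankDeficientHandlingConfig) -> str:
    """Dispatch on the config type, as a consumer of the config might."""
    if isinstance(config, PseudoInverseConfig):
        return "pseudo-inverse"
    if isinstance(config, PerturbationConfig):
        return "perturbation"
    raise TypeError(f"unsupported config: {type(config).__name__}")
```

Dispatching on the config type, rather than on a boolean flag, leaves room for each strategy to carry its own parameters.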

Note that the inverse root/exponent computation has been refactored out into scale_and_pow_eigenvalues(), which fixes a bug introduced in EigendecomposedShampooPreconditionerList where the eigenvalues were not being adjusted by the minimum eigenvalue (if it was negative) before being scaled by epsilon.
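
To make the described fix concrete, here is a rough pure-Python sketch of the two strategies acting on a list of eigenvalues. It is not the PR's `scale_and_pow_eigenvalues()`: the function name, signature, and the relative cutoff rule are assumptions. The perturbation branch shifts by the minimum eigenvalue when it is negative before adding epsilon, which is the step the description says was missing.

```python
def pow_eigenvalues_sketch(
    eigenvalues: list[float],
    root: float,
    epsilon: float = 1e-12,
    pseudo_inverse: bool = False,
    rank_rtol: float = 1e-6,
) -> list[float]:
    """Return eigenvalue ** (-1 / root) under one of two stabilization strategies."""
    exponent = -1.0 / root
    if pseudo_inverse:
        # Pseudo-inverse: treat eigenvalues at or below a relative cutoff as zero
        # and leave them at zero instead of inverting them.
        cutoff = rank_rtol * max(abs(v) for v in eigenvalues)
        return [v**exponent if v > cutoff else 0.0 for v in eigenvalues]
    # Perturbation: shift by the minimum eigenvalue only if it is negative,
    # then add epsilon so every value is strictly positive before the power.
    shift = min(min(eigenvalues), 0.0)
    return [(v - shift + epsilon) ** exponent for v in eigenvalues]


perturbed = pow_eigenvalues_sketch([-0.5, 0.0, 4.0], root=2, epsilon=0.25)
pseudo = pow_eigenvalues_sketch([0.0, 1e-12, 4.0], root=2, pseudo_inverse=True)
```

Without the shift, a slightly negative eigenvalue plus a small epsilon can still be negative, and the fractional power is then undefined over the reals.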

In the future, we intend to generalize the pseudo-inverse to all amortized computation methods, so rank_deficient_handling_config should eventually be moved under MatrixFunctionConfig.

Changes ported from D71908578 (internal) to OSS Shampoo.

@facebook-github-bot added the CLA Signed label Mar 31, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D72204068

anana10c added a commit to anana10c/optimizers that referenced this pull request Apr 1, 2025
…computation (facebookresearch#120)

Summary:

Port changes from D71908578 to OSS Shampoo.

Differential Revision: D72204068


@dataclass(kw_only=True)
class RegularizationConfig(NonInvertibleHandlingConfig):
Contributor

I would prefer a different name since regularization is very general and might be confusing in the context of an optimizer (this config is unrelated to weight decay or l2 regularization).

Some suggestions: TikhonovRegularizationConfig, DampingConfig, or even EpsilonConfig.

Contributor Author

We ended up deciding on PerturbationConfig for now, but let me know if you have any further suggestions. (I'm not particularly happy with this name, but so far it's the best compromise.)

@@ -14,14 +14,60 @@
from commons import AbstractDataclass


@dataclass(init=False)
class NonInvertibleHandlingConfig(AbstractDataclass):
Contributor

@runame runame Apr 3, 2025

I think this name is fine, but one downside is that it focuses on invertibility while this config also applies to methods that don't directly require an inverse. However, since being non-invertible and rank deficient are equivalent properties for square matrices, the name is still technically correct.

Comment on lines 218 to 219
# Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
if getattr(eigendecomposition_config, "add_epsilon_before_computation", False):
Contributor

Shouldn't it be eigendecomposition_config.noninvertible_handling_config? I.e.:

Suggested change
- # Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
- if getattr(eigendecomposition_config, "add_epsilon_before_computation", False):
+ # Only do it when add_epsilon_before_computation is True (eigendecomposition_config.noninvertible_handling_config must be a RegularizationConfig)
+ if getattr(eigendecomposition_config.noninvertible_handling_config, "add_epsilon_before_computation", False):

If this fix is correct it means that before getattr always returned False and there is no test that catches this.
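
The failure mode described here can be reproduced in isolation. The sketch below uses simplified stand-in dataclasses that borrow the names from this thread; they are assumptions for illustration, not the repository's definitions. It shows `getattr` with a default silently returning `False` when pointed at the outer config.

```python
from dataclasses import dataclass, field


@dataclass
class RegularizationConfig:
    add_epsilon_before_computation: bool = True


@dataclass
class EigendecompositionConfig:
    # The flag lives on the nested config, not on this outer object.
    noninvertible_handling_config: RegularizationConfig = field(
        default_factory=RegularizationConfig
    )


config = EigendecompositionConfig()

# Wrong object: the attribute does not exist here, so the default is returned.
outer_lookup = getattr(config, "add_epsilon_before_computation", False)

# Right object: the nested config actually carries the flag.
nested_lookup = getattr(
    config.noninvertible_handling_config, "add_epsilon_before_computation", False
)
```

A test asserting that the epsilon branch is actually taken for a config with the flag set would have caught the wrong lookup, since `outer_lookup` is `False` even though the flag is `True`.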

Contributor Author

Good catch! I'll definitely need to improve testing, because this is exactly the kind of situation everyone is worried about...

Contributor

We should replace those getattr() calls with operator.attrgetter() if the original getattr() is used mainly to keep the linter from complaining.
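
One way to read this suggestion, sketched with simplified stand-in dataclasses (assumptions for illustration, not the repository's definitions): `operator.attrgetter` accepts a dotted path and has no default, so a lookup on the wrong object raises `AttributeError` instead of silently evaluating to `False`.

```python
from dataclasses import dataclass, field
from operator import attrgetter


@dataclass
class RegularizationConfig:
    add_epsilon_before_computation: bool = True


@dataclass
class EigendecompositionConfig:
    noninvertible_handling_config: RegularizationConfig = field(
        default_factory=RegularizationConfig
    )


config = EigendecompositionConfig()

# Dotted path resolves through the nested config and returns the real flag.
get_flag = attrgetter("noninvertible_handling_config.add_epsilon_before_computation")
flag = get_flag(config)

# Looking on the outer object fails loudly instead of defaulting to False.
try:
    attrgetter("add_epsilon_before_computation")(config)
    wrong_path_raised = False
except AttributeError:
    wrong_path_raised = True
```

The trade-off is that a nested config type which legitimately lacks the attribute would also raise, so the caller would need an explicit type check for that case.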

Comment on lines 407 to 408
# Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
if getattr(root_inv_config, "add_epsilon_before_computation", False):
Contributor

Same as above:

Suggested change
- # Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
- if getattr(root_inv_config, "add_epsilon_before_computation", False):
+ # Only do it when add_epsilon_before_computation is True (root_inv_config.noninvertible_handling_config must be a RegularizationConfig)
+ if getattr(root_inv_config.noninvertible_handling_config, "add_epsilon_before_computation", False):

anana10c added a commit to anana10c/optimizers that referenced this pull request Apr 28, 2025
…computation (facebookresearch#120)

Summary:
Pull Request resolved: facebookresearch#120

Port changes from D71908578, D72674113, D72676653, D72685827 to OSS Shampoo with fbsource/fbcode/scripts/zong/oss_sync.sh.

Differential Revision: D72204068
@anana10c
Contributor Author

Proposed generalization of RankDeficientStabilityConfig:
[Screenshot, 2025-04-28 4:23:39 PM]

…computation (facebookresearch#120)

Summary:
Pull Request resolved: facebookresearch#120

Port changes from D71908578, D72674113, D72676653, D72685827 to OSS Shampoo with fbsource/fbcode/scripts/zong/oss_sync.sh.

Reviewed By: tsunghsienlee

Differential Revision: D72204068
@facebook-github-bot
Contributor

This pull request has been merged in 41138ab.
