enable pseudo-inverse behavior in eigendecomposition-based amortized computation #120
This pull request was exported from Phabricator. Differential Revision: D72204068
…computation (facebookresearch#120) Summary: Port changes from D71908578 to OSS Shampoo. Differential Revision: D72204068
Force-pushed from 27f7106 to f34500f
Force-pushed from f34500f to ac951a0
matrix_functions_types.py
Outdated
@dataclass(kw_only=True)
class RegularizationConfig(NonInvertibleHandlingConfig):
I would prefer a different name since regularization is very general and might be confusing in the context of an optimizer (this config is unrelated to weight decay or l2 regularization).
Some suggestions: TikhonovRegularizationConfig, DampingConfig, or even EpsilonConfig.
We ended up deciding on PerturbationConfig for now, but let me know if you have any further suggestions. (I'm not particularly happy with this name, but so far it's the best compromise.)
matrix_functions_types.py
Outdated
@@ -14,14 +14,60 @@
from commons import AbstractDataclass


@dataclass(init=False)
class NonInvertibleHandlingConfig(AbstractDataclass):
I think this name is fine, but one downside is that it focuses on invertibility while this config also applies to methods that don't directly require an inverse. However, since being non-invertible and rank deficient are equivalent properties for square matrices, the name is still technically correct.
matrix_functions.py
Outdated
# Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
if getattr(eigendecomposition_config, "add_epsilon_before_computation", False):
Shouldn't it be eigendecomposition_config.noninvertible_handling_config? I.e.:
- # Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
- if getattr(eigendecomposition_config, "add_epsilon_before_computation", False):
+ # Only do it when add_epsilon_before_computation is True (eigendecomposition_config.noninvertible_handling_config must be a RegularizationConfig)
+ if getattr(eigendecomposition_config.noninvertible_handling_config, "add_epsilon_before_computation", False):
If this fix is correct, it means that getattr previously always returned False, and that no test catches this.
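A minimal sketch of why the original check could never fire, assuming the flag lives on a nested config object rather than on the outer one. The class names `HandlingSketch` and `OuterSketch` are hypothetical stand-ins, not the classes in the repository.

```python
from dataclasses import dataclass, field


@dataclass
class HandlingSketch:
    # Hypothetical stand-in for the nested handling config.
    add_epsilon_before_computation: bool = True


@dataclass
class OuterSketch:
    # Hypothetical stand-in for the outer eigendecomposition config.
    noninvertible_handling_config: HandlingSketch = field(default_factory=HandlingSketch)


outer = OuterSketch()

# Looking up the flag on the outer object: the attribute does not exist there,
# so getattr() silently falls back to the default.
flag_on_outer = getattr(outer, "add_epsilon_before_computation", False)

# Looking up the flag on the nested object: the attribute exists and is True.
flag_on_nested = getattr(
    outer.noninvertible_handling_config, "add_epsilon_before_computation", False
)

print(flag_on_outer, flag_on_nested)
```

Because the fallback hides the missing attribute, the branch is skipped without any error, which is why only a behavioral test would reveal the problem.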
Good catch! I'll definitely need to improve testing, because this is exactly the kind of situation everyone is worried about...
We should replace those getattr() calls with operator.attrgetter() if the original getattr() is used mainly to silence linter complaints.
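For illustration, here is how `operator.attrgetter()` differs from `getattr()` with a default: it raises `AttributeError` on a missing attribute instead of silently returning a fallback, and it accepts dotted attribute paths. The objects below are hypothetical stand-ins built with `types.SimpleNamespace`, not the actual Shampoo configs.

```python
from operator import attrgetter
from types import SimpleNamespace

# Hypothetical config objects; only the attribute names come from the review thread.
config = SimpleNamespace(
    noninvertible_handling_config=SimpleNamespace(add_epsilon_before_computation=True)
)

# A dotted path resolves the nested attribute in one call.
nested_flag = attrgetter("noninvertible_handling_config.add_epsilon_before_computation")(config)

# Asking the outer object for the flag fails loudly instead of returning a default.
try:
    attrgetter("add_epsilon_before_computation")(config)
    lookup_failed = False
except AttributeError:
    lookup_failed = True

print(nested_flag, lookup_failed)
```

The trade-off is that `attrgetter()` has no default argument, so code that relies on the flag being absent for some config types would need an explicit type check instead.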
matrix_functions.py
Outdated
# Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
if getattr(root_inv_config, "add_epsilon_before_computation", False):
Same as above:
- # Only do it when add_epsilon_before_computation is True (root_inv_config must be a RegularizationConfig)
- if getattr(root_inv_config, "add_epsilon_before_computation", False):
+ # Only do it when add_epsilon_before_computation is True (root_inv_config.noninvertible_handling_config must be a RegularizationConfig)
+ if getattr(root_inv_config.noninvertible_handling_config, "add_epsilon_before_computation", False):
Force-pushed from ac951a0 to 9b2644f
…computation (facebookresearch#120) Summary: Pull Request resolved: facebookresearch#120 Port changes from D71908578, D72674113, D72676653, D72685827 to OSS Shampoo with fbsource/fbcode/scripts/zong/oss_sync.sh. Differential Revision: D72204068
…computation (facebookresearch#120) Summary: Pull Request resolved: facebookresearch#120 Port changes from D71908578, D72674113, D72676653, D72685827 to OSS Shampoo with fbsource/fbcode/scripts/zong/oss_sync.sh. Reviewed By: tsunghsienlee Differential Revision: D72204068
Force-pushed from 9b2644f to b2be636
This pull request has been merged in 41138ab.
Enable pseudo-inverse behavior in eigendecomposition-based computation with a unique config PseudoInverseConfig inheriting from RankDeficientHandlingConfig. (Pseudo-inverse implementation adapted from #11.) Refactor enhance_stability flag as an argument for PerturbationConfig, another RankDeficientHandlingConfig that perturbs the matrix with a given epsilon, i.e. the default behavior.

Note that the inverse root/exponent computation has been refactored out into scale_and_pow_eigenvalues(), which fixes a bug introduced in EigendecomposedShampooPreconditionerList where the eigenvalues were not being adjusted by the minimum eigenvalue (if it was negative) before being scaled by epsilon.

In the future, we intend to generalize the pseudo-inverse to all amortized computation methods, so rank_deficient_handling_config should eventually be moved under MatrixFunctionConfig.

Changes ported from D71908578 (internal) to OSS Shampoo.
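The two handling strategies described above can be sketched roughly as follows. This is not the repository's scale_and_pow_eigenvalues(); the names `PerturbationSketch`, `PseudoInverseSketch`, `epsilon`, `tolerance`, and the function signature are all assumptions made for illustration, and the sketch works on a plain list of eigenvalues rather than tensors. It also assumes that "adjusted by the minimum eigenvalue" means shifting the spectrum up when the minimum is negative, and that epsilon is then added.

```python
from dataclasses import dataclass


@dataclass
class PerturbationSketch:
    # Hypothetical field: amount added to every eigenvalue.
    epsilon: float = 1e-12


@dataclass
class PseudoInverseSketch:
    # Hypothetical field: eigenvalues at or below this are treated as zero.
    tolerance: float = 1e-10


def scale_and_pow_eigenvalues_sketch(eigenvalues, root, config):
    """Return eigenvalues raised to the power -1/root under the given handling."""
    exponent = -1.0 / root
    if isinstance(config, PseudoInverseSketch):
        # Pseudo-inverse: invert only the numerically nonzero eigenvalues
        # and map the rest to exactly zero.
        return [ev**exponent if ev > config.tolerance else 0.0 for ev in eigenvalues]
    # Perturbation: shift by the minimum eigenvalue if it is negative,
    # then add epsilon so every eigenvalue is strictly positive.
    shift = -min(min(eigenvalues), 0.0)
    return [(ev + shift + config.epsilon) ** exponent for ev in eigenvalues]


# Rank-deficient spectrum: the zero eigenvalue stays zero under the pseudo-inverse.
pseudo = scale_and_pow_eigenvalues_sketch([4.0, 0.0], 2, PseudoInverseSketch())

# Slightly negative spectrum: shifted by 1.0, then perturbed by epsilon.
perturbed = scale_and_pow_eigenvalues_sketch([2.0, -1.0], 2, PerturbationSketch(epsilon=1.0))

print(pseudo, perturbed)
```

Skipping the shift in the perturbation branch would leave a negative eigenvalue under a fractional power, which appears to be the kind of failure the refactoring addresses.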