Adding support for load_state_dict with assign=True for priors of Transformed distributions #2691
base: main
Conversation
gpytorch/priors/prior.py
Outdated
tensor_value = (
    value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
)
should be a noop if already a tensor...
Suggested change:
- tensor_value = (
-     value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
- )
+ tensor_value = torch.as_tensor(value)
Fixed
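(For reference, a minimal sketch of the no-op behavior, assuming standard torch semantics:)

import torch

t = torch.tensor([1.0, 2.0])
# With no dtype/device override, as_tensor returns the input tensor unchanged.
assert torch.as_tensor(t) is t
# Non-tensor inputs (Python scalars, lists, numpy arrays) are converted.
print(torch.as_tensor([1.0, 2.0]))  # tensor([1., 2.])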
# Update the base attribute in the base distribution
self.base_dist.__setattr__(base_attr_name, tensor_value)
# Update the transformed attribute as well
super().__setattr__(name, tensor_value)
Wouldn't we have to save the untransformed value for the attribute of the base distribution here? Seems odd that we assign the same value to both base dist and dist...
@Balandat _transformed_ is just a buffered copy of the base attribute; the name indicates that it comes from a torch.distributions.TransformedDistribution, not that the value itself is transformed. I could change the name to something like _buffered_ instead.
Recall that the issue is that the base attributes (loc, scale, etc.) are exposed as a @property delegating to the base distribution, e.g. torch.distributions.Normal, so we can't bufferize them on the LogNormalPrior. Thus, we need to bufferize an attribute containing the same information, which we can then use to set loc and scale on the base distribution.
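For context, a minimal sketch (using plain torch.distributions rather than the prior classes) of why these attributes can't be bufferized directly:

import torch
from torch.distributions import LogNormal

d = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
# loc/scale are read-only properties that delegate to the base Normal,
# so there is no plain attribute on the transformed distribution to buffer.
print(type(d).loc)               # <property object at ...>
print(d.loc is d.base_dist.loc)  # True
# d.loc = torch.tensor(1.0)      # AttributeError: can't set attribute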
@Balandat Added a note on this in a new version:
# Note: "_transformed_" is just an indicator that this attribute belongs to a
# TransformedDistribution, the value itself is not transformed.
Force-pushed from 6ce060b to 5c174a1.
buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
for buffered_attr in buffered_attrs:
    base_attr_name = buffered_attr.replace(buffered_str, "")
    setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
no need to traverse twice here
Suggested change:
- buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
- for buffered_attr in buffered_attrs:
-     base_attr_name = buffered_attr.replace(buffered_str, "")
-     setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
+ for attr in dir(module):
+     if buffered_str in attr:
+         base_attr_name = attr.replace(buffered_str, "")
+         setattr(module.base_dist, base_attr_name, getattr(module, attr))
# TransformedDistribution, NOT that the value itself is transformed.
# The _buffered_ buffer is simply a copy of the base_dist attribute,
# so we assign the same value to both.
if hasattr(self, name) and "_buffered_" in name:
let's make "_buffered_" a constant so that there is a single source of truth
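Something along these lines, for example (constant and helper names below are placeholders, not the PR's code):

# Hypothetical module-level constant acting as the single source of truth.
BUFFERED_ATTR_MARKER = "_buffered_"

def is_buffered_attr(name: str) -> bool:
    # Call sites check against the constant instead of repeating the literal.
    return BUFFERED_ATTR_MARKER in name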
Required changes in GPyTorch to unblock meta-pytorch/botorch#3080. When `load_state_dict` is called with `assign=True`, `setattr` is called on the `_transformed` attributes of the prior at the PyTorch level. This was not the intended use of the `_transformed` attribute, but it seems we have to enable modifying it directly.
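For reference, a minimal sketch of the call pattern that triggers this (assign=True requires a PyTorch version that supports it; the exact buffer contents depend on this PR):

import torch
from gpytorch.priors import LogNormalPrior

prior = LogNormalPrior(loc=0.0, scale=1.0)
state = prior.state_dict()

# With assign=True, PyTorch assigns the state-dict tensors to the module via
# setattr instead of copying them in place, which is what ends up writing to
# the prior's _transformed_* buffers.
prior.load_state_dict(state, assign=True)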