Conversation

@hvarfner (Contributor)

Required changes in GPyTorch to unblock meta-pytorch/botorch#3080. When load_state_dict is called with assign=True, setattr is called on the _transformed attributes of the prior at the PyTorch level.

This was not the intended use of the _transformed attribute, but it seems we have to allow it to be modified directly.
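
To make the mechanism concrete, here is a minimal, self-contained sketch of the behavior described above. The BufferedDist class and its print statement are illustrative stand-ins, not GPyTorch code; only the load_state_dict(..., assign=True) call is the relevant PyTorch API.

import torch
import torch.nn as nn

class BufferedDist(nn.Module):
    """Toy stand-in for a prior that keeps a "_transformed_loc" buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("_transformed_loc", torch.tensor(0.0))

    def __setattr__(self, name, value):
        # A prior could intercept the assignment here and keep its base
        # distribution in sync, which is what this PR enables.
        if "_transformed_" in name:
            print(f"__setattr__ called for {name}")
        super().__setattr__(name, value)

m = BufferedDist()
# With assign=True, PyTorch replaces the stored buffer object via setattr
# instead of copying into it in place, so the override above fires.
m.load_state_dict({"_transformed_loc": torch.tensor(1.0)}, assign=True)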

@hvarfner (Contributor, Author)

@Balandat

@hvarfner (Contributor, Author)

@SebastianAment

Comment on lines 49 to 51
tensor_value = (
    value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
)
Collaborator

should be a noop if already a tensor...

Suggested change
-tensor_value = (
-    value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
-)
+tensor_value = torch.as_tensor(value)

@hvarfner (Contributor, Author)

Fixed

Comment on lines 52 to 51
# Update the base attribute in the base distribution
self.base_dist.__setattr__(base_attr_name, tensor_value)
# Update the transformed attribute as well
super().__setattr__(name, tensor_value)
Collaborator

Wouldn't we have to save the untransformed value for the attribute of the base distribution here? Seems odd that we assign the same value to both base dist and dist...

@hvarfner (Contributor, Author)

@Balandat _transformed_ is just a buffer copy of the base attribute, indicating that it comes from a torch.distributions.TransformedDistribution, not a transformed version. I could change the name to something like _buffered_ instead.

@hvarfner (Contributor, Author) on Dec 28, 2025

Recall that the issue is that the base attributes (loc, scale, etc.) are @property attributes that just read from the base distribution (e.g. torch.distributions.Normal), so we can't register them as buffers on the LogNormalPrior. Thus, we need to buffer an attribute containing the same info, which we can then use to set the loc and scale on the base distribution.
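
For context on why the extra buffers are needed at all, here is a quick plain-torch illustration (not GPyTorch code) of the constraint described above: on torch.distributions.LogNormal, loc and scale are read-only @property attributes that delegate to the base Normal, so they cannot be assigned directly and the value has to be routed through base_dist instead.

import torch
from torch.distributions import LogNormal

d = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
print(type(d).loc)  # a property on the class, not a plain tensor attribute

try:
    d.loc = torch.tensor(1.0)  # fails: the property has no setter
except AttributeError as err:
    print(err)

# The same information has to be routed through the base distribution:
d.base_dist.loc = torch.tensor(1.0)
print(d.loc)  # tensor(1.), the property now reflects the updated base attribute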

@hvarfner (Contributor, Author)

@Balandat Added a note on this in a new version:

# Note: "_transformed_" is just an indicator that this attribute belongs to a
# TransformedDistribution, the value itself is not transformed.

Comment on lines +30 to +33
buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
for buffered_attr in buffered_attrs:
    base_attr_name = buffered_attr.replace(buffered_str, "")
    setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
Collaborator

no need to traverse twice here

Suggested change
-buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
-for buffered_attr in buffered_attrs:
-    base_attr_name = buffered_attr.replace(buffered_str, "")
-    setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
+for attr in dir(module):
+    if buffered_str in attr:
+        base_attr_name = attr.replace(buffered_str, "")
+        setattr(module.base_dist, base_attr_name, getattr(module, attr))

# TransformedDistribution, NOT that the value itself is transformed.
# The _buffered_ buffer is simply a copy of the base_dist attribute,
# so we assign the same value to both.
if hasattr(self, name) and "_buffered_" in name:
Collaborator

let's make "_buffered_" a constant so that there is a single source of truth
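
One possible way to follow this suggestion, as a sketch only (the constant name and helper below are illustrative, not the PR's final code): define the marker string once at module level and reference it from every place that currently hard-codes "_buffered_".

# Single source of truth for the marker string used in buffer names.
_BUFFERED_PREFIX = "_buffered_"

def _sync_base_dist(module) -> None:
    # Push every buffered copy back onto the base distribution.
    for attr in dir(module):
        if _BUFFERED_PREFIX in attr:
            base_attr_name = attr.replace(_BUFFERED_PREFIX, "")
            setattr(module.base_dist, base_attr_name, getattr(module, attr))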
