Closed
Labels: enhancement (New feature or request)
Description
It seems that zuko simply concatenates the features and the condition and then processes them together with a `MaskedLinear` (please correct me if I missed something). What if we used `MaskedLinear` only for the features and a plain `Linear` layer for the condition?
`zuko/zuko/flows/autoregressive.py`, lines 207 to 218 in 25fefe2:

```python
def meta(self, c: Tensor, x: Tensor) -> Transform:
    if c is not None:
        x = torch.cat(broadcast(x, c, ignore=1), dim=-1)

    phi = self.hyper(x)
    phi = phi.unflatten(-1, (-1, self.total))
    phi = unpack(phi, self.shapes)

    return DependentTransform(self.univariate(*phi), 1)

def forward(self, c: Tensor = None) -> Transform:
    return AutoregressiveTransform(partial(self.meta, c), self.passes)
```
where the hyper-network is defined in the same file (`zuko/zuko/flows/autoregressive.py`, line 152 in 25fefe2):

```python
self.hyper = MaskedMLP(adjacency, **kwargs)
```
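To make the current behaviour concrete, here is a minimal single-layer sketch in plain PyTorch of what a masked linear layer applied to `cat(x, c)` computes. The `MaskedLinear` class below is a hypothetical stand-in, not zuko's implementation, and the mask is an assumption: output `i` may see only features `j < i`, while every output may see the whole context.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """Hypothetical stand-in: a linear layer whose weight is multiplied by a fixed 0/1 mask."""

    def __init__(self, mask: torch.Tensor, bias: bool = True):
        out_features, in_features = mask.shape
        super().__init__(in_features, out_features, bias=bias)
        # Stored as a buffer so it moves with the module but is never trained.
        self.register_buffer("mask", mask.to(self.weight.dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.mask, self.bias)


torch.manual_seed(0)
features, context = 3, 2

# Assumed mask: output i may only see features j < i (strictly lower triangular) ...
feature_mask = torch.tril(torch.ones(features, features), diagonal=-1)
# ... but it may see the entire context.
context_mask = torch.ones(features, context)

# Current approach: one masked layer over the concatenated input.
layer = MaskedLinear(torch.cat((feature_mask, context_mask), dim=1))

x = torch.randn(4, features)
c = torch.randn(4, context)
y = layer(torch.cat((x, c), dim=-1))  # shape (4, features)
```

Under this mask, perturbing the last feature should leave every output unchanged, since no output is allowed to depend on it, while the context columns are never masked.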
Implementation
In other words, replace `MaskedLinear(cat(features, condition))` with `cat(MaskedLinear(features), Linear(condition))`.
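A minimal sketch of the split, assuming (as in a single masked layer whose context columns are all ones) that each output may depend on the whole condition. Because a linear map on a concatenation decomposes as `W_x @ x + W_c @ c + b`, this sketch sums the masked feature part and the plain `Linear` context part rather than concatenating them, which keeps the output width unchanged; concatenating instead would widen the hidden layer, and the following masked layers would then have to treat those extra units as unmasked. `SplitMaskedLinear` is a hypothetical name, not part of zuko.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitMaskedLinear(nn.Module):
    """Hypothetical layer: masked weights for the features, a plain Linear for the condition."""

    def __init__(self, features: int, context: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(features, features) / features**0.5)
        self.bias = nn.Parameter(torch.zeros(features))
        # Assumed autoregressive mask: output i may only depend on features j < i.
        self.register_buffer("mask", torch.tril(torch.ones(features, features), diagonal=-1))
        # The condition is unconstrained, so an ordinary (bias-free) Linear is enough.
        self.context = nn.Linear(context, features, bias=False)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.mask, self.bias) + self.context(c)


torch.manual_seed(0)
features, context = 3, 2
split = SplitMaskedLinear(features, context)

# The same map written the current way: one masked weight applied to cat(x, c),
# with all-ones mask columns for the condition.
full_weight = torch.cat((split.weight, split.context.weight), dim=1)
full_mask = torch.cat((split.mask, torch.ones(features, context)), dim=1)

x = torch.randn(4, features)
c = torch.randn(4, context)

y_split = split(x, c)
y_cat = F.linear(torch.cat((x, c), dim=-1), full_weight * full_mask, split.bias)
print(torch.allclose(y_split, y_cat, atol=1e-6))  # True
```

Under the stated assumption, the two forms compute the same function for a single layer; the split mainly avoids storing mask entries for the condition columns and keeps the condition weights in a separate dense layer.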
Thanks in advance