
Why do we use MaskedLinear for the condition? #67

@yangysc

Description

It seems that in zuko the features and the condition are simply concatenated and then processed together by a single MaskedLinear layer (please correct me if I missed something). What if we used MaskedLinear only for the features and a plain Linear layer for the condition?

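```python
def meta(self, c: Tensor, x: Tensor) -> Transform:
    if c is not None:
        # Broadcast the batch dimensions of x and c, then append c to the features
        x = torch.cat(broadcast(x, c, ignore=1), dim=-1)
    # The masked hypernetwork maps cat(x, c) to the flattened transformation parameters
    phi = self.hyper(x)
    # Reshape to one parameter vector of length self.total per feature
    phi = phi.unflatten(-1, (-1, self.total))
    # Split each vector into the parameter shapes expected by self.univariate
    phi = unpack(phi, self.shapes)
    return DependentTransform(self.univariate(*phi), 1)

def forward(self, c: Tensor = None) -> Transform:
    # Bind the condition so that the autoregressive transform only passes x to meta
    return AutoregressiveTransform(partial(self.meta, c), self.passes)
```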
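In `meta`, the condition reaches the hypernetwork only through the `torch.cat(broadcast(x, c, ignore=1), dim=-1)` line. A minimal NumPy sketch of that step, assuming `broadcast(..., ignore=1)` broadcasts every dimension except the last one; `broadcast_cat` is a hypothetical helper, not part of zuko:

```python
import numpy as np


def broadcast_cat(x: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Hypothetical helper: broadcast the batch dimensions of x and c
    (every axis except the last), then concatenate along the last axis."""
    batch = np.broadcast_shapes(x.shape[:-1], c.shape[:-1])
    x = np.broadcast_to(x, batch + x.shape[-1:])
    c = np.broadcast_to(c, batch + c.shape[-1:])
    return np.concatenate((x, c), axis=-1)


x = np.zeros((8, 3))  # batch of 8 samples with 3 features
c = np.ones((1, 2))   # one 2-dimensional condition shared by the whole batch
xc = broadcast_cat(x, c)
print(xc.shape)       # (8, 5)
```

The hypernetwork therefore sees a single input vector of size `features + context` per sample, with no structural distinction between the two parts other than what its masks encode.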

where the hypernetwork is

```python
self.hyper = MaskedMLP(adjacency, **kwargs)
```
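Whether the condition is treated differently from the features depends entirely on the adjacency mask. A NumPy sketch of one masked linear layer over `cat(x, c)`, assuming an autoregressive adjacency in which output `i` may depend on features `j < i` and on every condition entry; the names `adjacency`, `masked_linear`, `D` and `C` are illustrative, not zuko's:

```python
import numpy as np

D, C = 4, 2  # number of features and condition entries (illustrative sizes)

# Feature block: output i sees feature j only if j < i (strictly lower triangular)
feature_mask = np.tril(np.ones((D, D), dtype=bool), k=-1)
# Condition block: every output sees every condition entry
condition_mask = np.ones((D, C), dtype=bool)
adjacency = np.concatenate((feature_mask, condition_mask), axis=1)  # shape (D, D + C)

rng = np.random.default_rng(0)
W = rng.normal(size=(D, D + C))
b = rng.normal(size=D)


def masked_linear(xc: np.ndarray) -> np.ndarray:
    """Linear layer whose weights are zeroed wherever the adjacency is False."""
    return xc @ (W * adjacency).T + b


x = rng.normal(size=D)
c = rng.normal(size=C)
y = masked_linear(np.concatenate((x, c)))

# Perturbing the last feature changes no output, since no output i has D - 1 < i
x2 = x.copy()
x2[-1] += 1.0
assert np.allclose(masked_linear(np.concatenate((x2, c))), y)

# Perturbing the condition can change every output
delta = masked_linear(np.concatenate((x, c + 1.0))) - y
print(np.abs(delta) > 0)
```

With all-ones condition columns, the mask constrains only the feature-to-output connections; the condition behaves like an ordinary dense input.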

Implementation

The change would replace `MaskedLinear(cat(features, condition))` with `cat(MaskedLinear(features), Linear(condition))`.
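If the condition columns of the mask are all ones, splitting the weight matrix column-wise gives `W @ cat(x, c) = W_x @ x + W_c @ c`, so the current layer already computes a masked layer on the features plus a plain linear layer on the condition, as long as the two branches are summed. Concatenating the two branch outputs, as written above, would instead double the layer's output width. A NumPy sketch of the summed reading, with illustrative names and a hypothetical random feature mask:

```python
import numpy as np

rng = np.random.default_rng(1)
D, C, H = 3, 2, 5  # features, condition entries, hidden units (illustrative)

# Hypothetical mask for the feature columns only; the condition columns are unmasked
feature_mask = rng.random((H, D)) < 0.5
mask = np.concatenate((feature_mask, np.ones((H, C), dtype=bool)), axis=1)

W = rng.normal(size=(H, D + C))
b = rng.normal(size=H)
x = rng.normal(size=D)
c = rng.normal(size=C)

# Current design: one masked layer applied to cat(x, c)
joint = (W * mask) @ np.concatenate((x, c)) + b

# Proposed design (summed): masked layer on x plus a plain linear layer on c
W_x, W_c = W[:, :D], W[:, D:]
split = (W_x * feature_mask) @ x + W_c @ c + b

print(np.allclose(joint, split))  # True
```

Under that assumption the two designs differ in how the parameters are organised rather than in what the layer can represent.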

Thanks in advance

Metadata

- Assignees: No one assigned
- Labels: enhancement (New feature or request)
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests