
In SubModuleXD, shapes are duplicated within output_shape and _to_xd #81

@ghost

Description

For example, looking at SubModule2D:

    def output_shape(self, dim=None):
        if dim == 1:
            f, l = self._output_shape
            return (f * l,)
        ...
    ...
    def _to_1d(self, submodule_output):
        """
        :param submodule_output: torch.Tensor (Batch + 2D)
        :return: torch.Tensor (Batch + 1D)
        """
        n, f, l = submodule_output.size()
        return submodule_output.view(n, f * l)

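The fact that the 2D -> 1D conversion maps (F, L) -> (F * L) is encoded in two places within the SubModule2D class: once in output_shape(dim=1) and again in _to_1d. The same is true in general for the other mD -> nD conversions. Since the two copies must be kept in sync by hand, it may be worth eliminating this duplication.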

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
