
[RFC] U-Net framework #6610

Open
@TeodorPoncu

Description


🚀 The feature

A module-based approach to building U-Nets inside torchvision, similar to torchmultimodal's sub-network approach. This is mostly a food-for-thought experiment, given the similar nature of other popular vision architectural frameworks (namely DETR, Deformable-DETR, Mask2Former, or Mask DINO, which are just compositions of sub-networks that can be adapted to most vision tasks). See the sketch below for what this could look like.
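To make the idea concrete, here is a minimal sketch of a block-based U-Net. Everything in it (the `UNet` and `ConvBlock` names, the constructor arguments) is hypothetical, not an existing torchvision API; the point is that the scaffold owns the U shape and the skip plumbing, while users only supply the block:

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """The classic two-convolution U-Net stage; any nn.Module with the same
    (in_channels, out_channels) constructor could be swapped in."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.layers(x)


class UNet(nn.Module):
    """Hypothetical scaffold that composes user-supplied blocks into the U shape."""

    def __init__(self, channels=(64, 128, 256), block=ConvBlock):
        super().__init__()
        self.stem = block(3, channels[0])
        self.pool = nn.MaxPool2d(2)
        self.downs = nn.ModuleList(
            block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])
        )
        self.ups = nn.ModuleList(
            nn.ConvTranspose2d(c_out, c_in, 2, stride=2)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )
        # After the "cat" merge, each decoder block sees skip + upsampled features.
        self.dec_blocks = nn.ModuleList(block(2 * c, c) for c in channels[:-1])

    def forward(self, x):
        x = self.stem(x)
        skips = [x]
        for down in self.downs:
            x = down(self.pool(x))
            skips.append(x)
        skips.pop()  # the deepest features are the current x, not a skip
        # Walk back up, merging with the matching encoder level.
        for up, dec, skip in zip(
            reversed(self.ups), reversed(self.dec_blocks), reversed(skips)
        ):
            x = dec(torch.cat([up(x), skip], dim=1))
        return x


unet = UNet()
print(unet(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 64, 64])
```

Swapping `block` for a residual or attention block is then a one-argument change rather than a re-implementation.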

Motivation, pitch

U-Nets are a good example given the rising popularity of Diffusion Models, in which the U-Net paradigm is widely used (the layers or merge strategies being the main difference between most implementations).

Unlike DETR or Mask2Former, which can be broken down quite simply into a module with 2-4 sub-modules followed by a task head, the U-Net framework presents some more intricate challenges at the configuration-specification level, given that we have to sync encoder and decoder configurations across levels (illustrated below).
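A rough sketch of how a configuration object could keep the two sides in sync by deriving the decoder widths from the encoder ones, instead of asking users to specify both and get them wrong (the `UNetConfig` name and fields are again hypothetical):

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass
class UNetConfig:
    # Encoder widths, shallow to deep; the decoder side is derived rather
    # than specified, so the two can never drift out of sync.
    encoder_channels: Sequence[int] = (64, 128, 256, 512)
    merge: str = "cat"  # "cat" concatenates skips, "add" sums them

    def decoder_in_channels(self) -> Tuple[int, ...]:
        # Decoder level i consumes the upsampled features plus the skip from
        # the matching encoder level, so its input width depends on the merge.
        factor = 2 if self.merge == "cat" else 1
        return tuple(factor * c for c in reversed(self.encoder_channels[:-1]))


print(UNetConfig(merge="cat").decoder_in_channels())  # (512, 256, 128)
print(UNetConfig(merge="add").decoder_in_channels())  # (256, 128, 64)
```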

Shifting towards a more framework/block-based approach for larger architectures (think of an nn.Module-like experience, but for present-day vision architectures) would benefit users when it comes to sharing code, contributing improvements, or simply swapping out backbones and other components. For instance, if someone wanted to grab a Mask2Former today, they would have to go and integrate with Detectron2 themselves.

Similarly, if someone wanted to jump into doing diffusion, they would first have to find or write their own U-Net implementation, even though what they most likely want to do is add a bottleneck with attention, a residual connection somewhere in the network, or simply a different normalization layer compared to the original paper.

These classic paradigms should be easy to configure or specify (the same way torch.nn.Transformer handles the transformer; see the example below), and if more invasive changes are wanted, a user should have access to a code-base which they can copy-paste and apply minimal modifications to (similar to how DETR handles positional embeddings in the decoder).
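For reference, this is the kind of constructor-level configurability torch.nn.Transformer already offers today, and which a U-Net entry point could mirror:

```python
import torch
from torch import nn

# The whole transformer is specified via constructor arguments, including
# swappable pieces such as the activation; no subclassing needed.
model = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    activation="gelu",
)
src = torch.rand(10, 32, 512)  # (seq, batch, d_model)
tgt = torch.rand(20, 32, 512)
print(model(src, tgt).shape)  # torch.Size([20, 32, 512])
```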

Even if they opt for modifying the code-base and later share their work, there is the bonus of familiarity when others want to build on top of their code, since it is not entirely different from the base version they already know.

Supporting some of these architectures or frameworks might attract users working on tasks that torchvision currently does not support (for instance, the Monocular Depth Estimation SotA makes use of Mask2Former), which could provide us with valuable insight into the needs of the larger vision community.

Alternatives

Currently, Lucidrains has been leading these kinds of efforts for attention operations and Transformers, and more recently for diffusion models.

Additional context

No response

cc @datumbox
