
Implement torchvision.ops.roi_align in torchxla2 #8288

Open
@qihqi

Description

🚀 Feature

https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi_align#torchvision.ops.roi_align

A few ideas:

  1. Use the torch decomposition here: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L115. I tried this and found that JAX runs out of memory (OOM) at https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L74, so the problem seems to be that the advanced indexing there creates large intermediates. The torch side needed a "loop-less" implementation to help Inductor; for JAX we could instead rewrite it using jax.vmap and jax.lax.fori_loop.
  2. Start from this JAX implementation: https://github.com/google-research/scenic/blob/74225e8e71ba27a76abd62e6bc56e8a64c4cc19e/scenic/projects/baselines/centernet/modeling/roi_align.py#L103. However, it takes output_size as an int rather than a tuple of ints (i.e., it assumes the output width and height are the same), so it will need some modification.
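A minimal sketch of how idea 1 might look, with the output_size generalization from idea 2 folded in. It vmaps over ROIs and uses jax.lax.fori_loop over the sampling grid, so each loop step materializes only one (pooled_h, pooled_w, C) sample map per ROI instead of gathering every sampling point at once. Assumptions: boxes are a (K, 5) array of (batch_idx, x1, y1, x2, y2) as in torchvision; sampling_ratio is a static positive int (torchvision's adaptive sampling_ratio <= 0 is data-dependent and is not handled); the boundary handling is intended to mirror torchvision's CPU kernel but has not been checked against it. The names `roi_align` and `_bilinear` here are hypothetical, not an existing torchxla2 API.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _bilinear(feat, y, x):
    """Bilinearly sample feat (C, H, W) at scalar (y, x).

    Intended to mirror torchvision's CPU kernel: points more than one pixel
    outside the map contribute zero; otherwise coordinates are clamped.
    """
    _, height, width = feat.shape
    outside = (y < -1.0) | (y > height) | (x < -1.0) | (x > width)
    y = jnp.clip(y, 0.0, height - 1)
    x = jnp.clip(x, 0.0, width - 1)
    y_low = jnp.floor(y).astype(jnp.int32)
    x_low = jnp.floor(x).astype(jnp.int32)
    y_high = jnp.minimum(y_low + 1, height - 1)
    x_high = jnp.minimum(x_low + 1, width - 1)
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    val = (hy * hx * feat[:, y_low, x_low] + hy * lx * feat[:, y_low, x_high]
           + ly * hx * feat[:, y_high, x_low] + ly * lx * feat[:, y_high, x_high])
    return jnp.where(outside, 0.0, val)


@partial(jax.jit, static_argnames=("output_size", "sampling_ratio", "aligned"))
def roi_align(features, rois, output_size, spatial_scale=1.0,
              sampling_ratio=2, aligned=False):
    """features: (N, C, H, W); rois: (K, 5). Returns (K, C, pooled_h, pooled_w)."""
    # Accept an int (square output) or an (h, w) tuple.
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    pooled_h, pooled_w = output_size
    grid = sampling_ratio  # assumed static and > 0
    offset = 0.5 if aligned else 0.0
    # Sample (C,) at every (y, x) pair: result is (pooled_h, pooled_w, C).
    sample = jax.vmap(jax.vmap(_bilinear, in_axes=(None, None, 0)),
                      in_axes=(None, 0, None))

    def one_roi(roi):
        feat = features[roi[0].astype(jnp.int32)]  # (C, H, W)
        x1, y1, x2, y2 = roi[1:] * spatial_scale - offset
        roi_w, roi_h = x2 - x1, y2 - y1
        if not aligned:  # legacy behavior: force ROIs to be at least 1x1
            roi_w = jnp.maximum(roi_w, 1.0)
            roi_h = jnp.maximum(roi_h, 1.0)
        bin_h, bin_w = roi_h / pooled_h, roi_w / pooled_w
        # Sample offsets inside a bin, e.g. [0.25, 0.75] for grid=2.
        frac = (jnp.arange(grid) + 0.5) / grid
        ys = y1 + (jnp.arange(pooled_h)[:, None] + frac[None, :]) * bin_h
        xs = x1 + (jnp.arange(pooled_w)[:, None] + frac[None, :]) * bin_w

        def body(i, acc):
            # One sampling point per bin per step keeps intermediates small.
            iy, ix = i // grid, i % grid
            return acc + sample(feat, ys[:, iy], xs[:, ix])

        acc = jnp.zeros((pooled_h, pooled_w, features.shape[1]), features.dtype)
        acc = jax.lax.fori_loop(0, grid * grid, body, acc)
        return jnp.transpose(acc / (grid * grid), (2, 0, 1))

    return jax.vmap(one_roi)(rois)
```

Usage would look like `roi_align(features, rois, output_size=(7, 7), spatial_scale=1 / 16, sampling_ratio=2, aligned=True)`, returning shape (K, C, 7, 7). Comparing it against torchvision.ops.roi_align on random inputs, and profiling peak memory against the decomposition, would be the natural next checks.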


Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
