
Integration of RAFT Optical Flow model to SG. #1945


Closed
wants to merge 2 commits

Conversation

Yael-Baron (Contributor)

No description provided.

def forward(self, x, **kwargs):
    """Estimate optical flow between pairs of frames"""

    image1 = x[:, 0]
@BloodAxe (Contributor) commented Apr 2, 2024:

  1. The way the input image pair is passed as [B,2,C,H,W] should be reflected in the docstring (as well as the order: prev/curr or curr/prev).
  2. No kwargs.
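
For illustration, a minimal sketch of what this suggestion could look like; the prev/curr ordering and the placeholder body below are assumptions, not the PR's actual implementation:

import torch
from torch import nn


class RAFTForwardSketch(nn.Module):
    # Hypothetical stand-in for RAFT.forward, only to illustrate the docstring.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Estimate optical flow between pairs of frames.

        :param x: Image pair packed as [B, 2, C, H, W], where x[:, 0] is the
                  previous frame and x[:, 1] is the current frame (assumed order).
        :return: Predicted optical flow.
        """
        image1 = x[:, 0]  # previous frame
        image2 = x[:, 1]  # current frame
        # Feature extraction and iterative refinement would go here; a
        # placeholder is returned so the sketch is runnable on its own.
        return image2 - image1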

# run update block network
flow_predictions, flow_up = self.flow_iterative_block(coords0, coords1, net, inp, fmap1, fmap2)

if not self.training:
Contributor:

Why do we need this?
I suggest the model have the same output format regardless of whether it is in training/eval mode.
If you want to omit some outputs when exporting the model to ONNX, you can do if torch.jit.is_tracing(), but I suggest not making the output of the model state-dependent.
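
A small sketch of the suggested pattern; the names flow_predictions and flow_up follow the snippet above, and whether they match the PR's actual return values is an assumption:

import torch


def select_outputs(flow_predictions, flow_up):
    # Branch on tracing rather than on self.training, as suggested above.
    if torch.jit.is_tracing():
        # During ONNX export / tracing, return only the final upsampled flow.
        return flow_up
    # In eager mode the output structure is identical in train and eval.
    return flow_predictions, flow_up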

Models.RAFT_L,
]

# def test_infer_input_image_shape_from_model(self):
Contributor:

Please remove irrelevant commented-out code pieces.

import_pytorch_quantization_or_install()


class TestOpticalFlowModelExport(unittest.TestCase):
Contributor:

A test case should be registered in the unit test runner
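
For illustration only, a registration sketch using the standard unittest machinery; the import path is hypothetical and the exact hook used by SG's unit test runner may differ:

import unittest

# Hypothetical import path for the test case added in this PR.
from tests.unit_tests.optical_flow_export_test import TestOpticalFlowModelExport


def build_suite() -> unittest.TestSuite:
    suite = unittest.TestSuite()
    # Register the new test case so the CI unit test runner actually runs it.
    suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestOpticalFlowModelExport))
    return suite


if __name__ == "__main__":
    unittest.TextTestRunner(verbosity=2).run(build_suite())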

@BloodAxe (Contributor) commented Apr 2, 2024:

I have a few questions regarding this PR:

  • Are we going to include pretrained weights for RAFT (official or ours)?
  • What about the dataset, training recipes and loss? Without those, it does not make much sense to include the model.

from .raft_base import Encoder, ContextEncoder, FlowIterativeBlock


class RAFT(ExportableOpticalFlowModel, SgModule):
Contributor:

Please include a docstring with a link to the original paper/implementation.

Contributor:

And add docstrings to the rest of the methods of this class.
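
A sketch of what the class docstring could contain; the links point to the original RAFT paper and reference implementation, and the base classes are left out so the stub stands alone:

class RAFT:  # base classes (ExportableOpticalFlowModel, SgModule) omitted in this stub
    """
    RAFT: Recurrent All-Pairs Field Transforms for Optical Flow.

    Paper: Teed & Deng, ECCV 2020, https://arxiv.org/abs/2003.12039
    Reference implementation: https://github.com/princeton-vl/RAFT
    """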

"FP16 quantization is done by calling model.half() so you don't need to pass calibration_loader, as it will be ignored."
)

if engine in {ExportTargetBackend.ONNXRUNTIME, ExportTargetBackend.TENSORRT}:
Contributor:

I thought we didn't check that TRT conversion works.
We should not add code that has not been tested.

# update the quantization_mode to INT8, so that we can correctly export the model.
quantization_mode = ExportQuantizationMode.INT8

from super_gradients.training.models.conversion import ConvertableCompletePipelineModel
Contributor:

Why is the import here?


self.flow_iterative_block = FlowIterativeBlock(encoder_params, encoder_params.update_block.hidden_dim, flow_params, corr_params.alternate_corr)

def freeze_bn(self):
Contributor:

What's the use of this?
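
For context, the reference RAFT implementation uses freeze_bn to keep BatchNorm running statistics frozen during fine-tuning; a sketch of that typical behaviour follows (whether this PR's method does exactly the same is an assumption):

import torch.nn as nn


def freeze_bn(model: nn.Module) -> None:
    # Put every BatchNorm2d layer into eval mode so its running mean/var are
    # not updated while the rest of the network keeps training.
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()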


@Yael-Baron (Contributor, Author):

@BloodAxe

  1. We're planning to add pre-trained weights for RAFT_L and RAFT_S.
  2. There will be a separate PR for the dataset, transforms, loss + metric.

@Yael-Baron (Contributor, Author):

See PR Feature/dle 599 optical flow integration #1984.

@Yael-Baron closed this on May 9, 2024.