-
Notifications
You must be signed in to change notification settings - Fork 543
Integration of RAFT Optical Flow model to SG. #1945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
def forward(self, x, **kwargs): | ||
"""Estimate optical flow between pairs of frames""" | ||
|
||
image1 = x[:, 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- The way of passing input image pair as [B,2,C,H,W] should be reflected in docstring (Order of prev/curr or curr/prev as well)
- No kwargs
# run update block network | ||
flow_predictions, flow_up = self.flow_iterative_block(coords0, coords1, net, inp, fmap1, fmap2) | ||
|
||
if not self.training: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we need this?
I suggest the model to have same output format regardless whether it is in training/eval mode.
If you want to omit some outputs when exporting model to onnx you can do if torch.jit.is_tracing()
but I suggest not to make output of the model be state-dependent.
Models.RAFT_L, | ||
] | ||
|
||
# def test_infer_input_image_shape_from_model(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove irrelevant commented code pieces
import_pytorch_quantization_or_install() | ||
|
||
|
||
class TestOpticalFlowModelExport(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A test case should be registered in the unit test runner
I have a few questions regarding this PR:
|
from .raft_base import Encoder, ContextEncoder, FlowIterativeBlock | ||
|
||
|
||
class RAFT(ExportableOpticalFlowModel, SgModule): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please include a docstring with a link to original paper/implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And add docstrings to the rest of the methods of this class.
"FP16 quantization is done by calling model.half() so you don't need to pass calibration_loader, as it will be ignored." | ||
) | ||
|
||
if engine in {ExportTargetBackend.ONNXRUNTIME, ExportTargetBackend.TENSORRT}: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we didnt check that TRT conversion works.
We should not add code that has not been tested.
# update the quantization_mode to INT8, so that we can correctly export the model. | ||
quantization_mode = ExportQuantizationMode.INT8 | ||
|
||
from super_gradients.training.models.conversion import ConvertableCompletePipelineModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is the import here ?
|
||
self.flow_iterative_block = FlowIterativeBlock(encoder_params, encoder_params.update_block.hidden_dim, flow_params, corr_params.alternate_corr) | ||
|
||
def freeze_bn(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the use of this ?
from .raft_base import Encoder, ContextEncoder, FlowIterativeBlock | ||
|
||
|
||
class RAFT(ExportableOpticalFlowModel, SgModule): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And add docstrings to the rest of the methods of this class.
|
See PR Feature/dle 599 optical flow integration #1984. |
No description provided.