diff --git a/requirements.txt b/requirements.txt index b67d6f022c..5eb6424745 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,4 @@ albumentations~=1.3 fonttools>=4.43.0 # not directly required, pinned by Snyk to avoid a vulnerability werkzeug>=2.3.8 +imagesize~=1.4.1 diff --git a/src/super_gradients/common/object_names.py b/src/super_gradients/common/object_names.py index db78546bad..c4d90c056b 100644 --- a/src/super_gradients/common/object_names.py +++ b/src/super_gradients/common/object_names.py @@ -17,6 +17,7 @@ class Losses: DEKR_LOSS = "DEKRLoss" RESCORING_LOSS = "RescoringLoss" YOLONAS_POSE_LOSS = "YoloNASPoseLoss" + OPTICAL_FLOW_LOSS = "OpticalFlowLoss" class Metrics: @@ -44,6 +45,7 @@ class Metrics: DepthMSE = "DepthMSE" DepthRMSE = "DepthRMSE" DepthMSLE = "DepthMSLE" + EPE = "EPE" class Transforms: @@ -82,6 +84,14 @@ class Transforms: DetectionNormalize = "DetectionNormalize" DetectionPadIfNeeded = "DetectionPadIfNeeded" DetectionLongestMaxSize = "DetectionLongestMaxSize" + # Optical flow transforms + OpticalFlowColorJitter = "OpticalFlowColorJitter" + OpticalFlowOcclusion = "OpticalFlowOcclusion" + OpticalFlowRandomRescale = "OpticalFlowRandomRescale" + OpticalFlowRandomFlip = "OpticalFlowRandomFlip" + OpticalFlowCrop = "OpticalFlowCrop" + OpticalFlowInputPadder = "OpticalFlowInputPadder" + OpticalFlowNormalize = "OpticalFlowNormalize" # RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation" RandAugmentTransform = "RandAugmentTransform" @@ -338,6 +348,9 @@ class Models: YOLO_NAS_POSE_M = "yolo_nas_pose_m" YOLO_NAS_POSE_L = "yolo_nas_pose_l" + RAFT_S = "raft_s" + RAFT_L = "raft_l" + class ConcatenatedTensorFormats: XYXY_LABEL = "XYXY_LABEL" @@ -417,6 +430,8 @@ class Dataloaders: PASCAL_VOC_DETECTION_VAL = "pascal_voc_detection_val" ROBOFLOW_TRAIN_BASE = "roboflow_train_yolox" ROBOFLOW_VAL_BASE = "roboflow_val_yolox" + KITTI2015_OPTICAL_FLOW_TRAIN = "kitti2015_optical_flow_train" + KITTI2015_OPTICAL_FLOW_VAL = "kitti2015_optical_flow_val" class Datasets: @@ -438,6 +453,7 @@ class Datasets: COCO_KEY_POINTS_DATASET = "COCOKeypointsDataset" COCO_POSE_ESTIMATION_DATASET = "COCOPoseEstimationDataset" NYUV2_DEPTH_ESTIMATION_DATASET = "NYUv2DepthEstimationDataset" + KITTI_OPTICAL_FLOW_DATASET = "KITTIOpticalFlowDataset" class Processings: diff --git a/src/super_gradients/module_interfaces/__init__.py b/src/super_gradients/module_interfaces/__init__.py index f9871c3825..9838aab106 100644 --- a/src/super_gradients/module_interfaces/__init__.py +++ b/src/super_gradients/module_interfaces/__init__.py @@ -12,6 +12,7 @@ SemanticSegmentationDecodingModule, BinarySegmentationDecodingModule, ) +from .exportable_optical_flow import ExportableOpticalFlowModel, OpticalFlowModelExportResult __all__ = [ "HasPredict", @@ -35,4 +36,6 @@ "AbstractSegmentationDecodingModule", "SemanticSegmentationDecodingModule", "BinarySegmentationDecodingModule", + "ExportableOpticalFlowModel", + "OpticalFlowModelExportResult", ] diff --git a/src/super_gradients/module_interfaces/exportable_optical_flow.py b/src/super_gradients/module_interfaces/exportable_optical_flow.py new file mode 100644 index 0000000000..79b5c0ec0f --- /dev/null +++ b/src/super_gradients/module_interfaces/exportable_optical_flow.py @@ -0,0 +1,293 @@ +import copy +import dataclasses +import gc +from typing import Union, Optional, List, Tuple + +import numpy as np +import onnx +import onnxsim +import torch +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.conversion 
import ExportTargetBackend, ExportQuantizationMode +from super_gradients.conversion.conversion_utils import find_compatible_model_device_for_dtype +from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install +from super_gradients.import_utils import import_pytorch_quantization_or_install +from super_gradients.module_interfaces.supports_input_shape_check import SupportsInputShapeCheck +from super_gradients.training.utils.export_utils import ( + infer_image_shape_from_model, + infer_image_input_channels, +) +from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules +from super_gradients.conversion.onnx.export_to_onnx import export_to_onnx +from torch import nn +from torch.utils.data import DataLoader + +logger = get_logger(__name__) + +__all__ = ["ExportableOpticalFlowModel", "OpticalFlowModelExportResult"] + + +@dataclasses.dataclass +class OpticalFlowModelExportResult: + """ + A dataclass that holds the result of model export. + """ + + input_image_channels: int + input_image_dtype: torch.dtype + input_image_shape: Tuple[int, int] + + engine: ExportTargetBackend + quantization_mode: Optional[ExportQuantizationMode] + + output: str + + usage_instructions: str = "" + + def __repr__(self): + return self.usage_instructions + + +class ExportableOpticalFlowModel: + """ + A mixin class that adds export functionality to the optical flow models. + Classes that inherit from this mixin must implement the following methods: + - get_decoding_module() + - get_preprocessing_callback() + Providing these methods are implemented correctly, the model can be exported to ONNX or TensorRT formats + using model.export(...) method. + """ + + def export( + self, + output: str, + quantization_mode: Optional[ExportQuantizationMode] = None, + selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa + calibration_loader: Optional[DataLoader] = None, + calibration_method: str = "percentile", + calibration_batches: int = 16, + calibration_percentile: float = 99.99, + batch_size: int = 1, + input_image_shape: Optional[Tuple[int, int]] = None, + input_image_channels: Optional[int] = None, + input_image_dtype: Optional[torch.dtype] = None, + onnx_export_kwargs: Optional[dict] = None, + onnx_simplify: bool = True, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Export the model to one of supported formats. Format is inferred from the output file extension or can be + explicitly specified via `format` argument. + + :param output: Output file name of the exported model. + :param quantization_mode: (QuantizationMode) Sets the quantization mode for the exported model. + If None, the model is exported as-is without any changes to mode weights. + If QuantizationMode.FP16, the model is exported with weights converted to half precision. + If QuantizationMode.INT8, the model is exported with weights quantized to INT8 (Using PTQ). + For this mode you can use calibration_loader to specify a data loader for calibrating the model. + :param selective_quantizer: (SelectiveQuantizer) An optional quantizer for selectively quantizing model weights. + :param calibration_loader: (torch.utils.data.DataLoader) An optional data loader for calibrating a quantized model. + :param calibration_method: (str) Calibration method for quantized models. See QuantizationCalibrator for details. + :param calibration_batches: (int) Number of batches to use for calibration. Default is 16. 
+ :param calibration_percentile: (float) Percentile for percentile calibration method. Default is 99.99. + :param batch_size: (int) Batch size for the exported model. + :param input_image_shape: (tuple) Input image shape (height, width) for the exported model. + If None, the function will infer the image shape from the model's preprocessing params. + :param input_image_channels: (int) Number of input image channels for the exported model. + If None, the function will infer the number of channels from the model itself + (No implemented now, will use hard-coded value of 3 for now). + :param input_image_dtype: (torch.dtype) Type of the input image for the exported model. + If None, the function will infer the dtype from the model's preprocessing and other parameters. + If preprocessing is True, dtype will default to torch.uint8. + If preprocessing is False and requested quantization mode is FP16 a torch.float16 will be used, + otherwise a default torch.float32 dtype will be used. + :param device: (torch.device) Device to use for exporting the model. If not specified, the device is inferred from the model itself. + :param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function. + :param onnx_simplify: (bool) If True, apply onnx-simplifier to the exported model. + :return: + """ + + # Do imports here to avoid raising error of missing onnx_graphsurgeon package if it is not needed. + import_onnx_graphsurgeon_or_install() + if ExportQuantizationMode.INT8 == quantization_mode: + import_pytorch_quantization_or_install() + from super_gradients.conversion.conversion_utils import torch_dtype_to_numpy_dtype + + usage_instructions = [] + + # Hard-coded for now + # Will be made a parameter if we decide to support CoreML/OpenVino/TRT export in the future + engine = ExportTargetBackend.ONNXRUNTIME + + if not isinstance(self, nn.Module): + raise TypeError(f"Export is only supported for torch.nn.Module. Got type {type(self)}") + + device: torch.device = device or infer_model_device(self) + if device is None: + raise ValueError( + "Device is not specified and cannot be inferred from the model. " + "Please specify the device explicitly: model.export(..., device=torch.device(...))" + ) + + # The following is a trick to infer the exact device index in order to make sure the model using right device. + # User may pass device="cuda", which is not explicitly specifying device index. + # Using this trick, we can infer the correct device (cuda:3 for instance) and use it later for checking + # whether model places all it's parameters on the right device. + device = torch.zeros(1).to(device).device + + logger.debug(f"Using device: {device} for exporting model {self.__class__.__name__}") + + model: nn.Module = copy.deepcopy(self).eval() + + # Infer the input image shape from the model + if input_image_shape is None: + input_image_shape = infer_image_shape_from_model(model) + logger.debug(f"Inferred input image shape: {input_image_shape} from model {model.__class__.__name__}") + + if input_image_shape is None: + raise ValueError( + "Image shape is not specified and cannot be inferred from the model. 
" + "Please specify the image shape explicitly: model.export(..., input_image_shape=(height, width))" + ) + + try: + rows, cols = input_image_shape + except ValueError: + raise ValueError(f"Image shape must be a tuple of two integers (height, width), got {input_image_shape} instead") + + # Infer the number of input channels from the model + if input_image_channels is None: + input_image_channels = infer_image_input_channels(model) + logger.debug(f"Inferred input image channels: {input_image_channels} from model {model.__class__.__name__}") + + if input_image_channels is None: + raise ValueError( + "Number of input channels is not specified and cannot be inferred from the model. " + "Please specify the number of input channels explicitly: model.export(..., input_image_channels=NUM_CHANNELS_YOUR_MODEL_TAKES)" + ) + + input_shape = (batch_size, 2, input_image_channels, rows, cols) + + if isinstance(model, SupportsInputShapeCheck): + model.validate_input_shape(input_shape) + + prep_model_for_conversion_kwargs = { + "input_size": input_shape, + } + + model_type = torch.half if quantization_mode == ExportQuantizationMode.FP16 else torch.float32 + device = find_compatible_model_device_for_dtype(device, model_type) + + # This variable holds the output names of the model. + # If postprocessing is enabled, it will be set to the output names of the postprocessing module. + output_names: Optional[List[str]] = None + + if hasattr(model, "prep_model_for_conversion"): + model.prep_model_for_conversion(**prep_model_for_conversion_kwargs) + + contains_quantized_modules = check_model_contains_quantized_modules(model) + + if quantization_mode == ExportQuantizationMode.INT8: + from super_gradients.training.utils.quantization import ptq + + model = ptq( + model, + selective_quantizer=selective_quantizer, + calibration_loader=calibration_loader, + calibration_method=calibration_method, + calibration_batches=calibration_batches, + calibration_percentile=calibration_percentile, + ) + elif quantization_mode == ExportQuantizationMode.FP16: + if contains_quantized_modules: + raise RuntimeError("Model contains quantized modules for INT8 mode. " "FP16 quantization is not supported for such models.") + elif quantization_mode is None and contains_quantized_modules: + # If quantization_mode is None, but we have quantized modules in the model, we need to + # update the quantization_mode to INT8, so that we can correctly export the model. + quantization_mode = ExportQuantizationMode.INT8 + + from super_gradients.training.models.conversion import ConvertableCompletePipelineModel + + # The model.prep_model_for_conversion will be called inside ConvertableCompletePipelineModel once more, + # but as long as implementation of prep_model_for_conversion is idempotent, it should be fine. + complete_model = ( + ConvertableCompletePipelineModel(model=model, pre_process=None, post_process=None, **prep_model_for_conversion_kwargs).to(device).eval() + ) + + if quantization_mode == ExportQuantizationMode.FP16: + # For FP16 quantization, we simply can to convert the whole model to half precision + complete_model = complete_model.half() + + if calibration_loader is not None: + logger.warning( + "It seems you've passed calibration_loader to export function, but quantization_mode is set to FP16. " + "FP16 quantization is done by calling model.half() so you don't need to pass calibration_loader, as it will be ignored." 
+ ) + + if engine in {ExportTargetBackend.ONNXRUNTIME}: + + onnx_export_kwargs = onnx_export_kwargs or {} + onnx_input = torch.randn(input_shape).to(device=device, dtype=input_image_dtype) + + export_to_onnx( + model=complete_model, + model_input=onnx_input, + onnx_filename=output, + input_names=["input"], + output_names=output_names, + onnx_opset=onnx_export_kwargs.get("opset_version", None), + do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), + dynamic_axes=onnx_export_kwargs.get("dynamic_axes", None), + keep_initializers_as_inputs=onnx_export_kwargs.get("keep_initializers_as_inputs", False), + verbose=onnx_export_kwargs.get("verbose", False), + ) + + if onnx_simplify: + model_opt, simplify_successful = onnxsim.simplify(output) + if not simplify_successful: + raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.") + onnx.save(model_opt, output) + + logger.debug(f"Ran onnxsim.simplify on {output}") + else: + raise ValueError(f"Unsupported export format: {engine}. Supported formats: onnxruntime, tensorrt") + + # Cleanup memory, not sure whether it is necessary but just in case + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Add usage instructions + usage_instructions.append(f"Model exported successfully to {output}") + usage_instructions.append( + f"Model expects input image of shape [{batch_size}, {2}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}]" + ) + usage_instructions.append(f"Input image dtype is {input_image_dtype}") + + usage_instructions.append("Exported model is in ONNX format and can be used with ONNXRuntime") + usage_instructions.append("To run inference with ONNXRuntime, please use the following code snippet:") + usage_instructions.append("") + usage_instructions.append(" import onnxruntime") + usage_instructions.append(" import numpy as np") + + usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])') + usage_instructions.append(" inputs = [o.name for o in session.get_inputs()]") + usage_instructions.append(" outputs = [o.name for o in session.get_outputs()]") + + dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name + usage_instructions.append( + f" example_input_batch = np.zeros(({batch_size}, {2}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})" # noqa + ) + + usage_instructions.append(" flow_prediction = session.run(outputs, {inputs[0]: example_input_batch})") + usage_instructions.append("") + + return OpticalFlowModelExportResult( + input_image_channels=input_image_channels, + input_image_dtype=input_image_dtype, + input_image_shape=input_image_shape, + engine=engine, + quantization_mode=quantization_mode, + output=output, + usage_instructions="\n".join(usage_instructions), + ) diff --git a/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml b/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml new file mode 100644 index 0000000000..17f20ac8ac --- /dev/null +++ b/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml @@ -0,0 +1,42 @@ +in_channels: 3 # the number of in_channels to fnet and cnet. 1 for greyscale image. 
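+# A hedged override sketch: for single-channel (greyscale) inputs this value can be overridden from the
+# training recipe rather than edited here (recipes merge their arch_params block over this file, as
+# kitti_raft_l.yaml already does for num_classes), e.g.:
+#   arch_params:
+#     in_channels: 1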
+ +encoder_params: + in_planes: 64 + hidden_dim: 128 + context_dim: 128 + corr_levels: 4 + corr_radius: 4 + dropout: 0 + fnet: # Feature encoder + block: ResidualBlock + output_dim: 256 + norm_fn: 'instance' # 'instance' - for instance normalization + cnet: # Context encoder + block: ResidualBlock + norm_fn: 'batch' # 'batch' - for batch normalization + output_dim: 256 # context_dim + hidden_dim + update_block: + hidden_dim: ${..hidden_dim} + use_mask: True + motion_encoder: + num_corr_conv: 2 # limited to max 2 conv layers + convc1_output_dim: 256 + convc2_output_dim: 192 + convf1_output_dim: 128 + convf2_output_dim: 64 + conv_output_dim: 126 + gru: + block: SepConvGRU + hidden_dim: ${...hidden_dim} + input_dim: 256 + flow_head: + hidden_dim: 256 + input_dim: ${arch_params.encoder_params.update_block.gru.hidden_dim} + +corr_params: + alternate_corr: False + +flow_params: + training_iters: 12 # the number of iterations during training the optimization loop will run to refine the optical flow predictions over multiple iterations + validation_iters: 24 # the number of iterations during validating + upsample_mode: convex # if none, then using a predefined upsample function (upflow8) OR convex with large model only diff --git a/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml b/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml new file mode 100644 index 0000000000..95e5b5cfe8 --- /dev/null +++ b/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml @@ -0,0 +1,42 @@ +in_channels: 3 # the number of in_channels to fnet and cnet. 1 for greyscale image. + +encoder_params: + in_planes: 32 + hidden_dim: 96 + context_dim: 64 + corr_levels: 4 + corr_radius: 3 + dropout: 0 + fnet: # Feature encoder + block: BottleneckBlock + output_dim: 128 + norm_fn: 'instance' # 'instance' - for instance normalization + cnet: # Context encoder + block: BottleneckBlock + norm_fn: 'none' # 'batch' - for batch normalization + output_dim: 160 # context_dim + hidden_dim + update_block: + hidden_dim: ${..hidden_dim} + use_mask: False + motion_encoder: + num_corr_conv: 1 # limited to max 2 conv layers + convc1_output_dim: 96 + convc2_output_dim: # convc2 layer is set only for raft_l + convf1_output_dim: 64 + convf2_output_dim: 32 + conv_output_dim: 80 + gru: + block: ConvGRU + hidden_dim: ${...hidden_dim} + input_dim: 146 + flow_head: + hidden_dim: 128 + input_dim: ${arch_params.encoder_params.update_block.gru.hidden_dim} + +corr_params: + alternate_corr: False + +flow_params: + training_iters: 12 # the number of iterations during training the optimization loop will run to refine the optical flow predictions over multiple iterations + validation_iters: 24 # the number of iterations during validating + upsample_mode: # if none, then using a predefined upsample function (upflow8) OR convex with large model only diff --git a/src/super_gradients/recipes/dataset_params/kitti_optical_flow_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/kitti_optical_flow_dataset_params.yaml new file mode 100644 index 0000000000..4ea3a3f1d7 --- /dev/null +++ b/src/super_gradients/recipes/dataset_params/kitti_optical_flow_dataset_params.yaml @@ -0,0 +1,47 @@ +num_workers: 8 +image_size: [288, 960] +batch_size: 3 + +train_dataset_params: + root: 'https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow' + transforms: + - OpticalFlowColorJitter: + brightness: 0.3 + contrast: 0.3 + saturation: 0.3 + hue: 0.096 + prob: 0.2 + - OpticalFlowOcclusion: + bounds: [50, 100] + - 
OpticalFlowRandomRescale + - OpticalFlowRandomFlip + - OpticalFlowCrop: + mode: random + crop_size: ${.....image_size} + - OpticalFlowNormalize + + +val_dataset_params: + root: 'https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow' + transforms: + - OpticalFlowCrop: + crop_size: [376, 1248] + - OpticalFlowNormalize + +train_dataloader_params: + dataset: KITTIOpticalFlowDataset + shuffle: True + batch_size: ${..batch_size} + num_workers: ${..num_workers} + drop_last: True + collate_fn: OpticalFlowCollateFN + +val_dataloader_params: + dataset: KITTIOpticalFlowDataset + shuffle: False + batch_size: ${..batch_size} + num_workers: ${..num_workers} + drop_last: False + collate_fn: OpticalFlowCollateFN + +_convert_: all diff --git a/src/super_gradients/recipes/kitti_raft_l.yaml b/src/super_gradients/recipes/kitti_raft_l.yaml new file mode 100644 index 0000000000..0495e119c5 --- /dev/null +++ b/src/super_gradients/recipes/kitti_raft_l.yaml @@ -0,0 +1,62 @@ +defaults: + - training_hyperparams: default_train_params + - checkpoint_params: default_checkpoint_params + - dataset_params: kitti_optical_flow_dataset_params + - arch_params: raft_l_arch_params + - _self_ + +architecture: RAFT_L +num_classes: 1 + +multi_gpu: DDP +num_gpus: 2 + +dataset_params: + batch_size: 3 + +experiment_name: ${architecture} +ckpt_root_dir: ./checkpoints +resume: False + +arch_params: + num_classes: ${num_classes} + +checkpoint_params: + pretrained_weights: things + +training_hyperparams: + resume: ${resume} + + max_epochs: 1000 + + initial_lr: 1e-4 + lr_mode: cosine + cosine_final_lr_ratio: 0.1 + + optimizer: AdamW + optimizer_params: + weight_decay: 0.00001 + + loss: + OpticalFlowLoss: + gamma: 0.85 + max_flow: 400 + + train_metrics_list: + - EPE: + max_flow: 400 + + valid_metrics_list: + - EPE + + metric_to_watch: epe + greater_metric_to_watch_is_better: False + + +hydra: + searchpath: + - pkg://super_gradients.recipes + run: + dir: ${hydra_output_dir:${ckpt_root_dir}, ${experiment_name}} + job: + chdir: False diff --git a/src/super_gradients/recipes/kitti_raft_s.yaml b/src/super_gradients/recipes/kitti_raft_s.yaml new file mode 100644 index 0000000000..8cbf0dc6dc --- /dev/null +++ b/src/super_gradients/recipes/kitti_raft_s.yaml @@ -0,0 +1,63 @@ +defaults: + - training_hyperparams: default_train_params + - checkpoint_params: default_checkpoint_params + - dataset_params: kitti_optical_flow_dataset_params + - arch_params: raft_s_arch_params + - _self_ + +architecture: RAFT_S +num_classes: 1 + +multi_gpu: DDP +num_gpus: 2 + +dataset_params: + batch_size: 3 + +experiment_name: ${architecture} +ckpt_root_dir: ./checkpoints +resume: False + +arch_params: + num_classes: ${num_classes} + +checkpoint_params: + pretrained_weights: things + +training_hyperparams: + resume: ${resume} + + max_epochs: 1000 + + initial_lr: 1e-4 + lr_mode: cosine + cosine_final_lr_ratio: 0.1 + + optimizer: AdamW + optimizer_params: + weight_decay: 0.00001 + + loss: + OpticalFlowLoss: + gamma: 0.85 + max_flow: 400 + + + train_metrics_list: + - EPE: + max_flow: 400 + + valid_metrics_list: + - EPE + + metric_to_watch: epe + greater_metric_to_watch_is_better: False + + +hydra: + searchpath: + - pkg://super_gradients.recipes + run: + dir: ${hydra_output_dir:${ckpt_root_dir}, ${experiment_name}} + job: + chdir: False diff --git a/src/super_gradients/training/dataloaders/dataloaders.py b/src/super_gradients/training/dataloaders/dataloaders.py index 7b701b7830..532deb7dc6 100644 --- a/src/super_gradients/training/dataloaders/dataloaders.py 
+++ b/src/super_gradients/training/dataloaders/dataloaders.py @@ -14,7 +14,7 @@ from super_gradients.common.factories.datasets_factory import DatasetsFactory from super_gradients.common.factories.samplers_factory import SamplersFactory from super_gradients.common.object_names import Dataloaders -from super_gradients.training.datasets import ImageNetDataset +from super_gradients.training.datasets import ImageNetDataset, KITTIOpticalFlowDataset from super_gradients.training.datasets.classification_datasets.cifar import ( Cifar10, Cifar100, @@ -886,6 +886,28 @@ def coco2017_rescoring_val(dataset_params: Dict = None, dataloader_params: Dict ) +@register_dataloader(Dataloaders.KITTI2015_OPTICAL_FLOW_TRAIN) +def kitti2015_optical_flow_train(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader: + return get_data_loader( + config_name="kitti_optical_flow_dataset_params", + dataset_cls=KITTIOpticalFlowDataset, + train=True, + dataset_params=dataset_params, + dataloader_params=dataloader_params, + ) + + +@register_dataloader(Dataloaders.KITTI2015_OPTICAL_FLOW_VAL) +def kitti2015_optical_flow_val(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader: + return get_data_loader( + config_name="kitti_optical_flow_dataset_params", + dataset_cls=KITTIOpticalFlowDataset, + train=False, + dataset_params=dataset_params, + dataloader_params=dataloader_params, + ) + + def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader: """ Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS. diff --git a/src/super_gradients/training/datasets/__init__.py b/src/super_gradients/training/datasets/__init__.py index 754fd09a79..d1e44f2c63 100755 --- a/src/super_gradients/training/datasets/__init__.py +++ b/src/super_gradients/training/datasets/__init__.py @@ -25,6 +25,7 @@ BaseKeypointsDataset, COCOPoseEstimationDataset, ) +from super_gradients.training.datasets.optical_flow_datasets.kitti_dataset import KITTIOpticalFlowDataset __all__ = [ @@ -50,6 +51,7 @@ "SuperviselyPersonsDataset", "COCOKeypointsDataset", "COCOPoseEstimationDataset", + "KITTIOpticalFlowDataset", ] cv2.setNumThreads(0) diff --git a/src/super_gradients/training/datasets/optical_flow_datasets/__init__.py b/src/super_gradients/training/datasets/optical_flow_datasets/__init__.py new file mode 100644 index 0000000000..30ef8faf9b --- /dev/null +++ b/src/super_gradients/training/datasets/optical_flow_datasets/__init__.py @@ -0,0 +1,4 @@ +from super_gradients.training.datasets.optical_flow_datasets.abstract_optical_flow_dataset import AbstractOpticalFlowDataset +from super_gradients.training.datasets.optical_flow_datasets.kitti_dataset import KITTIOpticalFlowDataset + +__all__ = ["AbstractOpticalFlowDataset", "KITTIOpticalFlowDataset"] diff --git a/src/super_gradients/training/datasets/optical_flow_datasets/abstract_optical_flow_dataset.py b/src/super_gradients/training/datasets/optical_flow_datasets/abstract_optical_flow_dataset.py new file mode 100644 index 0000000000..6d8f97b21d --- /dev/null +++ b/src/super_gradients/training/datasets/optical_flow_datasets/abstract_optical_flow_dataset.py @@ -0,0 +1,124 @@ +import abc +from typing import List, Tuple + +import random + +import numpy as np +from data_gradients.common.decorators import resolve_param +from matplotlib import pyplot as plt +from torch.utils.data.dataloader import Dataset + +from super_gradients.common.factories.list_factory import ListFactory 
+from super_gradients.common.factories.transforms_factory import TransformsFactory +from super_gradients.training.samples import OpticalFlowSample +from super_gradients.training.transforms.optical_flow import AbstractOpticalFlowTransform +from super_gradients.training.utils.visualization.optical_flow import FlowVisualization + + +class AbstractOpticalFlowDataset(Dataset): + """ + Abstract class for datasets for optical flow task. + + Attempting to follow principles provided in pose_etimation_dataset. + """ + + @resolve_param("transforms", ListFactory(TransformsFactory())) + def __init__(self, transforms: List[AbstractOpticalFlowTransform] = None): + super().__init__() + self.transforms = transforms or [] + + @abc.abstractmethod + def load_sample(self, index: int) -> OpticalFlowSample: + """ + Load an optical flow sample from the dataset. + + :param index: Index of the sample to load. + :return: Instance of OpticalFlowSample. + + """ + raise NotImplementedError() + + def load_random_sample(self) -> OpticalFlowSample: + """ + Return a random sample from the dataset + + :return: Instance of OpticalFlowSample + """ + num_samples = len(self) + random_index = random.randrange(0, num_samples) + return self.load_sample(random_index) + + def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Get a transformed optical flow sample from the dataset. + + :param index: Index of the sample to retrieve. + :return: Tuple containing the transformed images and flow map as np.ndarrays. + + After applying the transforms pipeline, the image is expected to be in 2HWC format, and the flow map should be + a 3D array (e.g., 2 x Height x Width). + + Before returning the images and flow map, the image's channels are moved to 2CHW format and the flow_map's channels are moved to CHW format. + """ + sample = self.load_sample(index) + for transform in self.transforms: + sample = transform(sample) + + images = np.transpose(sample.images, (0, 3, 1, 2)).astype(np.float32) + flow_map = np.transpose(sample.flow_map, (2, 0, 1)).astype(np.float32) + valid = sample.valid.astype(np.float32) + return images, (flow_map, valid) + + def plot( + self, + max_samples_per_plot: int = 8, + n_plots: int = 1, + plot_transformed_data: bool = True, + ): + """ + Combine samples of images with flow maps into plots and display the result. + + :param max_samples_per_plot: Maximum number of samples (image with depth map) to be displayed per plot. + :param n_plots: Number of plots to display. + :param plot_transformed_data: If True, the plot will be over samples after applying transforms (i.e., on __getitem__). + If False, the plot will be over the raw samples (i.e., on load_sample). 
+
+        :return: None
+        """
+        plot_counter = 0
+
+        for plot_i in range(n_plots):
+            fig, axes = plt.subplots(3, max_samples_per_plot, figsize=(20, 7))
+            for sample_i in range(max_samples_per_plot):
+                index = sample_i + plot_i * max_samples_per_plot
+                if plot_transformed_data:
+                    images, (flow_map, valid) = self[index]
+
+                    # Transpose to HWC format for visualization
+                    images = images.transpose(0, 2, 3, 1)
+                    flow_map = flow_map.squeeze()  # Remove dummy dimension
+                else:
+                    sample = self.load_sample(index)
+                    images, flow_map, _ = sample.images, sample.flow_map, sample.valid
+
+                # Plot the image
+                axes[0, sample_i].imshow(images[0].astype(np.uint8))
+                axes[0, sample_i].axis("off")
+                axes[0, sample_i].set_title(f"Sample {index} image1")
+
+                axes[1, sample_i].imshow(images[1].astype(np.uint8))
+                axes[1, sample_i].axis("off")
+                axes[1, sample_i].set_title(f"Sample {index} image2")
+
+                # Plot the flow map with the selected color scheme
+                flow_map = FlowVisualization.process_flow_map_for_visualization(flow_map)
+                axes[2, sample_i].imshow(flow_map)
+                axes[2, sample_i].axis("off")
+                axes[2, sample_i].set_title(f"Flow Map {index}")
+
+            plt.show()
+            plt.close()
+
+            plot_counter += 1
+            if plot_counter == n_plots:
+                return
diff --git a/src/super_gradients/training/datasets/optical_flow_datasets/kitti_dataset.py b/src/super_gradients/training/datasets/optical_flow_datasets/kitti_dataset.py
new file mode 100644
index 0000000000..f57d701f81
--- /dev/null
+++ b/src/super_gradients/training/datasets/optical_flow_datasets/kitti_dataset.py
@@ -0,0 +1,127 @@
+import warnings
+
+import numpy as np
+
+from super_gradients.common.object_names import Datasets
+from super_gradients.common.registry import register_dataset
+from super_gradients.training.datasets.optical_flow_datasets import kitti_utils
+from super_gradients.training.datasets.optical_flow_datasets.abstract_optical_flow_dataset import AbstractOpticalFlowDataset
+
+from super_gradients.training.samples import OpticalFlowSample
+from glob import glob
+import os
+
+
+@register_dataset(Datasets.KITTI_OPTICAL_FLOW_DATASET)
+class KITTIOpticalFlowDataset(AbstractOpticalFlowDataset):
+    """
+    Dataset class for the KITTI 2015 optical flow dataset.
+
+    :param root: Root directory containing the dataset.
+    :param transforms: Transforms to be applied to the samples.
+
+    To use the KITTIOpticalFlowDataset class, ensure that your dataset directory is organized as follows:
+
+    - Root directory (specified as 'root' when initializing the dataset)
+        - training
+            - image_2
+                - 000000_10.png
+                - 000000_11.png
+                - 000001_10.png
+                - 000001_11.png
+                - ...
+            - flow_occ
+                - 000000_10.png
+                - 000001_10.png
+                - ...
+
+    Data can be obtained at https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow
+    ...
+    """
+
+    def __init__(self, root: str, transforms=None):
+        """
+        Initialize KITTIOpticalFlowDataset.
+
+        :param root: Root directory containing the dataset.
+        :param transforms: Transforms to be applied to the samples.
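+
+        A minimal usage sketch (the root path below is illustrative):
+
+            >>> dataset = KITTIOpticalFlowDataset(root="/data/kitti2015", transforms=[])
+            >>> images, (flow_map, valid) = dataset[0]
+            >>> images.shape    # (2, 3, H, W), float32 image pair in 2CHW layout
+            >>> flow_map.shape  # (2, H, W), float32 flow map in CHW layout
+
+        The same dataset is also exposed through the registered dataloader factories, e.g.
+        dataloaders.get("kitti2015_optical_flow_train", dataset_params={"root": "/data/kitti2015"}).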
+        """
+        super(KITTIOpticalFlowDataset, self).__init__(transforms=transforms)
+
+        images_list = []
+
+        data_root = os.path.join(root, "training")
+
+        images1 = sorted(glob(os.path.join(data_root, "image_2/*_10.png")))
+        images2 = sorted(glob(os.path.join(data_root, "image_2/*_11.png")))
+
+        for img1, img2 in zip(images1, images2):
+            images_list += [[img1, img2]]
+
+        flow_list = sorted(glob(os.path.join(data_root, "flow_occ/*_10.png")))
+
+        self.files_list = [(elem1[0], elem1[1], elem2) for elem1, elem2 in zip(images_list, flow_list)]
+
+        self._check_paths_exist()
+
+    def load_sample(self, index: int) -> OpticalFlowSample:
+        """
+        Load an optical flow estimation sample at the specified index.
+
+        :param index: Index of the sample.
+
+        :return: Loaded optical flow estimation sample.
+        """
+        flow_map, valid = kitti_utils.read_flow_kitti(self.files_list[index][2])
+
+        image1 = kitti_utils.read_gen(self.files_list[index][0])
+        image2 = kitti_utils.read_gen(self.files_list[index][1])
+
+        flow_map = np.array(flow_map).astype(np.float32)
+        image1 = np.array(image1).astype(np.uint8)
+        image2 = np.array(image2).astype(np.uint8)
+
+        # grayscale images
+        if len(image1.shape) == 2:
+            image1 = np.tile(image1[..., None], (1, 1, 3))
+            image2 = np.tile(image2[..., None], (1, 1, 3))
+        else:
+            image1 = image1[..., :3]
+            image2 = image2[..., :3]
+
+        images = np.stack([image1, image2])
+
+        if valid is None:
+            valid = (np.abs(flow_map[:, :, 0]) < 1000) & (np.abs(flow_map[:, :, 1]) < 1000)
+
+        return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid)
+
+    def __len__(self):
+        """
+        Get the number of samples in the dataset.
+
+        :return: Number of samples in the dataset.
+        """
+        return len(self.files_list)
+
+    def _check_paths_exist(self):
+        """
+        Check that all paths in self.files_list exist. Remove entries with missing paths and warn about each removed entry.
+        Raise an error if all entries are removed.
+        """
+        valid_paths = []
+
+        for idx in range(len(self.files_list)):
+            paths_exist = all(os.path.exists(path) for path in self.files_list[idx])
+            if paths_exist:
+                valid_paths.append(self.files_list[idx])
+            else:
+                warnings.warn(f"Removed the following entry because one or more paths do not exist: {self.files_list[idx]}")
+
+        if not valid_paths:
+            raise FileNotFoundError("All entries in the dataset have been removed because some paths do not exist. Please check the paths and dataset structure.")
+
+        self.files_list = valid_paths
diff --git a/src/super_gradients/training/datasets/optical_flow_datasets/kitti_utils.py b/src/super_gradients/training/datasets/optical_flow_datasets/kitti_utils.py
new file mode 100644
index 0000000000..bd8c6eefa8
--- /dev/null
+++ b/src/super_gradients/training/datasets/optical_flow_datasets/kitti_utils.py
@@ -0,0 +1,91 @@
+import numpy as np
+from PIL import Image
+import os
+import re
+
+import cv2
+
+
+def readFlow(fn):
+    """Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+    # print 'fn = %s'%(fn)
+    with open(fn, "rb") as f:
+        magic = np.fromfile(f, np.float32, count=1)
+        if 202021.25 != magic:
+            print("Magic number incorrect. 
Invalid .flo file") + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def readPFM(file): + file = open(file, "rb") + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b"PF": + color = True + elif header == b"Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + + +def read_gen(file_name, pil=False): + ext = os.path.splitext(file_name)[-1] + if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": + return Image.open(file_name) + elif ext == ".bin" or ext == ".raw": + return np.load(file_name) + elif ext == ".flo": + return readFlow(file_name).astype(np.float32) + elif ext == ".pfm": + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] + + +def read_flow_kitti(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid diff --git a/src/super_gradients/training/losses/__init__.py b/src/super_gradients/training/losses/__init__.py index b70ff82bbc..5a82f52126 100755 --- a/src/super_gradients/training/losses/__init__.py +++ b/src/super_gradients/training/losses/__init__.py @@ -13,6 +13,7 @@ from super_gradients.training.losses.stdc_loss import STDCLoss from super_gradients.training.losses.rescoring_loss import RescoringLoss from super_gradients.training.losses.yolo_nas_pose_loss import YoloNASPoseLoss +from super_gradients.training.losses.optical_flow_loss import OpticalFlowLoss from super_gradients.common.object_names import Losses from super_gradients.common.registry.registry import LOSSES @@ -35,4 +36,5 @@ "STDCLoss", "RescoringLoss", "YoloNASPoseLoss", + "OpticalFlowLoss", ] diff --git a/src/super_gradients/training/losses/optical_flow_loss.py b/src/super_gradients/training/losses/optical_flow_loss.py new file mode 100644 index 0000000000..516714b3a1 --- /dev/null +++ b/src/super_gradients/training/losses/optical_flow_loss.py @@ -0,0 +1,53 @@ +from typing import Union + +import torch +from torch import Tensor +from torch.nn.modules.loss import _Loss + +from super_gradients.common.registry import register_loss +from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.common.factories.losses_factory import LossesFactory +from super_gradients.training.losses.loss_utils import apply_reduce, LossReduction + + +@register_loss() +class OpticalFlowLoss(_Loss): + @resolve_param("criterion", LossesFactory()) + def __init__(self, gamma: float, max_flow: int = 400, reduction: 
Union[LossReduction, str] = "mean"): + """ + Loss function defined over sequence of flow predictions + + :param gamma: Loss weights factor + :param max_flow: The maximum flow displacement allowed. Flow values above it will be excluded from metric calculation. + :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`. + `none`: no reduction will be applied. + `mean`: the sum of the output will be divided by the number of elements in the output. + `sum`: the output will be summed. + Default: `mean` + """ + super().__init__() + + self.gamma = gamma + self.max_flow = max_flow + self.reduction = reduction + + def forward(self, preds: Tensor, target: Tensor): + flow_loss = 0.0 + + flow_gt, valid = target + + if torch.is_tensor(preds): + preds = [preds] + + # exclude invalid pixels and extremely large displacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + n_predictions = len(preds) + + for i in range(n_predictions): + i_weight = self.gamma ** (n_predictions - i - 1) + i_loss = i_weight * (valid[:, None] * (preds[i] - flow_gt).abs()) + flow_loss += apply_reduce(i_loss, self.reduction) + + return flow_loss diff --git a/src/super_gradients/training/metrics/__init__.py b/src/super_gradients/training/metrics/__init__.py index 7916b13cd4..850ba05a20 100755 --- a/src/super_gradients/training/metrics/__init__.py +++ b/src/super_gradients/training/metrics/__init__.py @@ -17,6 +17,7 @@ DepthRMSE, DepthMSLE, ) +from super_gradients.training.metrics.optical_flow_metric import EPE __all__ = [ "METRICS", @@ -45,4 +46,5 @@ "DepthMSE", "DepthRMSE", "DepthMSLE", + "EPE", ] diff --git a/src/super_gradients/training/metrics/optical_flow_metric.py b/src/super_gradients/training/metrics/optical_flow_metric.py new file mode 100644 index 0000000000..db146340dc --- /dev/null +++ b/src/super_gradients/training/metrics/optical_flow_metric.py @@ -0,0 +1,140 @@ +import collections +from abc import ABC, abstractmethod +from typing import Union, Dict, List, Tuple +import torch +from torchmetrics import Metric + +import super_gradients +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.environment.ddp_utils import get_world_size +from super_gradients.common.registry import register_metric +from super_gradients.common.object_names import Metrics + +logger = get_logger(__name__) + +__all__ = ["EPE"] + + +class AbstractMetricsArgsPrepFn(ABC): + """ + Abstract preprocess metrics arguments class. + """ + + @abstractmethod + def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + All base classes must implement this function and return a tuple of torch tensors (predictions, target). + """ + raise NotImplementedError() + + +class PreprocessOpticalFlowMetricsArgs(AbstractMetricsArgsPrepFn): + """ + Default optical flow inputs preprocess function before updating optical flow metrics, handles multiple inputs. + """ + + def __init__(self, pad_factor: int = 8, apply_unpad: bool = False): + """ + :param pad_factor: The factor by which the input images were padded. By default, set to 8. + :param apply_unpad: Whether to apply unpading on predictions list. By default, set to False. 
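+
+        Padding is computed so that each spatial dimension becomes a multiple of pad_factor;
+        for example, with pad_factor=8 a height of 375 is padded by one row, while 376 needs no padding.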
+ """ + self.pad_factor = pad_factor + self.apply_unpad = apply_unpad + + def __call__(self, preds, target: torch.Tensor) -> List[torch.Tensor]: + # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[-1] IS THE MAIN FLOW MAP + ht, wd = target.shape[-2:] + pad_ht = (((ht // self.pad_factor) + 1) * self.pad_factor - ht) % self.pad_factor + pad_wd = (((wd // self.pad_factor) + 1) * self.pad_factor - wd) % self.pad_factor + pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + if isinstance(preds, (tuple, list)): + preds = preds[-1] + if self.apply_unpad: + ht, wd = preds.shape[-2:] + c = [pad[2], ht - pad[3], pad[0], wd - pad[1]] + preds = preds[..., c[0] : c[1], c[2] : c[3]] + + return [preds] + + +@register_metric(Metrics.EPE) +class EPE(Metric): + """ + End-Point-Error metric for optical flow. + + :param max_flow: The maximum flow displacement allowed. Flow values above it will be excluded from metric calculation. + :param apply_unpad: Bool, if to apply unpad to the predicted flow map. By default, set to False. + """ + + def __init__(self, max_flow: int = None, apply_unpad: bool = False): + super().__init__() + + greater_component_is_better = [ + ("epe", False), + ] + + self.max_flow = max_flow + self.metrics_args_prep_fn = PreprocessOpticalFlowMetricsArgs(apply_unpad=apply_unpad) + self.greater_component_is_better = collections.OrderedDict(greater_component_is_better) + self.component_names = list(self.greater_component_is_better.keys()) + self.components = len(self.component_names) + self.world_size = None + self.rank = None + self.is_distributed = super_gradients.is_distributed() + + self.add_state("epe", default=[], dist_reduce_fx="cat") + + def update(self, preds: List[torch.Tensor], target: torch.Tensor): + flow_gt, valid = target + + if torch.is_tensor(preds): + preds = [preds] + + preds = self.metrics_args_prep_fn(preds, flow_gt) + + # exclude invalid pixels and extremely large displacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + epe = torch.sum((preds[-1] - flow_gt) ** 2, dim=1).sqrt() + + epe = epe.view(-1) + mag = mag.view(-1) + valid = valid.view(-1) + + if self.max_flow is None: + valid = valid >= 0.5 + else: + valid = (valid >= 0.5) & (mag < self.max_flow) + + epe = epe[valid].mean().item() + + self.epe.append(torch.tensor(epe, dtype=torch.float32)) + + def compute(self) -> Dict[str, Union[float, torch.Tensor]]: + return dict(epe=torch.tensor(self.epe).mean().item()) + + def _sync_dist(self, dist_sync_fn=None, process_group=None): + """ + When in distributed mode, stats are aggregated after each forward pass to the metric state. 
Since these have all + different sizes we override the synchronization function since it works only for tensors (and use + all_gather_object) + """ + if self.world_size is None: + self.world_size = get_world_size() if self.is_distributed else -1 + if self.rank is None: + self.rank = torch.distributed.get_rank() if self.is_distributed else -1 + + if self.is_distributed: + local_state_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} + gathered_state_dicts = [None] * self.world_size + torch.distributed.barrier() + torch.distributed.all_gather_object(gathered_state_dicts, local_state_dict) + metric_keys = {"epe": []} + + for state_dict in gathered_state_dicts: + for key in state_dict.keys(): + if len(state_dict[key]) > 0: + metric_keys[key].extend(state_dict[key]) + + for key in metric_keys.keys(): + setattr(self, key, metric_keys[key]) diff --git a/src/super_gradients/training/models/__init__.py b/src/super_gradients/training/models/__init__.py index abc6bcb349..2855b2f06f 100755 --- a/src/super_gradients/training/models/__init__.py +++ b/src/super_gradients/training/models/__init__.py @@ -125,6 +125,7 @@ from super_gradients.training.models.arch_params_factory import get_arch_params from super_gradients.training.models.conversion import convert_to_coreml, convert_to_onnx, convert_from_config +from super_gradients.training.models.optical_flow_models.raft.raft_variants import RAFT_S, RAFT_L from super_gradients.common.object_names import Models from super_gradients.common.registry.registry import ARCHITECTURES @@ -295,4 +296,6 @@ "SegFormerB5", "DDRNet39Backbone", "BasicResNetBlock", + "RAFT_S", + "RAFT_L", ] diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py index 82bcbcdee0..feb64b0608 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py @@ -224,7 +224,12 @@ def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor] reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1])) reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, height_mul_width]), [0, 2, 3, 1]) - reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + + # OpenVINO cannot handle this: + # reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + # So we do it with multiplication instead + reg_dist_reduced = torch.nn.functional.softmax(reg_dist_reduced, dim=1) * self.proj_conv + reg_dist_reduced = reg_dist_reduced.sum(dim=1, keepdim=False) # cls and reg cls_score_list.append(cls_logit.reshape([b, self.num_classes, height_mul_width])) diff --git a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py index 9d8ca0e8ef..e8bba7e4a1 100644 --- a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py +++ b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py @@ -207,7 +207,12 @@ def forward(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tup reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1])) reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, 
height_mul_width]), [0, 2, 3, 1]) - reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + + # OpenVINO cannot handle this: + # reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + # So we do it with multiplication instead + reg_dist_reduced = torch.nn.functional.softmax(reg_dist_reduced, dim=1) * self.proj_conv + reg_dist_reduced = reg_dist_reduced.sum(dim=1, keepdim=False) # cls and reg cls_score_list.append(cls_logit.reshape([b, self.num_classes, height_mul_width])) diff --git a/src/super_gradients/training/models/optical_flow_models/__init__.py b/src/super_gradients/training/models/optical_flow_models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/super_gradients/training/models/optical_flow_models/raft/__init__.py b/src/super_gradients/training/models/optical_flow_models/raft/__init__.py new file mode 100644 index 0000000000..a15c35a510 --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/__init__.py @@ -0,0 +1,31 @@ +from super_gradients.training.models.optical_flow_models.raft.raft_base import ( + BottleneckBlock, + Encoder, + ContextEncoder, + FlowHead, + SepConvGRU, + ConvGRU, + MotionEncoder, + UpdateBlock, + CorrBlock, + AlternateCorrBlock, + FlowIterativeBlock, +) + +from super_gradients.training.models.optical_flow_models.raft.raft_variants import RAFT_S, RAFT_L + +__all__ = [ + "BottleneckBlock", + "Encoder", + "ContextEncoder", + "FlowHead", + "SepConvGRU", + "ConvGRU", + "MotionEncoder", + "UpdateBlock", + "CorrBlock", + "AlternateCorrBlock", + "FlowIterativeBlock", + "RAFT_S", + "RAFT_L", +] diff --git a/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py b/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py new file mode 100644 index 0000000000..a859700e8b --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py @@ -0,0 +1,565 @@ +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from super_gradients.module_interfaces import SupportsReplaceInputChannels + + +__all__ = [ + "BottleneckBlock", + "ResidualBlock", + "Encoder", + "ContextEncoder", + "FlowHead", + "SepConvGRU", + "ConvGRU", + "MotionEncoder", + "UpdateBlock", + "CorrBlock", + "AlternateCorrBlock", + "FlowIterativeBlock", +] + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes: int, planes: int, norm_fn: str = "group", stride: int = 1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = 
nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Encoder(nn.Module, SupportsReplaceInputChannels): + def __init__( + self, + in_channels: int, + in_planes: int, + block: str, + output_dim: int = 128, + norm_fn: str = "batch", + dropout: float = 0.0, + ): + super(Encoder, self).__init__() + self.norm_fn = norm_fn + self.in_planes = in_planes + self.block = block + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + dim = int(self.in_planes / 32) + self.layer1 = self._make_layer(dim * 32, stride=1) + self.layer2 = self._make_layer((dim + 1) * 32, stride=2) + self.layer3 = self._make_layer((dim + 2) * 32, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d((dim + 2) * 32, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): 
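+                # Standard initialization: Kaiming-normal (fan_out, relu) for conv layers;
+                # the normalization layers below get weight=1 and bias=0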
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim: int, stride: int = 1): + if self.block == "BottleneckBlock": + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + elif self.block == "ResidualBlock": + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None): + from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels + + self.conv1 = replace_conv2d_input_channels(conv=self.conv1, in_channels=in_channels, fn=compute_new_weights_fn) + + def get_input_channels(self) -> int: + return self.conv1.in_channels + + +class ContextEncoder(nn.Module): + def __init__( + self, + in_channels: int, + in_planes: int, + block: str, + hidden_dim: int, + context_dim: int, + output_dim: int = 128, + norm_fn: str = "batch", + dropout: float = 0.0, + ): + super(ContextEncoder, self).__init__() + + self.cnet = Encoder(in_channels=in_channels, in_planes=in_planes, block=block, output_dim=output_dim, norm_fn=norm_fn, dropout=dropout) + + self.hidden_dim = hidden_dim + self.context_dim = context_dim + + def forward(self, x): + out = self.cnet(x) + net, inp = torch.split(out, [self.hidden_dim, self.context_dim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + return net, inp + + +class FlowHead(nn.Module): + def __init__(self, input_dim: int = 128, hidden_dim: int = 256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim: int = 128, input_dim: int = 192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, 
x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim: int = 128, input_dim: int = 192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class MotionEncoder(nn.Module): + def __init__( + self, + corr_levels: int, + corr_radius: int, + num_corr_conv: int, + convc1_output_dim: int, + convc2_output_dim: int, + convf1_output_dim: int, + convf2_output_dim: int, + conv_output_dim: int, + ): + super(MotionEncoder, self).__init__() + self.num_corr_conv = num_corr_conv + + cor_planes = corr_levels * (2 * corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, convc1_output_dim, 1, padding=0) + if self.num_corr_conv == 2: + self.convc2 = nn.Conv2d(convc1_output_dim, convc2_output_dim, 3, padding=1) + conv_input_dim = convf2_output_dim + convc2_output_dim + else: + conv_input_dim = convf2_output_dim + convc1_output_dim + + self.convf1 = nn.Conv2d(2, convf1_output_dim, 7, padding=3) + self.convf2 = nn.Conv2d(convf1_output_dim, convf2_output_dim, 3, padding=1) + self.conv = nn.Conv2d(conv_input_dim, conv_output_dim, 3, padding=1) + + def forward(self, flow: Tensor, corr: Tensor): + cor = F.relu(self.convc1(corr)) + if self.num_corr_conv == 2: + cor = F.relu(self.convc2(cor)) + + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class UpdateBlock(nn.Module): + def __init__(self, encoder_params, hidden_dim: int = 128): + super(UpdateBlock, self).__init__() + self.use_mask = encoder_params.update_block.use_mask + + self.encoder = MotionEncoder( + encoder_params.corr_levels, + encoder_params.corr_radius, + encoder_params.update_block.motion_encoder.num_corr_conv, + encoder_params.update_block.motion_encoder.convc1_output_dim, + encoder_params.update_block.motion_encoder.convc2_output_dim, + encoder_params.update_block.motion_encoder.convf1_output_dim, + encoder_params.update_block.motion_encoder.convf2_output_dim, + encoder_params.update_block.motion_encoder.conv_output_dim, + ) + + if encoder_params.update_block.gru.block == "ConvGRU": + self.gru = ConvGRU(hidden_dim=encoder_params.update_block.gru.hidden_dim, input_dim=encoder_params.update_block.gru.input_dim) + elif encoder_params.update_block.gru.block == "SepConvGRU": + self.gru = SepConvGRU(hidden_dim=encoder_params.update_block.gru.hidden_dim, input_dim=encoder_params.update_block.gru.input_dim) + + self.flow_head = FlowHead(hidden_dim, hidden_dim=encoder_params.update_block.flow_head.hidden_dim) + + if self.use_mask: + self.mask = nn.Sequential(nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, 
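# Illustrative sketch (not part of the patch): how the MotionEncoder input width is derived.
# Each correlation pyramid level contributes a (2 * corr_radius + 1)^2 window of correlation
# values per pixel. For example, with corr_levels=4 and corr_radius=4 (an assumed configuration,
# not necessarily the one in the arch params here), convc1 receives 4 * 9 ** 2 = 324 channels.
corr_levels, corr_radius = 4, 4  # assumed values, for illustration only
cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
assert cor_planes == 324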
inp) + delta_flow = self.flow_head(net) + + if self.use_mask: + # scale mask to balance gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + else: + return net, None, delta_flow + + +class CorrBlock: + def __init__(self, num_levels: int = 4, radius: int = 4): + self.num_levels = num_levels + self.radius = radius + + def __call__(self, coords, fmap1, fmap2): + corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + corr_pyramid.append(corr) + + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_pyramid.append(corr) + + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = self.bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + @staticmethod + def bilinear_sampler(img, coords, mask: bool = False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +class AlternateCorrBlock: + def __init__(self, num_levels: int = 4, radius: int = 4): + self.num_levels = num_levels + self.radius = radius + + def __call__( + self, + coords, + fmap1, + fmap2, + ): + import alt_cuda_corr + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class FlowIterativeBlock(nn.Module): + def __init__(self, encoder_params, hidden_dim, flow_params, alternate_corr): + super(FlowIterativeBlock, self).__init__() + self.update_block = UpdateBlock(encoder_params, hidden_dim) + self.upsample_mode = flow_params.upsample_mode + self.training_iters = 
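# Illustrative sketch (not part of the patch): the all-pairs correlation volume built by
# CorrBlock.corr. For feature maps of shape [B, D, H, W], every pixel of fmap1 is correlated with
# every pixel of fmap2 via a matrix product, giving a volume of shape [B, H, W, 1, H, W] scaled by
# 1/sqrt(D), which is then reshaped and average-pooled into the multi-level pyramid.
import torch

B, D, H, W = 1, 8, 4, 5  # toy sizes, for illustration only
fmap1 = torch.randn(B, D, H, W)
fmap2 = torch.randn(B, D, H, W)

corr = torch.matmul(fmap1.view(B, D, H * W).transpose(1, 2), fmap2.view(B, D, H * W))
corr = corr.view(B, H, W, 1, H, W) / torch.sqrt(torch.tensor(float(D)))
assert corr.shape == (B, H, W, 1, H, W)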
flow_params.training_iters + self.validation_iters = flow_params.validation_iters + + if alternate_corr: + self.corr_fn = AlternateCorrBlock(radius=encoder_params.corr_radius) + else: + self.corr_fn = CorrBlock(radius=encoder_params.corr_radius) + + @staticmethod + def upsample_flow(flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + @staticmethod + def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + def forward(self, coords0, coords1, net, inp, fmap1, fmap2): + + flow_predictions = [] + + if self.training: + iters = self.training_iters + else: + iters = self.validation_iters + + for itr in range(iters): + coords1 = coords1.detach() + corr = self.corr_fn(coords1, fmap1, fmap2) # index correlation volume + + flow = coords1 - coords0 + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + # update the coordinates based on the flow change + coords1 = coords1 + delta_flow + + # upsample flow predictions + if self.upsample_mode is None: + flow_up = self.upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + return flow_predictions, flow_up diff --git a/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py b/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py new file mode 100644 index 0000000000..9173400468 --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py @@ -0,0 +1,161 @@ +import copy +from typing import Union, Optional, Callable + +import torch +import torch.nn as nn +from omegaconf import DictConfig + +from super_gradients.common.object_names import Models + +from super_gradients.common.registry import register_model +from super_gradients.module_interfaces import SupportsReplaceInputChannels +from super_gradients.module_interfaces.exportable_optical_flow import ExportableOpticalFlowModel +from super_gradients.training.models import get_arch_params, SgModule +from super_gradients.training.utils.utils import HpmStruct, get_param + +from .raft_base import Encoder, ContextEncoder, FlowIterativeBlock + +""" +paper: RAFT: Recurrent All-Pairs Field Transforms for Optical Flow + ( https://arxiv.org/pdf/2003.12039 ) + +Code and KITTI pre-trained weights adopted from GitHub repo: +https://github.com/princeton-vl/RAFT/tree/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02 +""" + + +class RAFT(ExportableOpticalFlowModel, SgModule): + def __init__(self, in_channels, encoder_params, corr_params, flow_params, num_classes): + super().__init__() + + self.in_channels = in_channels + + self.feature_encoder = Encoder( + in_channels=self.in_channels, + in_planes=encoder_params.in_planes, + block=encoder_params.fnet.block, + output_dim=encoder_params.fnet.output_dim, + norm_fn=encoder_params.fnet.norm_fn, + dropout=encoder_params.dropout, + ) + + self.context_encoder = ContextEncoder( + in_channels=self.in_channels, + in_planes=encoder_params.in_planes, + block=encoder_params.cnet.block, + 
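# Illustrative sketch (not part of the patch): the "convex upsampling" performed by upsample_flow.
# The predicted mask holds, for every coarse pixel, 8x8 sets of 9 softmax weights over the 3x3
# neighbourhood of the 1/8-resolution flow, so each full-resolution flow vector is a convex
# combination of coarse flow vectors (scaled by 8 to account for the resolution change).
import torch
import torch.nn.functional as F

N, H, W = 1, 6, 7  # toy 1/8-resolution grid, for illustration only
flow = torch.randn(N, 2, H, W)
mask = torch.randn(N, 64 * 9, H, W)

mask = torch.softmax(mask.view(N, 1, 9, 8, 8, H, W), dim=2)
up_flow = F.unfold(8 * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2).permute(0, 1, 4, 2, 5, 3)
assert up_flow.reshape(N, 2, 8 * H, 8 * W).shape == (1, 2, 48, 56)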
hidden_dim=encoder_params.hidden_dim, + context_dim=encoder_params.context_dim, + output_dim=encoder_params.cnet.output_dim, + norm_fn=encoder_params.cnet.norm_fn, + dropout=encoder_params.dropout, + ) + + self.flow_iterative_block = FlowIterativeBlock(encoder_params, encoder_params.update_block.hidden_dim, flow_params, corr_params.alternate_corr) + + @staticmethod + def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = self.coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = self.coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def forward(self, x): + """ + Estimate optical flow between pairs of frames. + + :param x: Input image pair of shape [B,2,C,H,W], order prev/curr. + :return: Flow predictions for each input image pair. During training, the predicted flow will be a list of flow + predictions of length equal to the number of iterations in the FlowIterativeBlock. + """ + + image1 = x[:, 0] + image2 = x[:, 1] + + # run the feature network + fmap1, fmap2 = self.feature_encoder([image1, image2]) + + # run the context network + net, inp = self.context_encoder(image1) + + # initialize flow + coords0, coords1 = self.initialize_flow(image1) + + # run update block network + flow_predictions, flow_up = self.flow_iterative_block(coords0, coords1, net, inp, fmap1, fmap2) + + if torch.jit.is_tracing(): + return flow_up # removed 1st coords1 - coords0, + + return flow_predictions + + def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None): + if isinstance(self.feature_encoder, SupportsReplaceInputChannels) and isinstance(self.context_encoder, SupportsReplaceInputChannels): + self.feature_encoder.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn) + self.context_encoder.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn) + + self.in_channels = self.get_input_channels() + else: + raise NotImplementedError( + f"`{self.feature_encoder.__class__.__name__}` and `{self.context_encoder.__class__.__name__}` do not support `replace_input_channels`" + ) + + def get_input_channels(self) -> int: + if isinstance(self.feature_encoder, SupportsReplaceInputChannels) and isinstance(self.context_encoder, SupportsReplaceInputChannels): + return self.feature_encoder.get_input_channels() + else: + raise NotImplementedError( + f"`{self.feature_encoder.__class__.__name__}` and `{self.context_encoder.__class__.__name__}` do not support `replace_input_channels`" + ) + + def prep_model_for_conversion(self, input_size: Optional[Union[tuple, list]] = None, **kwargs): + for module in self.modules(): + if module != self and hasattr(module, "prep_model_for_conversion"): + module.prep_model_for_conversion(input_size, **kwargs) + + +@register_model(Models.RAFT_S) +class RAFT_S(RAFT): + def __init__(self, arch_params: Union[HpmStruct, DictConfig]): + """ + RAFT S architecture + :param arch_params: architecture parameters + """ + + default_arch_params = get_arch_params("raft_s_arch_params") + merged_arch_params = 
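# Illustrative usage sketch (not part of the patch), assuming the raft_s recipe registered in this
# PR: the model consumes a stacked image pair of shape [B, 2, C, H, W] (H and W divisible by 8)
# and in eager mode returns a list with one [B, 2, H, W] flow tensor per refinement iteration;
# only during ONNX/JIT tracing is the final flow tensor returned on its own.
import torch
from super_gradients.training import models

model = models.get("raft_s", num_classes=1).eval()
with torch.no_grad():
    flow_predictions = model(torch.rand(1, 2, 3, 384, 512))
print(len(flow_predictions), flow_predictions[-1].shape)  # e.g. validation_iters tensors of [1, 2, 384, 512]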
HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + in_channels=merged_arch_params.in_channels, + encoder_params=merged_arch_params.encoder_params, + corr_params=merged_arch_params.corr_params, + flow_params=merged_arch_params.flow_params, + num_classes=get_param(merged_arch_params, "num_classes", None), + ) + + +@register_model(Models.RAFT_L) +class RAFT_L(RAFT): + def __init__(self, arch_params: Union[HpmStruct, DictConfig]): + """ + RAFT L architecture + :param arch_params: architecture parameters + """ + + default_arch_params = get_arch_params("raft_l_arch_params") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + in_channels=merged_arch_params.in_channels, + encoder_params=merged_arch_params.encoder_params, + corr_params=merged_arch_params.corr_params, + flow_params=merged_arch_params.flow_params, + num_classes=get_param(merged_arch_params, "num_classes", None), + ) diff --git a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_ndfl_heads.py b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_ndfl_heads.py index 06b09e6ca2..631fe91de3 100644 --- a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_ndfl_heads.py +++ b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_ndfl_heads.py @@ -150,7 +150,10 @@ def forward(self, feats: Tuple[Tensor, ...]) -> Union[YoloNasPoseDecodedPredicti reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1])) reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, height_mul_width]), [0, 2, 3, 1]) - reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + # OpenVINO cannot handle this: + # reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv).squeeze(1) + # So we do it with multiplication instead + reg_dist_reduced = torch.nn.functional.softmax(reg_dist_reduced, dim=1).mul(self.proj_conv).sum(1) # cls and reg cls_score_list.append(cls_logit.reshape([b, -1, height_mul_width])) diff --git a/src/super_gradients/training/pretrained_models.py b/src/super_gradients/training/pretrained_models.py index 9c938b4a4d..521a78dbbb 100644 --- a/src/super_gradients/training/pretrained_models.py +++ b/src/super_gradients/training/pretrained_models.py @@ -59,6 +59,8 @@ "yolo_nas_pose_s_coco_pose": "https://sghub.deci.ai/models/yolo_nas_pose_s_coco_pose.pth", "yolo_nas_pose_m_coco_pose": "https://sghub.deci.ai/models/yolo_nas_pose_m_coco_pose.pth", "yolo_nas_pose_l_coco_pose": "https://sghub.deci.ai/models/yolo_nas_pose_l_coco_pose.pth", + "raft_s_flying_things": "s3://yael-tmp/raft-small.pth", + "raft_l_flying_things": "s3://yael-tmp/raft-sintel.pth", } PRETRAINED_NUM_CLASSES = { @@ -69,6 +71,7 @@ "coco": 80, "coco_pose": 17, "cifar10": 10, + "flying_things": 1, } DATASET_LICENSES = { diff --git a/src/super_gradients/training/samples/__init__.py b/src/super_gradients/training/samples/__init__.py index 93f00253ae..b40fdfffe7 100644 --- a/src/super_gradients/training/samples/__init__.py +++ b/src/super_gradients/training/samples/__init__.py @@ -2,5 +2,6 @@ from .pose_estimation_sample import PoseEstimationSample from .detection_sample import DetectionSample from .segmentation_sample 
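# Illustrative sketch (not part of the patch): why the OpenVINO-friendly rewrite in
# yolo_nas_pose_ndfl_heads.py is equivalent to the original 1x1 convolution. A conv2d whose weight
# is a [1, reg_max + 1, 1, 1] projection vector computes the expectation of the softmax
# distribution over bins, which the element-wise multiply + sum reproduces exactly.
import torch
import torch.nn.functional as F

reg_max = 16
proj = torch.arange(reg_max + 1, dtype=torch.float32).view(1, reg_max + 1, 1, 1)  # toy projection weights
dist = torch.randn(2, reg_max + 1, 10, 4)  # toy [batch, bins, anchors, 4] logits

via_conv = F.conv2d(F.softmax(dist, dim=1), weight=proj).squeeze(1)
via_mul = F.softmax(dist, dim=1).mul(proj).sum(1)
assert torch.allclose(via_conv, via_mul, atol=1e-6)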
import SegmentationSample
+from .optical_flow_sample import OpticalFlowSample

-__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample"]
+__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample", "OpticalFlowSample"]
diff --git a/src/super_gradients/training/samples/optical_flow_sample.py b/src/super_gradients/training/samples/optical_flow_sample.py
new file mode 100644
index 0000000000..98b03c684c
--- /dev/null
+++ b/src/super_gradients/training/samples/optical_flow_sample.py
@@ -0,0 +1,38 @@
+import dataclasses
+
+import numpy as np
+
+
+__all__ = ["OpticalFlowSample"]
+
+
+@dataclasses.dataclass
+class OpticalFlowSample:
+    """
+    A dataclass representing a single optical flow sample.
+    Contains input images and flow map.
+
+    :param images: np.ndarray, Image pair of [2, H, W, C] shape.
+    :param flow_map: Flow map of [H, W, 2] shape.
+    :param valid: Valid map of [H, W] shape.
+    """
+
+    __slots__ = ["images", "flow_map", "valid"]
+
+    images: np.ndarray
+    flow_map: np.ndarray
+    valid: np.ndarray
+
+    def __init__(self, images: np.ndarray, flow_map: np.ndarray, valid: np.ndarray = None):
+        # small sanity check
+        dm_shape = flow_map.shape
+
+        if len(dm_shape) == 4:
+            if dm_shape[-1] == 1:
+                flow_map = np.squeeze(flow_map, axis=-1)
+            else:
+                raise RuntimeError(f"Flow map should contain only H and W dimensions for both u and v axes, got {len(dm_shape)} dimensions instead.")
+
+        self.images = images
+        self.flow_map = flow_map
+        self.valid = valid
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 5f856da732..7ba08998ae 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1421,6 +1421,14 @@ def get_finetune_lr_dict(self, lr: float) -> Dict[str, float]:
         self.ckpt_best_name = self.training_params.ckpt_best_name
+        if self.training_params.average_best_models and not self.training_params.save_model:
+            logger.warning(
+                "'training_params.average_best_models' is enabled, but 'training_params.save_model' is disabled. \n"
+                "Model averaging requires saving snapshot checkpoints to function properly. As a result, "
+                "'training_params.average_best_models' will be disabled. 
" + ) + self.training_params.average_best_models = False + self.max_train_batches = self.training_params.max_train_batches self.max_valid_batches = self.training_params.max_valid_batches diff --git a/src/super_gradients/training/transforms/__init__.py b/src/super_gradients/training/transforms/__init__.py index 1f52079464..647526801c 100644 --- a/src/super_gradients/training/transforms/__init__.py +++ b/src/super_gradients/training/transforms/__init__.py @@ -11,6 +11,13 @@ DetectionTargetsFormatTransform, Standardize, DetectionTransform, + OpticalFlowColorJitter, + OpticalFlowOcclusion, + OpticalFlowRandomRescale, + OpticalFlowRandomFlip, + OpticalFlowCrop, + OpticalFlowInputPadder, + OpticalFlowNormalize, ) from super_gradients.training.transforms.keypoints import ( AbstractKeypointTransform, @@ -37,6 +44,7 @@ from super_gradients.common.registry.registry import TRANSFORMS from super_gradients.common.registry.albumentation import ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS, imported_albumentations_failure from super_gradients.training.transforms.detection import AbstractDetectionTransform, DetectionPadIfNeeded, DetectionLongestMaxSize +from super_gradients.training.transforms.optical_flow import AbstractOpticalFlowTransform __all__ = [ "TRANSFORMS", @@ -76,6 +84,14 @@ "DetectionPadIfNeeded", "DetectionLongestMaxSize", "AbstractDetectionTransform", + "AbstractOpticalFlowTransform", + "OpticalFlowColorJitter", + "OpticalFlowOcclusion", + "OpticalFlowRandomRescale", + "OpticalFlowRandomFlip", + "OpticalFlowCrop", + "OpticalFlowInputPadder", + "OpticalFlowNormalize", ] cv2.setNumThreads(0) diff --git a/src/super_gradients/training/transforms/optical_flow/__init__.py b/src/super_gradients/training/transforms/optical_flow/__init__.py new file mode 100644 index 0000000000..6eb48abab3 --- /dev/null +++ b/src/super_gradients/training/transforms/optical_flow/__init__.py @@ -0,0 +1,3 @@ +from .abstract_optical_flow_transform import AbstractOpticalFlowTransform + +__all__ = ["AbstractOpticalFlowTransform"] diff --git a/src/super_gradients/training/transforms/optical_flow/abstract_optical_flow_transform.py b/src/super_gradients/training/transforms/optical_flow/abstract_optical_flow_transform.py new file mode 100644 index 0000000000..aadd467c47 --- /dev/null +++ b/src/super_gradients/training/transforms/optical_flow/abstract_optical_flow_transform.py @@ -0,0 +1,22 @@ +import abc + +from super_gradients.training.samples import OpticalFlowSample + + +class AbstractOpticalFlowTransform(abc.ABC): + """ + Base class for all transforms for optical flow sample augmentation. + """ + + @abc.abstractmethod + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + """ + Apply transformation to given optical flow sample. + Important note - function call may return new object, may modify it in-place. + This is implementation dependent and if you need to keep original sample intact it + is recommended to make a copy of it BEFORE passing it to transform. + + :param sample: Input sample to transform. + :return: Modified sample (It can be the same instance as input or a new object). 
+ """ + raise NotImplementedError() diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py index f18b2b9028..0accd6ce03 100644 --- a/src/super_gradients/training/transforms/transforms.py +++ b/src/super_gradients/training/transforms/transforms.py @@ -21,8 +21,9 @@ from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xyxy_to_xywh from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, ConcatenatedTensorFormat -from super_gradients.training.samples import DetectionSample, SegmentationSample +from super_gradients.training.samples import DetectionSample, SegmentationSample, OpticalFlowSample from super_gradients.training.transforms.detection import DetectionPadIfNeeded, AbstractDetectionTransform, LegacyDetectionTransformMixin +from super_gradients.training.transforms.optical_flow import AbstractOpticalFlowTransform from super_gradients.training.transforms.segmentation.abstract_segmentation_transform import AbstractSegmentationTransform from super_gradients.training.transforms.segmentation.legacy_segmentation_transform_mixin import LegacySegmentationTransformMixin from super_gradients.training.transforms.utils import ( @@ -1658,3 +1659,279 @@ def _max_targets_deprication(max_targets: Optional[int] = None): "If you are using collate_fn provided by SG, it is safe to simply drop this argument.", DeprecationWarning, ) + + +@register_transform(Transforms.OpticalFlowColorJitter) +class OpticalFlowColorJitter(AbstractOpticalFlowTransform): + """ + Apply color jitter transformation to the input images with a certain probability. + + :param brightness: (float) Brightness factor for color jitter. + :param contrast: (float) Contrast factor for color jitter + :param saturation: (float) Saturation factor for color jitter + :param hue: (float) Hue factor for color jitter + :param prob: (float) Probability of applying color jitter transformation. + """ + + def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.16, prob=0.5): + self.color_jitter = _transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) + self.prob = prob + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + + if np.random.rand() < self.prob: + img1 = self.color_jitter(Image.fromarray(img1)) + img2 = self.color_jitter(Image.fromarray(img2)) + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=sample.flow_map, valid=sample.valid) + + +@register_transform(Transforms.OpticalFlowOcclusion) +class OpticalFlowOcclusion(AbstractOpticalFlowTransform): + """ + Apply occlusion augmentation to optical flow images. + + :param prob: Probability of applying occlusion. + :param bounds: Bounds for occlusion size. 
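# Illustrative sketch (not part of the patch): a minimal custom transform that follows the
# AbstractOpticalFlowTransform contract introduced above. The transform name and behaviour below
# are hypothetical, shown only to demonstrate the sample-in / sample-out interface.
import numpy as np
from super_gradients.training.samples import OpticalFlowSample
from super_gradients.training.transforms.optical_flow import AbstractOpticalFlowTransform


class ClipFlowMagnitude(AbstractOpticalFlowTransform):
    """Toy transform that clamps flow components to +-max_flow, leaving images and valid untouched."""

    def __init__(self, max_flow: float = 400.0):
        self.max_flow = max_flow

    def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample:
        flow_map = np.clip(sample.flow_map, -self.max_flow, self.max_flow)
        return OpticalFlowSample(images=sample.images, flow_map=flow_map, valid=sample.valid)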
+ """ + + def __init__(self, prob=0.5, bounds=(10, 30)): + self.prob = prob + self.bounds = bounds + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + ht, wd = img1.shape[:2] + + if np.random.rand() < self.prob: + mean_color = np.mean(img2) + + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(self.bounds[0], self.bounds[1]) + dy = np.random.randint(self.bounds[0], self.bounds[1]) + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=sample.flow_map, valid=sample.valid) + + +@register_transform(Transforms.OpticalFlowRandomRescale) +class OpticalFlowRandomRescale(AbstractOpticalFlowTransform): + """ + Apply random rescaling to optical flow images. + + :param min_scale: Minimum scaling factor. + :param max_scale: Maximum scaling factor. + :param prob: Probability of applying random rescale. + """ + + def __init__(self, min_scale=0.9, max_scale=1.2, prob=0.5): + self.scale = np.random.uniform(min_scale, max_scale) + self.prob = prob + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + flow_map = sample.flow_map + valid = sample.valid + + if np.random.rand() < self.prob: + img1 = cv2.resize(img1, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_LINEAR) + flow_map, valid = self.resize_sparse_flow_map(flow_map, valid, fx=self.scale, fy=self.scale) + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid) + + @staticmethod + def resize_sparse_flow_map(flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid >= 1] + flow0 = flow[valid >= 1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + +@register_transform(Transforms.OpticalFlowRandomFlip) +class OpticalFlowRandomFlip(AbstractOpticalFlowTransform): + """ + Apply random flipping to optical flow images. + + :param h_flip_prob: Probability of horizontal flipping. + :param v_flip_prob: Probability of vertical flipping. 
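# Illustrative sketch (not part of the patch): resize_sparse_flow_map above rescales both the pixel
# positions and the flow vectors. A valid pixel at (x, y) = (10, 20) with flow (u, v) = (4, -2)
# rescaled by fx = fy = 1.5 moves to (15, 30) and its flow becomes (6, -3); pixels whose rescaled
# positions land outside the new grid are dropped from the valid mask.
fx = fy = 1.5
x, y, u, v = 10, 20, 4.0, -2.0
new_x, new_y = round(x * fx), round(y * fy)
new_u, new_v = u * fx, v * fy
assert (new_x, new_y, new_u, new_v) == (15, 30, 6.0, -3.0)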
+ """ + + def __init__(self, h_flip_prob=0.5, v_flip_prob=0.1): + self.h_flip_prob = h_flip_prob + self.v_flip_prob = v_flip_prob + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + flow_map = sample.flow_map + valid = sample.valid + + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow_map = flow_map[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow_map = flow_map[::-1, :] * [1.0, -1.0] + valid = valid[::-1, :] + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid) + + +@register_transform(Transforms.OpticalFlowCrop) +class OpticalFlowCrop(AbstractOpticalFlowTransform): + """ + Crop optical flow images. + + :param crop_size: Size of the crop in the format (height, width). + :param mode: Crop mode ('center' or 'random'). + """ + + def __init__(self, crop_size: Union[Tuple, List], mode: str = "center"): + self.crop_size = crop_size + self.mode = mode + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + flow_map = sample.flow_map + valid = sample.valid + + # Get the height and width of the image + img_height, img_width = img1.shape[:2] + + # Extract the desired crop size + crop_height, crop_width = self.crop_size + + # Check if cropping is necessary + if img_height > crop_height: + if self.mode == "center": + y_start = max(0, (img_height - crop_height) // 2) + elif self.mode == "random": + y_start = np.random.randint(0, img_height - crop_height + 1) + else: + raise ValueError("Invalid crop mode. Supported modes are 'center' and 'random'.") + img1 = img1[y_start : y_start + crop_height, :] + img2 = img2[y_start : y_start + crop_height, :] + flow_map = flow_map[y_start : y_start + crop_height, :] + valid = valid[y_start : y_start + crop_height, :] + + if img_width > crop_width: + if self.mode == "center": + x_start = max(0, (img_width - crop_width) // 2) + elif self.mode == "random": + x_start = np.random.randint(0, img_width - crop_width + 1) + else: + raise ValueError("Invalid crop mode. Supported modes are 'center' and 'random'.") + img1 = img1[:, x_start : x_start + crop_width] + img2 = img2[:, x_start : x_start + crop_width] + flow_map = flow_map[:, x_start : x_start + crop_width] + valid = valid[:, x_start : x_start + crop_width] + + # Calculate padding amounts + pad_height = max(0, crop_height - img1.shape[0]) + pad_width = max(0, crop_width - img1.shape[1]) + + # Pad the cropped image using the bottom-right method + img1 = np.pad(img1, ((0, pad_height), (0, pad_width), (0, 0)), mode="symmetric") + img2 = np.pad(img2, ((0, pad_height), (0, pad_width), (0, 0)), mode="symmetric") + flow_map = np.pad(flow_map, ((0, pad_height), (0, pad_width), (0, 0)), mode="symmetric") + valid = np.pad(valid, ((0, pad_height), (0, pad_width)), mode="constant", constant_values=0) + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid) + + +@register_transform(Transforms.OpticalFlowInputPadder) +class OpticalFlowInputPadder(AbstractOpticalFlowTransform): + """ + Pads images such that dimensions are divisible by pad_factor. + + :param dataset_mode: The padding method is determined by the dataset. + :param pad_factor: The factor which the padded image should be divisible by. 
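# Illustrative sketch (not part of the patch): OpticalFlowCrop above first crops dimensions larger
# than crop_size and then bottom/right-pads dimensions that are smaller, so the output is always
# exactly crop_size. Images and flow use symmetric padding, while the valid mask is zero-padded so
# the padded pixels are marked invalid.
import numpy as np

img = np.zeros((300, 500, 3), dtype=np.uint8)  # toy frame, for illustration only
crop_h, crop_w = 320, 448
cropped = img[:, :crop_w]                                                  # width cropped (500 -> 448)
padded = np.pad(cropped, ((0, crop_h - img.shape[0]), (0, 0), (0, 0)), mode="symmetric")  # height padded (300 -> 320)
assert padded.shape == (320, 448, 3)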
+ """ + + def __init__(self, dataset_mode: str, pad_factor: int = 8): + self.dataset_mode = dataset_mode + self.pad_factor = pad_factor + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + img1, img2 = sample.images[0], sample.images[1] + + # Get the input image dims + ht, wd = img1.shape[:2] + + # Calculate padding amounts + pad_ht = (((ht // self.pad_factor) + 1) * self.pad_factor - ht) % self.pad_factor + pad_wd = (((wd // self.pad_factor) + 1) * self.pad_factor - wd) % self.pad_factor + if self.dataset_mode == "sintel": + pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + else: + pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + # Pad the cropped image using the bottom-right method + img1 = np.pad(img1, ((pad[2], pad[3]), (pad[0], pad[1]), (0, 0)), mode="symmetric") + img2 = np.pad(img2, ((pad[2], pad[3]), (pad[0], pad[1]), (0, 0)), mode="symmetric") + + images = np.stack([img1, img2]) + + return OpticalFlowSample(images=images, flow_map=sample.flow_map, valid=sample.valid) + + +@register_transform(Transforms.OpticalFlowNormalize) +class OpticalFlowNormalize(AbstractOpticalFlowTransform): + """ + Normalize OpticalFlowSample images from [0, 255] range to [0, 1] range. + """ + + def __init__(self): + super().__init__() + + def __call__(self, sample: OpticalFlowSample) -> OpticalFlowSample: + images = np.ascontiguousarray(sample.images) / 255 + flow_map = np.ascontiguousarray(sample.flow_map) + valid = np.ascontiguousarray(sample.valid) + + return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid) diff --git a/src/super_gradients/training/utils/collate_fn/__init__.py b/src/super_gradients/training/utils/collate_fn/__init__.py index 338182e9f9..284347f8be 100644 --- a/src/super_gradients/training/utils/collate_fn/__init__.py +++ b/src/super_gradients/training/utils/collate_fn/__init__.py @@ -2,5 +2,6 @@ from .ppyoloe_collate_fn import PPYoloECollateFN from .crowd_detection_collate_fn import CrowdDetectionCollateFN from .crowd_detection_ppyoloe_collate_fn import CrowdDetectionPPYoloECollateFN +from .optical_flow_collate_fn import OpticalFlowCollateFN -__all__ = ["DetectionCollateFN", "PPYoloECollateFN", "CrowdDetectionCollateFN", "CrowdDetectionPPYoloECollateFN"] +__all__ = ["DetectionCollateFN", "PPYoloECollateFN", "CrowdDetectionCollateFN", "CrowdDetectionPPYoloECollateFN", "OpticalFlowCollateFN"] diff --git a/src/super_gradients/training/utils/collate_fn/optical_flow_collate_fn.py b/src/super_gradients/training/utils/collate_fn/optical_flow_collate_fn.py new file mode 100644 index 0000000000..1e3f7fcb2f --- /dev/null +++ b/src/super_gradients/training/utils/collate_fn/optical_flow_collate_fn.py @@ -0,0 +1,45 @@ +from typing import Tuple, List, Union + +import numpy as np +import torch + +from super_gradients.common.registry import register_collate_function +from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException + + +@register_collate_function() +class OpticalFlowCollateFN: + """ + Collate function for optical flow training + """ + + def __init__(self): + self.expected_item_names = ("images", "targets") + + def __call__(self, data) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + try: + images_batch, labels_batch = list(zip(*data)) + except (ValueError, TypeError): + raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names) + + return self._format_images(images_batch), self._format_labels(labels_batch) + + 
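# Illustrative worked example (not part of the patch): the padding arithmetic used by
# OpticalFlowInputPadder above for a KITTI-sized frame of 375 x 1242 with pad_factor = 8.
# Heights and widths are rounded up to the next multiple of 8; in non-"sintel" mode the height
# padding goes entirely to the bottom while the width padding is split between left and right.
ht, wd, f = 375, 1242, 8
pad_ht = (((ht // f) + 1) * f - ht) % f   # -> 1
pad_wd = (((wd // f) + 1) * f - wd) % f   # -> 6
pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]  # [left, right, top, bottom]
assert (ht + pad_ht, wd + pad_wd) == (376, 1248) and pad == [3, 3, 0, 1]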
@staticmethod + def _format_images(images_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor: + images_batch = [torch.tensor(img) for img in images_batch] + images_batch_stack = torch.stack(images_batch, 0) + return images_batch_stack + + @staticmethod + def _format_labels(labels_batch: List[Union[torch.Tensor, np.array]]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split labels to flow maps and valid tensors + :param labels_batch: a list of targets per image + :return: a tuple of two tensors of targets of all images, where one tensor is the flow map of shape [2, H, W] + and another tensor is the valid map of shape [H, W] + """ + flow_labels_batch = [torch.tensor(flow) for flow, _ in labels_batch] + valid_labels_batch = [torch.tensor(valid) for _, valid in labels_batch] + flow_labels_batch_stack = torch.stack(flow_labels_batch, 0) + valid_labels_batch_stack = torch.stack(valid_labels_batch, 0) + return flow_labels_batch_stack, valid_labels_batch_stack diff --git a/src/super_gradients/training/utils/distributed_training_utils.py b/src/super_gradients/training/utils/distributed_training_utils.py index c4dbd55e80..6d587466a1 100755 --- a/src/super_gradients/training/utils/distributed_training_utils.py +++ b/src/super_gradients/training/utils/distributed_training_utils.py @@ -10,7 +10,6 @@ from torch import distributed as dist from torch.cuda.amp import autocast from torch.distributed import get_rank, all_gather_object -from torch.distributed.elastic.multiprocessing import Std from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.launcher.api import LaunchConfig, elastic_launch @@ -345,9 +344,6 @@ def restart_script_with_ddp(num_gpus: int = None): max_restarts=0, monitor_interval=5, start_method="spawn", - log_dir=None, - redirects=Std.NONE, - tee=Std.NONE, metrics_cfg={}, ) diff --git a/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py index 2019a5ba17..345c7d15ee 100644 --- a/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py +++ b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py @@ -119,7 +119,14 @@ def save( :param show_confidence: Whether to show confidence scores on the image. :param box_thickness: (Optional) Thickness of bounding boxes. If None, will adapt to the box size. """ - image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence) + image = self.draw( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) save_image(image=image, path=output_path) diff --git a/src/super_gradients/training/utils/visualization/optical_flow.py b/src/super_gradients/training/utils/visualization/optical_flow.py new file mode 100644 index 0000000000..ae62667bc4 --- /dev/null +++ b/src/super_gradients/training/utils/visualization/optical_flow.py @@ -0,0 +1,28 @@ +import torch +import numpy as np +from torchvision.utils import flow_to_image + + +class FlowVisualization: + @staticmethod + def process_flow_map_for_visualization( + flow_map: np.ndarray, + ) -> np.ndarray: + """ + Process a flow map for visualization. + + :param flow_map: Input depth map as a NumPy array. + + :return: Processed colormap of the flow map for visualization. 
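# Illustrative usage sketch (not part of the patch): torchvision's flow_to_image expects a
# [2, H, W] float tensor, so the helper below is fed a channels-first flow map and returns an
# [H, W, 3] uint8 color-wheel image that can be saved or displayed directly.
import numpy as np
from super_gradients.training.utils.visualization.optical_flow import FlowVisualization

flow_map = np.random.randn(2, 240, 320).astype(np.float32)  # toy [2, H, W] flow, for illustration
flow_rgb = FlowVisualization.process_flow_map_for_visualization(flow_map)
assert flow_rgb.shape == (240, 320, 3)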
+ """ + + # Convert to Torch tensor + flow_map = torch.tensor(flow_map) + + # Convert flow map to an image + flow_map_img = flow_to_image(flow_map) + + # Convert to NumPy array + flow_map_img = flow_map_img.permute(1, 2, 0).numpy() + + return flow_map_img diff --git a/tests/data/kitti_2015/training/flow_occ/000000_10.png b/tests/data/kitti_2015/training/flow_occ/000000_10.png new file mode 100644 index 0000000000..f8984eb33e Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000000_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000001_10.png b/tests/data/kitti_2015/training/flow_occ/000001_10.png new file mode 100644 index 0000000000..3a295c7cfc Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000001_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000002_10.png b/tests/data/kitti_2015/training/flow_occ/000002_10.png new file mode 100644 index 0000000000..4545a7f9fa Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000002_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000003_10.png b/tests/data/kitti_2015/training/flow_occ/000003_10.png new file mode 100644 index 0000000000..3df98ded1f Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000003_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000004_10.png b/tests/data/kitti_2015/training/flow_occ/000004_10.png new file mode 100644 index 0000000000..49943cec5d Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000004_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000005_10.png b/tests/data/kitti_2015/training/flow_occ/000005_10.png new file mode 100755 index 0000000000..3ef888cad1 Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000005_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000006_10.png b/tests/data/kitti_2015/training/flow_occ/000006_10.png new file mode 100644 index 0000000000..05bb560bd3 Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000006_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000007_10.png b/tests/data/kitti_2015/training/flow_occ/000007_10.png new file mode 100644 index 0000000000..69556fff3c Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000007_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000008_10.png b/tests/data/kitti_2015/training/flow_occ/000008_10.png new file mode 100644 index 0000000000..138ed2ab94 Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000008_10.png differ diff --git a/tests/data/kitti_2015/training/flow_occ/000009_10.png b/tests/data/kitti_2015/training/flow_occ/000009_10.png new file mode 100644 index 0000000000..9b080a493d Binary files /dev/null and b/tests/data/kitti_2015/training/flow_occ/000009_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000000_10.png b/tests/data/kitti_2015/training/image_2/000000_10.png new file mode 100755 index 0000000000..cf07f6d277 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000000_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000000_11.png b/tests/data/kitti_2015/training/image_2/000000_11.png new file mode 100755 index 0000000000..608902288a Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000000_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000001_10.png b/tests/data/kitti_2015/training/image_2/000001_10.png new file mode 100755 index 
0000000000..dd33de122a Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000001_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000001_11.png b/tests/data/kitti_2015/training/image_2/000001_11.png new file mode 100755 index 0000000000..7a1e16107d Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000001_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000002_10.png b/tests/data/kitti_2015/training/image_2/000002_10.png new file mode 100755 index 0000000000..112f573318 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000002_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000002_11.png b/tests/data/kitti_2015/training/image_2/000002_11.png new file mode 100755 index 0000000000..62994f279c Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000002_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000003_10.png b/tests/data/kitti_2015/training/image_2/000003_10.png new file mode 100755 index 0000000000..7d94c44448 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000003_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000003_11.png b/tests/data/kitti_2015/training/image_2/000003_11.png new file mode 100755 index 0000000000..b3b9968b72 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000003_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000004_10.png b/tests/data/kitti_2015/training/image_2/000004_10.png new file mode 100755 index 0000000000..5a93261767 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000004_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000004_11.png b/tests/data/kitti_2015/training/image_2/000004_11.png new file mode 100755 index 0000000000..204ae2fb35 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000004_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000005_10.png b/tests/data/kitti_2015/training/image_2/000005_10.png new file mode 100755 index 0000000000..c2dbc99620 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000005_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000005_11.png b/tests/data/kitti_2015/training/image_2/000005_11.png new file mode 100755 index 0000000000..5b35618996 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000005_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000006_10.png b/tests/data/kitti_2015/training/image_2/000006_10.png new file mode 100755 index 0000000000..da2fce1358 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000006_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000006_11.png b/tests/data/kitti_2015/training/image_2/000006_11.png new file mode 100755 index 0000000000..353f36db02 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000006_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000007_10.png b/tests/data/kitti_2015/training/image_2/000007_10.png new file mode 100755 index 0000000000..e609b9c3b2 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000007_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000007_11.png b/tests/data/kitti_2015/training/image_2/000007_11.png new file mode 100755 index 0000000000..0980452b4b Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000007_11.png differ diff --git 
a/tests/data/kitti_2015/training/image_2/000008_10.png b/tests/data/kitti_2015/training/image_2/000008_10.png new file mode 100755 index 0000000000..1da683178e Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000008_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000008_11.png b/tests/data/kitti_2015/training/image_2/000008_11.png new file mode 100755 index 0000000000..a109024e43 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000008_11.png differ diff --git a/tests/data/kitti_2015/training/image_2/000009_10.png b/tests/data/kitti_2015/training/image_2/000009_10.png new file mode 100755 index 0000000000..29210d8194 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000009_10.png differ diff --git a/tests/data/kitti_2015/training/image_2/000009_11.png b/tests/data/kitti_2015/training/image_2/000009_11.png new file mode 100755 index 0000000000..df93cec4e8 Binary files /dev/null and b/tests/data/kitti_2015/training/image_2/000009_11.png differ diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 27a92f1079..726138337c 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -90,6 +90,11 @@ from tests.unit_tests.detection_metrics_distance_based_test import TestDetectionMetricsDistanceBased from tests.unit_tests.class_balancer_test import ClassBalancerTest from tests.unit_tests.class_balanced_sampler_test import ClassBalancedSamplerTest +from tests.unit_tests.optical_flow_dataset_test import OpticalFlowDatasetTest +from tests.unit_tests.optical_flow_transforms_test import OpticalFlowTransformsTest +from tests.unit_tests.optical_flow_loss_test import OpticalFlowLossTest +from tests.unit_tests.test_optical_flow_metric import TestOpticalFlowMetric +from tests.unit_tests.export_optical_flow_model_test import TestOpticalFlowModelExport class CoreUnitTestSuiteRunner: @@ -192,6 +197,11 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ClassBalancerTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ClassBalancedSamplerTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationModelExport)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(OpticalFlowDatasetTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestOpticalFlowModelExport)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(OpticalFlowTransformsTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(OpticalFlowLossTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestOpticalFlowMetric)) def _add_modules_to_end_to_end_tests_suite(self): """ diff --git a/tests/integration_tests/raft_integration_test.py b/tests/integration_tests/raft_integration_test.py new file mode 100644 index 0000000000..62cef63851 --- /dev/null +++ b/tests/integration_tests/raft_integration_test.py @@ -0,0 +1,35 @@ +import unittest +from super_gradients.training import models +from super_gradients.training import Trainer +from super_gradients.training.dataloaders.dataloaders import kitti2015_optical_flow_val +from super_gradients.training.metrics import EPE +from super_gradients.training.transforms import OpticalFlowInputPadder, OpticalFlowNormalize + + +class RAFTIntegrationTest(unittest.TestCase): + def setUp(self): + self.data_dir = "/home/yael.baron/data/kitti" + self.dl = 
kitti2015_optical_flow_val( + dataset_params=dict(root=self.data_dir, transforms=[OpticalFlowInputPadder(dataset_mode="kitti", pad_factor=8), OpticalFlowNormalize()]), + dataloader_params=dict(batch_size=1), + ) + + def test_raft_s_kitti(self): + trainer = Trainer("test_raft_s") + model = models.get("raft_s", num_classes=1, checkpoint_path="/home/yael.baron/checkpoints/RAFT_pretrained_weights/raft-small.pth") + # model = models.get("raft_s", num_classes=1, pretrained_weights="flying_things") + metric = EPE(apply_unpad=True) + metric_values = trainer.test(model=model, test_loader=self.dl, test_metrics_list=[metric]) + self.assertAlmostEqual(metric_values["epe"], 7.672, delta=0.1) + + def test_raft_l_kitti(self): + trainer = Trainer("test_raft_l") + model = models.get("raft_l", num_classes=1, checkpoint_path="/home/yael.baron/checkpoints/RAFT_pretrained_weights/raft-things.pth") + # model = models.get("raft_l", num_classes=1, pretrained_weights="flying_things") + metric = EPE(apply_unpad=True) + metric_values = trainer.test(model=model, test_loader=self.dl, test_metrics_list=[metric]) + self.assertAlmostEqual(metric_values["epe"], 5.044, delta=0.001) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index c43384006f..be2115b149 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -29,6 +29,10 @@ from tests.unit_tests.test_models_factory import DynamicModelTests from tests.unit_tests.test_mixed_precision_cpu import TestMixedPrecisionDisabled from tests.unit_tests.test_data_adapters import TestClassificationAdapter, TestDetectionAdapter, TestSegmentationAdapter +from tests.unit_tests.optical_flow_dataset_test import OpticalFlowDatasetTest +from tests.unit_tests.optical_flow_transforms_test import OpticalFlowTransformsTest +from tests.unit_tests.optical_flow_loss_test import OpticalFlowLossTest +from tests.unit_tests.test_optical_flow_metric import TestOpticalFlowMetric __all__ = [ "CrashTipTest", @@ -63,4 +67,8 @@ "TestClassificationAdapter", "TestDetectionAdapter", "TestSegmentationAdapter", + "OpticalFlowDatasetTest", + "OpticalFlowTransformsTest", + "OpticalFlowLossTest", + "TestOpticalFlowMetric", ] diff --git a/tests/unit_tests/export_optical_flow_model_test.py b/tests/unit_tests/export_optical_flow_model_test.py new file mode 100644 index 0000000000..a5748b491b --- /dev/null +++ b/tests/unit_tests/export_optical_flow_model_test.py @@ -0,0 +1,68 @@ +import logging +import os +import tempfile +import unittest + +import numpy as np +import onnxruntime +import torch +from super_gradients.common.object_names import Models +from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install +from super_gradients.import_utils import import_pytorch_quantization_or_install +from super_gradients.module_interfaces import ExportableOpticalFlowModel, OpticalFlowModelExportResult +from super_gradients.training import models + + +gs = import_onnx_graphsurgeon_or_install() +import_pytorch_quantization_or_install() + + +class TestOpticalFlowModelExport(unittest.TestCase): + def setUp(self) -> None: + logging.getLogger().setLevel(logging.DEBUG) + + self.models_to_test = [ + Models.RAFT_S, + Models.RAFT_L, + ] + + def test_export_to_onnxruntime_and_run(self): + """ + Test export to ONNX + """ + + with tempfile.TemporaryDirectory() as tmpdirname: + for model_type in self.models_to_test: + with self.subTest(model_type=model_type): + model_name = str(model_type).lower().replace(".", "_") + 
out_path = os.path.join(tmpdirname, f"{model_name}_onnxruntime.onnx") + + model_arch: ExportableOpticalFlowModel = models.get(model_name, num_classes=1) + export_result = model_arch.export( + out_path, + input_image_shape=(640, 640), # Force .export() to infer image shape from the model itself + input_image_channels=3, + input_image_dtype=torch.float32, + onnx_export_kwargs={"opset_version": 16}, + ) + + [flow_prediction] = self._run_inference_with_onnx(export_result) + self.assertTrue(flow_prediction.shape[0] == 1) + self.assertTrue(flow_prediction.shape[1] == 2) + self.assertTrue(flow_prediction.shape[2] == 640) + self.assertTrue(flow_prediction.shape[3] == 640) + + @staticmethod + def _run_inference_with_onnx(export_result: OpticalFlowModelExportResult): + input = np.zeros((1, 2, 3, 640, 640)).astype(np.float32) + + session = onnxruntime.InferenceSession(export_result.output) + inputs = [o.name for o in session.get_inputs()] + outputs = [o.name for o in session.get_outputs()] + result = session.run(outputs, {inputs[0]: input}) + + return result + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/optical_flow_dataset_test.py b/tests/unit_tests/optical_flow_dataset_test.py new file mode 100644 index 0000000000..7b0ea30ede --- /dev/null +++ b/tests/unit_tests/optical_flow_dataset_test.py @@ -0,0 +1,28 @@ +import unittest +from pathlib import Path + +import numpy as np + +from super_gradients.training.datasets.optical_flow_datasets.kitti_dataset import KITTIOpticalFlowDataset + + +class OpticalFlowDatasetTest(unittest.TestCase): + def setUp(self) -> None: + self.kitti_2015_data_dir = str(Path(__file__).parent.parent / "data" / "kitti_2015") + + def test_kitti_creation(self): + dataset = KITTIOpticalFlowDataset(root=self.kitti_2015_data_dir) + for i, (images, target) in enumerate(dataset): + flow, valid = target + self.assertTrue(isinstance(images, np.ndarray)) + self.assertTrue(isinstance(flow, np.ndarray)) + self.assertTrue(isinstance(valid, np.ndarray)) + self.assertTrue(len(dataset) == 10 and i == 9) + + def test_optical_flow_plot(self): + dataset = KITTIOpticalFlowDataset(root=self.kitti_2015_data_dir) + dataset.plot(max_samples_per_plot=8) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/optical_flow_loss_test.py b/tests/unit_tests/optical_flow_loss_test.py new file mode 100644 index 0000000000..7c0f7c6e1b --- /dev/null +++ b/tests/unit_tests/optical_flow_loss_test.py @@ -0,0 +1,49 @@ +import torch +import unittest + +from super_gradients.training.losses import OpticalFlowLoss +from super_gradients.training.losses.loss_utils import apply_reduce + + +class OpticalFlowLossTest(unittest.TestCase): + def setUp(self) -> None: + self.img_size = 100 + self.gamma = 0.8 + self.max_flow = 400 + self.reduction = "mean" + self.batch_size = 1 + + def _get_default_predictions_tensor(self, n_predictions: int, fill_value: float): + return [torch.empty(self.batch_size, 2, self.img_size, self.img_size).fill_(fill_value) for _ in range(n_predictions)] + + def _get_default_target_tensor(self): + return (torch.zeros(self.batch_size, 2, self.img_size, self.img_size).long(), torch.ones(self.img_size, self.img_size)) + + def _assertion_flow_loss_torch_values(self, expected_value: torch.Tensor, found_value: torch.Tensor, rtol: float = 1e-5): + self.assertTrue(torch.allclose(found_value, expected_value, rtol=rtol), msg=f"Unequal flow loss: excepted: {expected_value}, found: {found_value}") + + def test_flow_loss_l1_criterion(self): + predictions = 
self._get_default_predictions_tensor(3, 2.5) + target, valid = self._get_default_target_tensor() + + loss_fn = OpticalFlowLoss(gamma=self.gamma, max_flow=self.max_flow, reduction=self.reduction) + + flow_loss = loss_fn(predictions, (target, valid)) + + # expected_flow_loss + expected_flow_loss = 0.0 + mag = torch.sum(target**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + n_predictions = len(predictions) + + for i in range(n_predictions): + i_weight = self.gamma ** (n_predictions - i - 1) + i_loss = i_weight * (valid[:, None] * (predictions[i] - target).abs()) # L1 dist + expected_flow_loss += apply_reduce(i_loss, self.reduction) + + self._assertion_flow_loss_torch_values(torch.tensor(expected_flow_loss), flow_loss) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/optical_flow_transforms_test.py b/tests/unit_tests/optical_flow_transforms_test.py new file mode 100644 index 0000000000..846d1298a4 --- /dev/null +++ b/tests/unit_tests/optical_flow_transforms_test.py @@ -0,0 +1,106 @@ +import unittest + +import numpy as np + +from super_gradients.training.samples import OpticalFlowSample +from super_gradients.training.transforms.transforms import ( + OpticalFlowColorJitter, + OpticalFlowOcclusion, + OpticalFlowRandomRescale, + OpticalFlowRandomFlip, + OpticalFlowCrop, + OpticalFlowNormalize, + OpticalFlowInputPadder, +) + + +class OpticalFlowTransformsTest(unittest.TestCase): + def setUp(self): + # Create an OpticalFlowSample + self.h, self.w = 400, 400 + img1 = np.random.randint(0, 255, size=(self.h, self.w, 3), dtype=np.uint8) + img2 = np.random.randint(0, 255, size=(self.h, self.w, 3), dtype=np.uint8) + flow_map = np.random.randn(self.h, self.w, 2) + valid = (np.abs(flow_map[:, :, 0]) < 1000) & (np.abs(flow_map[:, :, 1]) < 1000) + + self.sample = OpticalFlowSample(images=np.stack([img1, img2]), flow_map=flow_map, valid=valid) + + def test_OpticalFlowColorJitter(self): + transform = OpticalFlowColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.16, prob=1.0) + transformed_sample = transform(self.sample) + self.assertIsInstance(transformed_sample.images, np.ndarray) + self.assertEqual(transformed_sample.images.shape, self.sample.images.shape) + + def test_OpticalFlowOcclusion(self): + transform = OpticalFlowOcclusion(prob=1.0, bounds=(10, 30)) + transformed_sample = transform(self.sample) + self.assertIsInstance(transformed_sample.images, np.ndarray) + self.assertEqual(transformed_sample.images.shape, self.sample.images.shape) + + def test_OpticalFlowRandomRescale(self): + transform = OpticalFlowRandomRescale(min_scale=0.9, max_scale=1.2, prob=1.0) + transformed_sample = transform(self.sample) + + scale_factor = transform.scale + + expected_shape_images = (2, int(round(self.sample.images.shape[1] * scale_factor)), int(round(self.sample.images.shape[2] * scale_factor)), 3) + expected_shape_flow = (int(round(self.sample.flow_map.shape[0] * scale_factor)), int(round(self.sample.flow_map.shape[1] * scale_factor)), 2) + expected_shape_valid = (int(round(self.sample.valid.shape[0] * scale_factor)), int(round(self.sample.valid.shape[1] * scale_factor))) + + self.assertIsInstance(transformed_sample.images, np.ndarray) + self.assertEqual(transformed_sample.images.shape, expected_shape_images) + + self.assertIsInstance(transformed_sample.flow_map, np.ndarray) + self.assertEqual(transformed_sample.flow_map.shape, expected_shape_flow) + + self.assertIsInstance(transformed_sample.valid, np.ndarray) + 
self.assertEqual(transformed_sample.valid.shape, expected_shape_valid) + + def test_OpticalFlowRandomFlip(self): + transform = OpticalFlowRandomFlip(h_flip_prob=0.5, v_flip_prob=0.1) + transformed_sample = transform(self.sample) + + self.assertIsInstance(transformed_sample.images, np.ndarray) + self.assertEqual(transformed_sample.images.shape, self.sample.images.shape) + + self.assertIsInstance(transformed_sample.flow_map, np.ndarray) + self.assertEqual(transformed_sample.flow_map.shape, self.sample.flow_map.shape) + + self.assertIsInstance(transformed_sample.valid, np.ndarray) + self.assertEqual(transformed_sample.valid.shape, self.sample.valid.shape) + + def test_OpticalFlowCrop(self): + transform = OpticalFlowCrop(crop_size=(50, 50), mode="random") + transformed_sample = transform(self.sample) + self.assertIsInstance(transformed_sample.images, np.ndarray) + self.assertEqual(transformed_sample.images.shape, (2, 50, 50, 3)) + + def test_OpticalFlowInputPadder(self): + pad_factor = 8 + transformer = OpticalFlowInputPadder(dataset_mode="kitti", pad_factor=pad_factor) + + # Apply the transform + transformed_sample = transformer(self.sample) + + # Calculate the expected padded dimensions + expected_pad_ht = (((self.h // pad_factor) + 1) * pad_factor - self.h) % pad_factor + expected_pad_wd = (((self.w // pad_factor) + 1) * pad_factor - self.w) % pad_factor + expected_padded_height = self.h + expected_pad_ht + expected_padded_width = self.w + expected_pad_wd + + # Check if padding is applied correctly + self.assertEqual(transformed_sample.images.shape[1], expected_padded_height) + self.assertEqual(transformed_sample.images.shape[2], expected_padded_width) + + def test_OpticalFlowNormalize(self): + transform = OpticalFlowNormalize() + transformed_sample = transform(self.sample) + + # Check if normalization is applied correctly + self.assertTrue(np.allclose(transformed_sample.images, self.sample.images / 255.0)) + self.assertTrue(np.allclose(transformed_sample.flow_map, self.sample.flow_map)) + self.assertTrue(np.array_equal(transformed_sample.valid, self.sample.valid)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/raft_tests.py b/tests/unit_tests/raft_tests.py new file mode 100644 index 0000000000..dccb44d28e --- /dev/null +++ b/tests/unit_tests/raft_tests.py @@ -0,0 +1,38 @@ +import unittest + +import torch + +from super_gradients.common.object_names import Models +from super_gradients.training import models + + +class TestRAFT(unittest.TestCase): + def setUp(self): + self.models_to_test = [ + Models.RAFT_S, + Models.RAFT_L, + ] + + def test_raft_custom_in_channels(self): + """ + Validate that we can create a RAFT model with custom in_channels. + """ + for model_type in self.models_to_test: + with self.subTest(model_type=model_type): + model_name = str(model_type).lower().replace(".", "_") + model = models.get(model_name, arch_params=dict(in_channels=1), num_classes=1).eval() + model(torch.rand(1, 2, 1, 640, 640)) + + def test_raft_forward(self): + """ + Validate that the RAFT model forward pass runs on the default 3-channel input. 
+ """ + for model_type in self.models_to_test: + with self.subTest(model_type=model_type): + model_name = str(model_type).lower().replace(".", "_") + model = models.get(model_name, num_classes=1).eval() + model(torch.rand(1, 2, 3, 640, 640)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_optical_flow_metric.py b/tests/unit_tests/test_optical_flow_metric.py new file mode 100644 index 0000000000..7e7626774d --- /dev/null +++ b/tests/unit_tests/test_optical_flow_metric.py @@ -0,0 +1,34 @@ +import torch +import unittest + +from super_gradients.training.metrics.optical_flow_metric import EPE + + +class TestOpticalFlowMetric(unittest.TestCase): + def test_epe_metric(self): + # Specific example data + pred_flow = [torch.ones(1, 2, 100, 100)] + gt_flow = torch.zeros(1, 2, 100, 100) + valid = torch.ones(100, 100) + + # Create instances of delta metrics + max_flow = 400 + metric = EPE(max_flow=max_flow) + + # Update metrics with specific example data + metric.update(pred_flow, (gt_flow, valid)) + + # Expected metric + mag = torch.sum(gt_flow**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < max_flow) + + expected_epe = torch.sum((pred_flow[-1] - gt_flow) ** 2, dim=1).sqrt() + expected_epe = expected_epe.view(-1)[valid.view(-1)] + expected_epe = expected_epe.mean().item() + + # Compute and assert the delta metrics + self.assertAlmostEqual(metric.compute()["epe"], expected_epe) + + +if __name__ == "__main__": + unittest.main()