diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 4c9ef72c76f..97808d12857 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -1484,8 +1484,14 @@ def __call__(self, results): else: results['masks'] = [mask for mask in results['masks'].masks] + # Convert to RGB since Albumentations works with RGB images + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB) + results = self.aug(**results) + # Convert back to BGR + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR) + if 'bboxes' in results: if isinstance(results['bboxes'], list): results['bboxes'] = np.array( diff --git a/tests/test_data/test_pipelines/test_transform/test_transform.py b/tests/test_data/test_pipelines/test_transform/test_transform.py index 1ebc4f36914..eaf3f93ca39 100644 --- a/tests/test_data/test_pipelines/test_transform/test_transform.py +++ b/tests/test_data/test_pipelines/test_transform/test_transform.py @@ -496,6 +496,42 @@ def test_albu_transform(): assert results['img'].dtype == np.float32 +def test_albu_channel_order(): + results = dict( + img_prefix=osp.join(osp.dirname(__file__), '../../../data'), + img_info=dict(filename='color.jpg')) + + # Define simple pipeline + load = dict(type='LoadImageFromFile') + load = build_from_cfg(load, PIPELINES) + + # Transform is modifying B channel + albu_transform = dict( + type='Albu', + transforms=[ + dict( + type='RGBShift', + r_shift_limit=0, + g_shift_limit=0, + b_shift_limit=200, + p=1) + ]) + albu_transform = build_from_cfg(albu_transform, PIPELINES) + + # Execute transforms + results_load = load(results) + results_albu = albu_transform(results_load) + + # assert only Green and Red channel are not modified + np.testing.assert_array_equal(results_albu['img'][..., 1:], + results_load['img'][..., 1:]) + + # assert Blue channel is modified + with pytest.raises(AssertionError): + np.testing.assert_array_equal(results_albu['img'][..., 0], + results_load['img'][..., 0]) + + def test_random_center_crop_pad(): # test assertion for invalid crop_size while test_mode=False with pytest.raises(AssertionError):