1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from dataclasses import dataclass
15- from typing import Any , Callable , Dict , Tuple , Union
16-
17- import torch
15+ from typing import Any , Callable , Dict , Tuple
1816
1917from flash .core .data .io .input import DataKeys
2018from flash .core .data .io .input_transform import InputTransform
21- from flash .core .data .transforms import ApplyToKeys , kornia_collate , KorniaParallelTransforms
22- from flash .core .utilities .imports import _KORNIA_AVAILABLE , _TORCHVISION_AVAILABLE , requires
19+ from flash .core .data .transforms import AlbumentationsAdapter , ApplyToKeys
20+ from flash .core .utilities .imports import _ALBUMENTATIONS_AVAILABLE , _TORCHVISION_AVAILABLE , requires
2321
24- if _KORNIA_AVAILABLE :
25- import kornia as K
22+ if _ALBUMENTATIONS_AVAILABLE :
23+ import albumentations as alb
24+ else :
25+ alb = None
2626
2727if _TORCHVISION_AVAILABLE :
2828 from torchvision import transforms as T
def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Convert the target mask to long and remove singleton non-batch dimensions.

    Args:
        batch: Collated batch dictionary. If it contains ``DataKeys.TARGET``,
            that entry is replaced in place.

    Returns:
        The same ``batch`` dictionary.
    """
    if DataKeys.TARGET in batch:
        target = batch[DataKeys.TARGET]
        # A bare ``squeeze()`` would also drop dim 0 when the batch holds a
        # single sample, so only dims after the first are squeezed. Walking
        # from the last dim backwards keeps the remaining indices valid;
        # ``squeeze(dim)`` is a no-op when that dim is not of size 1.
        for dim in range(target.ndim - 1, 0, -1):
            target = target.squeeze(dim)
        batch[DataKeys.TARGET] = target.long()
    return batch
3636
3737
def permute_target(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Reorder the axes of the sample's target with ``transpose((1, 2, 0))``.

    A 2-D target first gets a leading axis of size one, so after the transpose
    that axis ends up last. Samples without ``DataKeys.TARGET`` are returned
    untouched. The sample dictionary is updated in place.

    NOTE(review): the tuple-style ``transpose`` call suggests the target is a
    NumPy array at this point -- confirm against the input loader.
    """
    if DataKeys.TARGET not in sample:
        return sample

    mask = sample[DataKeys.TARGET]
    if mask.ndim == 2:
        mask = mask[None, ...]
    sample[DataKeys.TARGET] = mask.transpose((1, 2, 0))
    return sample
4545
4646
@@ -53,62 +53,48 @@ def remove_extra_dimensions(batch: Dict[str, Any]):
5353
@dataclass
class SemanticSegmentationInputTransform(InputTransform):
    """Input transform for semantic segmentation built on albumentations.

    Each per-sample pipeline permutes the target, runs the albumentations
    transforms through ``AlbumentationsAdapter``, and finally converts the
    input with ``T.ToTensor()``.
    """

    # https://albumentations.ai/docs/examples/pytorch_semantic_segmentation

    # Unpacked positionally into ``alb.Resize``.
    image_size: Tuple[int, int] = (128, 128)
    # Passed to ``alb.Normalize`` as ``mean`` / ``std``.
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

    @requires("image")
    def train_per_sample_transform(self) -> Callable:
        """Build the training pipeline: resize, random horizontal flip, normalize."""
        augmentations = AlbumentationsAdapter(
            [
                alb.Resize(*self.image_size),
                alb.HorizontalFlip(p=0.5),
                alb.Normalize(mean=self.mean, std=self.std),
            ]
        )
        input_to_tensor = ApplyToKeys(DataKeys.INPUT, T.ToTensor())
        return T.Compose([permute_target, augmentations, input_to_tensor])

    @requires("image")
    def per_sample_transform(self) -> Callable:
        """Build the pipeline without the random flip: resize and normalize only."""
        augmentations = AlbumentationsAdapter(
            [
                alb.Resize(*self.image_size),
                alb.Normalize(mean=self.mean, std=self.std),
            ]
        )
        input_to_tensor = ApplyToKeys(DataKeys.INPUT, T.ToTensor())
        return T.Compose([permute_target, augmentations, input_to_tensor])

    def per_batch_transform(self) -> Callable:
        """Compose the batch-level steps: ``prepare_target`` then ``remove_extra_dimensions``."""
        batch_steps = [prepare_target, remove_extra_dimensions]
        return T.Compose(batch_steps)
0 commit comments