diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter.py b/keras_hub/src/models/segformer/segformer_image_segmenter.py index 7c7dc73d17..f61564c04e 100644 --- a/keras_hub/src/models/segformer/segformer_image_segmenter.py +++ b/keras_hub/src/models/segformer/segformer_image_segmenter.py @@ -40,6 +40,8 @@ class SegFormerImageSegmenter(ImageSegmenter): projection_filters: int, number of filters in the convolution layer projecting the concatenated features into a segmentation map. Defaults to 256`. + dropout_rate: float. The dropout rate to apply before the + segmentation head. Defaults to `0.1`. Example: @@ -121,6 +123,7 @@ def __init__( backbone, num_classes, preprocessor=None, + dropout_rate=0.1, **kwargs, ): if not isinstance(backbone, keras.layers.Layer) or not isinstance( @@ -137,7 +140,7 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor - self.dropout = keras.layers.Dropout(0.1) + self.dropout = keras.layers.Dropout(dropout_rate) self.output_segmentation_head = keras.layers.Conv2D( filters=num_classes, kernel_size=1, strides=1 ) @@ -162,6 +165,7 @@ def __init__( # === Config === self.num_classes = num_classes self.backbone = backbone + self.dropout_rate = dropout_rate def get_config(self): config = super().get_config() @@ -169,6 +173,7 @@ def get_config(self): { "num_classes": self.num_classes, "backbone": keras.saving.serialize_keras_object(self.backbone), + "dropout_rate": self.dropout_rate, } ) return config