-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhubconf.py
More file actions
243 lines (197 loc) · 9.15 KB
/
hubconf.py
File metadata and controls
243 lines (197 loc) · 9.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""
PyTorch Hub configuration for U-Net implementation.
Based on U-Net: Convolutional Networks for Biomedical Image Segmentation
by Ronneberger et al. (2015): https://arxiv.org/abs/1505.04597
Author: Ole-Christian Galbo Engstrøm
Email: ocge@foss.dk
"""
# Packages that torch.hub checks are installed before loading any entrypoint
# defined in this hubconf.
dependencies = [
    "torch",
]
# Import the U-Net implementation
from unet import UNet as _UNet
def unet(
    pretrained=False,
    in_channels=3,
    out_channels=1,
    pad=True,
    bilinear=True,
    normalization=None,
    depth=5,
    **kwargs,
):
    """
    U-Net model for semantic segmentation

    This implementation follows the original U-Net architecture with options for
    different normalization techniques, upsampling methods, and padding strategies.

    Args:
        pretrained (bool): Request pre-trained weights. No weights are published yet,
            so passing True raises NotImplementedError (default: False)
        in_channels (int): Number of input channels (default: 3 for RGB images)
        out_channels (int): Number of output channels/classes (default: 1 for binary segmentation)
        pad (bool): If True, the input size is preserved by zero-padding convolutions and,
            if necessary, the results of the upsampling operations. If False, output size
            will be reduced compared to input size (default: True)
        bilinear (bool): If True, use bilinear upsampling. If False, use transposed
            convolution (default: True)
        normalization (None | str): Normalization type. Options:
            - None: No normalization
            - 'bn': Batch normalization
            - 'ln': Layer normalization
            (default: None)
        depth (int): The depth of the U-Net, i.e. the number of levels in the encoder
            (mirrored by the decoder). Level i uses 64 * 2**i intermediate channels,
            so depth = 5 gives [64, 128, 256, 512, 1024] (default: 5)
        **kwargs: Reserved for future extensions. Currently unused; any keys passed
            here trigger a UserWarning so that misspelled argument names are noticed.

    Returns:
        torch.nn.Module: U-Net model with intermediate channels
        [64 * 2**i for i in range(depth)]

    Raises:
        NotImplementedError: If ``pretrained`` is True.

    Example:
        >>> import torch
        >>>
        >>> # Basic U-Net for binary segmentation (e.g., medical imaging)
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False)
        >>>
        >>> # Multi-class segmentation (e.g., 21 classes for PASCAL VOC)
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, out_channels=21)
        >>>
        >>> # U-Net with batch normalization
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, normalization='bn')
        >>>
        >>> # U-Net with transposed convolution upsampling instead of bilinear interpolation
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, bilinear=False)
        >>>
        >>> # Grayscale input (e.g., medical images, satellite imagery)
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=1)
        >>>
        >>> # Forward pass with the default 3-channel model
        >>> model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False)
        >>> x = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
        >>> with torch.no_grad():
        ...     output = model(x)
        >>> print(f"Input shape: {x.shape}")
        >>> print(f"Output shape: {output.shape}")  # (1, out_channels, 256, 256) if pad=True

    Note:
        - With the default depth=5 the model uses intermediate channels
          [64, 128, 256, 512, 1024], following the original paper
        - When pad=True, output spatial dimensions are identical to input spatial dimensions
        - When pad=False, output will be smaller than input due to valid convolutions and
          potential dropping of rows/columns in the strided pooling layers
        - Bilinear upsampling uses fewer parameters than transposed convolution and avoids
          checkerboard artifacts
        - Normalization can be set to 'bn' for batch normalization or 'ln' for layer normalization
    """
    if pretrained:
        # Fail fast: there are no weights to load, so building a model here would
        # only waste time and memory before the error is raised anyway.
        raise NotImplementedError(
            "Pretrained weights are not yet available. "
            "Use pretrained=False to get a randomly initialized model, "
            "then train it on your specific dataset."
        )

    if kwargs:
        # Extra keyword arguments are accepted for forward compatibility, but they
        # are not used. Warn instead of silently discarding them so that typos such
        # as `out_chanels=21` do not go unnoticed.
        import warnings

        warnings.warn(
            f"unet() ignores unsupported keyword arguments: {sorted(kwargs)}",
            UserWarning,
            stacklevel=2,
        )

    return _UNet(
        in_channels=in_channels,
        out_channels=out_channels,
        pad=pad,
        bilinear=bilinear,
        normalization=normalization,
        depth=depth,
    )
def unet_bn(pretrained=False, in_channels=3, out_channels=1, **kwargs):
    """
    U-Net model with Batch Normalization

    Convenience entrypoint that delegates to ``unet`` with ``normalization="bn"``
    fixed. Batch normalization can be beneficial for training stability when using
    larger batch sizes.

    Args:
        pretrained (bool): Request pre-trained weights (not yet available)
        in_channels (int): Number of input channels (default: 3)
        out_channels (int): Number of output channels (default: 1)
        **kwargs: Any other ``unet`` option (e.g. ``pad``, ``bilinear``, ``depth``).
            Passing ``normalization`` here is an error, since it is already fixed.

    Returns:
        torch.nn.Module: U-Net model with batch normalization

    Example:
        >>> model = torch.hub.load('sm00thix/unet', 'unet_bn', pretrained=False)
        >>> # Equivalent to: unet(normalization='bn')
    """
    # The first three parameters line up positionally with unet's signature.
    return unet(pretrained, in_channels, out_channels, normalization="bn", **kwargs)
def unet_ln(pretrained=False, in_channels=3, out_channels=1, **kwargs):
    """
    U-Net model with Layer Normalization

    Convenience entrypoint that delegates to ``unet`` with ``normalization="ln"``
    fixed. Layer normalization can be beneficial when batch sizes are small.

    Args:
        pretrained (bool): Request pre-trained weights (not yet available)
        in_channels (int): Number of input channels (default: 3)
        out_channels (int): Number of output channels (default: 1)
        **kwargs: Any other ``unet`` option (e.g. ``pad``, ``bilinear``, ``depth``).
            Passing ``normalization`` here is an error, since it is already fixed.

    Returns:
        torch.nn.Module: U-Net model with layer normalization

    Example:
        >>> model = torch.hub.load('sm00thix/unet', 'unet_ln', pretrained=False)
        >>> # Equivalent to: unet(normalization='ln')
    """
    model = unet(
        normalization="ln",
        pretrained=pretrained,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
    return model
def unet_medical(pretrained=False, in_channels=1, out_channels=1, **kwargs):
    """
    U-Net model configured for medical image segmentation

    Defaults to grayscale input (typical for medical images) and binary output
    (e.g., organ/background segmentation). Both can be overridden, e.g.
    ``out_channels=3`` for multi-class segmentation. All other options fall back
    to the ``unet`` defaults unless given in ``kwargs``.

    Args:
        pretrained (bool): Request pre-trained weights (not yet available)
        in_channels (int): Number of input channels (default: 1, grayscale)
        out_channels (int): Number of output channels (default: 1, binary segmentation)
        **kwargs: Additional arguments passed to the base unet function

    Returns:
        torch.nn.Module: U-Net model configured for medical imaging

    Example:
        >>> model = torch.hub.load('sm00thix/unet', 'unet_medical', pretrained=False)
        >>> # Single-channel input, no normalization (pass normalization='bn' to add it)
        >>> x = torch.randn(1, 1, 512, 512)  # Typical medical image size
        >>> output = model(x)
    """
    # in_channels/out_channels are real parameters (rather than hard-coded in the
    # call) so that callers can override them without a "multiple values for
    # keyword argument" TypeError.
    return unet(
        pretrained=pretrained,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
def unet_transconv(pretrained=False, in_channels=3, out_channels=1, **kwargs):
    """
    U-Net model using transposed convolution for upsampling

    Convenience entrypoint that delegates to ``unet`` with ``bilinear=False``, so the
    decoder upsamples with transposed convolutions instead of bilinear interpolation.

    Args:
        pretrained (bool): Request pre-trained weights (not yet available)
        in_channels (int): Number of input channels (default: 3)
        out_channels (int): Number of output channels (default: 1)
        **kwargs: Any other ``unet`` option (e.g. ``pad``, ``normalization``, ``depth``).
            Passing ``bilinear`` here is an error, since it is already fixed.

    Returns:
        torch.nn.Module: U-Net model with transposed convolution upsampling

    Example:
        >>> model = torch.hub.load('sm00thix/unet', 'unet_transconv', pretrained=False)
        >>> # Equivalent to: unet(bilinear=False)
    """
    use_bilinear = False  # False selects transposed convolution in unet
    return unet(pretrained, in_channels, out_channels, bilinear=use_bilinear, **kwargs)
# Example usage for documentation
_EXAMPLE_USAGE = """
# Load and use U-Net models
import torch
# Basic usage
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False)
print(f"Model loaded: {model.__class__.__name__}")
# Multi-class segmentation example
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, out_channels=21) # PASCAL VOC classes
# Medical imaging example
model = torch.hub.load('sm00thix/unet', 'unet_medical', pretrained=False)
# Original U-Net with transposed convolution upsampling and no padding
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=1, out_channels=1, pad=False, bilinear=False, normalization=None)
# U-Net with depth 3
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, depth=3)
# Example forward pass
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False)
x = torch.randn(1, 3, 256, 256) # RGB image
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape} -> Output: {output.shape}")
# List all available models
available_models = torch.hub.list('sm00thix/unet')
print(f"Available models: {available_models}")
"""