|
2 | 2 | Title: Semi-supervision and domain adaptation with AdaMatch |
3 | 3 | Author: [Sayak Paul](https://twitter.com/RisingSayak) |
4 | 4 | Date created: 2021/06/19 |
5 | | -Last modified: 2026/03/09 |
| 5 | +Last modified: 2026/03/10 |
6 | 6 | Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch. |
7 | 7 | Accelerator: GPU |
8 | 8 | Converted to Keras 3 by: [Maitry Sinha](https://github.com/maitry63) |
|
70 | 70 | import numpy as np |
71 | 71 | from keras import layers, ops |
72 | 72 | import scipy.io |
| 73 | +from PIL import Image |
73 | 74 |
|
74 | 75 | """ |
75 | 76 | ## Prepare the data |
@@ -126,21 +127,59 @@ def load_svhn_data(): |
126 | 127 |
|
127 | 128 |
|
128 | 129 | class AdaMatchDataset(keras.utils.PyDataset): |
    def __init__(self, source_x, source_y, target_x, target_size=32, **kwargs):
        """
        Dataset for AdaMatch training.

        Holds the labeled source images/labels and the unlabeled target
        images. Source images are later resized with a resize-and-pad
        (aspect-ratio preserving, zero padded) to ``target_size`` x
        ``target_size`` and tiled to 3 channels if needed.

        Args:
            source_x: labeled source-domain images.
            source_y: labels for ``source_x``.
            target_x: unlabeled target-domain images.
            target_size: side length of the square canvas the source
                images are resized/padded to (default 32).
            **kwargs: forwarded to ``keras.utils.PyDataset``.
        """
        super().__init__(**kwargs)
        self.source_x = source_x
        self.source_y = source_y
        self.target_x = target_x
        self.target_size = target_size
135 | 141 |
|
    def __len__(self):
        # An "epoch" is a fixed number of batches rather than a full pass
        # over the data: __getitem__ samples indices randomly each call.
        return STEPS_PER_EPOCH
138 | 144 |
|
| 145 | + def resize_and_pad(self, images): |
| 146 | + """ |
| 147 | + Resize images to target_size x target_size while preserving aspect ratio. |
| 148 | + Pads with zeros if necessary. |
| 149 | + """ |
| 150 | + resized_images = [] |
| 151 | + for img in images: |
| 152 | + img = np.squeeze(img) |
| 153 | + if img.ndim == 2: |
| 154 | + img = np.expand_dims(img, -1) # grayscale to (H,W,1) |
| 155 | + h, w = img.shape[:2] |
| 156 | + scale = self.target_size / max(h, w) |
| 157 | + new_h = int(h * scale) |
| 158 | + new_w = int(w * scale) |
| 159 | + if img.shape[2] == 1: |
| 160 | + pil_img = Image.fromarray(img[:, :, 0]) |
| 161 | + else: |
| 162 | + pil_img = Image.fromarray(img.astype(np.uint8)) |
| 163 | + pil_resized = pil_img.resize((new_w, new_h), Image.BILINEAR) |
| 164 | + resized = ( |
| 165 | + np.expand_dims(np.array(pil_resized), -1) |
| 166 | + if img.shape[2] == 1 |
| 167 | + else np.array(pil_resized) |
| 168 | + ) |
| 169 | + # Pad |
| 170 | + pad_h = (self.target_size - new_h) // 2 |
| 171 | + pad_w = (self.target_size - new_w) // 2 |
| 172 | + padded = np.zeros( |
| 173 | + (self.target_size, self.target_size, img.shape[2]), dtype=img.dtype |
| 174 | + ) |
| 175 | + padded[pad_h : pad_h + new_h, pad_w : pad_w + new_w, :] = resized |
| 176 | + resized_images.append(padded) |
| 177 | + return np.array(resized_images, dtype="float32") |
| 178 | + |
139 | 179 | def __getitem__(self, idx): |
140 | 180 | s_idx = np.random.choice(len(self.source_x), SOURCE_BATCH_SIZE) |
141 | 181 | t_idx = np.random.choice(len(self.target_x), TARGET_BATCH_SIZE) |
142 | | - |
143 | | - s_imgs = self.resizer(self.source_x[s_idx].astype("float32")) |
| 182 | + s_imgs = self.resize_and_pad(self.source_x[s_idx]) |
144 | 183 | s_imgs = ops.tile(s_imgs, (1, 1, 1, 3)) |
145 | 184 |
|
146 | 185 | t_imgs = self.target_x[t_idx].astype("float32") |
@@ -279,12 +318,11 @@ def compute_loss(self, x=None, y_true=None, y_pred=None, sample_weight=None): |
279 | 318 | loss_func = keras.losses.CategoricalCrossentropy(from_logits=True) |
280 | 319 |
|
281 | 320 | ## Compute losses (pay attention to the indexing) ## |
282 | | - source_loss = ( |
283 | | - loss_func(source_labels, final_source_logits[:SOURCE_BATCH_SIZE]) |
284 | | - + loss_func( |
285 | | - source_labels, final_source_logits[SOURCE_BATCH_SIZE:total_source] |
286 | | - ) |
287 | | - ) / 2 |
| 321 | + source_loss = loss_func( |
| 322 | + source_labels, final_source_logits[:SOURCE_BATCH_SIZE] |
| 323 | + ) + loss_func( |
| 324 | + source_labels, final_source_logits[SOURCE_BATCH_SIZE:total_source] |
| 325 | + ) |
288 | 326 |
|
289 | 327 | target_loss = ops.mean( |
290 | 328 | keras.losses.categorical_crossentropy( |
@@ -368,7 +406,13 @@ def wide_basic(x, n_input_plane, n_output_plane, stride): |
368 | 406 | x = layers.Activation("relu")(x) |
369 | 407 |
|
370 | 408 | shortcut = layers.Conv2D( |
371 | | - n_output_plane, (1, 1), strides=stride, padding="same", use_bias=False |
| 409 | + n_output_plane, |
| 410 | + (1, 1), |
| 411 | + strides=stride, |
| 412 | + padding="same", |
| 413 | + use_bias=False, |
| 414 | + kernel_initializer=INIT, |
| 415 | + kernel_regularizer=keras.regularizers.l2(WEIGHT_DECAY), |
372 | 416 | )(x) |
373 | 417 |
|
374 | 418 | convs = layers.Conv2D( |
|
0 commit comments