Skip to content

Commit 6a4df47

Browse files
committed
fix: address the gemini review
1 parent b8ba1cd commit 6a4df47

File tree

1 file changed

+56
-12
lines changed

1 file changed

+56
-12
lines changed

examples/vision/adamatch.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: Semi-supervision and domain adaptation with AdaMatch
33
Author: [Sayak Paul](https://twitter.com/RisingSayak)
44
Date created: 2021/06/19
5-
Last modified: 2026/03/09
5+
Last modified: 2026/03/10
66
Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
77
Accelerator: GPU
88
Converted to Keras 3 by: [Maitry Sinha](https://github.com/maitry63)
@@ -70,6 +70,7 @@
7070
import numpy as np
7171
from keras import layers, ops
7272
import scipy.io
73+
from PIL import Image
7374

7475
"""
7576
## Prepare the data
@@ -126,21 +127,59 @@ def load_svhn_data():
126127

127128

128129
class AdaMatchDataset(keras.utils.PyDataset):
    def __init__(self, source_x, source_y, target_x, target_size=32, **kwargs):
        """Paired source/target dataset for AdaMatch training.

        Each batch combines a labelled source batch with an unlabelled
        target batch. Source images are resized with their aspect ratio
        preserved and zero-padded to `target_size` x `target_size`,
        then tiled to 3 channels if needed.

        Args:
            source_x: Source-domain images (labelled).
            source_y: Labels corresponding to `source_x`.
            target_x: Target-domain images (unlabelled).
            target_size: Spatial size source images are resized/padded to.
            **kwargs: Forwarded to `keras.utils.PyDataset`.
        """
        super().__init__(**kwargs)
        self.target_size = target_size
        self.source_x, self.source_y = source_x, source_y
        self.target_x = target_x
135141

136142
def __len__(self):
    # Batches are sampled randomly with replacement (see __getitem__),
    # so an "epoch" is a fixed number of steps, not a full data pass.
    return STEPS_PER_EPOCH
138144

145+
def resize_and_pad(self, images):
    """Resize a batch of images to `target_size` x `target_size`.

    The aspect ratio of each image is preserved: the longer side is
    scaled to `target_size`, the shorter side proportionally, and the
    result is centred on a zero (black) canvas.

    Args:
        images: Iterable of arrays shaped (H, W), (H, W, 1) or (H, W, C).

    Returns:
        float32 array of shape (batch, target_size, target_size, C).
    """
    resized_images = []
    for img in images:
        img = np.squeeze(img)
        if img.ndim == 2:
            img = np.expand_dims(img, -1)  # grayscale to (H, W, 1)
        h, w = img.shape[:2]
        scale = self.target_size / max(h, w)
        # Clamp to >= 1 px: plain truncation can yield a zero-sized
        # dimension for extreme aspect ratios, which PIL rejects.
        new_h = max(1, int(h * scale))
        new_w = max(1, int(w * scale))
        is_grayscale = img.shape[2] == 1
        if is_grayscale:
            # Single channel: hand PIL the 2-D plane in its native dtype.
            pil_img = Image.fromarray(img[:, :, 0])
        else:
            pil_img = Image.fromarray(img.astype(np.uint8))
        pil_resized = pil_img.resize((new_w, new_h), Image.BILINEAR)
        resized = np.array(pil_resized)
        if is_grayscale:
            resized = np.expand_dims(resized, -1)
        # Centre the resized image on a zero canvas.
        pad_h = (self.target_size - new_h) // 2
        pad_w = (self.target_size - new_w) // 2
        padded = np.zeros(
            (self.target_size, self.target_size, img.shape[2]), dtype=img.dtype
        )
        padded[pad_h : pad_h + new_h, pad_w : pad_w + new_w, :] = resized
        resized_images.append(padded)
    return np.array(resized_images, dtype="float32")
178+
139179
def __getitem__(self, idx):
140180
s_idx = np.random.choice(len(self.source_x), SOURCE_BATCH_SIZE)
141181
t_idx = np.random.choice(len(self.target_x), TARGET_BATCH_SIZE)
142-
143-
s_imgs = self.resizer(self.source_x[s_idx].astype("float32"))
182+
s_imgs = self.resize_and_pad(self.source_x[s_idx])
144183
s_imgs = ops.tile(s_imgs, (1, 1, 1, 3))
145184

146185
t_imgs = self.target_x[t_idx].astype("float32")
@@ -279,12 +318,11 @@ def compute_loss(self, x=None, y_true=None, y_pred=None, sample_weight=None):
279318
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
280319

281320
## Compute losses (pay attention to the indexing) ##
282-
source_loss = (
283-
loss_func(source_labels, final_source_logits[:SOURCE_BATCH_SIZE])
284-
+ loss_func(
285-
source_labels, final_source_logits[SOURCE_BATCH_SIZE:total_source]
286-
)
287-
) / 2
321+
source_loss = loss_func(
322+
source_labels, final_source_logits[:SOURCE_BATCH_SIZE]
323+
) + loss_func(
324+
source_labels, final_source_logits[SOURCE_BATCH_SIZE:total_source]
325+
)
288326

289327
target_loss = ops.mean(
290328
keras.losses.categorical_crossentropy(
@@ -368,7 +406,13 @@ def wide_basic(x, n_input_plane, n_output_plane, stride):
368406
x = layers.Activation("relu")(x)
369407

370408
shortcut = layers.Conv2D(
371-
n_output_plane, (1, 1), strides=stride, padding="same", use_bias=False
409+
n_output_plane,
410+
(1, 1),
411+
strides=stride,
412+
padding="same",
413+
use_bias=False,
414+
kernel_initializer=INIT,
415+
kernel_regularizer=keras.regularizers.l2(WEIGHT_DECAY),
372416
)(x)
373417

374418
convs = layers.Conv2D(

0 commit comments

Comments
 (0)