Skip to content

Commit b37c390

Browse files
committed
fix: output layer
1 parent 6a4df47 commit b37c390

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

examples/vision/adamatch.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: Semi-supervision and domain adaptation with AdaMatch
33
Author: [Sayak Paul](https://twitter.com/RisingSayak)
44
Date created: 2021/06/19
5-
Last modified: 2026/03/10
5+
Last modified: 2026/03/11
66
Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
77
Accelerator: GPU
88
Converted to Keras 3 by: [Maitry Sinha](https://github.com/maitry63)
@@ -18,9 +18,6 @@
1818
unifies semi-supervised learning (SSL) and unsupervised domain adaptation
1919
(UDA) under one framework. It thereby provides a way to perform semi-supervised domain
2020
adaptation (SSDA).
21-
22-
This example requires TensorFlow 2.5 or higher, as well as TensorFlow Models, which can
23-
be installed using the following command:
2421
"""
2522

2623
"""
@@ -110,7 +107,7 @@ def load_svhn_data():
110107

111108
SOURCE_BATCH_SIZE = 64
112109
TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE # Reference: Section 3.2
113-
EPOCHS = 5
110+
EPOCHS = 2
114111
STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
115112
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
116113

@@ -125,7 +122,6 @@ def load_svhn_data():
125122
## Data loading utilities
126123
"""
127124

128-
129125
class AdaMatchDataset(keras.utils.PyDataset):
130126
def __init__(self, source_x, source_y, target_x, target_size=32, **kwargs):
131127
"""
@@ -235,6 +231,7 @@ def __init__(self, model, total_steps, tau=0.9):
235231
self.tau = tau
236232
self.total_steps = total_steps
237233
self.current_step = keras.Variable(0.0, dtype="float32")
234+
self.loss_tracker = keras.metrics.Mean(name="loss")
238235

239236
self.weak_augment = keras.Sequential(
240237
[
@@ -245,6 +242,10 @@ def __init__(self, model, total_steps, tau=0.9):
245242

246243
rand_aug = layers.RandAugment(value_range=(0, 255), num_ops=2, factor=0.5)
247244
self.strong_aug = rand_aug
245+
246+
@property
247+
def metrics(self):
248+
return [self.loss_tracker]
248249

249250
# This is a warmup schedule to update the weight of the
250251
# loss contributed by the target unlabeled samples. More
@@ -340,6 +341,7 @@ def compute_loss(self, x=None, y_true=None, y_pred=None, sample_weight=None):
340341

341342
self.current_step.assign_add(1.0)
342343

344+
self.loss_tracker.update_state(total_loss)
343345
return total_loss
344346

345347

@@ -457,7 +459,7 @@ def get_network():
457459
x = layers.Activation("relu")(x)
458460
x = layers.GlobalAveragePooling2D()(x)
459461

460-
outputs = layers.Dense(10)(x)
462+
outputs = layers.Dense(10, kernel_regularizer=keras.regularizers.l2(WEIGHT_DECAY))(x)
461463
return keras.Model(inputs, outputs)
462464

463465

@@ -494,7 +496,7 @@ def get_network():
494496
"""
495497

496498
adamatch_trained_model = adamatch_trainer.model
497-
adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())
499+
adamatch_trained_model.compile(metrics=[keras.metrics.SparseCategoricalAccuracy()])
498500

499501
test_path = keras.utils.get_file(
500502
"test_32x32.mat",

0 commit comments

Comments
 (0)