22Title: Semi-supervision and domain adaptation with AdaMatch
33Author: [Sayak Paul](https://twitter.com/RisingSayak)
44Date created: 2021/06/19
5- Last modified: 2026/03/10
5+ Last modified: 2026/03/11
66Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
77Accelerator: GPU
88Converted to Keras 3 by: [Maitry Sinha](https://github.com/maitry63)
1818unifies semi-supervised learning (SSL) and unsupervised domain adaptation
1919(UDA) under one framework. It thereby provides a way to perform semi-supervised domain
2020adaptation (SSDA).
21-
22- This example requires TensorFlow 2.5 or higher, as well as TensorFlow Models, which can
23- be installed using the following command:
2421"""
2522
2623"""
@@ -110,7 +107,7 @@ def load_svhn_data():
110107
111108SOURCE_BATCH_SIZE = 64
112109TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE # Reference: Section 3.2
113- EPOCHS = 5
110+ EPOCHS = 2
114111STEPS_PER_EPOCH = len (mnist_x_train ) // SOURCE_BATCH_SIZE
115112TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
116113
@@ -125,7 +122,6 @@ def load_svhn_data():
125122## Data loading utilities
126123"""
127124
128-
129125class AdaMatchDataset (keras .utils .PyDataset ):
130126 def __init__ (self , source_x , source_y , target_x , target_size = 32 , ** kwargs ):
131127 """
@@ -235,6 +231,7 @@ def __init__(self, model, total_steps, tau=0.9):
235231 self .tau = tau
236232 self .total_steps = total_steps
237233 self .current_step = keras .Variable (0.0 , dtype = "float32" )
234+ self .loss_tracker = keras .metrics .Mean (name = "loss" )
238235
239236 self .weak_augment = keras .Sequential (
240237 [
@@ -245,6 +242,10 @@ def __init__(self, model, total_steps, tau=0.9):
245242
246243 rand_aug = layers .RandAugment (value_range = (0 , 255 ), num_ops = 2 , factor = 0.5 )
247244 self .strong_aug = rand_aug
245+
246+ @property
247+ def metrics (self ):
248+ return [self .loss_tracker ]
248249
249250 # This is a warmup schedule to update the weight of the
250251 # loss contributed by the target unlabeled samples. More
@@ -340,6 +341,7 @@ def compute_loss(self, x=None, y_true=None, y_pred=None, sample_weight=None):
340341
341342 self .current_step .assign_add (1.0 )
342343
344+ self .loss_tracker .update_state (total_loss )
343345 return total_loss
344346
345347
@@ -457,7 +459,7 @@ def get_network():
457459 x = layers .Activation ("relu" )(x )
458460 x = layers .GlobalAveragePooling2D ()(x )
459461
460- outputs = layers .Dense (10 )(x )
462+ outputs = layers .Dense (10 , kernel_regularizer = keras . regularizers . l2 ( WEIGHT_DECAY ) )(x )
461463 return keras .Model (inputs , outputs )
462464
463465
@@ -494,7 +496,7 @@ def get_network():
494496"""
495497
496498adamatch_trained_model = adamatch_trainer .model
497- adamatch_trained_model .compile (metrics = keras .metrics .SparseCategoricalAccuracy ())
499+ adamatch_trained_model .compile (metrics = [ keras .metrics .SparseCategoricalAccuracy ()] )
498500
499501test_path = keras .utils .get_file (
500502 "test_32x32.mat" ,
0 commit comments