|
2 | 2 | Title: Semi-supervision and domain adaptation with AdaMatch |
3 | 3 | Author: [Sayak Paul](https://twitter.com/RisingSayak) |
4 | 4 | Date created: 2021/06/19 |
5 | | -Last modified: 2026/03/11 |
| 5 | +Last modified: 2026/03/12 |
6 | 6 | Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch. |
7 | 7 | Accelerator: GPU |
8 | 8 | Converted to Keras 3 by: [Maitry Sinha](https://github.com/maitry63) |
@@ -122,6 +122,7 @@ def load_svhn_data(): |
122 | 122 | ## Data loading utilities |
123 | 123 | """ |
124 | 124 |
|
| 125 | + |
125 | 126 | class AdaMatchDataset(keras.utils.PyDataset): |
126 | 127 | def __init__(self, source_x, source_y, target_x, target_size=32, **kwargs): |
127 | 128 | """ |
@@ -242,7 +243,7 @@ def __init__(self, model, total_steps, tau=0.9): |
242 | 243 |
|
243 | 244 | rand_aug = layers.RandAugment(value_range=(0, 255), num_ops=2, factor=0.5) |
244 | 245 | self.strong_aug = rand_aug |
245 | | - |
| 246 | + |
246 | 247 | @property |
247 | 248 | def metrics(self): |
248 | 249 | return [self.loss_tracker] |
@@ -459,7 +460,9 @@ def get_network(): |
459 | 460 | x = layers.Activation("relu")(x) |
460 | 461 | x = layers.GlobalAveragePooling2D()(x) |
461 | 462 |
|
462 | | - outputs = layers.Dense(10, kernel_regularizer=keras.regularizers.l2(WEIGHT_DECAY))(x) |
| 463 | + outputs = layers.Dense(10, kernel_regularizer=keras.regularizers.l2(WEIGHT_DECAY))( |
| 464 | + x |
| 465 | + ) |
463 | 466 | return keras.Model(inputs, outputs) |
464 | 467 |
|
465 | 468 |
|
|
0 commit comments