Skip to content

Commit ba408a9

Browse files
authored
migrate metric learning example to keras 3
Refactor loss function and tensor handling in metric learning for compatibility with Keras 3.
1 parent b6949c5 commit ba408a9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

examples/vision/metric_learning_tf_similarity.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
import numpy as np
5555

5656
import tensorflow as tf
57-
from tensorflow import keras
57+
import keras
5858

5959
import tensorflow_similarity as tfsim
6060

@@ -216,13 +216,14 @@
216216
val_steps = 50
217217

218218
# init similarity loss
219-
loss = tfsim.losses.MultiSimilarityLoss()
219+
loss = tfsim.losses.MultiSimilarityLoss(reduction='sum_over_batch_size')
220220

221221
# compiling and training
222222
model.compile(
223223
optimizer=keras.optimizers.Adam(learning_rate),
224224
loss=loss,
225225
steps_per_execution=10,
226+
run_eagerly=True,
226227
)
227228
history = model.fit(
228229
train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
@@ -321,7 +322,7 @@
321322
for idx in np.argsort(y_display):
322323
tfsim.visualization.viz_neigbors_imgs(
323324
x_display[idx],
324-
y_display[idx],
325+
y_display[idx].numpy(),
325326
nns[idx],
326327
class_mapping=class_mapping,
327328
fig_size=(16, 2),
@@ -394,7 +395,7 @@
394395
"""
395396

396397
idx_no_match = np.where(np.array(matches) == 10)
397-
no_match_queries = x_confusion[idx_no_match]
398+
no_match_queries = no_match_queries = keras.ops.take(x_confusion, keras.ops.cast(idx_no_match[0], dtype="int32"), axis=0)
398399
if len(no_match_queries):
399400
plt.imshow(no_match_queries[0])
400401
else:

0 commit comments

Comments
 (0)