Commit 98359d8

Migrate "Using the Forward-Forward Algorithm for Image Classification" to Keras 3.0 (TensorFlow backend only) (#1932)

* migration to keras3
* add md and ipynb files

1 parent 67f981b

5 files changed: +678 −381 lines changed
examples/vision/forwardforward.py (51 additions, 39 deletions)
```diff
@@ -2,7 +2,7 @@
 Title: Using the Forward-Forward Algorithm for Image Classification
 Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
 Date created: 2023/01/08
-Last modified: 2023/01/08
+Last modified: 2024/09/17
 Description: Training a Dense-layer model using the Forward-Forward algorithm.
 Accelerator: GPU
 """
```
```diff
@@ -59,9 +59,13 @@
 """
 ## Setup imports
 """
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
 
 import tensorflow as tf
-from tensorflow import keras
+import keras
+from keras import ops
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
```
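One behavioral note on the new imports: in Keras 3 the `KERAS_BACKEND` environment variable is read at import time, which is why the commit sets it before `import keras`. A minimal sketch to verify the selected backend (the assertion is illustrative, not part of the commit):

```python
import os

# Must happen before `import keras`; the backend is fixed at import time.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

# Illustrative check only: confirms which backend Keras 3 picked up.
assert keras.backend.backend() == "tensorflow"
```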
```diff
@@ -143,7 +147,7 @@ class FFDense(keras.layers.Layer):
     def __init__(
         self,
         units,
-        optimizer,
+        init_optimizer,
         loss_metric,
         num_epochs=50,
         use_bias=True,
```
```diff
@@ -163,7 +167,7 @@ def __init__(
             bias_regularizer=bias_regularizer,
         )
         self.relu = keras.layers.ReLU()
-        self.optimizer = optimizer
+        self.optimizer = init_optimizer()
         self.loss_metric = loss_metric
         self.threshold = 1.5
         self.num_epochs = num_epochs
```
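The switch from an `optimizer` instance to an `init_optimizer` factory matters in Keras 3: an optimizer builds its slot variables against the first set of weights it sees, so a single instance shared across `FFDense` layers would break. A minimal sketch of the pattern outside the class (names here are illustrative):

```python
import keras

# Factory: each call returns a fresh, unbuilt optimizer instance.
make_optimizer = lambda: keras.optimizers.Adam(learning_rate=0.03)

opt_a = make_optimizer()  # will be built against one layer's weights
opt_b = make_optimizer()  # independent state for another layer
assert opt_a is not opt_b
```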
```diff
@@ -172,7 +176,7 @@ def __init__(
     # layer.
 
     def call(self, x):
-        x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)
+        x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
         x_norm = x_norm + 1e-4
         x_dir = x / x_norm
         res = self.dense(x_dir)
```
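`ops.norm` is the backend-agnostic replacement for `tf.norm` here; with the TensorFlow backend the two are numerically equivalent. A quick illustrative check:

```python
import numpy as np
from keras import ops

x = np.array([[3.0, 4.0]])
# L2 norm per row, kept as a column so `x / x_norm` broadcasts.
x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
print(np.asarray(x_norm))  # [[5.]]
```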
```diff
@@ -192,22 +196,24 @@ def call(self, x):
     def forward_forward(self, x_pos, x_neg):
         for i in range(self.num_epochs):
             with tf.GradientTape() as tape:
-                g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)
-                g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)
+                g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
+                g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)
 
-                loss = tf.math.log(
+                loss = ops.log(
                     1
-                    + tf.math.exp(
-                        tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)
+                    + ops.exp(
+                        ops.concatenate(
+                            [-g_pos + self.threshold, g_neg - self.threshold], 0
+                        )
                     )
                 )
-                mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)
+                mean_loss = ops.cast(ops.mean(loss), dtype="float32")
                 self.loss_metric.update_state([mean_loss])
             gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
             self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
         return (
-            tf.stop_gradient(self.call(x_pos)),
-            tf.stop_gradient(self.call(x_neg)),
+            ops.stop_gradient(self.call(x_pos)),
+            ops.stop_gradient(self.call(x_neg)),
             self.loss_metric.result(),
         )
 
```
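For reference, the loss above is the Forward-Forward "goodness" objective: a softplus that pushes positive-sample goodness above the threshold and negative-sample goodness below it. A standalone sketch with made-up activations:

```python
import numpy as np
from keras import ops

threshold = 1.5
h_pos = np.array([[1.0, 2.0]])  # hypothetical positive-pass activations
h_neg = np.array([[0.1, 0.2]])  # hypothetical negative-pass activations

g_pos = ops.mean(ops.power(h_pos, 2), 1)  # goodness = mean squared activation
g_neg = ops.mean(ops.power(h_neg, 2), 1)
loss = ops.log(
    1 + ops.exp(ops.concatenate([-g_pos + threshold, g_neg - threshold], 0))
)
print(float(ops.mean(loss)))
```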
```diff
@@ -248,25 +254,24 @@ class FFNetwork(keras.Model):
     # the `Adam` optimizer with a default learning rate of 0.03 as that was
     # found to be the best rate after experimentation.
     # Loss is tracked using `loss_var` and `loss_count` variables.
-    # Use legacy optimizer for Layer Optimizer to fix issue
-    # https://github.com/keras-team/keras-io/issues/1241
 
     def __init__(
         self,
         dims,
-        layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.03),
+        init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.layer_optimizer = layer_optimizer
-        self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)
-        self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)
+        self.init_layer_optimizer = init_layer_optimizer
+        self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
+        self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
         self.layer_list = [keras.Input(shape=(dims[0],))]
+        self.metrics_built = False
         for d in range(len(dims) - 1):
             self.layer_list += [
                 FFDense(
                     dims[d + 1],
-                    optimizer=self.layer_optimizer,
+                    init_optimizer=self.init_layer_optimizer,
                     loss_metric=keras.metrics.Mean(),
                 )
             ]
```
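`keras.Variable` replaces `tf.Variable` for the loss accumulators; it is the backend-agnostic variable type in Keras 3 and supports the in-place updates the running-loss bookkeeping relies on. Illustrative usage:

```python
import keras

loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
loss_count = keras.Variable(0.0, trainable=False, dtype="float32")

loss_var.assign_add(2.5)   # accumulate a layer loss
loss_count.assign_add(1.0)
print(float(loss_var) / float(loss_count))  # running mean: 2.5
```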
```diff
@@ -280,9 +285,9 @@ def __init__(
     @tf.function(reduce_retracing=True)
     def overlay_y_on_x(self, data):
         X_sample, y_sample = data
-        max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)
-        max_sample = tf.cast(max_sample, dtype=tf.float64)
-        X_zeros = tf.zeros([10], dtype=tf.float64)
+        max_sample = ops.amax(X_sample, axis=0, keepdims=True)
+        max_sample = ops.cast(max_sample, dtype="float64")
+        X_zeros = ops.zeros([10], dtype="float64")
         X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
         X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
         return X_sample, y_sample
```
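`overlay_y_on_x` is unchanged in spirit: it writes the label into the first ten pixels of the flattened image (position `y` gets the sample's max value, the rest are zeroed). A hypothetical NumPy analogue of the same computation, for intuition only:

```python
import numpy as np

def overlay_y_on_x_np(x, y, num_classes=10):
    # NumPy analogue (illustrative, not from the commit).
    x = x.copy()
    max_val = x.max()          # mirrors ops.amax over the sample
    x[:num_classes] = 0.0      # clear the label slots
    x[y] = max_val             # encode the label at position y
    return x

sample = overlay_y_on_x_np(np.random.rand(784), y=3)
print(sample[:10])  # only index 3 is non-zero
```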
```diff
@@ -297,25 +302,23 @@ def overlay_y_on_x(self, data):
     @tf.function(reduce_retracing=True)
     def predict_one_sample(self, x):
         goodness_per_label = []
-        x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])
+        x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
         for label in range(10):
             h, label = self.overlay_y_on_x(data=(x, label))
-            h = tf.reshape(h, [-1, tf.shape(h)[0]])
+            h = ops.reshape(h, [-1, ops.shape(h)[0]])
             goodness = []
             for layer_idx in range(1, len(self.layer_list)):
                 layer = self.layer_list[layer_idx]
                 h = layer(h)
-                goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]
-            goodness_per_label += [
-                tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)
-            ]
+                goodness += [ops.mean(ops.power(h, 2), 1)]
+            goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
         goodness_per_label = tf.concat(goodness_per_label, 1)
-        return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)
+        return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")
 
     def predict(self, data):
         x = data
         preds = list()
-        preds = tf.map_fn(fn=self.predict_one_sample, elems=x)
+        preds = ops.vectorized_map(self.predict_one_sample, x)
         return np.asarray(preds, dtype=int)
 
     # This custom `train_step` function overrides the internal `train_step`
```
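`ops.vectorized_map` is the Keras 3 counterpart used in `predict` above in place of `tf.map_fn`, mapping `predict_one_sample` over the batch. A tiny illustrative example:

```python
import numpy as np
from keras import ops

batch = np.arange(6.0).reshape(3, 2)
# Apply a per-row function across the leading (batch) dimension.
row_energy = ops.vectorized_map(lambda row: ops.sum(row**2), batch)
print(np.asarray(row_energy))  # [ 1. 13. 41.]
```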
```diff
@@ -328,17 +331,26 @@ def predict(self, data):
     # the Forward-Forward computation on it. The returned loss is the final
     # loss value over all the layers.
 
-    @tf.function(jit_compile=True)
+    @tf.function(jit_compile=False)
     def train_step(self, data):
         x, y = data
 
+        if not self.metrics_built:
+            # build metrics to ensure they can be queried without erroring out.
+            # We can't update the metrics' state, as we would usually do, since
+            # we do not perform predictions within the train step
+            for metric in self.metrics:
+                if hasattr(metric, "build"):
+                    metric.build(y, y)
+            self.metrics_built = True
+
         # Flatten op
-        x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])
+        x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])
 
-        x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))
+        x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))
 
         random_y = tf.random.shuffle(y)
-        x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))
+        x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))
 
         h_pos, h_neg = x_pos, x_neg
```
```diff
@@ -351,7 +363,7 @@ def train_step(self, data):
             else:
                 print(f"Passing layer {idx+1} now : ")
                 x = layer(x)
-        mean_res = tf.math.divide(self.loss_var, self.loss_count)
+        mean_res = ops.divide(self.loss_var, self.loss_count)
         return {"FinalLoss": mean_res}
 
 
```
```diff
@@ -386,8 +398,8 @@ def train_step(self, data):
 model.compile(
     optimizer=keras.optimizers.Adam(learning_rate=0.03),
     loss="mse",
-    jit_compile=True,
-    metrics=[keras.metrics.Mean()],
+    jit_compile=False,
+    metrics=[],
 )
 
 epochs = 250
```
```diff
@@ -400,7 +412,7 @@ def train_step(self, data):
 test set. We calculate the Accuracy Score to understand the results closely.
 """
 
-preds = model.predict(tf.convert_to_tensor(x_test))
+preds = model.predict(ops.convert_to_tensor(x_test))
 
 preds = preds.reshape((preds.shape[0], preds.shape[1]))
 
```
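The example then scores the reshaped predictions against the ground truth using the `accuracy_score` import from the setup. A hedged sketch of that evaluation step (assumes `y_test` holds integer labels; the exact lines are not part of this diff):

```python
from sklearn.metrics import accuracy_score

# Illustrative evaluation, mirroring the step the docstring describes.
results = accuracy_score(preds.flatten(), y_test)
print(f"Test Accuracy score : {results * 100}%")
```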
Two image files also changed in this commit (8.56 KB and 10.3 KB); previews not shown.
