 Title: Using the Forward-Forward Algorithm for Image Classification
 Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
 Date created: 2023/01/08
-Last modified: 2023/01/08
+Last modified: 2024/09/17
 Description: Training a Dense-layer model using the Forward-Forward algorithm.
 Accelerator: GPU
 """
5959"""
6060## Setup imports
6161"""
62+ import os
63+
64+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
6265
6366import tensorflow as tf
64- from tensorflow import keras
67+ import keras
68+ from keras import ops
6569import numpy as np
6670import matplotlib .pyplot as plt
6771from sklearn .metrics import accuracy_score
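
One detail worth calling out in the new import block: `KERAS_BACKEND` is read once, when Keras is first imported, so the environment variable has to be set before `import keras`. A minimal standalone sketch of that ordering (backend names other than "tensorflow" are shown only as possibilities):

import os

# Select the backend before Keras is imported anywhere in the process.
os.environ["KERAS_BACKEND"] = "tensorflow"  # could also be "jax" or "torch"

import keras

print(keras.backend.backend())  # prints the backend actually in use
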
@@ -143,7 +147,7 @@ class FFDense(keras.layers.Layer):
     def __init__(
         self,
         units,
-        optimizer,
+        init_optimizer,
         loss_metric,
         num_epochs=50,
         use_bias=True,
@@ -163,7 +167,7 @@ def __init__(
             bias_regularizer=bias_regularizer,
         )
         self.relu = keras.layers.ReLU()
-        self.optimizer = optimizer
+        self.optimizer = init_optimizer()
         self.loss_metric = loss_metric
         self.threshold = 1.5
         self.num_epochs = num_epochs
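
The `init_optimizer` change above replaces a shared optimizer instance with a factory callable: each `FFDense` layer trains its own weights locally, so each one now constructs its own optimizer (and therefore its own optimizer state) by calling the factory. A minimal sketch of the pattern, with hypothetical names, to illustrate the idea:

import keras

# A zero-argument factory that returns a fresh optimizer every time it is called.
def make_layer_optimizer():
    return keras.optimizers.Adam(learning_rate=0.03)

# Two layers calling the factory end up with independent optimizer instances,
# so their internal state (Adam moments) never collides.
optimizer_a = make_layer_optimizer()
optimizer_b = make_layer_optimizer()
assert optimizer_a is not optimizer_b
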
@@ -172,7 +176,7 @@ def __init__(
     # layer.
 
     def call(self, x):
-        x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)
+        x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
         x_norm = x_norm + 1e-4
         x_dir = x / x_norm
         res = self.dense(x_dir)
@@ -192,22 +196,24 @@ def call(self, x):
     def forward_forward(self, x_pos, x_neg):
         for i in range(self.num_epochs):
             with tf.GradientTape() as tape:
-                g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)
-                g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)
+                g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
+                g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)
 
-                loss = tf.math.log(
+                loss = ops.log(
                     1
-                    + tf.math.exp(
-                        tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)
+                    + ops.exp(
+                        ops.concatenate(
+                            [-g_pos + self.threshold, g_neg - self.threshold], 0
+                        )
                     )
                 )
-                mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)
+                mean_loss = ops.cast(ops.mean(loss), dtype="float32")
                 self.loss_metric.update_state([mean_loss])
             gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
             self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
         return (
-            tf.stop_gradient(self.call(x_pos)),
-            tf.stop_gradient(self.call(x_neg)),
+            ops.stop_gradient(self.call(x_pos)),
+            ops.stop_gradient(self.call(x_neg)),
             self.loss_metric.result(),
         )
 
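
For readers following the port, the loss being rewritten here is the layer-local Forward-Forward objective: the mean squared activation ("goodness") of positive samples should rise above `self.threshold`, while that of negative samples should fall below it, and a softplus-style penalty is applied to both margins at once. A small self-contained sketch of the same computation with `keras.ops` on dummy activations (shapes and names are illustrative only):

import numpy as np
from keras import ops

threshold = 1.5
act_pos = ops.convert_to_tensor(np.random.rand(8, 32).astype("float32"))  # activations for positive samples
act_neg = ops.convert_to_tensor(np.random.rand(8, 32).astype("float32"))  # activations for negative samples

g_pos = ops.mean(ops.power(act_pos, 2), 1)  # per-sample goodness, positive batch
g_neg = ops.mean(ops.power(act_neg, 2), 1)  # per-sample goodness, negative batch

# Penalize positive goodness below the threshold and negative goodness above it.
margins = ops.concatenate([-g_pos + threshold, g_neg - threshold], 0)
mean_loss = ops.mean(ops.log(1 + ops.exp(margins)))
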
@@ -248,25 +254,24 @@ class FFNetwork(keras.Model):
     # the `Adam` optimizer with a default learning rate of 0.03 as that was
     # found to be the best rate after experimentation.
     # Loss is tracked using `loss_var` and `loss_count` variables.
-    # Use legacy optimizer for Layer Optimizer to fix issue
-    # https://github.com/keras-team/keras-io/issues/1241
 
     def __init__(
         self,
         dims,
-        layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.03),
+        init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.layer_optimizer = layer_optimizer
-        self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)
-        self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)
+        self.init_layer_optimizer = init_layer_optimizer
+        self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
+        self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
         self.layer_list = [keras.Input(shape=(dims[0],))]
+        self.metrics_built = False
         for d in range(len(dims) - 1):
             self.layer_list += [
                 FFDense(
                     dims[d + 1],
-                    optimizer=self.layer_optimizer,
+                    init_optimizer=self.init_layer_optimizer,
                     loss_metric=keras.metrics.Mean(),
                 )
             ]
@@ -280,9 +285,9 @@ def __init__(
     @tf.function(reduce_retracing=True)
     def overlay_y_on_x(self, data):
         X_sample, y_sample = data
-        max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)
-        max_sample = tf.cast(max_sample, dtype=tf.float64)
-        X_zeros = tf.zeros([10], dtype=tf.float64)
+        max_sample = ops.amax(X_sample, axis=0, keepdims=True)
+        max_sample = ops.cast(max_sample, dtype="float64")
+        X_zeros = ops.zeros([10], dtype="float64")
         X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
         X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
         return X_sample, y_sample
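
As a reminder of what `overlay_y_on_x` is doing while the dtype handling changes: the label is embedded directly into the sample by zeroing the first ten pixels of the flattened image and writing the sample's brightest pixel value at the label's index. A rough NumPy equivalent of the idea (illustrative only, not the traced function above):

import numpy as np

def overlay_label(x_flat, label, num_classes=10):
    # x_flat: flattened image as a 1D array, label: integer class id.
    out = x_flat.copy()
    peak = x_flat.max()      # brightest pixel in the sample
    out[:num_classes] = 0.0  # clear the first ten positions
    out[label] = peak        # mark the chosen label with the peak value
    return out
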
@@ -297,25 +302,23 @@ def overlay_y_on_x(self, data):
     @tf.function(reduce_retracing=True)
     def predict_one_sample(self, x):
         goodness_per_label = []
-        x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])
+        x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
         for label in range(10):
             h, label = self.overlay_y_on_x(data=(x, label))
-            h = tf.reshape(h, [-1, tf.shape(h)[0]])
+            h = ops.reshape(h, [-1, ops.shape(h)[0]])
             goodness = []
             for layer_idx in range(1, len(self.layer_list)):
                 layer = self.layer_list[layer_idx]
                 h = layer(h)
-                goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]
-            goodness_per_label += [
-                tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)
-            ]
+                goodness += [ops.mean(ops.power(h, 2), 1)]
+            goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
         goodness_per_label = tf.concat(goodness_per_label, 1)
-        return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)
+        return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")
 
     def predict(self, data):
         x = data
         preds = list()
-        preds = tf.map_fn(fn=self.predict_one_sample, elems=x)
+        preds = ops.vectorized_map(self.predict_one_sample, x)
         return np.asarray(preds, dtype=int)
 
     # This custom `train_step` function overrides the internal `train_step`
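
The prediction path that these reshapes feed works by trying every candidate label: each label is overlaid on the sample, the sample is pushed through the `FFDense` layers, and the label with the largest accumulated goodness wins. A compact sketch of that selection logic, assuming `layers` is a list of callables and `overlay_label` is the helper sketched earlier:

from keras import ops

def predict_by_goodness(x_flat, layers, overlay_label, num_classes=10):
    scores = []
    for label in range(num_classes):
        h = ops.reshape(overlay_label(x_flat, label), (1, -1))  # candidate label embedded in the input
        goodness = 0.0
        for layer in layers:
            h = layer(h)
            goodness = goodness + ops.mean(ops.power(h, 2))  # accumulate squared-activation goodness
        scores.append(goodness)
    return int(ops.argmax(ops.stack(scores)))  # label with the highest total goodness
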
@@ -328,17 +331,26 @@ def predict(self, data):
     # the Forward-Forward computation on it. The returned loss is the final
     # loss value over all the layers.
 
-    @tf.function(jit_compile=True)
+    @tf.function(jit_compile=False)
     def train_step(self, data):
         x, y = data
 
+        if not self.metrics_built:
+            # build metrics to ensure they can be queried without erroring out.
+            # We can't update the metrics' state, as we would usually do, since
+            # we do not perform predictions within the train step
+            for metric in self.metrics:
+                if hasattr(metric, "build"):
+                    metric.build(y, y)
+            self.metrics_built = True
+
         # Flatten op
-        x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])
+        x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])
 
-        x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))
+        x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))
 
         random_y = tf.random.shuffle(y)
-        x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))
+        x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))
 
         h_pos, h_neg = x_pos, x_neg
 
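
The positive-sample path now uses `keras.ops.vectorized_map`, which, like the `tf.map_fn` call it replaces, applies a function across the leading batch dimension. A small standalone illustration of the call on dummy data:

import numpy as np
from keras import ops

batch = ops.convert_to_tensor(np.arange(12, dtype="float32").reshape(4, 3))

# Apply a per-row function over the batch dimension, analogous to tf.map_fn.
doubled = ops.vectorized_map(lambda row: row * 2.0, batch)
print(ops.shape(doubled))  # (4, 3)
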
@@ -351,7 +363,7 @@ def train_step(self, data):
             else:
                 print(f"Passing layer {idx+1} now : ")
                 x = layer(x)
-        mean_res = tf.math.divide(self.loss_var, self.loss_count)
+        mean_res = ops.divide(self.loss_var, self.loss_count)
         return {"FinalLoss": mean_res}
 
 
@@ -386,8 +398,8 @@ def train_step(self, data):
 model.compile(
     optimizer=keras.optimizers.Adam(learning_rate=0.03),
     loss="mse",
-    jit_compile=True,
-    metrics=[keras.metrics.Mean()],
+    jit_compile=False,
+    metrics=[],
 )
 
 epochs = 250
@@ -400,7 +412,7 @@ def train_step(self, data):
 test set. We calculate the Accuracy Score to understand the results closely.
 """
 
-preds = model.predict(tf.convert_to_tensor(x_test))
+preds = model.predict(ops.convert_to_tensor(x_test))
 
 preds = preds.reshape((preds.shape[0], preds.shape[1]))
 
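
The accuracy check mentioned in the docstring above uses `accuracy_score` from scikit-learn, which is imported in the setup block; as a reminder of the shape of that call, a minimal example on made-up labels:

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([3, 1, 4, 1, 5])
y_pred = np.array([3, 1, 4, 0, 5])
print(f"Accuracy score : {accuracy_score(y_true, y_pred) * 100}%")  # 80.0%
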