Title: Classification with Gated Residual and Variable Selection Networks
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2021/02/10
- Last modified: 2021/02/10
+ Last modified: 2025/01/08
Description: Using Gated Residual and Variable Selection Networks for income level prediction.
Accelerator: GPU
+ Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
"""

import os
+ import subprocess
+ import tarfile

- # Only the TensorFlow backend supports string inputs.
- os.environ["KERAS_BACKEND"] = "tensorflow"
+ os.environ["KERAS_BACKEND"] = "torch"  # or jax, or tensorflow

import numpy as np
import pandas as pd
- import tensorflow as tf
import keras
from keras import layers
    "income_level",
]

- data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip"
+ data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
keras.utils.get_file(origin=data_url, extract=True)
+
+ """
+ Determine the path of the downloaded .tar.gz archive and extract its
+ contents alongside it.
+ """
+
+ extracted_path = os.path.join(
+     os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"
+ )
+ for root, dirs, files in os.walk(extracted_path):
+     for file in files:
+         if file.endswith(".tar.gz"):
+             tar_gz_path = os.path.join(root, file)
+             with tarfile.open(tar_gz_path, "r:gz") as tar:
+                 tar.extractall(path=root)
+
train_data_path = os.path.join(
-     os.path.expanduser("~"), ".keras", "datasets", "adult.data"
+     os.path.expanduser("~"),
+     ".keras",
+     "datasets",
+     "census+income+kdd.zip",
+     "census-income.data",
)
test_data_path = os.path.join(
-     os.path.expanduser("~"), ".keras", "datasets", "adult.test"
+     os.path.expanduser("~"),
+     ".keras",
+     "datasets",
+     "census+income+kdd.zip",
+     "census-income.test",
)

data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)
training and evaluation.
"""

+ # TensorFlow is required for the tf.data.Dataset input pipelines.
+ import tensorflow as tf
+
+ # Process the categorical dataset elements here, converting strings to integer
+ # indices inside the tf.data pipeline rather than inside the model, since only
+ # the TensorFlow backend supports string tensors.
def process(features, target):
    for feature_name in features:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Cast categorical feature values to string.
-             features[feature_name] = keras.ops.cast(features[feature_name], "string")
+             features[feature_name] = tf.cast(features[feature_name], "string")
+             vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+             # Create a lookup to convert string values to integer indices.
+             # Since we are not using a mask token nor expecting any out-of-vocabulary
+             # (oov) token, we set mask_token to None and num_oov_indices to 0.
+             index = layers.StringLookup(
+                 vocabulary=vocabulary,
+                 mask_token=None,
+                 num_oov_indices=0,
+                 output_mode="int",
+             )
+             # Convert the string input values into integer indices.
+             value_index = index(features[feature_name])
+             features[feature_name] = value_index
+         else:
+             # Do nothing for numerical features.
+             pass
+
    # Get the instance weight.
    weight = features.pop(WEIGHT_COLUMN_NAME)
-     return features, target, weight
+     # Convert features from an OrderedDict to a plain dict to match the model's dict inputs.
+     return dict(features), target, weight
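"""
As a standalone illustration of what the `StringLookup` above does (a sketch
with a toy vocabulary, not taken from the dataset), each string maps to its
index in the vocabulary:
"""

toy_lookup = layers.StringLookup(
    vocabulary=["low", "medium", "high"], mask_token=None, num_oov_indices=0
)
print(toy_lookup(tf.constant(["medium", "high", "low"])))  # -> [1 2 0]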

def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
@@ -245,56 +293,19 @@ def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
-         if feature_name in NUMERIC_FEATURE_NAMES:
+         if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+             # Make them int64; they are categorical (whole units).
            inputs[feature_name] = layers.Input(
-                 name=feature_name, shape=(), dtype="float32"
+                 name=feature_name, shape=(), dtype="int64"
            )
        else:
+             # Make them float32; they are real numbers.
            inputs[feature_name] = layers.Input(
-                 name=feature_name, shape=(), dtype="string"
+                 name=feature_name, shape=(), dtype="float32"
            )
    return inputs
258308
259- """
260- ## Encode input features
261-
262- For categorical features, we encode them using `layers.Embedding` using the
263- `encoding_size` as the embedding dimensions. For the numerical features,
264- we apply linear transformation using `layers.Dense` to project each feature into
265- `encoding_size`-dimensional vector. Thus, all the encoded features will have the
266- same dimensionality.
267-
268- """
269-
270-
271- def encode_inputs (inputs , encoding_size ):
272- encoded_features = []
273- for feature_name in inputs :
274- if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY :
275- vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY [feature_name ]
276- # Create a lookup to convert a string values to an integer indices.
277- # Since we are not using a mask token nor expecting any out of vocabulary
278- # (oov) token, we set mask_token to None and num_oov_indices to 0.
279- index = layers .StringLookup (
280- vocabulary = vocabulary , mask_token = None , num_oov_indices = 0
281- )
282- # Convert the string input values into integer indices.
283- value_index = index (inputs [feature_name ])
284- # Create an embedding layer with the specified dimensions
285- embedding_ecoder = layers .Embedding (
286- input_dim = len (vocabulary ), output_dim = encoding_size
287- )
288- # Convert the index values to embedding representations.
289- encoded_feature = embedding_ecoder (value_index )
290- else :
291- # Project the numeric feature to encoding_size using linear transformation.
292- encoded_feature = keras .ops .expand_dims (inputs [feature_name ], - 1 )
293- encoded_feature = layers .Dense (units = encoding_size )(encoded_feature )
294- encoded_features .append (encoded_feature )
295- return encoded_features
296-
297-
"""
## Implement the Gated Linear Unit

@@ -312,6 +323,10 @@ def __init__(self, units):
    def call(self, inputs):
        return self.linear(inputs) * self.sigmoid(inputs)

+     # Remove build warnings
+     def build(self):
+         self.built = True
+
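"""
A minimal shape check, as an illustrative sketch rather than part of the
original example: the GLU maps a `(batch, features)` input to `(batch, units)`,
gating each output unit with a learned sigmoid.
"""

glu = GatedLinearUnit(units=8)
print(glu(keras.ops.ones((2, 4))).shape)  # (2, 8)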

"""
## Implement the Gated Residual Network
@@ -347,6 +362,10 @@ def call(self, inputs):
        x = self.layer_norm(x)
        return x

+     # Remove build warnings
+     def build(self):
+         self.built = True
+
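"""
A similar sketch for the GRN (assuming the `GatedResidualNetwork(units,
dropout_rate)` constructor used later by the Variable Selection Network):
when the input dimensionality differs from `units`, the residual branch is
projected, so the output is always `(batch, units)`.
"""

grn = GatedResidualNetwork(units=8, dropout_rate=0.1)
print(grn(keras.ops.ones((2, 4))).shape)  # (2, 8)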
"""
## Implement the Variable Selection Network
@@ -360,12 +379,35 @@ def call(self, inputs):

Note that the output of the VSN is [batch_size, encoding_size], regardless of the
number of the input features.
+
+ For categorical features, we encode them using `layers.Embedding`, with
+ `encoding_size` as the embedding dimension. For the numerical features,
+ we apply a linear transformation using `layers.Dense` to project each feature into
+ an `encoding_size`-dimensional vector. Thus, all the encoded features will have the
+ same dimensionality.
+
"""

class VariableSelection(layers.Layer):
    def __init__(self, num_features, units, dropout_rate):
        super().__init__()
+         self.units = units
+         # Create an embedding layer for each categorical feature,
+         # with the specified dimensions.
+         self.embeddings = dict()
+         for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+             vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]
+             embedding_encoder = layers.Embedding(
+                 input_dim=len(vocabulary), output_dim=self.units, name=input_
+             )
+             self.embeddings[input_] = embedding_encoder
+
+         # Projection layers for the numeric features
+         self.proj_layer = dict()
+         for input_ in NUMERIC_FEATURE_NAMES:
+             proj_layer = layers.Dense(units=self.units)
+             self.proj_layer[input_] = proj_layer
+
        self.grns = list()
        # Create a GRN for each feature independently
        for idx in range(num_features):
@@ -376,17 +418,35 @@ def __init__(self, num_features, units, dropout_rate):
        self.softmax = layers.Dense(units=num_features, activation="softmax")

    def call(self, inputs):
-         v = layers.concatenate(inputs)
+         concat_inputs = []
+         for input_ in inputs:
+             if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+                 max_index = self.embeddings[input_].input_dim - 1  # Clamp the indices
+                 # The torch backend raised out-of-range index errors during the
+                 # embedding lookup, hence the `clip` on the indices.
+                 embedded_feature = self.embeddings[input_](
+                     keras.ops.clip(inputs[input_], 0, max_index)
+                 )
+                 concat_inputs.append(embedded_feature)
+             else:
+                 # Project the numeric feature to encoding_size using linear transformation.
+                 proj_feature = keras.ops.expand_dims(inputs[input_], -1)
+                 proj_feature = self.proj_layer[input_](proj_feature)
+                 concat_inputs.append(proj_feature)
+
+         v = layers.concatenate(concat_inputs)
        v = self.grn_concat(v)
        v = keras.ops.expand_dims(self.softmax(v), axis=-1)
-
        x = []
-         for idx, input in enumerate(inputs):
+         for idx, input in enumerate(concat_inputs):
            x.append(self.grns[idx](input))
        x = keras.ops.stack(x, axis=1)
+         return keras.ops.squeeze(
+             keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
+         )

-         outputs = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
-         return outputs
+     # Remove build warnings
+     def build(self):
+         self.built = True
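"""
The final matmul is the soft feature selection: `v` holds softmax weights of
shape `(batch, num_features, 1)` and `x` the per-feature GRN outputs of shape
`(batch, num_features, units)`, so transposing `v` and multiplying computes
their weighted sum. A standalone sketch with toy shapes (not from the example):
"""

v_toy = keras.ops.expand_dims(
    keras.ops.softmax(keras.random.uniform((2, 3)), axis=-1), axis=-1
)  # selection weights: (batch=2, num_features=3, 1)
x_toy = keras.random.uniform((2, 3, 4))  # GRN outputs: (2, 3, units=4)
out = keras.ops.squeeze(
    keras.ops.matmul(keras.ops.transpose(v_toy, axes=[0, 2, 1]), x_toy), axis=1
)
print(out.shape)  # (2, 4)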

"""
@@ -396,14 +456,10 @@ def call(self, inputs):

def create_model(encoding_size):
    inputs = create_model_inputs()
-     feature_list = encode_inputs(inputs, encoding_size)
-     num_features = len(feature_list)
-
-     features = VariableSelection(num_features, encoding_size, dropout_rate)(
-         feature_list
-     )
-
+     num_features = len(inputs)
+     features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)
    outputs = layers.Dense(units=1, activation="sigmoid")(features)
+     # Functional model
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
@@ -415,7 +471,7 @@ def create_model(encoding_size):
learning_rate = 0.001
dropout_rate = 0.15
batch_size = 265
- num_epochs = 20
+ num_epochs = 20  # may be adjusted to a desired value
encoding_size = 16

model = create_model(encoding_size)
@@ -425,6 +481,13 @@ def create_model(encoding_size):
    metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
)

+ """
+ Let's visualize our connectivity graph:
+ """
+
+ # `rankdir='LR'` is to make the graph horizontal.
+ keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")
+

# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(