Commit fcf47ee

Adapting the script classification_with_grn_and_vsn to be Backend-Agnostic (#2023)
* adapting the script classification_with_grn_and_vsn to be backend-agnostic
* adapting the script classification_with_grn_and_vsn to be Backend-Agnostic
* script variable name changes refactoring
* addressing the PR comments for the script: classification_with_grn_and_vsn.py
* addressing comments for classification_with_grn_and_vsn.py
* addressing comments for classification_with_grn_and_vsn.py
1 parent 504b5a6 commit fcf47ee

File tree

3 files changed: +3531 / -267 lines changed

examples/structured_data/classification_with_grn_and_vsn.py

Lines changed: 127 additions & 64 deletions
@@ -2,9 +2,10 @@
 Title: Classification with Gated Residual and Variable Selection Networks
 Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
 Date created: 2021/02/10
-Last modified: 2021/02/10
+Last modified: 2025/01/08
 Description: Using Gated Residual and Variable Selection Networks for income level prediction.
 Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
 """
 
 """
@@ -46,13 +47,13 @@
 """
 
 import os
+import subprocess
+import tarfile
 
-# Only the TensorFlow backend supports string inputs.
-os.environ["KERAS_BACKEND"] = "tensorflow"
+os.environ["KERAS_BACKEND"] = "torch"  # or jax, or tensorflow
 
 import numpy as np
 import pandas as pd
-import tensorflow as tf
 import keras
 from keras import layers
 
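The backend switch above only takes effect if `KERAS_BACKEND` is set before `keras` is imported for the first time. Here is a minimal sketch (not part of this diff) of selecting a backend at run time:

```python
import os

# Must be set before the first `import keras`; changing it afterwards is ignored.
os.environ["KERAS_BACKEND"] = "jax"  # or "torch", or "tensorflow"

import keras

print(keras.backend.backend())  # prints the active backend, e.g. "jax"
```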
@@ -108,13 +109,37 @@
     "income_level",
 ]
 
-data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip"
+data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
 keras.utils.get_file(origin=data_url, extract=True)
+
+"""
+Determine the path of the downloaded .tar.gz file and
+extract its contents
+"""
+
+extracted_path = os.path.join(
+    os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"
+)
+for root, dirs, files in os.walk(extracted_path):
+    for file in files:
+        if file.endswith(".tar.gz"):
+            tar_gz_path = os.path.join(root, file)
+            with tarfile.open(tar_gz_path, "r:gz") as tar:
+                tar.extractall(path=root)
+
 train_data_path = os.path.join(
-    os.path.expanduser("~"), ".keras", "datasets", "adult.data"
+    os.path.expanduser("~"),
+    ".keras",
+    "datasets",
+    "census+income+kdd.zip",
+    "census-income.data",
 )
 test_data_path = os.path.join(
-    os.path.expanduser("~"), ".keras", "datasets", "adult.test"
+    os.path.expanduser("~"),
+    ".keras",
+    "datasets",
+    "census+income+kdd.zip",
+    "census-income.test",
 )
 
 data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)
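A small sanity check (a sketch, not part of this diff) that the nested archive was actually unpacked before `pd.read_csv` runs, using the `train_data_path` and `test_data_path` defined above:

```python
import os

# Fail early with a clear message if extraction did not produce the CSV files.
for path in (train_data_path, test_data_path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Expected extracted file missing: {path}")
```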
@@ -211,15 +236,38 @@
 training and evaluation.
 """
 
+# TensorFlow is required for tf.data.Dataset objects
+import tensorflow as tf
+
 
+# We process our datasets' categorical elements here, converting them to integer indices,
+# to avoid doing this during model training, since only the TensorFlow backend supports strings.
 def process(features, target):
     for feature_name in features:
         if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
             # Cast categorical feature values to string.
-            features[feature_name] = keras.ops.cast(features[feature_name], "string")
+            features[feature_name] = tf.cast(features[feature_name], "string")
+            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+            # Create a lookup to convert string values to integer indices.
+            # Since we are not using a mask token nor expecting any out-of-vocabulary
+            # (oov) token, we set mask_token to None and num_oov_indices to 0.
+            index = layers.StringLookup(
+                vocabulary=vocabulary,
+                mask_token=None,
+                num_oov_indices=0,
+                output_mode="int",
+            )
+            # Convert the string input values into integer indices.
+            value_index = index(features[feature_name])
+            features[feature_name] = value_index
+        else:
+            # Do nothing for numerical features
+            pass
+
     # Get the instance weight.
     weight = features.pop(WEIGHT_COLUMN_NAME)
-    return features, target, weight
+    # Convert features from an OrderedDict to a plain dict, to match the model inputs.
+    return dict(features), target, weight
 
 
 def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
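For reference, this is what the `StringLookup` configuration above does in isolation (a sketch with a made-up vocabulary, not part of this diff): with `mask_token=None` and `num_oov_indices=0`, indices are zero-based and follow the vocabulary order exactly.

```python
import tensorflow as tf
from keras import layers

vocabulary = ["low", "medium", "high"]  # hypothetical vocabulary
lookup = layers.StringLookup(
    vocabulary=vocabulary, mask_token=None, num_oov_indices=0, output_mode="int"
)
# Zero-based indices in vocabulary order: "medium" -> 1, "high" -> 2, "low" -> 0
print(lookup(tf.constant(["medium", "high", "low"])))  # tf.Tensor([1 2 0], ...)
```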
@@ -245,56 +293,19 @@ def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
 def create_model_inputs():
     inputs = {}
     for feature_name in FEATURE_NAMES:
-        if feature_name in NUMERIC_FEATURE_NAMES:
+        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+            # Make them int64, they are categorical (whole units)
             inputs[feature_name] = layers.Input(
-                name=feature_name, shape=(), dtype="float32"
+                name=feature_name, shape=(), dtype="int64"
             )
         else:
+            # Make them float32, they are real numbers
             inputs[feature_name] = layers.Input(
-                name=feature_name, shape=(), dtype="string"
+                name=feature_name, shape=(), dtype="float32"
             )
     return inputs
 
 
-"""
-## Encode input features
-
-For categorical features, we encode them using `layers.Embedding` using the
-`encoding_size` as the embedding dimensions. For the numerical features,
-we apply linear transformation using `layers.Dense` to project each feature into
-`encoding_size`-dimensional vector. Thus, all the encoded features will have the
-same dimensionality.
-
-"""
-
-
-def encode_inputs(inputs, encoding_size):
-    encoded_features = []
-    for feature_name in inputs:
-        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
-            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
-            # Create a lookup to convert a string values to an integer indices.
-            # Since we are not using a mask token nor expecting any out of vocabulary
-            # (oov) token, we set mask_token to None and num_oov_indices to 0.
-            index = layers.StringLookup(
-                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
-            )
-            # Convert the string input values into integer indices.
-            value_index = index(inputs[feature_name])
-            # Create an embedding layer with the specified dimensions
-            embedding_ecoder = layers.Embedding(
-                input_dim=len(vocabulary), output_dim=encoding_size
-            )
-            # Convert the index values to embedding representations.
-            encoded_feature = embedding_ecoder(value_index)
-        else:
-            # Project the numeric feature to encoding_size using linear transformation.
-            encoded_feature = keras.ops.expand_dims(inputs[feature_name], -1)
-            encoded_feature = layers.Dense(units=encoding_size)(encoded_feature)
-        encoded_features.append(encoded_feature)
-    return encoded_features
-
-
 """
 ## Implement the Gated Linear Unit
 
@@ -312,6 +323,10 @@ def __init__(self, units):
     def call(self, inputs):
         return self.linear(inputs) * self.sigmoid(inputs)
 
+    # Remove build warnings
+    def build(self):
+        self.built = True
+
 
 """
 ## Implement the Gated Residual Network
@@ -347,6 +362,10 @@ def call(self, inputs):
         x = self.layer_norm(x)
         return x
 
+    # Remove build warnings
+    def build(self):
+        self.built = True
+
 
 """
 ## Implement the Variable Selection Network
@@ -360,12 +379,35 @@ def call(self, inputs):
 
 Note that the output of the VSN is [batch_size, encoding_size], regardless of the
 number of the input features.
+
+For categorical features, we encode them using `layers.Embedding`, with
+`encoding_size` as the embedding dimension. For the numerical features,
+we apply a linear transformation using `layers.Dense` to project each feature
+into an `encoding_size`-dimensional vector. Thus, all the encoded features
+will have the same dimensionality.
+
 """
 
 
 class VariableSelection(layers.Layer):
     def __init__(self, num_features, units, dropout_rate):
         super().__init__()
+        self.units = units
+        # Create embedding layers with the specified dimensions
+        self.embeddings = dict()
+        for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]
+            embedding_encoder = layers.Embedding(
+                input_dim=len(vocabulary), output_dim=self.units, name=input_
+            )
+            self.embeddings[input_] = embedding_encoder
+
+        # Projection layers for numeric features
+        self.proj_layer = dict()
+        for input_ in NUMERIC_FEATURE_NAMES:
+            proj_layer = layers.Dense(units=self.units)
+            self.proj_layer[input_] = proj_layer
+
        self.grns = list()
        # Create a GRN for each feature independently
        for idx in range(num_features):
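With the embeddings now created eagerly in `__init__`, each categorical feature gets its own `layers.Embedding` keyed by feature name. A sketch of what one such encoder does, with illustrative sizes (not part of this diff):

```python
import numpy as np
from keras import layers

# Vocabulary of 3 categories, each mapped to a 4-dimensional vector.
emb = layers.Embedding(input_dim=3, output_dim=4)
indices = np.array([0, 2, 1])  # integer-encoded categorical values
print(emb(indices).shape)  # (3, 4)
```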
@@ -376,17 +418,35 @@ def __init__(self, num_features, units, dropout_rate):
         self.softmax = layers.Dense(units=num_features, activation="softmax")
 
     def call(self, inputs):
-        v = layers.concatenate(inputs)
+        concat_inputs = []
+        for input_ in inputs:
+            if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+                max_index = self.embeddings[input_].input_dim - 1  # Clamp the indices
+                # The torch backend raised index errors during embedding, hence the clip
+                embedded_feature = self.embeddings[input_](
+                    keras.ops.clip(inputs[input_], 0, max_index)
+                )
+                concat_inputs.append(embedded_feature)
+            else:
+                # Project the numeric feature to encoding_size using a linear transformation.
+                proj_feature = keras.ops.expand_dims(inputs[input_], -1)
+                proj_feature = self.proj_layer[input_](proj_feature)
+                concat_inputs.append(proj_feature)
+
+        v = layers.concatenate(concat_inputs)
         v = self.grn_concat(v)
         v = keras.ops.expand_dims(self.softmax(v), axis=-1)
-
         x = []
-        for idx, input in enumerate(inputs):
+        for idx, input in enumerate(concat_inputs):
             x.append(self.grns[idx](input))
         x = keras.ops.stack(x, axis=1)
+        return keras.ops.squeeze(
+            keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
+        )
 
-        outputs = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
-        return outputs
+    # Remove build warnings
+    def build(self):
+        self.built = True
 
 
 """
@@ -396,14 +456,10 @@ def call(self, inputs):
 
 def create_model(encoding_size):
     inputs = create_model_inputs()
-    feature_list = encode_inputs(inputs, encoding_size)
-    num_features = len(feature_list)
-
-    features = VariableSelection(num_features, encoding_size, dropout_rate)(
-        feature_list
-    )
-
+    num_features = len(inputs)
+    features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)
     outputs = layers.Dense(units=1, activation="sigmoid")(features)
+    # Functional model
     model = keras.Model(inputs=inputs, outputs=outputs)
     return model
 
@@ -415,7 +471,7 @@ def create_model(encoding_size):
 learning_rate = 0.001
 dropout_rate = 0.15
 batch_size = 265
-num_epochs = 20
+num_epochs = 20  # may be adjusted to a desired value
 encoding_size = 16
 
 model = create_model(encoding_size)
@@ -425,6 +481,13 @@ def create_model(encoding_size):
     metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
 )
 
+"""
+Let's visualize our connectivity graph:
+"""
+
+# `rankdir='LR'` is to make the graph horizontal.
+keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")
+
 
 # Create an early stopping callback.
 early_stopping = keras.callbacks.EarlyStopping(
