Merged
Changes from 3 commits
240 changes: 176 additions & 64 deletions examples/structured_data/classification_with_grn_and_vsn.py
@@ -2,9 +2,10 @@
Title: Classification with Gated Residual and Variable Selection Networks
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2021/02/10
Last modified: 2021/02/10
Last modified: 2025/01/08
Description: Using Gated Residual and Variable Selection Networks for income level prediction.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
@@ -46,13 +47,13 @@
"""

import os
import subprocess
import tarfile

# Only the TensorFlow backend supports string inputs.
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow

import numpy as np
import pandas as pd
import tensorflow as tf
import keras
from keras import layers

@@ -108,13 +109,37 @@
"income_level",
]

data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip"
data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
Contributor:

Why the change in dataset? It seems like the original dataset was easier to handle.

Contributor (Author):

The dataset description in the script says it must have 41 input features: 7 numerical features and 34 categorical features. The original dataset only had 14 features, and its target variable was "<=50K" or ">50K", whereas in this script it is "- 50000" or "50000+".
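For illustration, a minimal sketch of the label difference, assuming the raw KDD labels are the strings "- 50000." and "50000+." as documented for the dataset (the binarization shown is an assumption about how the script consumes them, not its actual code):

```python
import pandas as pd

# Hypothetical illustration: the KDD target column holds strings like
# " - 50000." and " 50000+." rather than "<=50K" / ">50K".
sample = pd.Series([" - 50000.", " 50000+.", " - 50000."])
# Binarize: 0 for income below 50000, 1 for 50000 and above.
target = sample.str.strip().map(lambda s: 0 if s.startswith("-") else 1)
print(target.tolist())  # [0, 1, 0]
```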

keras.utils.get_file(origin=data_url, extract=True)

"""
Determine the path of the downloaded .tar.gz file and extract its contents.
"""

extracted_path = os.path.join(
os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"
)
for root, dirs, files in os.walk(extracted_path):
for file in files:
if file.endswith(".tar.gz"):
tar_gz_path = os.path.join(root, file)
with tarfile.open(tar_gz_path, "r:gz") as tar:
tar.extractall(path=root)

train_data_path = os.path.join(
os.path.expanduser("~"), ".keras", "datasets", "adult.data"
os.path.expanduser("~"),
".keras",
"datasets",
"census+income+kdd.zip",
"census-income.data",
)
test_data_path = os.path.join(
os.path.expanduser("~"), ".keras", "datasets", "adult.test"
os.path.expanduser("~"),
".keras",
"datasets",
"census+income+kdd.zip",
"census-income.test",
)

data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)
@@ -157,6 +182,20 @@
valid_data.to_csv(valid_data_file, index=False, header=False)
test_data.to_csv(test_data_file, index=False, header=False)

"""
Clean up the extracted files, keeping only the .tar.gz archive, and
remove any empty directories.
"""

subprocess.run(
f'find {extracted_path} -type f ! -name "*.tar.gz" -exec rm -f {{}} +',
shell=True,
check=True,
)
subprocess.run(
f"find {extracted_path} -type d -empty -exec rmdir {{}} +", shell=True, check=True
)

Contributor:

This could apply to any dataset, but I find this an unnecessary distraction. Can you remove?

Contributor (Author):

Removed.
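For reference, a portable sketch of the same cleanup using pathlib instead of shelling out to find, assuming the extracted_path layout used in the script:

```python
from pathlib import Path

extracted = Path.home() / ".keras" / "datasets" / "census+income+kdd.zip"
# Walk deepest paths first so empty directories can be removed bottom-up.
for path in sorted(extracted.rglob("*"), reverse=True):
    if path.is_file() and not path.name.endswith(".tar.gz"):
        path.unlink()  # delete extracted files, keep the archive
    elif path.is_dir() and not any(path.iterdir()):
        path.rmdir()  # remove now-empty directories
```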

"""
## Define dataset metadata

@@ -211,15 +250,38 @@
training and evaluation.
"""

# TensorFlow is required for tf.data.Dataset objects
import tensorflow as tf


# We convert the categorical string features to integer indices here, to avoid
# doing this step during model training, since only TensorFlow supports strings.
def process(features, target):
for feature_name in features:
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
# Cast categorical feature values to string.
features[feature_name] = keras.ops.cast(features[feature_name], "string")
features[feature_name] = tf.cast(features[feature_name], "string")
vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
# Create a lookup to convert string values to integer indices.
# Since we are not using a mask token nor expecting any out of vocabulary
# (oov) token, we set mask_token to None and num_oov_indices to 0.
index = layers.StringLookup(
vocabulary=vocabulary,
mask_token=None,
num_oov_indices=0,
output_mode="int",
)
# Convert the string input values into integer indices.
value_index = index(features[feature_name])
features[feature_name] = value_index
else:
# Do nothing for numerical features
pass

# Get the instance weight.
weight = features.pop(WEIGHT_COLUMN_NAME)
return features, target, weight
# Convert features from an OrderedDict to a plain dict, to match the model inputs, which are a dict.
return dict(features), target, weight
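
"""
The `process` function above is meant to be mapped over a `tf.data` dataset.
The following is a minimal sketch of that wiring, not the script's actual
implementation: the body of `get_dataset_from_csv` is collapsed in this diff,
and `TARGET_FEATURE_NAME` is assumed to come from the elided metadata section.
"""


def example_get_dataset(csv_file_path, shuffle=False, batch_size=128):
    # Sketch only: mirrors the expected shape of get_dataset_from_csv.
    dataset = tf.data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        label_name=TARGET_FEATURE_NAME,
        header=False,
        num_epochs=1,
        shuffle=shuffle,
    ).map(process)  # apply the categorical-to-index conversion defined above
    return dataset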


def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
@@ -245,56 +307,19 @@ def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
def create_model_inputs():
inputs = {}
for feature_name in FEATURE_NAMES:
if feature_name in NUMERIC_FEATURE_NAMES:
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
# Categorical features are whole units, so make them int64
inputs[feature_name] = layers.Input(
name=feature_name, shape=(), dtype="float32"
name=feature_name, shape=(), dtype="int64"
)
else:
# Numeric features are real numbers, so make them float32
inputs[feature_name] = layers.Input(
name=feature_name, shape=(), dtype="string"
name=feature_name, shape=(), dtype="float32"
)
return inputs


"""
## Encode input features

For categorical features, we encode them using `layers.Embedding` using the
`encoding_size` as the embedding dimensions. For the numerical features,
we apply linear transformation using `layers.Dense` to project each feature into
`encoding_size`-dimensional vector. Thus, all the encoded features will have the
same dimensionality.

"""


def encode_inputs(inputs, encoding_size):
encoded_features = []
for feature_name in inputs:
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
# Create a lookup to convert a string values to an integer indices.
# Since we are not using a mask token nor expecting any out of vocabulary
# (oov) token, we set mask_token to None and num_oov_indices to 0.
index = layers.StringLookup(
vocabulary=vocabulary, mask_token=None, num_oov_indices=0
)
# Convert the string input values into integer indices.
value_index = index(inputs[feature_name])
# Create an embedding layer with the specified dimensions
embedding_ecoder = layers.Embedding(
input_dim=len(vocabulary), output_dim=encoding_size
)
# Convert the index values to embedding representations.
encoded_feature = embedding_ecoder(value_index)
else:
# Project the numeric feature to encoding_size using linear transformation.
encoded_feature = keras.ops.expand_dims(inputs[feature_name], -1)
encoded_feature = layers.Dense(units=encoding_size)(encoded_feature)
encoded_features.append(encoded_feature)
return encoded_features


"""
## Implement the Gated Linear Unit

@@ -360,12 +385,35 @@ def call(self, inputs):

Note that the output of the VSN is [batch_size, encoding_size], regardless of the
number of input features.

We encode the categorical features using `layers.Embedding`, with
`encoding_size` as the embedding dimension. For the numerical features,
we apply a linear transformation using `layers.Dense` to project each feature
into an `encoding_size`-dimensional vector. Thus, all the encoded features
end up with the same dimensionality.

"""


class VariableSelection(layers.Layer):
def __init__(self, num_features, units, dropout_rate):
super().__init__()
self.units = units
# Create an embedding layer for each categorical feature, with the specified dimensions
self.embeddings = dict()
for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]
embedding_encoder = layers.Embedding(
input_dim=len(vocabulary), output_dim=self.units, name=input_
)
self.embeddings[input_] = embedding_encoder

# Projection layers for numeric features
self.proj_layer = dict()
for input_ in NUMERIC_FEATURE_NAMES:
proj_layer = layers.Dense(units=self.units)
self.proj_layer[input_] = proj_layer

self.grns = list()
# Create a GRN for each feature independently
for idx in range(num_features):
@@ -376,17 +424,78 @@ def __init__(self, num_features, units, dropout_rate):
self.softmax = layers.Dense(units=num_features, activation="softmax")

def call(self, inputs):
v = layers.concatenate(inputs)
concat_inputs = []
for input_ in inputs:
if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
max_index = self.embeddings[input_].input_dim - 1 # Clamp the indices
# The torch backend raised index errors during the embedding lookup, hence the clip
embedded_feature = self.embeddings[input_](
keras.ops.clip(inputs[input_], 0, max_index)
)
concat_inputs.append(embedded_feature)
else:
# Project the numeric feature to encoding_size using linear transformation.
proj_feature = keras.ops.expand_dims(inputs[input_], -1)
proj_feature = self.proj_layer[input_](proj_feature)
concat_inputs.append(proj_feature)

v = layers.concatenate(concat_inputs)
v = self.grn_concat(v)
v = keras.ops.expand_dims(self.softmax(v), axis=-1)

x = []
for idx, input in enumerate(inputs):
for idx, input in enumerate(concat_inputs):
x.append(self.grns[idx](input))
x = keras.ops.stack(x, axis=1)

outputs = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
return outputs
# The reason for the per-backend calculation is that I couldn't find an
# equivalent Keras operation that is backend-agnostic. There is a
# keras.ops.matmul, but it was returning errors. I could have used the
# TensorFlow matmul for all backends, but due to jax jit tracing it results
# in an error.
def matmul_dependent_on_backend(tensor_1, tensor_2):
"""
Function for executing matmul for each backend.
"""
# jax backend
if keras.backend.backend() == "jax":
import jax.numpy as jnp

result = jnp.sum(tensor_1 * tensor_2, axis=1)
# torch backend
elif keras.backend.backend() == "torch":
import torch

result = torch.sum(tensor_1 * tensor_2, dim=1)
# tensorflow backend
elif keras.backend.backend() == "tensorflow":
result = keras.ops.squeeze(tf.matmul(tensor_1, tensor_2, transpose_a=True), axis=1)
# unsupported backend exception
else:
raise ValueError(
"Unsupported backend: {}".format(keras.backend.backend())
)
return result

# jax backend
if keras.backend.backend() == "jax":
# These repetitive imports are intentional, to reinforce the idea of
# backend separation
import jax.numpy as jnp

result_jax = matmul_dependent_on_backend(v, x)
return result_jax
# torch backend
if keras.backend.backend() == "torch":
import torch

result_torch = matmul_dependent_on_backend(v, x)
return result_torch
# tensorflow backend
if keras.backend.backend() == "tensorflow":
import tensorflow as tf

result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
return result_tf
Contributor:

This definitely should not be needed.

What is the issue with keras.ops.squeeze(keras.ops.matmul(keras.ops.transpose(v), x), axis=1)?

Contributor (Author):

After careful thought, I've made it work. I also struggled with the Keras docstring for keras.ops.transpose; I don't think `axes` is explicit about the permutations, and I had to read the TensorFlow docs to get a clear picture. But nonetheless, it is resolved.
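For the record, a backend-agnostic sketch using only keras.ops, assuming v has shape [batch, num_features, 1] and x has shape [batch, num_features, units] as in the code above:

```python
import keras


def weighted_combine(v, x):
    """Combine per-feature GRN outputs x with selection weights v.

    v: [batch, num_features, 1]; x: [batch, num_features, units].
    """
    # Transpose v to [batch, 1, num_features]; matmul gives [batch, 1, units];
    # squeezing the singleton axis gives [batch, units].
    # An equivalent form is keras.ops.sum(v * x, axis=1).
    return keras.ops.squeeze(
        keras.ops.matmul(keras.ops.transpose(v, axes=(0, 2, 1)), x), axis=1
    )
```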


# Mark the layer as built, to remove the build warnings
def build(self):
self.built = True
Contributor (on lines +447 to +449):

Thanks!

Contributor (Author):

Added more of this.


"""
@@ -396,14 +505,10 @@ def call(self, inputs):

def create_model(encoding_size):
inputs = create_model_inputs()
feature_list = encode_inputs(inputs, encoding_size)
num_features = len(feature_list)

features = VariableSelection(num_features, encoding_size, dropout_rate)(
feature_list
)

num_features = len(inputs)
features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)
outputs = layers.Dense(units=1, activation="sigmoid")(features)
# Functional model
model = keras.Model(inputs=inputs, outputs=outputs)
return model

@@ -415,7 +520,7 @@ def create_model(encoding_size):
learning_rate = 0.001
dropout_rate = 0.15
batch_size = 265
num_epochs = 20
num_epochs = 1 # may be adjusted to a desired value
Contributor:

Please revert after you're done testing.

Contributor (Author):

Reverted to 20, but left the comment.

encoding_size = 16

model = create_model(encoding_size)
@@ -425,6 +530,13 @@ def create_model(encoding_size):
metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
)

"""
Let's visualize our connectivity graph:
"""

# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")


# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(