Adapting the script classification_with_grn_and_vsn to be Backend-Agnostic #2023
@@ -2,9 +2,10 @@ | |
| Title: Classification with Gated Residual and Variable Selection Networks | ||
| Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/) | ||
| Date created: 2021/02/10 | ||
| Last modified: 2021/02/10 | ||
| Last modified: 2025/01/08 | ||
| Description: Using Gated Residual and Variable Selection Networks for income level prediction. | ||
| Accelerator: GPU | ||
| Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234) | ||
| """ | ||
|
|
||
| """ | ||
|
|
@@ -46,13 +47,13 @@ | |
| """ | ||
|
|
||
| import os | ||
| import subprocess | ||
| import tarfile | ||
|
|
||
| # Only the TensorFlow backend supports string inputs. | ||
| os.environ["KERAS_BACKEND"] = "tensorflow" | ||
| os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import tensorflow as tf | ||
| import keras | ||
| from keras import layers | ||
|
|
||
|
|
@@ -108,13 +109,37 @@ | |
| "income_level", | ||
| ] | ||
|
|
||
| data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip" | ||
| data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip" | ||
| keras.utils.get_file(origin=data_url, extract=True) | ||
|
|
||
| """ | ||
| Determine the path of the downloaded .tar.gz file and | ||
| extract its contents | ||
| """ | ||
|
|
||
| extracted_path = os.path.join( | ||
| os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip" | ||
| ) | ||
| for root, dirs, files in os.walk(extracted_path): | ||
| for file in files: | ||
| if file.endswith(".tar.gz"): | ||
| tar_gz_path = os.path.join(root, file) | ||
| with tarfile.open(tar_gz_path, "r:gz") as tar: | ||
| tar.extractall(path=root) | ||
|
|
||
| train_data_path = os.path.join( | ||
| os.path.expanduser("~"), ".keras", "datasets", "adult.data" | ||
| os.path.expanduser("~"), | ||
| ".keras", | ||
| "datasets", | ||
| "census+income+kdd.zip", | ||
| "census-income.data", | ||
| ) | ||
| test_data_path = os.path.join( | ||
| os.path.expanduser("~"), ".keras", "datasets", "adult.test" | ||
| os.path.expanduser("~"), | ||
| ".keras", | ||
| "datasets", | ||
| "census+income+kdd.zip", | ||
| "census-income.test", | ||
| ) | ||
|
|
||
| data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER) | ||
|
|
@@ -157,6 +182,20 @@ | |
| valid_data.to_csv(valid_data_file, index=False, header=False) | ||
| test_data.to_csv(test_data_file, index=False, header=False) | ||
|
|
||
| """ | ||
| Clean up the download directory by removing everything except the .tar.gz archive, | ||
| and delete any empty directories left behind | ||
| """ | ||
|
|
||
| subprocess.run( | ||
| f'find {extracted_path} -type f ! -name "*.tar.gz" -exec rm -f {{}} +', | ||
| shell=True, | ||
| check=True, | ||
| ) | ||
| subprocess.run( | ||
| f"find {extracted_path} -type d -empty -exec rmdir {{}} +", shell=True, check=True | ||
| ) | ||
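A note on the cleanup above: it shells out to `find`, which assumes a Unix-like environment. A minimal cross-platform sketch using only the standard library, assuming `extracted_path` as defined earlier in the script, might look like this (illustrative only, not part of the PR):

```python
import os

# Walk bottom-up so that empty sub-directories can be removed after their files.
for root, dirs, files in os.walk(extracted_path, topdown=False):
    for name in files:
        if not name.endswith(".tar.gz"):
            os.remove(os.path.join(root, name))  # drop everything except the archive
    if not os.listdir(root):
        os.rmdir(root)  # remove directories that are now empty
```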
|
|
||
|
||
| """ | ||
| ## Define dataset metadata | ||
|
|
||
|
|
@@ -211,15 +250,38 @@ | |
| training and evaluation. | ||
| """ | ||
|
|
||
| # TensorFlow is required for tf.data.Dataset | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| # We process the categorical dataset elements here, converting them to integer indices, | ||
| # to avoid doing this during model training, since only the TensorFlow backend supports strings. | ||
| def process(features, target): | ||
| for feature_name in features: | ||
| if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY: | ||
| # Cast categorical feature values to string. | ||
| features[feature_name] = keras.ops.cast(features[feature_name], "string") | ||
| features[feature_name] = tf.cast(features[feature_name], "string") | ||
| vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name] | ||
| # Create a lookup to convert string values to integer indices. | ||
| # Since we are not using a mask token nor expecting any out of vocabulary | ||
| # (oov) token, we set mask_token to None and num_oov_indices to 0. | ||
| index = layers.StringLookup( | ||
| vocabulary=vocabulary, | ||
| mask_token=None, | ||
| num_oov_indices=0, | ||
| output_mode="int", | ||
| ) | ||
| # Convert the string input values into integer indices. | ||
| value_index = index(features[feature_name]) | ||
| features[feature_name] = value_index | ||
| else: | ||
| # Do nothing for numerical features | ||
| pass | ||
|
|
||
| # Get the instance weight. | ||
| weight = features.pop(WEIGHT_COLUMN_NAME) | ||
| return features, target, weight | ||
| # Convert features from an OrderedDict to a plain dict to match the model inputs, which are a dict. | ||
| return dict(features), target, weight | ||
|
|
||
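For readers unfamiliar with `layers.StringLookup`: with `mask_token=None` and `num_oov_indices=0`, indices start at 0 and follow the vocabulary order. A tiny, hypothetical usage sketch (the vocabulary below is made up, not taken from the dataset):

```python
# Hypothetical vocabulary; the real ones come from CATEGORICAL_FEATURES_WITH_VOCABULARY.
lookup = layers.StringLookup(
    vocabulary=["low", "medium", "high"], mask_token=None, num_oov_indices=0
)
print(lookup(tf.constant(["high", "low"])))  # -> [2, 0]
```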
|
|
||
| def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128): | ||
|
|
@@ -245,56 +307,19 @@ def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128): | |
| def create_model_inputs(): | ||
| inputs = {} | ||
| for feature_name in FEATURE_NAMES: | ||
| if feature_name in NUMERIC_FEATURE_NAMES: | ||
| if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY: | ||
| # Use int64: these are categorical features (integer indices) | ||
| inputs[feature_name] = layers.Input( | ||
| name=feature_name, shape=(), dtype="float32" | ||
| name=feature_name, shape=(), dtype="int64" | ||
| ) | ||
| else: | ||
| # Use float32: these are real-valued numerical features | ||
| inputs[feature_name] = layers.Input( | ||
| name=feature_name, shape=(), dtype="string" | ||
| name=feature_name, shape=(), dtype="float32" | ||
| ) | ||
| return inputs | ||
|
|
||
|
|
||
| """ | ||
| ## Encode input features | ||
|
|
||
| For categorical features, we encode them using `layers.Embedding` using the | ||
| `encoding_size` as the embedding dimensions. For the numerical features, | ||
| we apply linear transformation using `layers.Dense` to project each feature into | ||
| `encoding_size`-dimensional vector. Thus, all the encoded features will have the | ||
| same dimensionality. | ||
|
|
||
| """ | ||
|
|
||
|
|
||
| def encode_inputs(inputs, encoding_size): | ||
| encoded_features = [] | ||
| for feature_name in inputs: | ||
| if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY: | ||
| vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name] | ||
| # Create a lookup to convert a string values to an integer indices. | ||
| # Since we are not using a mask token nor expecting any out of vocabulary | ||
| # (oov) token, we set mask_token to None and num_oov_indices to 0. | ||
| index = layers.StringLookup( | ||
| vocabulary=vocabulary, mask_token=None, num_oov_indices=0 | ||
| ) | ||
| # Convert the string input values into integer indices. | ||
| value_index = index(inputs[feature_name]) | ||
| # Create an embedding layer with the specified dimensions | ||
| embedding_ecoder = layers.Embedding( | ||
| input_dim=len(vocabulary), output_dim=encoding_size | ||
| ) | ||
| # Convert the index values to embedding representations. | ||
| encoded_feature = embedding_ecoder(value_index) | ||
| else: | ||
| # Project the numeric feature to encoding_size using linear transformation. | ||
| encoded_feature = keras.ops.expand_dims(inputs[feature_name], -1) | ||
| encoded_feature = layers.Dense(units=encoding_size)(encoded_feature) | ||
| encoded_features.append(encoded_feature) | ||
| return encoded_features | ||
|
|
||
|
|
||
| """ | ||
| ## Implement the Gated Linear Unit | ||
|
|
||
|
|
@@ -360,12 +385,35 @@ def call(self, inputs): | |
|
|
||
| Note that the output of the VSN is [batch_size, encoding_size], regardless of the | ||
| number of input features. | ||
|
|
||
| For categorical features, we encode them using `layers.Embedding`, with | ||
| `encoding_size` as the embedding dimension. For the numerical features, | ||
| we apply a linear transformation using `layers.Dense` to project each feature into | ||
| an `encoding_size`-dimensional vector. Thus, all the encoded features will have the | ||
| same dimensionality. | ||
|
|
||
| """ | ||
|
|
||
|
|
||
| class VariableSelection(layers.Layer): | ||
| def __init__(self, num_features, units, dropout_rate): | ||
| super().__init__() | ||
| self.units = units | ||
| # Create embedding layers with the specified dimensions | ||
| self.embeddings = dict() | ||
| for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY: | ||
| vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_] | ||
| embedding_encoder = layers.Embedding( | ||
| input_dim=len(vocabulary), output_dim=self.units, name=input_ | ||
| ) | ||
| self.embeddings[input_] = embedding_encoder | ||
|
|
||
| # Projection layers for numeric features | ||
| self.proj_layer = dict() | ||
| for input_ in NUMERIC_FEATURE_NAMES: | ||
| proj_layer = layers.Dense(units=self.units) | ||
| self.proj_layer[input_] = proj_layer | ||
|
|
||
| self.grns = list() | ||
| # Create a GRN for each feature independently | ||
| for idx in range(num_features): | ||
|
|
@@ -376,17 +424,78 @@ def __init__(self, num_features, units, dropout_rate): | |
| self.softmax = layers.Dense(units=num_features, activation="softmax") | ||
|
|
||
| def call(self, inputs): | ||
| v = layers.concatenate(inputs) | ||
| concat_inputs = [] | ||
| for input_ in inputs: | ||
| if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY: | ||
| max_index = self.embeddings[input_].input_dim - 1  # highest valid embedding index | ||
| # The torch backend raised index errors during the embedding lookup, hence the clip below. | ||
| embedded_feature = self.embeddings[input_]( | ||
| keras.ops.clip(inputs[input_], 0, max_index) | ||
| ) | ||
| concat_inputs.append(embedded_feature) | ||
| else: | ||
| # Project the numeric feature to encoding_size using linear transformation. | ||
| proj_feature = keras.ops.expand_dims(inputs[input_], -1) | ||
| proj_feature = self.proj_layer[input_](proj_feature) | ||
| concat_inputs.append(proj_feature) | ||
|
|
||
| v = layers.concatenate(concat_inputs) | ||
| v = self.grn_concat(v) | ||
| v = keras.ops.expand_dims(self.softmax(v), axis=-1) | ||
|
|
||
| x = [] | ||
| for idx, input in enumerate(inputs): | ||
| for idx, input in enumerate(concat_inputs): | ||
| x.append(self.grns[idx](input)) | ||
| x = keras.ops.stack(x, axis=1) | ||
|
|
||
| outputs = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1) | ||
| return outputs | ||
| # Each backend gets its own calculation here because I couldn't find an | ||
| # equivalent Keras operation that is backend-agnostic. There is a | ||
| # keras.ops.matmul, but it was returning errors. I could have used the TensorFlow matmul | ||
| # for all backends, but jax jit tracing turns that into an error. | ||
| def matmul_dependent_on_backend(tensor_1, tensor_2): | ||
| """ | ||
| Function for executing matmul for each backend. | ||
| """ | ||
| # jax backend | ||
| if keras.backend.backend() == "jax": | ||
| import jax.numpy as jnp | ||
|
|
||
| result = jnp.sum(tensor_1 * tensor_2, axis=1) | ||
| elif keras.backend.backend() == "torch": | ||
| result = torch.sum(tensor_1 * tensor_2, dim=1) | ||
| # tensorflow backend | ||
| elif keras.backend.backend() == "tensorflow": | ||
| result = keras.ops.squeeze(tf.matmul(tensor_1, tensor_2, transpose_a=True), axis=1) | ||
| # unsupported backend exception | ||
| else: | ||
| raise ValueError( | ||
| "Unsupported backend: {}".format(keras.backend.backend()) | ||
| ) | ||
| return result | ||
|
|
||
| # jax backend | ||
| if keras.backend.backend() == "jax": | ||
| # These repetitive imports are intentional, to emphasize the backend | ||
| # separation | ||
| import jax.numpy as jnp | ||
|
|
||
| result_jax = matmul_dependent_on_backend(v, x) | ||
| return result_jax | ||
| # torch backend | ||
| if keras.backend.backend() == "torch": | ||
| import torch | ||
|
|
||
| result_torch = matmul_dependent_on_backend(v, x) | ||
| return result_torch | ||
| # tensorflow backend | ||
| if keras.backend.backend() == "tensorflow": | ||
| import tensorflow as tf | ||
|
|
||
| result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1) | ||
| return result_tf | ||
|
||
|
|
||
| # Mark the layer as built explicitly to avoid build warnings | ||
| def build(self): | ||
| self.built = True | ||
|
Comment on lines +447 to +449

Contributor: Thanks!

Contributor (Author): Added more of this.
||
|
|
||
|
|
||
| """ | ||
|
|
@@ -396,14 +505,10 @@ def call(self, inputs): | |
|
|
||
| def create_model(encoding_size): | ||
| inputs = create_model_inputs() | ||
| feature_list = encode_inputs(inputs, encoding_size) | ||
| num_features = len(feature_list) | ||
|
|
||
| features = VariableSelection(num_features, encoding_size, dropout_rate)( | ||
| feature_list | ||
| ) | ||
|
|
||
| num_features = len(inputs) | ||
| features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs) | ||
| outputs = layers.Dense(units=1, activation="sigmoid")(features) | ||
| # Functional model | ||
| model = keras.Model(inputs=inputs, outputs=outputs) | ||
| return model | ||
|
|
||
|
|
@@ -415,7 +520,7 @@ def create_model(encoding_size): | |
| learning_rate = 0.001 | ||
| dropout_rate = 0.15 | ||
| batch_size = 265 | ||
| num_epochs = 20 | ||
| num_epochs = 1 # may be adjusted to a desired value | ||
|
||
| encoding_size = 16 | ||
|
|
||
| model = create_model(encoding_size) | ||
|
|
@@ -425,6 +530,13 @@ def create_model(encoding_size): | |
| metrics=[keras.metrics.BinaryAccuracy(name="accuracy")], | ||
| ) | ||
|
|
||
| """ | ||
| Let's visualize our connectivity graph: | ||
| """ | ||
|
|
||
| # `rankdir='LR'` is to make the graph horizontal. | ||
| keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR") | ||
|
|
||
|
|
||
| # Create an early stopping callback. | ||
| early_stopping = keras.callbacks.EarlyStopping( | ||
|
|
||
Reviewer: Why the change in dataset? It seems like the original dataset was easier to handle.

Author: The dataset description in the script says it must have 41 input features: 7 numerical features and 34 categorical features. The original dataset only had 14 features and its target variable was in `<= or >= 50k`, whereas in the script it is in `-5000 or +5000`.
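To make that reply concrete, a small sanity check along these lines could be run against the metadata defined in the script; this is illustrative only and assumes the `NUMERIC_FEATURE_NAMES` and `CATEGORICAL_FEATURES_WITH_VOCABULARY` variables from the script:

```python
# The reply above says the script expects 41 input features:
# 7 numerical and 34 categorical.
assert len(NUMERIC_FEATURE_NAMES) == 7
assert len(CATEGORICAL_FEATURES_WITH_VOCABULARY) == 34
assert len(NUMERIC_FEATURE_NAMES) + len(CATEGORICAL_FEATURES_WITH_VOCABULARY) == 41
```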