Skip to content

Commit 78290c8

Browse files
authored
Refactor the example to use keras 3 and create custom layer
1 parent c15b2c7 commit 78290c8

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

examples/nlp/tweet-classification-using-tfdf.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@
3737
import pandas as pd
3838
import numpy as np
3939
import tensorflow as tf
40-
from tensorflow import keras
40+
import keras
4141
import tensorflow_hub as hub
42-
from tensorflow.keras import layers
4342
import tensorflow_decision_forests as tfdf
4443
import matplotlib.pyplot as plt
4544

@@ -158,9 +157,16 @@ def create_dataset(dataframe):
158157
159158
"""
160159

161-
sentence_encoder_layer = hub.KerasLayer(
162-
"https://tfhub.dev/google/universal-sentence-encoder/4"
163-
)
160+
sentence_encoder_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
161+
162+
class SentenceEncoderLayer(keras.layers.Layer):
163+
def __init__(self, **kwargs):
164+
super(SentenceEncoderLayer, self).__init__(**kwargs)
165+
self.encoder = hub.KerasLayer(sentence_encoder_url)
166+
167+
def call(self, inputs):
168+
return self.encoder(inputs)
169+
164170

165171
"""
166172
## Creating our models
@@ -175,8 +181,8 @@ def create_dataset(dataframe):
175181
Building model_1
176182
"""
177183

178-
inputs = layers.Input(shape=(), dtype=tf.string)
179-
outputs = sentence_encoder_layer(inputs)
184+
inputs = keras.layers.Input(shape=(), dtype="string")
185+
outputs = SentenceEncoderLayer()(inputs)
180186
preprocessor = keras.Model(inputs=inputs, outputs=outputs)
181187
model_1 = tfdf.keras.GradientBoostedTreesModel(preprocessing=preprocessor)
182188

@@ -278,9 +284,9 @@ def plot_curve(logs):
278284

279285
test_df.reset_index(inplace=True, drop=True)
280286
for index, row in test_df.iterrows():
281-
text = tf.expand_dims(row["text"], axis=0)
287+
text = keras.ops.expand_dims(row["text"], axis=0)
282288
preds = model_1.predict_step(text)
283-
preds = tf.squeeze(tf.round(preds))
289+
preds = keras.ops.squeeze(keras.ops.round(preds))
284290
print(f"Text: {row['text']}")
285291
print(f"Prediction: {int(preds)}")
286292
print(f"Ground Truth : {row['target']}")

0 commit comments

Comments
 (0)