Skip to content

Commit a5d689e

Browse files
authored
Tweaks to DistributedEmbedding for JAX example. (#2171)
- Add `pip install`s - Cast counts to int to not use `np.array`s in configuration
1 parent 1fa9749 commit a5d689e

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

examples/keras_rs/distributed_embedding_jax.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
libraries.
2424
"""
2525

26+
"""shell
27+
pip install -q jax-tpu-embedding
28+
pip install -q tensorflow-cpu
29+
pip install -q keras-rs
30+
"""
31+
2632
import os
2733

2834
os.environ["KERAS_BACKEND"] = "jax"
@@ -64,7 +70,7 @@
6470
index in the user embedding table.
6571
"""
6672

67-
users_count = (
73+
users_count = int(
6874
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
6975
.reduce(tf.constant(0, tf.int32), tf.maximum)
7076
.numpy()
@@ -75,7 +81,7 @@
7581
as an index in the movie embedding table.
7682
"""
7783

78-
movies_count = movies.cardinality().numpy()
84+
movies_count = int(movies.cardinality().numpy())
7985

8086
"""
8187
The inputs to the model are the user IDs and movie IDs and the labels are the

examples/keras_rs/ipynb/distributed_embedding_jax.ipynb

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@
3535
"libraries."
3636
]
3737
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": 0,
41+
"metadata": {
42+
"colab_type": "code"
43+
},
44+
"outputs": [],
45+
"source": [
46+
"!pip install -q jax-tpu-embedding\n",
47+
"!pip install -q tensorflow-cpu\n",
48+
"!pip install -q keras-rs"
49+
]
50+
},
3851
{
3952
"cell_type": "code",
4053
"execution_count": 0,
@@ -126,7 +139,7 @@
126139
},
127140
"outputs": [],
128141
"source": [
129-
"users_count = (\n",
142+
"users_count = int(\n",
130143
" ratings.map(lambda x: tf.strings.to_number(x[\"user_id\"], out_type=tf.int32))\n",
131144
" .reduce(tf.constant(0, tf.int32), tf.maximum)\n",
132145
" .numpy()\n",
@@ -151,7 +164,7 @@
151164
},
152165
"outputs": [],
153166
"source": [
154-
"movies_count = movies.cardinality().numpy()"
167+
"movies_count = int(movies.cardinality().numpy())"
155168
]
156169
},
157170
{

examples/keras_rs/md/distributed_embedding_jax.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ Let's begin by choosing JAX as the backend and importing all the necessary
2626
libraries.
2727

2828

29+
```python
30+
!pip install -q jax-tpu-embedding
31+
!pip install -q tensorflow-cpu
32+
!pip install -q keras-rs
33+
```
34+
2935
```python
3036
import os
3137

@@ -73,7 +79,7 @@ index in the user embedding table.
7379

7480

7581
```python
76-
users_count = (
82+
users_count = int(
7783
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
7884
.reduce(tf.constant(0, tf.int32), tf.maximum)
7985
.numpy()
@@ -85,7 +91,7 @@ as an index in the movie embedding table.
8591

8692

8793
```python
88-
movies_count = movies.cardinality().numpy()
94+
movies_count = int(movies.cardinality().numpy())
8995
```
9096

9197
The inputs to the model are the user IDs and movie IDs and the labels are the

0 commit comments

Comments
 (0)