Skip to content

Commit 5b13561

Browse files
authored
convert MDN to KerasV3 (#2050)
1 parent 98429b5 commit 5b13561

File tree

4 files changed

+75
-115
lines changed

4 files changed

+75
-115
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ templates/**/guides/**/*.md
1515
templates/keras_hub/getting_started.md
1616
templates/keras_tuner/getting_started.md
1717
datasets/*
18-
.vscode/*
18+
.vscode/*
19+
.history

examples/keras_recipes/approximating_non_function_mappings.py

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,29 @@
2727
"Mixture Density Networks".
2828
2929
I'm going to use the new
30-
[multibackend Keras Core project](https://github.com/keras-team/keras-core) to
30+
[multibackend Keras V3](https://github.com/keras-team/keras) to
3131
build my Mixture Density networks.
3232
Great job to the Keras team on the project - it's awesome to be able to swap
3333
frameworks in one line of code.
3434
35-
Some bad news: I use TensorFlow probability in this guide... so it doesn't
36-
actually work with other backends.
35+
Some bad news: I use TensorFlow probability in this guide... so it
36+
actually works only with TensorFlow and JAX backends.
3737
3838
Anyways, let's start by installing dependencies and sorting out imports:
3939
"""
4040
"""shell
41-
pip install -q --upgrade tensorflow-probability keras-core
41+
pip install -q --upgrade jax tensorflow-probability[jax] keras
4242
"""
4343

44+
import os
45+
46+
os.environ["KERAS_BACKEND"] = "jax"
47+
4448
import numpy as np
4549
import matplotlib.pyplot as plt
46-
import math
47-
import random
48-
from keras_core import callbacks
49-
import keras_core
50-
import tensorflow as tf
51-
from keras_core import layers
52-
from keras_core import optimizers
53-
from tensorflow_probability import distributions as tfd
50+
import keras
51+
from keras import callbacks, layers, ops
52+
from tensorflow_probability.substrates.jax import distributions as tfd
5453

5554
"""
5655
Next, lets generate a noisy spiral that we're going to attempt to approximate.
@@ -99,7 +98,7 @@ def create_noisy_spiral(n, jitter_std=0.2, revolutions=2):
9998

10099
N_HIDDEN = 128
101100

102-
model = keras_core.Sequential(
101+
model = keras.Sequential(
103102
[
104103
layers.Dense(N_HIDDEN, activation="relu"),
105104
layers.Dense(N_HIDDEN, activation="relu"),
@@ -179,7 +178,7 @@ def create_noisy_spiral(n, jitter_std=0.2, revolutions=2):
179178

180179

181180
def elu_plus_one_plus_epsilon(x):
182-
return keras_core.activations.elu(x) + 1 + keras_core.backend.epsilon()
181+
return keras.activations.elu(x) + 1 + keras.backend.epsilon()
183182

184183

185184
"""
@@ -238,7 +237,7 @@ def call(self, x, mask=None):
238237
OUTPUT_DIMS = 1
239238
N_MIXES = 20
240239

241-
mdn_network = keras_core.Sequential(
240+
mdn_network = keras.Sequential(
242241
[
243242
layers.Dense(N_HIDDEN, activation="relu"),
244243
layers.Dense(N_HIDDEN, activation="relu"),
@@ -255,36 +254,22 @@ def call(self, x, mask=None):
255254
def get_mixture_loss_func(output_dim, num_mixes):
256255
def mdn_loss_func(y_true, y_pred):
257256
# Reshape inputs in case this is used in a TimeDistributed layer
258-
y_pred = tf.reshape(
259-
y_pred,
260-
[-1, (2 * num_mixes * output_dim) + num_mixes],
261-
name="reshape_ypreds",
262-
)
263-
y_true = tf.reshape(y_true, [-1, output_dim], name="reshape_ytrue")
257+
y_pred = ops.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])
258+
y_true = ops.reshape(y_true, [-1, output_dim])
264259
# Split the inputs into parameters
265-
out_mu, out_sigma, out_pi = tf.split(
266-
y_pred,
267-
num_or_size_splits=[
268-
num_mixes * output_dim,
269-
num_mixes * output_dim,
270-
num_mixes,
271-
],
272-
axis=-1,
273-
name="mdn_coef_split",
274-
)
260+
out_mu, out_sigma, out_pi = ops.split(y_pred, 3, axis=-1)
275261
# Construct the mixture models
276262
cat = tfd.Categorical(logits=out_pi)
277-
component_splits = [output_dim] * num_mixes
278-
mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
279-
sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
263+
mus = ops.split(out_mu, num_mixes, axis=1)
264+
sigs = ops.split(out_sigma, num_mixes, axis=1)
280265
coll = [
281266
tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
282267
for loc, scale in zip(mus, sigs)
283268
]
284269
mixture = tfd.Mixture(cat=cat, components=coll)
285270
loss = mixture.log_prob(y_true)
286-
loss = tf.negative(loss)
287-
loss = tf.reduce_mean(loss)
271+
loss = ops.negative(loss)
272+
loss = ops.mean(loss)
288273
return loss
289274

290275
return mdn_loss_func
@@ -349,7 +334,7 @@ def sample_from_categorical(dist):
349334
accumulate += dist[i]
350335
if accumulate >= r:
351336
return i
352-
tf.logging.info("Error sampling categorical model.")
337+
print("Error sampling categorical model.")
353338
return -1
354339

355340

examples/keras_recipes/ipynb/approximating_non_function_mappings.ipynb

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,26 @@
3535
"\"Mixture Density Networks\".\n",
3636
"\n",
3737
"I'm going to use the new\n",
38-
"[multibackend Keras Core project](https://github.com/keras-team/keras-core) to\n",
38+
"[multibackend Keras V3](https://github.com/keras-team/keras) to\n",
3939
"build my Mixture Density networks.\n",
4040
"Great job to the Keras team on the project - it's awesome to be able to swap\n",
4141
"frameworks in one line of code.\n",
4242
"\n",
43-
"Some bad news: I use TensorFlow probability in this guide... so it doesn't\n",
44-
"actually work with other backends.\n",
43+
"Some bad news: I use TensorFlow probability in this guide... so it\n",
44+
"actually works only with TensorFlow and JAX backends.\n",
4545
"\n",
4646
"Anyways, let's start by installing dependencies and sorting out imports:"
4747
]
4848
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"%env KERAS_BACKEND=jax"
56+
]
57+
},
4958
{
5059
"cell_type": "code",
5160
"execution_count": null,
@@ -54,7 +63,7 @@
5463
},
5564
"outputs": [],
5665
"source": [
57-
"!pip install -q --upgrade tensorflow-probability keras-core"
66+
"%pip install -q --upgrade jax tensorflow-probability[jax] keras"
5867
]
5968
},
6069
{
@@ -67,14 +76,9 @@
6776
"source": [
6877
"import numpy as np\n",
6978
"import matplotlib.pyplot as plt\n",
70-
"import math\n",
71-
"import random\n",
72-
"from keras_core import callbacks\n",
73-
"import keras_core\n",
74-
"import tensorflow as tf\n",
75-
"from keras_core import layers\n",
76-
"from keras_core import optimizers\n",
77-
"from tensorflow_probability import distributions as tfd"
79+
"import keras\n",
80+
"from keras import callbacks, layers, ops\n",
81+
"from tensorflow_probability.substrates.jax import distributions as tfd"
7882
]
7983
},
8084
{
@@ -161,7 +165,7 @@
161165
"source": [
162166
"N_HIDDEN = 128\n",
163167
"\n",
164-
"model = keras_core.Sequential(\n",
168+
"model = keras.Sequential(\n",
165169
" [\n",
166170
" layers.Dense(N_HIDDEN, activation=\"relu\"),\n",
167171
" layers.Dense(N_HIDDEN, activation=\"relu\"),\n",
@@ -308,7 +312,7 @@
308312
"source": [
309313
"\n",
310314
"def elu_plus_one_plus_epsilon(x):\n",
311-
" return keras_core.activations.elu(x) + 1 + keras_core.backend.epsilon()\n"
315+
" return keras.activations.elu(x) + 1 + keras.backend.epsilon()\n"
312316
]
313317
},
314318
{
@@ -393,7 +397,7 @@
393397
"OUTPUT_DIMS = 1\n",
394398
"N_MIXES = 20\n",
395399
"\n",
396-
"mdn_network = keras_core.Sequential(\n",
400+
"mdn_network = keras.Sequential(\n",
397401
" [\n",
398402
" layers.Dense(N_HIDDEN, activation=\"relu\"),\n",
399403
" layers.Dense(N_HIDDEN, activation=\"relu\"),\n",
@@ -420,40 +424,25 @@
420424
},
421425
"outputs": [],
422426
"source": [
423-
"\n",
424427
"def get_mixture_loss_func(output_dim, num_mixes):\n",
425428
" def mdn_loss_func(y_true, y_pred):\n",
426429
" # Reshape inputs in case this is used in a TimeDistributed layer\n",
427-
" y_pred = tf.reshape(\n",
428-
" y_pred,\n",
429-
" [-1, (2 * num_mixes * output_dim) + num_mixes],\n",
430-
" name=\"reshape_ypreds\",\n",
431-
" )\n",
432-
" y_true = tf.reshape(y_true, [-1, output_dim], name=\"reshape_ytrue\")\n",
430+
" y_pred = ops.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])\n",
431+
" y_true = ops.reshape(y_true, [-1, output_dim])\n",
433432
" # Split the inputs into parameters\n",
434-
" out_mu, out_sigma, out_pi = tf.split(\n",
435-
" y_pred,\n",
436-
" num_or_size_splits=[\n",
437-
" num_mixes * output_dim,\n",
438-
" num_mixes * output_dim,\n",
439-
" num_mixes,\n",
440-
" ],\n",
441-
" axis=-1,\n",
442-
" name=\"mdn_coef_split\",\n",
443-
" )\n",
433+
" out_mu, out_sigma, out_pi = ops.split(y_pred, 3, axis=-1)\n",
444434
" # Construct the mixture models\n",
445435
" cat = tfd.Categorical(logits=out_pi)\n",
446-
" component_splits = [output_dim] * num_mixes\n",
447-
" mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)\n",
448-
" sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)\n",
436+
" mus = ops.split(out_mu, num_mixes, axis=1)\n",
437+
" sigs = ops.split(out_sigma, num_mixes, axis=1)\n",
449438
" coll = [\n",
450439
" tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)\n",
451440
" for loc, scale in zip(mus, sigs)\n",
452441
" ]\n",
453442
" mixture = tfd.Mixture(cat=cat, components=coll)\n",
454443
" loss = mixture.log_prob(y_true)\n",
455-
" loss = tf.negative(loss)\n",
456-
" loss = tf.reduce_mean(loss)\n",
444+
" loss = ops.negative(loss)\n",
445+
" loss = ops.mean(loss)\n",
457446
" return loss\n",
458447
"\n",
459448
" return mdn_loss_func\n",
@@ -560,7 +549,7 @@
560549
" accumulate += dist[i]\n",
561550
" if accumulate >= r:\n",
562551
" return i\n",
563-
" tf.logging.info(\"Error sampling categorical model.\")\n",
552+
" print(\"Error sampling categorical model.\")\n",
564553
" return -1\n",
565554
"\n",
566555
"\n",

0 commit comments

Comments
 (0)