|
35 | 35 | "\"Mixture Density Networks\".\n", |
36 | 36 | "\n", |
37 | 37 | "I'm going to use the new\n", |
38 | | - "[multibackend Keras Core project](https://github.com/keras-team/keras-core) to\n", |
| 38 | + "[multibackend Keras V3](https://github.com/keras-team/keras) to\n", |
39 | 39 | "build my Mixture Density networks.\n", |
40 | 40 | "Great job to the Keras team on the project - it's awesome to be able to swap\n", |
41 | 41 | "frameworks in one line of code.\n", |
42 | 42 | "\n", |
43 | | - "Some bad news: I use TensorFlow probability in this guide... so it doesn't\n", |
44 | | - "actually work with other backends.\n", |
| 43 | + "Some bad news: I use TensorFlow probability in this guide... so it\n", |
| 44 | + "actually works only with TensorFlow and JAX backends.\n", |
45 | 45 | "\n", |
46 | 46 | "Anyways, let's start by installing dependencies and sorting out imports:" |
47 | 47 | ] |
48 | 48 | }, |
| 49 | + { |
| 50 | + "cell_type": "code", |
| 51 | + "execution_count": null, |
| 52 | + "metadata": {}, |
| 53 | + "outputs": [], |
| 54 | + "source": [ |
| 55 | + "%env KERAS_BACKEND=jax" |
| 56 | + ] |
| 57 | + }, |
49 | 58 | { |
50 | 59 | "cell_type": "code", |
51 | 60 | "execution_count": null, |
|
54 | 63 | }, |
55 | 64 | "outputs": [], |
56 | 65 | "source": [ |
57 | | - "!pip install -q --upgrade tensorflow-probability keras-core" |
| 66 | + "%pip install -q --upgrade jax tensorflow-probability[jax] keras" |
58 | 67 | ] |
59 | 68 | }, |
60 | 69 | { |
|
67 | 76 | "source": [ |
68 | 77 | "import numpy as np\n", |
69 | 78 | "import matplotlib.pyplot as plt\n", |
70 | | - "import math\n", |
71 | | - "import random\n", |
72 | | - "from keras_core import callbacks\n", |
73 | | - "import keras_core\n", |
74 | | - "import tensorflow as tf\n", |
75 | | - "from keras_core import layers\n", |
76 | | - "from keras_core import optimizers\n", |
77 | | - "from tensorflow_probability import distributions as tfd" |
| 79 | + "import keras\n", |
| 80 | + "from keras import callbacks, layers, ops\n", |
| 81 | + "from tensorflow_probability.substrates.jax import distributions as tfd" |
78 | 82 | ] |
79 | 83 | }, |
80 | 84 | { |
|
161 | 165 | "source": [ |
162 | 166 | "N_HIDDEN = 128\n", |
163 | 167 | "\n", |
164 | | - "model = keras_core.Sequential(\n", |
| 168 | + "model = keras.Sequential(\n", |
165 | 169 | " [\n", |
166 | 170 | " layers.Dense(N_HIDDEN, activation=\"relu\"),\n", |
167 | 171 | " layers.Dense(N_HIDDEN, activation=\"relu\"),\n", |
|
308 | 312 | "source": [ |
309 | 313 | "\n", |
310 | 314 | "def elu_plus_one_plus_epsilon(x):\n", |
311 | | - " return keras_core.activations.elu(x) + 1 + keras_core.backend.epsilon()\n" |
| 315 | + " return keras.activations.elu(x) + 1 + keras.backend.epsilon()\n" |
312 | 316 | ] |
313 | 317 | }, |
314 | 318 | { |
|
393 | 397 | "OUTPUT_DIMS = 1\n", |
394 | 398 | "N_MIXES = 20\n", |
395 | 399 | "\n", |
396 | | - "mdn_network = keras_core.Sequential(\n", |
| 400 | + "mdn_network = keras.Sequential(\n", |
397 | 401 | " [\n", |
398 | 402 | " layers.Dense(N_HIDDEN, activation=\"relu\"),\n", |
399 | 403 | " layers.Dense(N_HIDDEN, activation=\"relu\"),\n", |
|
420 | 424 | }, |
421 | 425 | "outputs": [], |
422 | 426 | "source": [ |
423 | | - "\n", |
424 | 427 | "def get_mixture_loss_func(output_dim, num_mixes):\n", |
425 | 428 | " def mdn_loss_func(y_true, y_pred):\n", |
426 | 429 | " # Reshape inputs in case this is used in a TimeDistributed layer\n", |
427 | | - " y_pred = tf.reshape(\n", |
428 | | - " y_pred,\n", |
429 | | - " [-1, (2 * num_mixes * output_dim) + num_mixes],\n", |
430 | | - " name=\"reshape_ypreds\",\n", |
431 | | - " )\n", |
432 | | - " y_true = tf.reshape(y_true, [-1, output_dim], name=\"reshape_ytrue\")\n", |
| 430 | + " y_pred = ops.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])\n", |
| 431 | + " y_true = ops.reshape(y_true, [-1, output_dim])\n", |
433 | 432 | " # Split the inputs into parameters\n", |
434 | | - " out_mu, out_sigma, out_pi = tf.split(\n", |
435 | | - " y_pred,\n", |
436 | | - " num_or_size_splits=[\n", |
437 | | - " num_mixes * output_dim,\n", |
438 | | - " num_mixes * output_dim,\n", |
439 | | - " num_mixes,\n", |
440 | | - " ],\n", |
441 | | - " axis=-1,\n", |
442 | | - " name=\"mdn_coef_split\",\n", |
443 | | - " )\n", |
| 433 | + " out_mu, out_sigma, out_pi = ops.split(y_pred, 3, axis=-1)\n", |
444 | 434 | " # Construct the mixture models\n", |
445 | 435 | " cat = tfd.Categorical(logits=out_pi)\n", |
446 | | - " component_splits = [output_dim] * num_mixes\n", |
447 | | - " mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)\n", |
448 | | - " sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)\n", |
| 436 | + " mus = ops.split(out_mu, num_mixes, axis=1)\n", |
| 437 | + " sigs = ops.split(out_sigma, num_mixes, axis=1)\n", |
449 | 438 | " coll = [\n", |
450 | 439 | " tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)\n", |
451 | 440 | " for loc, scale in zip(mus, sigs)\n", |
452 | 441 | " ]\n", |
453 | 442 | " mixture = tfd.Mixture(cat=cat, components=coll)\n", |
454 | 443 | " loss = mixture.log_prob(y_true)\n", |
455 | | - " loss = tf.negative(loss)\n", |
456 | | - " loss = tf.reduce_mean(loss)\n", |
| 444 | + " loss = ops.negative(loss)\n", |
| 445 | + " loss = ops.mean(loss)\n", |
457 | 446 | " return loss\n", |
458 | 447 | "\n", |
459 | 448 | " return mdn_loss_func\n", |
|
560 | 549 | " accumulate += dist[i]\n", |
561 | 550 | " if accumulate >= r:\n", |
562 | 551 | " return i\n", |
563 | | - " tf.logging.info(\"Error sampling categorical model.\")\n", |
| 552 | + " print(\"Error sampling categorical model.\")\n", |
564 | 553 | " return -1\n", |
565 | 554 | "\n", |
566 | 555 | "\n", |
|