|
102 | 102 | "import os\n", |
103 | 103 | "\n", |
104 | 104 | "# Set backend before importing keras\n", |
105 | | - "os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\"\n", |
| 105 | + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # Or \"torch\" or \"tensorflow\"\n", |
106 | 106 | "\n", |
107 | 107 | "import numpy as np\n", |
108 | 108 | "import keras\n", |
109 | 109 | "from keras import layers\n", |
110 | 110 | "from keras import ops\n", |
111 | | - "from typing import Optional\n", |
112 | 111 | "from matplotlib import pyplot as plt\n", |
113 | 112 | "from random import randint\n", |
114 | 113 | "\n", |
|
155 | 154 | }, |
156 | 155 | "source": [ |
157 | 156 | "## Data Loading with PyDataset\n", |
| 157 | + "\n", |
158 | 158 | "Keras 3 introduces PyDataset as a standardized way to handle data.\n", |
159 | 159 | "It works identically across all backends and avoids the \"Symbolic Tensor\" issues often found\n", |
160 | 160 | "when using tf.data with JAX or PyTorch." |
|
363 | 363 | "outputs": [], |
364 | 364 | "source": [ |
365 | 365 | "\n", |
366 | | - "def MLP(\n", |
367 | | - " in_features: int,\n", |
368 | | - " hidden_features: Optional[int] = None,\n", |
369 | | - " out_features: Optional[int] = None,\n", |
370 | | - " mlp_drop_rate: float = 0.0,\n", |
371 | | - "):\n", |
| 366 | + "def MLP(in_features, hidden_features=None, out_features=None, mlp_drop_rate=0.0):\n", |
372 | 367 | " hidden_features = hidden_features or in_features\n", |
373 | 368 | " out_features = out_features or in_features\n", |
374 | 369 | " return keras.Sequential(\n", |
|
468 | 463 | "| Equation 5: Linear projection of `Z^0` (Source: Aritra and Ritwik) |\n", |
469 | 464 | "\n", |
470 | 465 | "`Z^0` is then passed on to a series of Depth-Wise (DWConv) Conv and\n", |
471 | | - "[GeLU](hhttps://keras.io/api/layers/activations/#gelu-function) layers. The\n", |
| 466 | + "[GeLU](https://keras.io/api/layers/activations/#gelu-function) layers. The\n", |
472 | 467 | "authors term each block of DWConv and GeLU as levels denoted by `l`. In **Figure 6** we\n", |
473 | 468 | "have two levels. Mathematically this is represented as:\n", |
474 | 469 | "\n", |
|
0 commit comments