diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 4b3a8c760..e02721782 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -169,6 +169,7 @@ doctest_default_flags = doctest.NORMALIZE_WHITESPACE doctest_global_setup = """ import jax +jax.config.update('jax_num_cpu_devices', 8) import jax.numpy as jnp from flax import nnx diff --git a/docs_nnx/guides/optimization_cookbook.ipynb b/docs_nnx/guides/optimization_cookbook.ipynb new file mode 100644 index 000000000..0c6bbef51 --- /dev/null +++ b/docs_nnx/guides/optimization_cookbook.ipynb @@ -0,0 +1,550 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f7fae97b", + "metadata": {}, + "source": [ + "# A Flax Optimization Cookbook" + ] + }, + { + "cell_type": "markdown", + "id": "73641666", + "metadata": {}, + "source": [ + "This notebook goes through some common problems in nontrivial training loops for Flax models. For clarity, all sections below will be training the following toy model. We allow extra keyword arguments so that the sharding and dtype can be determined on an instance by instance basis. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "22a1f44d", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "from flax import nnx\n", + "jax.config.update('jax_num_cpu_devices', 8)\n", + "import jax.numpy as jnp\n", + "import functools as ft\n", + "import matplotlib.pyplot as plt\n", + "\n", + "param_init = jax.nn.initializers.lecun_normal()\n", + "\n", + "rngs = nnx.Rngs(0)\n", + "\n", + "def make_model(rngs, **kwargs):\n", + " return nnx.Sequential(\n", + " nnx.Linear(2,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)),\n", + " nnx.Linear(8,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)))\n", + "\n", + "def loss_fn(model, x, y):\n", + " return jnp.sum((model(x) - y) ** 2)" + ] + }, + { + "cell_type": "markdown", + "id": "9e625b1c", + "metadata": {}, + "source": [ + "We'll operate on the following fake data:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6d5e1c78", + "metadata": {}, + "outputs": [], + "source": [ + "x = rngs.normal((32, 2))\n", + "y = rngs.normal((32, 8))" + ] + }, + { + "cell_type": "markdown", + "id": "e6f0987b", + "metadata": {}, + "source": [ + "# Exponential Moving Average\n", + "\n", + "Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Flax training loop to accomodate calculating exponential moving averages. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a816ef6a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAARUdJREFUeJzt3Qt8znXj//H3zrPZwbAx5nzKOYSh+GVEkqK7g2NyTkV1u8vv1+mu+x+3DqKElChKKJK7SIRkhomcj2UTMxo72nn/x/fDdlutMjtc13a9no/H97bv93vtuj77XNe9vfscnXJycnIEAABgR5xtXQAAAIDfIqAAAAC7Q0ABAAB2h4ACAADsDgEFAADYHQIKAACwOwQUAABgdwgoAADA7riqDMrOztbp06fl4+MjJycnWxcHAABcA2tt2MTERAUHB8vZ2bn8BRQrnISEhNi6GAAA4DpER0erZs2a5S+gWC0nuT+gr6+vrYsDAACuQUJCgmlgyP07Xu4CSm63jhVOCCgAAJQt1zI8g0GyAADA7hBQAACA3SGgAAAAu0NAAQAAdoeAAgAA7A4BBQAA2B0CCgAAsDsEFAAAYHcIKAAAwO4QUAAAgN0hoAAAgPIVUKZOnWrW0584cWLetZiYGA0ZMkTVqlWTt7e32rRpo08//TTf98XFxWnQoEFmHx1/f3+NGDFCSUlJRSkKAAAoR647oOzYsUNz585Vy5Yt810fOnSoDh8+rFWrVmnv3r3q37+/7r33Xv3www95j7HCyf79+7Vu3TqtXr1amzdv1ujRo2Vr8SkZGvJehPZEX7R1UQAAcGjXFVCs1g4rZMybN0+VKlXKd2/r1q169NFH1b59e9WrV0/PPPOMaSWJjIw09w8ePKg1a9bo3XffVYcOHdSlSxe9+eabWrJkiU6fPi1beuXrQ/ru6HkTUvb9Em/TsgAA4MiuK6CMHz9effr0UVhY2O/uderUSZ988onpxsnOzjbBIzU1Vd26dTP3w8PDTWBp165d3vdYz+Ps7KyIiAjZ0tO9b1Db2pWUkJqpwe9F6OCZBJuWBwAAR1XogGIFjl27dmnKlCkF3l+6dKkyMjJUuXJleXh4aMyYMVqxYoUaNGiQN0YlMDAw3/e4uroqICDA3CtIWlqaEhIS8h0loaKHq94ffpNahfjrYkqGBr8boaNnE0vktQAAQDEFlOjoaE2YMEGLFy+Wp6dngY959tlndfHiRX3zzTfauXOnnnjiCTMGxRqPcr2sMOTn55d3hISEqKT4errpg+Ht1byGr35NTtcD8yJ0/BwDeAEAKE1OOTk5Odf64JUrV+ruu++Wi4tL3rWsrCwzk8fqorEGx1otJfv27VOzZs3ydeFY1+fMmaP58+frySef1IULF/LuZ2ZmmsCzbNky8/wFtaBYRy6rBcUKKfHx8WYmUEm4kJyuge9e7uYJ8vXQJ6NDVaeKd4m8FgAAjiAhIcE0NFzL3+9CtaB0797dtITs3r0777DGklgDZq2vU1JSLj+pc/6ntQKNNR7FEhoaalpYcgfNWjZs2GDuW4NmC2J1FVk/yNVHSavk7a5FI9qrUVBFnU1I08B52xQdd/nnAwAAJcu1MA/28fFR8+bN812z1jqxxptY162xJ1ZLiTXu5NVXXzXXrVaX3OnElhtuuEG9evXSqFGjTIuK9T2PPPKI7r//fgUHB8ueVK7oocUjO+r+d8J1/FyyHpi3TZ+MCVUN/wq2LhoAAOVasa4k6+bmpi+//FJVq1ZV3759zRopH3zwgRYuXKjbb78973HWGJYmTZqYFhnrujXV+J133pE9qurjoY9GdVSdyl46deGSaUmJiU+1dbEAACjXCjUGpSz2YRWX0xcv6b53whUdd0n1qnhryZiOCvQpeKAwAAAoxTEojizYv4I+HtXRdO+cOJ+sgfMidC7xvwN3AQBA8SGgFELNSl4mpFT389Sx2CQNenebzicRUgAAKG4ElEKqVflySLGmHh85m2QWc4tLTi/2NwYAAEdGQLkO1nooS0aHKtDHQ4diEs3AWWvdFAAAUDwIKNepbhVvfTy6o5nlY4WUQe9G6GIKIQUAgOJAQCmC+lUr6uNRHVSloocOnEkwGwzGp2QUyxsDAIAjI6AUUYNAHxNSKnu7a98vV0LKJUIKAABFQUApBg2DfMxibgHe7tr7S7yGvhehhFRCCgAA14uAUkwaV/PR4pEdVMnLTXtOxWvY/O1KJKQAAHBdCCjF6Ibqvlo0soP8vdz0Q9RFQgoAANeJgFLMmgX7adGIDvKr4KZdURf14Ps7aEkBAKCQCCgloHkNP9PdY4WUyJMXCCkAABQSAaWEEFIAALh+BJQSREgBAOD6EFBKGCEFAIDCI6CUAkIKAACFQ0ApJYQUAACuHQGlFBFSAAC4NgSUUkZIAQDgrxFQ7CCkDJ2/nb17AAC4CgHFDkKKtSz+kPe2swsyAABXEFDsIKRYe/fsibZCSoTiU9gFGQAAAoodhJSPRnY0uyD/eCpeg97bposp6bYuFgAANkVAsQNNg3318eiOquztrn2/JGjgvAhdSCakAAAcFwHFTjSpdjmkVKnorgNnEjTw3QjFEVIAAA6KgGJHGgX5aMnojqrq46GDVkiZt02/JqXZulgAAJQ6AoqdaRB4OaQE+njoUEyiHpi3TecSCSkAAMdCQLFD9atWNCElyNdDR84mmZASm5hq62IBAFBqCCh2ql7VivpkdKiq+3nqWGyS7p+7TTHxhBQAgGMgoNixOlW8TUip4V9BJ84n6753wvXLxUu2LhYAACWOgGLnalX2Mt09IQEVdPLXFN03N1zRcSm2LhYAACWKgFIGhAR4mZaUulW8derCJRNSfj6fbOtiAQBQYggoZUSwfwXTklK/qrdOx6ea7p7j55JsXSwAAEoEAaUMCfL11JLRoWoUVFFnE9J039xtOno20dbFAgCg2BFQyhhrEbePR3XUDdV9dT4pTfe/s80s6gYAQHlCQCmDKle0QkoHtajhp1+T0806Kft+ibd1sQAAKDYElDLK38tdi0Z2UOsQf11MyTDL4u+OvmjrYgEAUCwIKGWYXwU3fTiivdrVrqSE1EwNfjdCO36Os3WxAAAoMgJKGefj6aaFD7VXaL3KSkrL1ND3tuv7Y+dtXSwAAGwXUKZOnSonJydNnDgx3/Xw8HDdeuut8vb2lq+vr2655RZduvTfFVDj4uI0aNAgc8/f318jRoxQUhJTZq+Xt4er3h9+k25pVFWXMrI0fMEOfXs4tihvLQAAZTOg7NixQ3PnzlXLli1/F0569eqlnj17avv27eZxjzzyiJyd//tSVjjZv3+/1q1bp9WrV2vz5s0aPXp00X4SB+fp5qJ5Q9uqR9MgpWdma/QHO7VmX4ytiwUAwHVxysnJySnsN1mtHW3atNHbb7+tf/3rX2rdurXeeOMNc69jx47q0aOHXnrppQK/9+DBg2ratKkJLu3atTPX1qxZo9tvv12nTp1ScHDwX75+QkKC/Pz8FB8fb1ph8F8ZWdma+Mlu/efHM3JxdtL0+1rrzlZ/XacAAJS0wvz9vq4WlPHjx6tPnz4KCwvLdz02NlYREREKDAxUp06dFBQUpK5du2rLli35Wlisbp3ccGKxnsdqYbG+F0Xj5uKsGfe1Vv82NZSVnaOJS37Qsp3RVCsAoExxLew3LFmyRLt27TItIL914sQJ8+8LL7ygV1991bSsfPDBB+revbv27dunhg0bKiYmxgSYfIVwdVVAQIC5V5C0tDRzXJ3A8CdvqouzXr2nlTxcXfTx9ihNWv6j0jKzNbhjbaoNAFAmFKoFJTo6WhMmTNDixYvl6en5u/vZ2dnm3zFjxmj48OG68cYbNX36dDVu3Fjz58+/7kJOmTLFNAnlHiEhIdf9XI7C2dlJL9/dXA92qmPOn1m5T+9t+cnWxQIAoPgDSmRkpOnGscafWK0e1rFp0ybNnDnTfG116VisMSZXu+GGGxQVFWW+rlatmnmOq2VmZpqZPda9gkyePNn0V+UeVlDCX7NmWD3ft6nGdatvzl9afUCzvj1G1QEAylcXj9VVs3fv3nzXrJaSJk2a6KmnnlK9evXMINfDhw/ne8yRI0fUu3dv83VoaKguXrxowk7btm3NtQ0bNpjWlw4dOhT4uh4eHubA9YWUf9zWWJ6uLpr+zRG9svawktMyNem2xuYeAABlPqD4+PioefPm+a5Za51Urlw57/qkSZP0/PPPq1WrVmYMysKFC3Xo0CEtX748rzXFmoY8atQozZkzRxkZGWYa8v33339NM3hQeFYQmRDWUBXcnfXyl4f09sbjSknP0nN3NDVdQQAAlPlBsn/FWrQtNTVVjz/+uOm2sYKKtd5J/fqXuxks1hgWK5RYLTLW7J0BAwaYbiKUrNG31FcFd1c9u3KfFmz9WSnpmZrSv6WZjgwAQJlfB8XWWAelaD6NPKVJy/coO0e6o2V1s1aKNT0ZAIAyvQ4KyrYBbWvqrYFt5ObipNU/ntG4RZFKzciydbEAAMhDQHFQt7eorneGtJOHq7O+ORirkQt3mi4fAADsAQHFgf1Pk0CzyaCXu4u2HDtvdkJOSM2wdbEAACCgOLpO9ato0cgO8vV01c6TFzRoXoTiktNtXSwAgIOjBQVqU6uSPh7dUQHe7tr7S7zufydcsQmp1AwAwGYIKDCaBftp6ZiOCvL10JGzSbpnTrii41KoHQCATRBQkKdBoI+Wj+2kWgFeiopL0d/mhOtYbCI1BAAodQQU5BMS4KVlY0PVMLCiYhJSTUjZeyqeWgIAlCoCCn4nyNdTn4wJVcuafrqQkqGB87Zp+09x1BQAoNQQUFAga8Ds4pEd1KFugBLTMjV0foQ2Hs6/CzUAACWFgII/5OPppoUPtdetTQKVmpGtUR/s1H9+PEONAQBKHAEFf8rTzUVzBrc1e/ZkZOXo0Y93aemOaGoNAFCiCCj4S+6uzppx/416oH2I2WDwH5/+qPe2/ETNAQBKDAEF18TF2Ukv391Co2+pZ85fWn1Ar687ojK4GTYAoAwgoOCaOTk5aXLvJnqyRyNzPnP9Ub2war+yrWYVAACKEQEFhQ4pj3ZvqJf6NZOTk7Qw/KSeWLpbGVnZ1CQAoNgQUHBdhoTW0Rv3tZars5NW7j6tMR9GKjUji9oEABQLAgquW7/WNTRvaDt5uDprw6FYDX1vuxJSM6hRAECREVBQJP/TJFAfjuggHw9Xbf85Tg+8s03nk9KoVQBAkRBQUGTt6wbo49EdVaWiu/afTtC9c8J16gI7IQMArh8BBcWieQ0/LRvbSTX8K+jE+WR2QgYAFAkBBcWmbhVvLR8XqgaBFXUm/vJOyHuiL1LDAIBCI6CgWFX3q6ClY0LV6qqdkLccPU8tAwAKhYCCktkJeVRHdWlQRcnpWRq+YLtW/3iamgYAXDMCCkpERQ9XvfdgO/VpkbvJ4A/6cNtJahsAcE0IKCgxHq4umvnAjRrUoZasLXueXblPM745yv49AIC/REBBiW8y+K+7mmtC94bmfPo3R9i/BwDwlwgoKJX9ex7v0Uj/vPO/+/dM+GS30jPZvwcAUDACCkrNsE51NOP+G+Xm4qQv9pzWiIU7lJyWyTsAAPgdAgpK1Z2tgvXesJtUwc1F3x09r4HvRiguOZ13AQCQDwEFpe6WRlX10agO8vdyMwu53TNnK0vjAwDyIaDAJm6sVUnLx4aqup+nTpxL1oDZW3UoJoF3AwBgEFBgMw0CffTZw53UMLCiziakmaXxt/8UxzsCACCgwPZL4y8bG6p2tSspMTVTg9+L0Nr9MbwtAODgaEGBzfl7uWvRyA4KuyHITD0etyhSiyNYdRYAHBkBBXbB081Fcwa30f03hSg7R/q/Faw6CwCOjIACu+Hq4qwp/VvosVsb5K06+8zKfcqyEgsAwKEQUGB3q84+0bOxXup3edXZxRFRGr94l1IzsmxdNABAKSKgwC4NCa2jWQPbyN3FWWv2x2jo/O2KT8mwdbEAAKWEgAK7dXuL6lr4UHv5eLia6cd/m7tVpy9esnWxAAD2HlCmTp1qmuQnTpz4u3s5OTnq3bu3ub9y5cp896KiotSnTx95eXkpMDBQkyZNUmYme7Lg90LrV9bSsaEK8vXQkbNJ6v82C7oBgCO47oCyY8cOzZ07Vy1btizw/htvvGHCyW9lZWWZcJKenq6tW7dq4cKFWrBggZ577rnrLQrKuRuq++qzhzurQWBFxSSkmgXdwo//autiAQDsLaAkJSVp0KBBmjdvnipVqvS7+7t379Zrr72m+fPn/+7e119/rQMHDmjRokVq3bq1aWV56aWXNGvWLBNagILU8K9glsa/qc7lBd2Gzd+u1T+eprIAoJy6roAyfvx40woSFhb2u3spKSkaOHCgCRzVqlX73f3w8HC1aNFCQUFBedduu+02JSQkaP/+/QW+Xlpamrl/9QHHXNDtwxEd1KtZNaVnZevRj3/Q/C0/2bpYAAB7CChLlizRrl27NGXKlALvP/744+rUqZP69etX4P2YmJh84cSSe27dK4j1Wn5+fnlHSEhIYYuNcrSg26xBbTQstLZycqQXVx/Qy18eVDZrpQCA4waU6OhoTZgwQYsXL5anp+fv7q9atUobNmww40+K0+TJkxUfH593WOWA43JxdtILdzbT072bmPN3Np/QxE92Ky2TtVIAwCEDSmRkpGJjY9WmTRu5urqaY9OmTZo5c6b5et26dTp+/Lj8/f3z7lsGDBigbt26ma+tbp+zZ8/me97c84K6hCweHh7y9fXNd8CxWQOwx3atr9fvbSVXZyet2nNaw9/foYRU1koBgPLAKceaD3yNEhMTdfJk/k3chg8friZNmuipp55SlSpVdP78+Xz3rfEmM2bMUN++fVW3bl199dVXuuOOO3TmzBkzxdjyzjvvmKnGVvixwshfscagWF09VmsKYQXfHT2nsR9GKjk9S42DfPT+8JsU7F+BigEAO1OYv9+XmziukY+Pj5o3b57vmre3typXrpx3vaBWkFq1aplwYunZs6eaNm2qIUOGaNq0aWbcyTPPPGMG3l5LOAF+6+aGVc1aKVYLyuGziWatFCukWNOTAQBlU6mvJOvi4qLVq1ebf0NDQzV48GANHTpUL774YmkXBeVIs2A/rRjfWQ2vWivFalkBADhAF4+9oIsHf8Tar2fMop3adiLOjE2ZOqCl7mlbkwoDgDL295u9eFCu+Hm5mf17+rUOVmZ2jv6+bI9mrj9qtl4AAJQdBBSUOx6uLpp+b2uN61bfnL++7oie/nSvMrKybV00AMA1IqCgXHJ2dtJTvZropbuay9lJ+mRntEYs3KmkNDalBICygICCcm1Ix9p6Z0g7VXBz0eYj53TvnHCdTUi1dbEAAH+BgIJyL6xpkJaM7qjK3u46cCZBd8/6Xodi2M8JAOwZAQUOoVWIv1Y83Fn1qnrrdHyq7pnNNGQAsGcEFDiMWpW99Nm4TmpfN8CMRbEWdvtkR5StiwUAKAABBQ7F38tdH45or7uuTEN+6tO9enXtYaYhA4CdIaDAMach39daj93awJy/9e0xdkMGADtDQIHD7ob8RM/GmnZPS7Pi7Oe7T2vIu9t1ITnd1kUDABBQ4OjubReiBcPby8fDVdt/jtOA2Vt18tdkWxcLABweLShweF0aVtHycZ1Uw7+CTpxP1t1vb1XkyQsOXy8AYEsEFEBS42o+WvFwJzWv4au45HQ9MG+bvthzmroBABshoABXBPp6aumYUIXdEKT0zGw9+vEPepONBgHAJggowFW83F01d0hbjexS15y/tu6Inly2R2mZWdQTAJQiAgrwGy7OTnrmjqZmo0Hr6892/aIh7zHDBwBKEwEF+JONBuc/eJMqWjN8fopT/9lb9dN5ZvgAQGkgoAB/omujqvr0ygwfK5zc/fb3JqwAAEoWAQW4lhk+4zuZDQcvpmRo0Lvb9NmuU9QbAJQgAgpwDQJ9PLVkVEfd3qKaMrJy9MTSPXrt68PKzs6h/gCgBBBQgGtUwd1Fbz3QRuO61Tfnb244ZqYiX0pnhg8AFDcCClCY/8M4O+mpXk30yj0t5ebipP/sPaP73wlXbEIq9QgAxYiAAlyHv7UL0aIRHeTv5aY9p+LVb9b32vdLPHUJAMWEgAJcpw71Kuvz8Z1Vv6q3zsSn6m9zwrV2fwz1CQDFgIACFEHtyt767OHOurlhFV3KyNLYRZGas+m4cnIYPAsARUFAAYrIr4Kb3n/wJrOwm5VLpn51SP9Y/qPZzwcAcH0IKEAxcHVxNkvj//POZnJ2kpZFntLg9yLMzsgAgMIjoADFaFinOmZ5fJ8ry+PfNet7HT2bSB0DQCERUIBi1q1xoD57uJNCAiooKi5Fd7+9Vd8eiqWeAaAQCChACWgY5KPPx3dR+7oBSkrL1IiFO/TudycYPAsA14iAApSQAG93s1bK/TeFyFoR/1//OWgGz6ZlsvIsAPwVAgpQgtxdnTWlfws9d0fTvMGzg+ZF6HxSGvUOAH+CgAKUMCcnJz3Upa7eH95ePp6u2nnygvq99b0Onkmg7gHgDxBQgFLStVFVrXi4s+pU9tIvFy9pwOyt+pqVZwGgQAQUoBQ1CKyoleM7q3ODykpJz9KYRZGa9e0xBs8CwG8QUIBS5u/lrgXD22to6OWVZ19Ze1gTluzWpXQGzwJALgIKYANuLs56sV9z/euu5nJ1dtKqPaf1t7lbdfriJd4PACCgALY1uGNtLRrZwUxJ3vdLgu5863tFnozjbQHg8GhBAWysY73K+nx8ZzWp5mOmH9//zjYt3RFt62IBQNkNKFOnTjVTKCdOnGjO4+Li9Oijj6px48aqUKGCatWqpccee0zx8fH5vi8qKkp9+vSRl5eXAgMDNWnSJGVmZhbtJwHKsJAAL306rpN6N6+mjKwc/ePTH/XCqv3KzGJHZACO6boDyo4dOzR37ly1bNky79rp06fN8eqrr2rfvn1asGCB1qxZoxEjRuQ9Jisry4ST9PR0bd26VQsXLjSPe+6554r+0wBlmLeHq2YNbKOJYQ3N+YKtP2vY+9t1gR2RATggp5wcax5B4SQlJalNmzZ6++239a9//UutW7fWG2+8UeBjly1bpsGDBys5OVmurq766quvdMcdd5ggExQUZB4zZ84cPfXUUzp37pzc3d3/8vUTEhLk5+dnWmZ8fX0LW3zA7q3Zd0ZPLN1jpiLXCvDSu8PaqVGQj62LBQBFUpi/39fVgjJ+/HjTChIWFvaXj80thBVOLOHh4WrRokVeOLHcdtttptD79+8v8DnS0tLM/asPoDzr1by62RG5ZqUrOyLP+l5rWdQNgAMpdEBZsmSJdu3apSlTpvzlY8+fP6+XXnpJo0ePzrsWExOTL5xYcs+tewWxXstKXLlHSEhIYYsNlDlNqvlq1SNd1LFegJKtRd0+jNTr644o29p5EADKuUIFlOjoaE2YMEGLFy+Wp6fnnz7WauWwWlmaNm2qF154oUiFnDx5smmJyT2scgCOwJp+/OGIDnqwUx1zPnP9UY3+MFKJqRm2LhoA2E9AiYyMVGxsrBl/YnXZWMemTZs0c+ZM87U1ANaSmJioXr16ycfHRytWrJCbm1vec1SrVk1nz57N97y559a9gnh4eJhuoqsPwJEWdXvhzmZ65Z6WZnfkbw6e1V2zvteJc0m2LhoA2EdA6d69u/bu3avdu3fnHe3atdOgQYPM1y4uLqblpGfPnmaw66pVq37X0hIaGmqewwo6udatW2dCh9XaAqBgf2sXoqVjQlXN11PHzyWbHZE3HMof9gHAoWfxXK1bt255s3hyw0lKSoppOfH29s57XNWqVU2AsVpZrMcHBwdr2rRpZtzJkCFDNHLkSL388svX9JrM4oEji01M1cOLdmnnyQtycpL+3rOxHu5W36xJBAAOPYvnj1iDZyMiIkwLSYMGDVS9evW8I3fciBVSVq9ebf61WlOsKchDhw7Viy++WJxFAcqtQB9PfTSqowZ1qJW32eD4j3YpOY3FDgGUH0VuQbEFWlCAyz6KiNLzq/aZ1WcbB/nonaFtVbvyf1suAcCe2KwFBUDpGtihlpaM7qiqPh46fDZRfd/com8P/3d8FwCUVQQUoIxrWztAXzzSRTfW8ldCaqYeWrBDb64/ynopAMo0AgpQDlTz8zQtKVaLitVp+9q6IxqziPVSAJRdBBSgnPBwddHLd7fQvwe0kLuLs9YdOKt+s77XsdhEWxcNAAqNgAKUM/fdVEtLx4aqup+nTlxZL2XNvoK3kQAAe0VAAcqh1iH++uLRLupQ9/I+PmMXReqVtYeUxT4+AMoIAgpQTlWp6KFFIzvooc51zfmsb49r+IIdupiSbuuiAcBfIqAA5Xwfn+f6NtWM+1vL081Zm4+cU9+3tmjfL/G2LhoA/CkCCuAA+rWuoc/GdVZIQAVFx13SgNlbtWwnu4IDsF8EFMBBNA321epHbtatTQKVlpmtSct/1P+u2Ku0zMu7kAOAPSGgAA7Ez8tN7w5tpyd6NDIbDVpL5d87J1y/XLxk66IBQD4EFMDBODs76bHuDfX+gzfJr4Kb9pyKN0vkbzl63tZFA4A8BBTAQXVrHKjVj3ZRs2BfxSWna+j8CL298ZjK4P6hAMohAgrgwEICvPTpuE76W9uaspZImbbmsMZ8GKmE1AxbFw2AgyOgAA7O081F0+5pqSn9Ly+R//WBs7rzzS06eCbB1kUD4MAIKADk5OSkB9rX0rKxoQr289TPv6bo7re/1/LIU9QOAJsgoADI0yrEX/957GZ1bVRVqRnZ+vuyPXr60x+VmsFUZACli4ACIJ9K3u5mhs/jYZenIi/ZEW0Wdov6NYWaAlBqCCgAfv+LwdlJE8IaauHw9qrk5ab9pxN0x5vf6ZsDZ6ktAKWCgALgD93SqKrp8rmxlr8SUjM18oOd+veaQ8rMyqbWAJQoAgqAPxXsX0GfjA7V8M51zPnsjcc15L3tOpeYRs0BKDEEFAB/yd3VWc/3baa3Bt4ob3cXhZ/4VX1mfqdtJ36l9gCUCAIKgGt2R8tgff5IFzUKqqjYxDQNnLdNs749pmxrlTcAKEYEFACF0iCwolaO76z+bWqY1WdfWXtYDy3coQvJ6dQkgGJDQAFQaF7urnrtb600bUBLebg6a+Phc6bLJ/LkBWoTQLEgoAC47tVn770pxLSm1K3irdPxqbpvbrje/e4EGw4CKDICCoAiuaG6r754tIv6tgpWZnaO/vWfg2bDwfhLbDgI4PoRUAAUWUUPV828v7Veuqt53oaD1sJuP566SO0CuC4EFADF1uUzpGNtfTquk0ICKig67pLumR2uhVt/pssHQKERUAAUqxY1/bT60Zt1W7MgpWdl6/lV+zVu0S66fAAUCgEFQLHzq+CmOYPb6vm+TeXm4qQ1+2PMLJ/d0XT5ALg2BBQAJdblM7xzXdPlUyvAS6cuXNLf5mxllg+Aa0JAAVCiWtb01+rHuuj2FtWUkXV5ls+oD3bqYgoLuwH4YwQUACXO19NNswa2uTzLx9VZ3xyM1e0zrIXd4qh9AAUioAAo1Vk+Kx7ulLew271zt2nOpuPs5QPgdwgoAEpVs2A/s7Dbna2ClZWdo6lfHdLwBTt0PimNdwJAHgIKAJss7Dbj/taa0r+F2ctn05Fz6j3jO31/7DzvBgCDgALAZl0+D7SvpVWPdFGjoIo6l5imwe9F6JW1h5SZlc27Ajg4AgoAm2pczUefj+9iwkpOjjTr2+O6751tOnUhhXcGcGBFCihTp041/xU0ceLEvGupqakaP368KleurIoVK2rAgAE6e/Zsvu+LiopSnz595OXlpcDAQE2aNEmZmZlFKQqAMqyCu4vp7rFm+vh4uCry5AUzy2fNvjO2LhqAshZQduzYoblz56ply5b5rj/++OP64osvtGzZMm3atEmnT59W//798+5nZWWZcJKenq6tW7dq4cKFWrBggZ577rmi/SQAyrw+Lavrywk3q3WIvxJSMzV20S49s3KvUjOybF00AGUhoCQlJWnQoEGaN2+eKlWqlHc9Pj5e7733nl5//XXdeuutatu2rd5//30TRLZt22Ye8/XXX+vAgQNatGiRWrdurd69e+ull17SrFmzTGgB4NhCAry0bGyoxnatb84XbYvSXbO+17HYRFsXDYC9BxSrC8dqBQkLC8t3PTIyUhkZGfmuN2nSRLVq1VJ4eLg5t/5t0aKFgoKC8h5z2223KSEhQfv37y/w9dLS0sz9qw8A5Zebi7Oe7t1EHzzUXlUquutQTKLueHOLPoqIYmdkwEEUOqAsWbJEu3bt0pQpU353LyYmRu7u7vL398933Qoj1r3cx1wdTnLv594riPVafn5+eUdISEhhiw2gDLqlUVXT5XNzwypKzcjW/67Ya3ZGZpl8oPwrVECJjo7WhAkTtHjxYnl6eqq0TJ482XQf5R5WOQA4hkAfTy0c3l7/e3uTvJ2RrTVTIk78auuiAbCXgGJ14cTGxqpNmzZydXU1hzUQdubMmeZrqyXEGkdy8WL+LdWtWTzVqlUzX1v//nZWT+557mN+y8PDQ76+vvkOAI7D2dlJo2+pr8/GdTbL5J+JT9UD87bp9a8Ps2YKUE4VKqB0795de/fu1e7du/OOdu3amQGzuV+7ublp/fr1ed9z+PBhM604NDTUnFv/Ws9hBZ1c69atM6GjadOmxfmzAShnWtT00+pHu+hvbWsqO0eaueGYWTMlOo41U4Dyxiknx1oa6fp169bNzMZ54403zPm4ceP05ZdfmqnDVuh49NFHzXVrJk/uNGPr8cHBwZo2bZoZdzJkyBCNHDlSL7/88jW9pjVI1hqLYnX30JoCOKZVe07r/z7bq8S0TPl4uurlu1uob6tgWxcLQDH9/S72lWSnT5+uO+64wyzQdsstt5hum88++yzvvouLi1avXm3+tVpTBg8erKFDh+rFF18s7qIAKMeszQatAbRtavkrMTVTj378gyYt26OkNBZ9BMqDIreg2AItKAByWfv2zFh/VG99e8wslV+7spfeuK+1bqz13zWaANgHm7agAEBpcnVx1pM9G2vJqI6q4V9BJ39N0T1zwvXm+qPKsgaqACiTCCgAyoUO9SqbLp87WlY3weS1dUd0/zvhbDoIlFEEFADlhl8FN735wI16/d5Wqujhqh0/X1DvN77T57t/sXXRABQSAQVAuWLtsN6/TU19+diVAbRpmZqwZLce/2S3ElIzbF08ANeIgAKgXKpV2UtLx4RqYlhDOTtJK374RbfP+E47f46zddEAXAMCCoByPYB2YlgjsztySEAFnbpwSffODderaw8rIyvb1sUD8CcIKADKvba1A0yXT/82NcwKtNaU5AGzt+r4uSRbFw3AHyCgAHAIPp5uev3e1po1sI0ZTPvjqXj1mfmdPgz/WWVwOSig3COgAHAofVpW19qJt+jmhlWUmpGtZz/fr+ELdig2MdXWRQNwFQIKAIdTzc9TC4e313N3NJW7q7M2Hj6nXm98p7X7Y2xdNABXEFAAOCRnZyc91KWu2R25aXVfxSWna8yHkXpq+Y/s5wPYAQIKAIfWKMhHK8Z30tiu9eXkJH2yM5rpyIAdIKAAcHgeri56uncTfXxlP5+ouBQzHfnfaw4pLTPL4esHsAUCCgBc0bFeZX018WYNaFPTTEeevfG4+r31vQ7FJFBHQCkjoADAVXw93fTava00Z3BbBXi761BMou5883vN2XSc3ZGBUkRAAYAC9GpezUxHDrshUOlZ2Zr61SGzO3LUrynUF1AKCCgA8Aeq+nho3tB2mjagpbzdXS7vjjxjs5Zsj2JxN6CEEVAA4C92R773phCtmXiL2tcNUHJ6lp7+bK9GLNzJ4m5ACSKgAMA1CAnw0pJRHfV/t98gdxdnbTgUq57TN+uLPaepP6AEEFAA4Fp/YTo7adQt9fTFo13ULNhXF1My9OjHP2j84l1moTcAxYeAAgCF1Liaj1aO76wJ3RvKxdlJ/9l7Rj2nb9LXLJUPFBsCCgBcBzcXZz3eo5FWPtxZjYIq6nxSukZ/GKknlu5W/KUM6hQoIgIKABRBi5p+WvVIF43pWs8slf/Zrl902/TN2nTkHPUKFAEBBQCKyNPNRZN736DlY0NVp7KXYhJSNWz+dv3vir1sPAhcJwIKABSTtrUD9OWEm/Vgpzrm/KOIKPV6Y7O2Hj9PHQOFREABgGLk5e6qF+5spo9GdjAbD566cEkD50Xo2ZX7lJyWSV0D14iAAgAloFODKlr7+C0a2KGWOf9w20n1mkFrCnCtCCgAUEIqerjq5btbaNGIy60p0XG0pgDXioACACWsS8MqWjPxZj3QPn9rSvjxX6l74A8QUACgFPh4umlK/xb6cER7Bft5mtaUB+Zt0/OfMzYFKAgBBQBK0c0Nq5qxKbmtKQvDGZsCFISAAgB20JpizfT5vxV7lZjKKrQAAQUA7KA1ZXDHy60piyOizCq0Gw/H8r7A4dGCAgA2bk35110t9NGoDqoV4KXT8al68P0d+vuyPYpPoTUFjouAAgB2oFP9yzN9Hupc1+zpszzylHpM36R1B87aumiATRBQAMCOVqF9rm9TLRsTqnpVvBWbmKZRH+zUYx//oLjkdFsXDyhVBBQAsDPt6lze02ds1/pydpJW7TmtHq9v0hd7TisnJ8fWxQNKBQEFAOx0h+SnezfRioc7q3GQj35NTtejH/+gUR9EKiY+1dbFA0ocAQUA7FirEH998WgXPR7WSG4uTvrm4FnTmmLtlJydTWsKyq9CBZTZs2erZcuW8vX1NUdoaKi++uqrvPsxMTEaMmSIqlWrJm9vb7Vp00affvppvueIi4vToEGDzPf7+/trxIgRSkpKKr6fCADKGXdXZ00Ia6j/PHazWof4KzEtU/+7Yq9Zifan88m2Lh5g+4BSs2ZNTZ06VZGRkdq5c6duvfVW9evXT/v37zf3hw4dqsOHD2vVqlXau3ev+vfvr3vvvVc//PBD3nNY4cR6/Lp167R69Wpt3rxZo0ePLv6fDADKmUZBPvp0XCc9e0dTVXBzUcRPcer1xmbN3XRcmVnZti4eUKyccoo44iogIECvvPKKaQmpWLGiaWWxWlFyVa5cWf/+9781cuRIHTx4UE2bNtWOHTvUrl07c3/NmjW6/fbbderUKQUHB1/TayYkJMjPz0/x8fGmJQYAHE10XIomf7ZXW46dN+ctavjp3wNaqmkwvxNhvwrz9/u6x6BkZWVpyZIlSk5ONl09lk6dOumTTz4x3TjZ2dnmfmpqqrp162buh4eHm26d3HBiCQsLk7OzsyIiIv7wtdLS0swPdfUBAI4sJMDLLJU/7Z6W8vV01d5f4nXnW1v07zWHlJqRZeviAUVW6IBidd1YLSUeHh4aO3asVqxYYVpFLEuXLlVGRoZpNbHujxkzxtxv0KBB3hiVwMDAfM/n6upqWmGse39kypQpJnHlHiEhIYX/SQGgnHFyctK97UL0zZNd1bt5NWVm52j2xuOm22frlZYVwGECSuPGjbV7927T4jFu3DgNGzZMBw4cMPeeffZZXbx4Ud98840Zo/LEE0+YMShWqCmKyZMnm+ag3CM6OrpIzwcA5Umgj6dmD26ruUPaKsjXQz//mqKB70Zo0rI9upjCAm9w0DEoVhdN/fr19Y9//MO0lOzbt0/NmjXLd9+6PmfOHM2fP19PPvmkLly4kHc/MzNTnp6eWrZsme6+++5rek3GoADAH/x+TM3QtDWHtGhblDmvUtFdz/Vtpr4tq5sWF6Dcj0HJZY01scaIpKSkXH5C5/xP6eLiYh5jscaqWC0s1iygXBs2bDD3O3ToUNSiAIDD872y+eDysaFqEFhR55PSzVL5Ixbu1C8XLzl8/aCctqBYXS29e/dWrVq1lJiYqI8++sjM0Fm7dq0ZCGuNRalevbpeffVVMw5l5cqVmjRpkplObM3UsVjff/bsWdOiYo1XGT58uBk0az3XtaIFBQD+WlpmlhmTMuvbY8rIypGXu4v+3rOxhnWqIxdrDX2glBXm73ehAoo1lXj9+vU6c+aMeQFr0bannnpKPXr0MPePHj2qp59+Wlu2bDGLr1ldO3//+9/zTTu2Zvg88sgj+uKLL0xry4ABAzRz5kwz8LYkfkAAcHTHYhP19Kd7tfPkhbzVaV++u7maBfvZumhwMAklFVDsBQEFAArHWhb/o+1R+vdXh8xKtFYLyogudTUxrKHZRRkod2NQAAD2z9nZSYM71jZTkvu0qK6s7By9s/mEery+Wd8eirV18YDfIaAAgAMJ8vXUrEFtNP/BdqrhX8EMnB2+YIfGL96l2AR2SYb9IKAAgAO6tUmQ1j1xi0bfUs909/xn7xl1f32TFm07yS7JsAsEFABwUNbYk/+9/QateqSzWtX0U2Jqpp5ZuU/3zNmqQzFsKQLbIqAAgIOzZvN89nBnvdC3qbzdXbQr6qL6zNyiKV8eVEp6pq2LBwdFQAEAmG6eBzvXNYNoezWrZgbRzr0yiHbdgbPUEEodAQUAkKe6XwXNGdJW7w377yDaUR/sNAcr0aI0EVAAAL/T/YYgffNEV43rVl+uzk6mFSXstU16Z/NxZWRd3r4EKEkEFABAgSq4u+ipXk305YSbdVOdSrqUkaWXvzykvm9uUeTJOGoNJYqAAgD4U42CfPTJ6FBNu6elKnm56VBMogbMDtfkz37UheR0ag8lgoACAPjrPxbOTrq3XYjWP9lN97araa59vD1at762UZ/siGLtFBQ79uIBABTajp/j9MyKfTp8NtGct6nlr5fuYgNC/Dk2CwQAlDhrsOzCrT9r+rojSk7PkrOTNDS0jp7o2Ui+nm68A/gdNgsEAJQ4Nxdnjby5nun26dOyurJzpAVbf1b31zbp892/KCcnh3cB140xKACAIqnm56lZA9vowxHtVbeKt84lpmnCkt0aOC9Cx2IvdwEBhUVAAQAUi5sbVtWaiTfryR6N5OHqrPATv6rXG9+ZJfOT0lgyH4VDQAEAFBsPVxc92r2hWeSte5NAZV5ZMr/7axu1as9pun1wzZjFAwAoMesPntU/vzigqLgUcx5ar7L+2a+ZWVsFjichIUF+fn6Kj4+Xr6/vnz6WgAIAKFGpGVmau+mE3t54TGmZ2Wbp/OGd62hCWCNV9HCl9h1IQiECCl08AIAS5enmoglhl7t9ejQNMt0+8777Sbe+upHZPvhDtKAAAErVt4dj9c9V+/Xzr5e7fTrUDTDdPk2q/fl/UaPso4sHAGD33T7vfndCb317TKkZ2XJxdtKQjrX1eFgj+XmxyFt5RUABAJQJpy6k6OUvD+rLvTHmPMDbXZNua2z2/bFCC8oXAgoAoEz5/th5vbBqv47GJpnzFjX89MKdzdS2diVbFw3FiIACACiTe/t8GH7S7O2TeGVht/5taujpXk0U6Otp6+KhGBBQAABl1vmkNE1bc0hLd54y59ZU5Me6N9CDnerK3ZXJp2UZAQUAUObtjr6o51ft157oi+a8XhVvPXtHU/1Pk0BbFw3XiYACACgXsrNztHzXKdOicj4p3Vzr1riqCSr1q1a0dfFQSAQUAEC5kpiaoTc3HNP73/+kjKwcsxrtg53q6LGwhvL1ZFpyWUFAAQCUSyfOJen//eeg1h+KNeeVr0xL/hvTkssEAgoAoFzbeDhWL60+oOPnks15s2BfPd+3mdrXDbB10fAnCCgAAIeYlvxB+Em98c0RJaZenpZ8R8vqerp3E9Ws5GXr4qEABBQAgMP4NSlNr359REt2RCknR/Jwddaom+tpXLf68ma3ZLtCQAEAOJwDpxP04ur92nYizpwH+njoH72aqP+NNeTMsvl2gYACAHBIOTk5+vrAWbO/z8kruyVby+Y/17epbqrD+BRbI6AAABxaWmaWFm79WW+uP5a3bH4fa3xKryYKCWB8iq0QUAAAuLJs/uvrjmjJ9ihl58gslT+yS10zPsWH9VNKHQEFAICrHDyTYKYlbz3+qzmvUtFdT/RorHvb1ZSrC/v7lBYCCgAABYxP+eZgrBmf8tP5y+unNAqqqP/r01RdG1WlvuwsoBQqNs6ePVstW7Y0T2odoaGh+uqrr/I9Jjw8XLfeequ8vb3NY2655RZdunQp735cXJwGDRpk7vn7+2vEiBFKSkoq7M8IAEChODk5qUfTIK2deIue79tU/l5uOnI2ScPmb9fQ+dt1OCaRGrUjhQooNWvW1NSpUxUZGamdO3eaINKvXz/t378/L5z06tVLPXv21Pbt27Vjxw498sgjcnb+78tY4cR6/Lp167R69Wpt3rxZo0ePLv6fDACAAljjUIZ3rqtNf/8fMx7FzcVJm4+cU+8Zm/W/K/bqXGIa9WYHnHKsNq8iCAgI0CuvvGJaQjp27KgePXropZdeKvCxBw8eVNOmTU1wadeunbm2Zs0a3X777Tp16pSCg4OLvYkIAIA/8/P5ZE396pDW7I8x5xU9XM0g2hFd6srTzYXKKwtdPFfLysrSkiVLlJycbLp6YmNjFRERocDAQHXq1ElBQUHq2rWrtmzZkvc9VguL1a2TG04sYWFhpoXF+t4/kpaWZn6oqw8AAIpDnSremjOkrZaOCVXLmn5KSsvUK2sP69ZXN+rTyFPKtqb/oNQVOqDs3btXFStWlIeHh8aOHasVK1aYVpETJ06Y+y+88IJGjRplWkbatGmj7t276+jRo+ZeTEyMCTBXc3V1Na0w1r0/MmXKFJO4co+QkJDC/6QAAPwJa6PBlQ931vT7WqmGfwWdjk/Vk8v26I43t+j7Y+epO3sPKI0bN9bu3btNi8e4ceM0bNgwHThwQNnZ2eb+mDFjNHz4cN14442aPn26efz8+fOLVMjJkyeb5qDcIzo6ukjPBwBAQawl8e++sabWP9lVT/VqIh8PVx04k6BB70bowfcZSFuaXAv7De7u7mrQoIH5um3btmY8yYwZM/T000+ba1ZrytVuuOEGRUVFma+rVatmuoKulpmZaWb2WPf+iNVaYx0AAJQGa+yJNQ7lvptCNHP9US3adlIbD58zg2mta4+HNVKgrydvRgkq8uo0VsuJNUakTp06ZpDr4cOH890/cuSIateubb62xqpcvHjRzALKtWHDBvMcHTp0KGpRAAAoVgHe7nrhzmZa90RX9W5ezaxG+/H2aHV7daOmrzui5CvL6MPGLShWV0vv3r1Vq1YtJSYm6qOPPtLGjRu1du1aM7980qRJev7559WqVSu1bt1aCxcu1KFDh7R8+fK81hRrGrI1RmXOnDnKyMgw05Dvv//+a57BAwBAaatbxVuzB7fVzp/j9P++PKgfoi5qxvqjWhwRpQlhDXX/TSFyY0Va200ztqYSr1+/XmfOnDGDVa1F25566ikztTiXtU7KrFmzTLeNFVSmTZumLl265N23rluh5IsvvjCzdwYMGKCZM2eagbfXimnGAABbsf5sfrk3Rq+sPaSfr+yYXK+Kt/7Rq7Fua1bN/Ac7CsZS9wAAlLD0zGx9vD3KjFH5NTndXGtTy1+Tb79BN9UJoP4LQEABAKCUJKZmaN7mE5r33U+6lJFlrllL6j/Vq7EaBPrwPlyFgAIAQCmLTUjV9G+OaunOaGVl58jZSWbGz8SwRgpixo9BQAEAwEaOxSZp2ppD+vrAWXPu6XZ575+xXevLr4KbQ78vCYVY6r7Ie/HYAoNkAQD2zprxM+WrQ4o8ecGcW+Hk4W71NaxTHYfd4yeBgAIAgO1ZbQDfHIw1M36OnE0y16r7eWpiWEMNaFNTrg42NTmBgAIAgP2wxqR8tuuUWdzN2uPH0iCwov7e05qaHOQwU5MTCCgAANif1Iwss2z+W98e08WUDHPtxlr+Zt+fjvUqq7xLIKAAAGC/ElIz9M6mE3pvy3+nJt/SqKr+cVtjNa/hp/KKgAIAQBmZmvzmhmNmwbdMa6MfSX1aVNcTPRupftVrX2G9rCCgAABQhkT9mqLp3xzRyt2/yJpb6+LspHva1DT7/AT7V1B5QUABAKAMOhSToFfXHtE3By+voeLu4qzBHWtr/P/UV+WKHirrCCgAAJRhkScvmKnJ207EmXNvdxeNuLmeRt5cV76eZXexNwIKAADlYA2V746e1ytrD2vvL/F5i71ZK9IO61RbXu6uKmsIKAAAlKOgsmZfjF5bd8Qso2+pUtHDdPsM7FBLHq5lZ1VaAgoAAOVwsbfPd/+iN745qqi4FHMt2M9Tj3VvqAFta8qtDKxKS0ABAKCcysjK1rKdpzRz/VHFJFxelbZOZS+za3LfVsFmBpC9IqAAAOAAq9IujojS298e06/J6eZao6CKejyskW5rVk3OdhhUCCgAADiI5LRMLdj6s+ZuOq6E1ExzrWl1Xz3eo5HCbgi0q31+CCgAADiY+EsZZun8+Vt+UlLa5aDSqqafCSpdG1W1i6BCQAEAwEFdSE7XvO9OmFaVlPTL+/y0rV1JT/RopE71K9s0qBBQAABwcOeT0ky3zwfhJ5WWmW2udagbYIJKBxvtnExAAQAAeRsSvr3xuD6KiFJ61uWg0rlBZTOYtl2dAJUmAgoAAMjnTPwlvbXhmJbujFZG1uWdk7s0qKLHezRU29qlE1QIKAAAoECnLqRo1rfHtWxntDKzLweVmxtWMeuoWGNVShIBBQAA/KnouBS9vfGYWfQtN6jc0qiqJoY1VJtaJRNUCCgAAOCag8qsb49peeR/g0rXK0HlxmIOKoUJKPa/cD8AACgxIQFemjqgpTY82U33tQsxS+VvOnLOLKVvS0451jaJZUxhEhgAALh2Ub+m6K1vj2pQh9pqFeIvW/39di3WVwYAAGVarcpemnZPK1sXgy4eAABgfxiDAgAA7A4BBQAA2B0CCgAAsDsEFAAAYHcIKAAAwO4QUAAAgN0hoAAAALtDQAEAAHaHgAIAAMp2QJk9e7Zatmxp1s+3jtDQUH311Ve/e5y1vU/v3r3l5OSklStX5rsXFRWlPn36yMvLS4GBgZo0aZIyMzOL/pMAAIByo1B78dSsWVNTp05Vw4YNTQhZuHCh+vXrpx9++EHNmjXLe9wbb7xhwslvZWVlmXBSrVo1bd26VWfOnNHQoUPl5uaml19+uXh+IgAAUOYVeTfjgIAAvfLKKxoxYoQ53717t+644w7t3LlT1atX14oVK3TXXXeZe1Zri3Xv9OnTCgoKMtfmzJmjp556SufOnZO7u/s1vSa7GQMAUPaUym7GVmvIsmXLlJycbLp6LCkpKRo4cKBmzZplWkl+Kzw8XC1atMgLJ5bbbrtN48aN0/79+3XjjTcW+FppaWnmyGX9YLk/KAAAKBty/25fS9tIoQPK3r17TSBJTU1VxYoVTQtJ06ZNzb3HH39cnTp1Mt0+BYmJickXTiy559a9PzJlyhT985///N31kJCQwhYfAADYWGJiomlJKdaA0rhxY9ONY7ViLF++XMOGDdOmTZt07NgxbdiwwYxHKW6TJ0/WE088kXeenZ2tuLg4Va5cucCxLkVNd1bwiY6O/svmJ1DfZQ2fb+q7POPzbf/1bbWcWOEkODj4Lx9b6IBijRNp0KCB+bpt27basWOHZsyYoQoVKuj48ePy9/fP9/gBAwbo5ptv1saNG023z/bt2/PdP3v2rPm3oC6hXB4eHua42m9fp7jlzlRC6aC+Sxf1TX2XZ3y+7bu+/6rlpNjWQbFaM6zxIU8//bR+/PFH07qSe1imT5+u999/33xtdQ1ZXUSxsbF5379u3Trzg+V2EwEAALgWtqvFWt+kVq1aponmo48+Mi0ja9euNS0gBbWCWI+tW7eu+bpnz54miAwZMkTTpk0z406eeeYZjR8//nctJAAAwHEVKqBYLR/WuiXW+iVWE421aJsVTnr06HFN3+/i4qLVq1ebWTtWa4q3t7cZw/Liiy/KXlhB6fnnnycwUd/lEp9v6rs84/Ndvuq7yOugAAAAFDf24gEAAHaHgAIAAOwOAQUAANgdAgoAALA7BJSrWHsI1alTR56enurQocPvFpXD9dm8ebP69u1rVg60Vv5duXJlvvvWOO3nnnvObC5pLfgXFhamo0ePUt3Xydoa4qabbpKPj48CAwPNZp2HDx/O9xhrqwprer+1GrO1ZYW1oGLuookonNmzZ5sZjbmLVVkzFK2NUanr0jF16lTze2XixInUeQl44YUXTP1efTRp0qRUfpcQUK745JNPzHL61pSpXbt2qVWrVmYjw6sXlcP1sTaUtOrTCoAFsdbEmTlzptnZOiIiwkw/t+re+uCj8KytJ6xfGNu2bTMLIWZkZJg1iKz3IZe1b9YXX3xhNvy0Hm/tMN6/f3+q+zrUrFnT/JGMjIw0u7jfeuutZj8yawNU6rpkWSuZz5071wTEq/H5Ll7NmjUzy4vkHlu2bCmduramGSMnp3379jnjx4/Pq4qsrKyc4ODgnClTplA9xcj6yK1YsSLvPDs7O6datWo5r7zySt61ixcv5nh4eOR8/PHH1H0xiI2NNfW+adOmvPp1c3PLWbZsWd5jDh48aB4THh5OnReDSpUq5bz77rvUdQlKTEzMadiwYc66detyunbtmjNhwgRznc938Xr++edzWrVqVeC9kq5rWlAkpaenm//6sboWcjk7O5vz8PDw4kmCKNBPP/1kVhS+uu6tRQCtLjbqvnhYG3taAgICzL/WZ91qVbm6zq0mW2vVZ+q8aLKysrRkyRLTWmV19VDXJcdqJezTp0++z7GFOi9+Vpe71UVfr149DRo0SFFRUaVS14XeLLA8On/+vPnFEhQUlO+6dX7o0CGblcsRWOHEUlDd595D0fbKsvrmO3furObNm+fVubXp52833KTOr5+1x5gVSKxuSasffsWKFWZbD2tPMuq6+Fkh0OqKt7p4fovPd/Gy/mNxwYIFaty4sene+ec//2k2AN63b1+J1zUBBSjn/5Vp/SK5us8Yxc/65W2FEau1avny5WYLD6s/HsUvOjpaEyZMMOOrrAkNKFnW/nu5rLE+VmCpXbu2li5daiY1lCS6eCRVqVLF7BP025HH1nlBGyCi+OTWL3Vf/B555BGz99W3335rBnJeXedWt+bFixfzPZ7P+/Wz/iuyQYMGatu2rZlFZQ0KnzFjBnVdAqxuBWvyQps2beTq6moOKwxaA+2tr63/eufzXXKs1pJGjRrp2LFjJf75JqBc+eVi/WJZv359vqZx69xqtkXJsXa6tj7IV9d9QkKCmc1D3V8fayyyFU6sboYNGzbk7Saey/qsu7m55atzaxqy1a9MnRcP6/dHWloadV0CunfvbrrUrBar3KNdu3ZmbETu13y+S05SUpKOHz9uloUo8d8lRR5mW04sWbLEzBxZsGBBzoEDB3JGjx6d4+/vnxMTE2PropWL0fY//PCDOayP3Ouvv26+PnnypLk/depUU9eff/55zo8//pjTr1+/nLp16+ZcunTJ1kUvk8aNG5fj5+eXs3HjxpwzZ87kHSkpKXmPGTt2bE6tWrVyNmzYkLNz586c0NBQc6Dwnn76aTND6qeffjKfX+vcyckp5+uvv6auS8nVs3gsfL6Lz5NPPml+l1if7++//z4nLCwsp0qVKmZ2YEnXNQHlKm+++aapaHd3dzPteNu2bcVSyY7u22+/NcHkt8ewYcPypho/++yzOUFBQSYkdu/ePefw4cO2LnaZVVBdW8f777+f9xgr/D388MNmOqyXl1fO3XffbUIMCu+hhx7KqV27tvm9UbVqVfP5zQ0n1LVtAgqf7+Jz33335VSvXt18vmvUqGHOjx07Vip17WT9T9HbYQAAAIoPY1AAAIDdIaAAAAC7Q0ABAAB2h4ACAADsDgEFAADYHQIKAACwOwQUAABgdwgoAADA7hBQAACA3SGgAAAAu0NAAQAAdoeAAgAAZG/+P4kyvPeV1PsMAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import optax\n", + "from jax import tree\n", + "\n", + "class Ema(nnx.Module):\n", + " def __init__(self, model, decay=0.9):\n", + " self.decay = decay\n", + " self.ema = nnx.clone(model)\n", + " def update(self, model):\n", + " def ema_update(ema, new_val):\n", + " return self.decay * ema + (1 - self.decay) * new_val\n", + " self.ema = tree.map(ema_update, self.ema, model)\n", + "\n", + "model = make_model(rngs)\n", + "ema = Ema(model)\n", + "\n", + "optimizer = nnx.Optimizer(\n", + " model,\n", + " tx=optax.adam(1e-3),\n", + " wrt=nnx.Param)\n", + "\n", + "@nnx.jit\n", + "def train_step(model, optimizer, ema, x, y):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(model, grads)\n", + " ema.update(model)\n", + " return loss\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(model, optimizer, ema, x, y)\n", + " losses.append(loss)\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "id": "78a1f698", + "metadata": {}, + "source": [ + "# Low Rank Adaptation" + ] + }, + { + "cell_type": "markdown", + "id": "f3d54f47", + "metadata": {}, + "source": [ + "The pattern for adding low rank adaptation to an optimization loop is very similar to adding an exponential moving average. As before, we create a new pytree with the same structure as our model parameters, but here we store low rank additions to these parameters rather than weighted average values. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ca8d9603", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def add_rank2_lora(path, node):\n", + " if isinstance(node, nnx.Linear):\n", + " return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs)\n", + " return node\n", + "\n", + "base_model = make_model(rngs)\n", + "lora_model = nnx.recursive_map(add_rank2_lora, base_model)\n", + "nnx.display(lora_model)" + ] + }, + { + "cell_type": "markdown", + "id": "dd6b1d62-2009-40c3-af3b-6544b476ecbf", + "metadata": {}, + "source": [ + "To indicate that we only want to to update the low rank corrections, we add the `wrt=nnx.LoRAParam` argument to `nnx.Optimizer`. This will filter out all the variables in the gradient that are not `nnx.LoRAParam`s. The other components of the gradient will go unused, so Jax's dead code elimination passes should prevent us from computing them in the first place once the code gets compiled. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3e98cfba", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def train_step(model, optimizer, x, y):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(model, grads)\n", + " return loss\n", + "\n", + "optimizer = nnx.Optimizer(\n", + " lora_model,\n", + " tx=optax.adam(1e-3),\n", + " wrt=nnx.LoRAParam,\n", + ")\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(lora_model, optimizer, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "752983e3", + "metadata": {}, + "source": [ + "# LBFGS" + ] + }, + { + "cell_type": "markdown", + "id": "b753ec03", + "metadata": {}, + "source": [ + "So far, we've been using optax optimizers with the interface ``optimizer.update(grads, opt_state)``. This works for simple optimization algorithms like ADAM, but for algorithms that use a line search like LBFGS, we need to pass more parameters. Below, we can see how the call to ``optimizer.update`` is given additional parameters when using LBFGS." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a177b31a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAL2RJREFUeJzt3QmYVOWd7/F/VXVX7zv0RrPJarMOqNiSGEYQouigknmcR0SMRgbEuQheHsNziQl4I1zyjEZz3W6SESbqkOCVJJAhQmRxFFBkpxGu3SDd0Btb73RXd1Xd5327q+yGBnqpqnOq6vuZOTmn6pyqOp6up/vH+/7f91jcbrdbAAAATMJq9AkAAAC0RTgBAACmQjgBAACmQjgBAACmQjgBAACmQjgBAACmQjgBAACmQjgBAACmEiFByOVySUlJiSQkJIjFYjH6dAAAQCeoeV9ramokOztbrFZraIUTFUz69u1r9GkAAIBuKC4ulpycnNAKJ6rFxPMfl5iYaPTpAACATqiurtaNC56/4yEVTjxdOSqYEE4AAAguNyrJoCAWAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYCuEEAACYSlDe+M9fth0vl50nzsnEwb1k6ohMo08HAICwRMtJG1+cuiRrd5+WXYUXjPuJAAAQ5ggnbfRJidHrM5cuG/XzAAAg7BFO2shpDSdnKwknAAAYhXDSRl9vy0m9UT8PAADCHuGkjezklnBS09AsVZebwv7LAQCAEQgnbcTaIyQtzq63z1J3AgCAIQgn16g7oWsHAABjEE6uwIgdAACMRTi5Qk5KrF4zYgcAAGMQTq5Atw4AAMYinFyhT+uIHSZiAwAgCMPJqlWrxGKxyLPPPut9btKkSfq5tsu8efPava6oqEimT58usbGxkp6eLkuWLJHm5mYxA7p1AAAI0hv/7d27V95++20ZPXr0VfueeuopWbFihfexCiEeTqdTB5PMzEzZtWuXlJaWymOPPSaRkZHy0ksviVkKYivrm6SmoUkSoiONPiUAAMJKt1pOamtrZdasWfLrX/9aUlJSrtqvwogKH54lMTHRu2/Lli1y7Ngxeffdd2Xs2LFyzz33yIsvviivv/66OBwOMVp8VIQkx7YEEopiAQAIknCyYMEC3foxZcqUDve/99570qtXLxk5cqQsXbpU6uu/nQ5+9+7dMmrUKMnIyPA+N23aNKmurpb8/PwO36+xsVHvb7sE5B47TMQGAID5u3XWrVsn+/fv1906HXnkkUekf//+kp2dLYcPH5bnn39eTpw4IR9++KHeX1ZW1i6YKJ7Hal9HVq5cKcuXL5dAyUmOlaNnqymKBQDA7OGkuLhYFi5cKFu3bpXo6OgOj5k7d653W7WQZGVlyeTJk6WwsFAGDRrUrZNUrS+LFy/2PlYtJ3379hX/T8TGDQABADB1t86+ffukoqJCxo0bJxEREXrZuXOnvPbaa3pbFbteacKECXpdUFCg16oGpby8vN0xnsdqX0eioqJ03UrbJSDdOpWX/fo5AACgh+FEtYAcOXJEDh486F1uueUWXRyrtm0221WvUc8rqgVFycvL0++hQo6HaolRgSM3N1fMNJyYuU4AADB5t05CQoIucm0rLi5O0tLS9POq6+b999+Xe++9Vz+nak4WLVokd955p3fI8dSpU3UImT17tqxevVrXmSxbtkwX2aoWEjNgIjYAAEJkhli73S5/+9vfdAAZPny4PPfcczJz5kzZuHGj9xjVurJp0ya9Vq0ojz76qJ7npO28KEbz1JxcrHNIvcMck8MBABAuLG632y1BRhXEJiUlSVVVld/qT0b/7COpbmiWrYvulCEZCX75DAAAwkl1J/9+c2+da+hD3QkAAIYgnNzo7sSM2AEAIKAIJzcKJ8x1AgBAQBFOroEROwAAGINwcoO5Tri/DgAAgUU4uWG3DrPEAgAQSISTG4ST87WN0tB09bT8AADAPwgn15AUEynxUS0T6HKPHQAAAodwcg0Wi4WuHQAADEA46dSInfpA/TwAAAh7hJNO1J0wYgcAgMAhnHRiODEjdgAACBzCSSfuTky3DgAAgUM46Uy3DvfXAQAgYAgnnejWKa9ulMZm5joBACAQCCfXkRIbKTGRNr1dUtkQkB8IAADhjnDSyblOGLEDAEBgEE46fY8d5joBACAQCCedHrHDDQABAAgEwkkni2IZsQMAQGAQTm6Abh0AAAKLcNLp++vQrQMAQCAQTjo910mDOJpdgfiZAAAQ1ggnN9Ar3i5REVZxuUXKqpjrBAAAfyOcdGKuE+6xAwBA4BBOunJ3Yu6xAwCA3xFOOoGiWAAAAodw0gkMJwYAIHAIJ53A/XUAAAgcwkmXWk6Y6wQAAH8jnHShILasukGancx1AgCAPxFOOqF3fJTYbVZxutw6oAAAAP8hnHTmIlktkp0crbfp2gEAwL8IJ12d64S6EwAA/Ipw0kmM2AEAIDAIJ12eiK3enz8PAADCHuGkk3JSGU4MAEAgEE66WHNylvvrAADgV4STLnbrlFRe1kOKAQCAfxBOOikjMVoirBZpdrmlnLlOAADwG8JJJ9n0XCctrSd07QAA4D+Eky5gxA4AAP5HOOnODQAvcgNAAAD8hXDSBYzYAQDA/wgnXdDH03LCFPYAAPgN4aQ73TrMEgsAgN8QTroRTkoqG8TFXCcAAPgF4aQLMhOj9ZBih9Ml52ob/fMTAQAgzBFOuiDCZtUBRaFrBwAA/yCcdLvuhOHEAAD4A+GkixixAwCAfxFOujnXCS0nAAD4B+GkixhODACAfxFOumhQ7zi9PlFW44+fBwAAYY9w0kU3ZyWKxSJSUdMoFTUNYf8FAgDA1wgnXRRrj5BBveP1dn5Jtc9/IAAAhDvCSTeMyE7U6/yzVb7+eQAAEPYIJ90wMjtJr2k5AQDA9wgn3TCiT0vLydESWk4AAPA1wkk3jMhqaTkpvnhZquqbfP0zAQAgrBFOuiEpNlL6prZMY59fSusJAAC+RDjpYetJ/llG7AAA4EuEk24aSd0JAAB+QTjpphF9GLEDAIDpwsmqVavEYrHIs88+632uoaFBFixYIGlpaRIfHy8zZ86U8vLydq8rKiqS6dOnS2xsrKSnp8uSJUukublZgnGuk8JztVLvCK5zBwAgJMPJ3r175e2335bRo0e3e37RokWyceNGWb9+vezcuVNKSkrkoYce8u53Op06mDgcDtm1a5esXbtW1qxZIy+88IIEk/SEaElPiBK3W+SrUupOAAAwNJzU1tbKrFmz5Ne//rWkpKR4n6+qqpLf/va38vLLL8tdd90l48ePl3feeUeHkD179uhjtmzZIseOHZN3331Xxo4dK/fcc4+8+OKL8vrrr+vAEkxG0rUDAIA5wonqtlGtH1OmTGn3/L59+6Spqand88OHD5d+/frJ7t279WO1HjVqlGRkZHiPmTZtmlRXV0t+fn6Hn9fY2Kj3t13M1LVzlGnsAQDwmYiuvmDdunWyf/9+3a1zpbKyMrHb7ZKcnNzueRVE1D7PMW2DiWe/Z19HVq5cKcuXLxezGdE6jf1RhhMDAGBMy0lxcbEsXLhQ3nvvPYmOjpZAWbp0qe4y8izqPMw0nPjrihppbHYafToAAIRfOFHdNhUVFTJu3DiJiIjQiyp6fe211/S2agFRdSOVlZXtXqdG62RmZupttb5y9I7nseeYK0VFRUliYmK7xQz6JMdIUkykNDnd8nV5rdGnAwBA+IWTyZMny5EjR+TgwYPe5ZZbbtHFsZ7tyMhI+fjjj72vOXHihB46nJeXpx+rtXoPFXI8tm7dqgNHbm6uBBM1jNo7GRt1JwAABL7mJCEhQUaOHNnuubi4OD2nief5J598UhYvXiypqak6cPzLv/yLDiS333673j916lQdQmbPni2rV6/WdSbLli3TRbaqhSTYjMxOks8KLkh+iTmKdAEACLuC2Bt55ZVXxGq16snX1CgbNRLnjTfe8O632WyyadMmmT9/vg4tKtzMmTNHVqxYIcEo1zNip4QbAAIA4AsWt1tNIxZc1FDipKQkXRxrdP2JmiF28r/ulOhIq+Qv/77YrBZDzwcAgGD/+829dXpoYFqcxNlt0tDkkpPnKIoFAKCnCCc9vYBWi9ycRdcOAAC+Qjjx4TT2TMYGAEDPEU58OI19PkWxAAD0GOHEh9PYq+HEQVhfDACAqRBOfGBIRrzYbVapaWiW4ouXffGWAACELcKJD0TarDI8K0FvM98JAAA9Qzjxcd0J09gDANAzhBM/1J0AAIDuI5z4fDhxFUWxAAD0AOHER4ZnJuip6y/UOaS8utFXbwsAQNghnPhIdKRNBveO19vMdwIAQPcRTvxSFEvdCQAA3UU48aERnroTZooFAKDbCCc+NLK15eQYI3YAAOg2wokP5baGk7OVl+VincOXbw0AQNggnPhQQnSkDEiL1dsUxQIA0D2EEz/VnTAZGwAA3UM48TGmsQcAoGcIJz42kmnsAQDoEcKJn1pOTp2vk5qGJl+/PQAAIY9w4mNp8VGSlRStt78qrfH12wMAEPIIJ368Q7G6CSAAAOgawokfjOzT0rXDiB0AALqOcOLHlhPmOgEAoOsIJ35sOfm6olYampz++AgAAEIW4cQPMhOjpVd8lDhdbjlWyh2KAQDoCsKJH1gsFhmd09K1c+QMRbEAAHQF4cRPRrVOY3/oTKW/PgIAgJBEOPETWk4AAOgewomfW04KztVKXWOzvz4GAICQQzjxk/TEaF0Y63Yz3wkAAF1BOAlA185h6k4AAOg0wkkg6k6Yxh4AgE4jnPjRqJxkvWY4MQAAnUc4CUBR7MnzdVLd0OTPjwIAIGQQTvwoNc4uOSkxeps7FAMA0DmEEz9jvhMAALqGcOJno/q01J0cpigWAIBOIZz4GS0nAAB0DeHEz0ZmtxTFFl2sl0t1Dn9/HAAAQY9w4mdJsZEyIC1WbzPfCQAAN0Y4CeR8J9SdAABwQ4STABjdOt8J09gDAHBjhJMAGOWZxv5MVSA+DgCAoEY4CYCRfZLEYhEpqWqQczWNgfhIAACCFuEkAOKjImRQ73i9zUyxAABcH+Ek4HUndO0AAHA9hJNA152crQzURwIAEJQIJwGeKZaWEwAAro9wEiC5WUlitYhU1DRKeXVDoD4WAICgQzgJkBi7TYZmJOhtWk8AALg2wkkAjWIyNgAAbohwEkDUnQAAcGOEE4PuseN2uwP50QAABA3CSQANz0yQCKtFLtY55Gzl5UB+NAAAQYNwEkDRkTYZltlSFMt9dgAA6BjhxKi6k7PMFAsAQEcIJwE22lN3wjT2AAB0iHBi4HBiimIBALga4STA1ERs9girVDc0S9HF+kB/PAAApkc4CTAVTG7OStTbzBQLAMDVCCcGGN3ataPmOwEAAO0RTgwwynuH4kojPh4AgNAJJ2+++aaMHj1aEhMT9ZKXlyebN2/27p80aZJYLJZ2y7x589q9R1FRkUyfPl1iY2MlPT1dlixZIs3NzRKOw4mPnq0Wl4uZYgEAaCtCuiAnJ0dWrVolQ4YM0SNN1q5dKzNmzJADBw7IiBEj9DFPPfWUrFixwvsaFUI8nE6nDiaZmZmya9cuKS0tlccee0wiIyPlpZdeknAxuHe8REdapbaxWU6er5PB6fFGnxIAAMHZcnL//ffLvffeq8PJ0KFD5ec//7nEx8fLnj172oURFT48i2ph8diyZYscO3ZM3n33XRk7dqzcc8898uKLL8rrr78uDodDwkWEzSojsj11J3TtAADgk5oT1Qqybt06qaur0907Hu+995706tVLRo4cKUuXLpX6+m+Hy+7evVtGjRolGRkZ3uemTZsm1dXVkp+ff83Pamxs1Me0XUJnvhOKYgEA6Ha3jnLkyBEdRhoaGnSryYYNGyQ3N1fve+SRR6R///6SnZ0thw8flueff15OnDghH374od5fVlbWLpgonsdq37WsXLlSli9fLqFYd8JMsQAA9DCcDBs2TA4ePChVVVXywQcfyJw5c2Tnzp06oMydO9d7nGohycrKksmTJ0thYaEMGjRIuku1wCxevNj7WLWc9O3bV0IhnOSXVEuz06W7egAAQDe6dex2uwwePFjGjx+vWzTGjBkjr776aofHTpgwQa8LCgr0WtWglJeXtzvG81jtu5aoqCjvCCHPEuwG9oqXOLtNLjc55euKWqNPBwAA0+jxP9ddLpeuCemIamFRVAuKorqDVLdQRUWF95itW7fqsOHpGgoXNqtFxvRtuQnggSKKYgEA6FY4Ud0rn3zyiXzzzTc6ZKjHO3bskFmzZumuGzXyZt++fXr/n//8Zz1M+M4779RzoyhTp07VIWT27Nly6NAh+eijj2TZsmWyYMEC3ToSbsb1S9HrfacvGX0qAAAEZ82JavFQgUPNT5KUlKRDhwoYd999txQXF8vf/vY3+eUvf6lH8KiakJkzZ+rw4WGz2WTTpk0yf/583YoSFxena1bazosSTsb3bwknB4oIJwAAeFjcaja1IKMKYlU4UkW5wVx/UlnvkLErturt/T+5W1Lj7EafEgAAhv/9ZoiIgZJj7XJT7zi9TesJAAAtCCcGG99ad7Kfrh0AADTCicHGtdad7D/NiB0AABTCiUlG7BwsrtSTsQEAEO4IJwYbkh4vCVERejK242U1Rp8OAACGI5wY/QOwWmRsP89kbAwpBgCAcGICTMYGAMC3CCdmKoplGnsAAAgnZjC2b7JYLCJFF+vlXE3H9ykCACBc0HJiAkkxkbowVmG+EwBAuCOcmKzuhHACAAh3hBOT1Z0cYDI2AECYI5yYrOXk0JlKcTQzGRsAIHwRTkzipl5xuvaksdklX5VWG306AAAYhnBiosnYxrVOxkbdCQAgnBFOTITJ2AAAIJyYsyiWydgAAGGMlhMTGdM3WawWkbOVl6WsqsHo0wEAwBCEExOJj4qQYZmJepu6EwBAuCKcmIy3KPY0dygGAIQnwonJjPfeBJBwAgAIT4QTk47YOXq2WhqbnUafDgAAAUc4MZn+abGSGmcXh9OlAwoAAOGGcGIyFouajM0zpJiuHQBA+CGcmNC4/i1FsfsoigUAhCHCiQl5Wk5UUazb7Tb6dAAACCjCiQmNyUkWm9Ui5dWNUsJkbACAMEM4MaEYu01ys1omY6NrBwAQbggnJsVkbACAcEU4Mf1NABmxAwAIL4QTkxfF5pdUS0MTk7EBAMIH4cSkclJipHdClDS73HL4TJXRpwMAQMAQTkw8Gdv4NkOKAQAIF4QTE2MyNgBAOCKcmFjbaeyZjA0AEC4IJyY2sk+SRNoscr7WIcUXLxt9OgAABAThxMSiI20yIjtJb+/95qLRpwMAQEAQTkxuwk2per2r8ILRpwIAQEAQTkxu4qBeer278Dx1JwCAsEA4MblbB6TquhN1A8BvLtQbfToAAPgd4SQIbgL4d62jdj4rOG/06QAA4HeEkyDq2tlVSDgBAIQ+wkkQmDg4Ta93F14Ql8tt9OkAAOBXhJMgMKZvssTZbXKpvkmOlVYbfToAAPgV4SQIRNqscttAz5BiunYAAKGNcBIkJg721J0w3wkAILQRToLEHa1FsV+cuiiOZpfRpwMAgN8QToLE8MwESY2zS73DKYfOVBp9OgAA+A3hJEhYrRbJG9Qyaof5TgAAoYxwEozznRRQdwIACF2EkyByR2vLyYHiS1LvaDb6dAAA8AvCSRDpnxYrfZJjpMnp1oWxAACEIsJJELFYLN7WE4YUAwBCFeEkaOc7YTI2AEBoIpwEGU/LSX5JtVTWO4w+HQAAfI5wEmTSE6NlSHq8uN0tNwIEACDUEE6CuGvnM7p2AAAhiHAShLxFscx3AgAIQYSTIDThpjSxWkROnq+T0qrLRp8OAAA+RTgJQkkxkTIqJ1lvf0brCQAgxBBOgtS3850wpBgAEFoIJyFwnx23GroDAECIIJwEqVsGpIg9wipl1Q269gQAgFBBOAlS0ZE2Gd8vRW/vKqBrBwAQpuHkzTfflNGjR0tiYqJe8vLyZPPmzd79DQ0NsmDBAklLS5P4+HiZOXOmlJeXt3uPoqIimT59usTGxkp6erosWbJEmpu5w253TBzcUndCUSwAIGzDSU5OjqxatUr27dsnX375pdx1110yY8YMyc/P1/sXLVokGzdulPXr18vOnTulpKREHnroIe/rnU6nDiYOh0N27dola9eulTVr1sgLL7zg+/+yMHBH62Rsu09eEKeLuhMAQGiwuHtYTZmamiq/+MUv5Ac/+IH07t1b3n//fb2tHD9+XG6++WbZvXu33H777bqV5b777tOhJSMjQx/z1ltvyfPPPy/nzp0Tu93eqc+srq6WpKQkqaqq0i044arZ6ZK/W7FVahqbZeMz35FROUlGnxIAAD3++93tmhPVCrJu3Tqpq6vT3TuqNaWpqUmmTJniPWb48OHSr18/HU4UtR41apQ3mCjTpk3TJ+tpfelIY2OjPqbtApEIm1Um3JSqLwVDigEAoaLL4eTIkSO6niQqKkrmzZsnGzZskNzcXCkrK9MtH8nJLZODeaggovYpat02mHj2e/Zdy8qVK3XS8ix9+/bt6mmHrDtahxR/xk0AAQDhGk6GDRsmBw8elM8//1zmz58vc+bMkWPHjok/LV26VDcBeZbi4mK/fl4w3gRw76mL4mh2GX06AAD0WERXX6BaRwYPHqy3x48fL3v37pVXX31VHn74YV3oWllZ2a71RI3WyczM1Ntq/cUXX7R7P89oHs8xHVGtNGrB1YZmxEuveLucr3XIgaJL+r47AACE9TwnLpdL14SooBIZGSkff/yxd9+JEyf00GFVk6KoteoWqqio8B6zdetWXRSjuobQdRaLRfJau3Y+Zb4TAEC4hRPVvfLJJ5/IN998o0OGerxjxw6ZNWuWrgV58sknZfHixbJ9+3ZdIPvDH/5QBxI1UkeZOnWqDiGzZ8+WQ4cOyUcffSTLli3Tc6PQMtJ9k4b21uu/Hr123Q4AACHZraNaPB577DEpLS3VYURNyKYCxt133633v/LKK2K1WvXka6o1RY3EeeONN7yvt9lssmnTJl2rokJLXFycrllZsWKF7//LwsiU3Ayx26zydUWt/L/yGhmakWD0KQEAYNw8J0ZgnpOrPblmr3x8vEIWTh4ii+4easBPBQAAg+c5gblMH52l1/95pNToUwEAoEcIJyHatQMAQLAinISIxOhIuXNoy6idvxym9QQAELwIJyHYtfOXI6UShKVEAABohJMQMvnmlq6dAt21U2v06QAA0C2Ek5Dr2untbT0BACAYEU5CzPTRLbcB+MvhErp2AABBiXASYqaorp0IqxSeq6NrBwAQlAgnISZBde0Mae3aOVxi9OkAANBlhJMQdF/rqJ1NjNoBAAQhwkkImnxzuu7aOXmuTk4wIRsAIMgQTkK0a+d7nlE7TMgGAAgyhJMQ79phQjYAQLAhnITyhGytXTvHy7jXDgAgeBBOQlR8VIRMau3a4U7FAIBgQjgJh3vtHOZeOwCA4EE4CYeunfN18lUpXTsAgOBAOAlhdO0AAIIR4SRcunaYkA0AECQIJ2HStXOKrh0AQJAgnIRB187fD2udkO0I99oBAJgf4SQMTB+drdf/eaRM3G630acDAMB1EU7CwOTh6RLV2rVzrLTa6NMBAOC6CCdhIE537aTrbe61AwAwO8JJmI3a+WDfGWlsdhp9OgAAXBPhJExMHZEhmYnRUlHTKH88cNbo0wEA4JoIJ2EiKsImP/ruQL399s6T4nRRGAsAMCfCSRj5p9v6SVJMpJ7Ofkt+mdGnAwBAhwgnYTbnyZy8/nr7zZ2FDCsGAJgS4STMzLljgERHWuXwmSrZVXjB6NMBAOAqhJMwkxYfJf90az+9/eaOQqNPBwCAqxBOwpAqjLVZLfJpwXk5fKbS6NMBAKAdwkkYykmJlRljWqa0f2snrScAAHMhnISpf/7eIL3efLRMTp6rNfp0AADwIpyEqWGZCTLl5nRR9wH8P5+cNPp0AADwIpyEsfmTWlpP/u/+M1JW1WD06QAAoBFOwtj4/qly24BUaXK65d8+O2X06QAAoBFOwty8STfp9Xt7TktVfZPRpwMAAOEk3P39sHQZlpEgdQ6n/G7PN0afDgAAhJNwZ7FYvLUn73z2jTQ0OY0+JQBAmKNbB3Lf6CzJSYmRC3UOWf9lMVcEAGAowgkkwmaVuXe21J68/clJaXa6uCoAAMMQTqD94/i+khZnlzOXLstfjpRyVQAAhiGcQIux2+SHEwfo7de3F0gTrScAAIMQTuA1+/YBkhQTKf+vvFZ+ta2AKwMAMAThBF5JsZHyPx8Y6W09OVjMHYsBAIFHOEE794/Jln8Yky1Ol1sW//6gXHYwtBgAEFiEE1xlxYwRkpEYJSfP18n/+utxrhAAIKAIJ7hKcqxdVv9gjN5es+sb+fTr81wlAEDAEE7Qoe8N7S2zb++vt5d8cEiqLnPfHQBAYBBOcE1L7x0uA9JipbSqQX7253yuFAAgIAgnuKZYe4S8/PBYsVpENhw4K//J5GwAgAAgnOC6xvVLkacnDdbb/2PDEamobuCKAQD8inCCG/pvk4dIblaiXKpvkh9/eETcbjdXDQDgN4QT3JA9wiqvPDxW7DarbDteIev2cudiAID/EE7QKcMyE2TJtGF6+8VNx6ToQj1XDgDgF4QTdNoT3xkotw1MlXqHUxb/4aA4ml1cPQCAzxFO0Gk2q0X+9R/HSJzdJl+eviQL3t9PQAEA+BzhBF3SNzVW3nh0vK5D2XqsXOa/u08am7n/DgDAdwgn6Nbssb+dc4tERVjl4+MV8s+/2ycNTQQUAIBvEE7QLd8d0lv+7fFbJTrSKjtOnJOn/v1LAgoAwCcIJ+i2iYN7yTuP3yYxkTb5r6/Py4/WfimXHbSgAAB6hnCCHskblCZrfnirxNpt8mnBeXlizV6pdzRzVQEA3UY4QY9NuClN/v2J2/Qont0nL8jj7+yVukYCCgAgAOFk5cqVcuutt0pCQoKkp6fLAw88ICdOnGh3zKRJk8RisbRb5s2b1+6YoqIimT59usTGxur3WbJkiTQ388csmN0yIFX+/ckJEh8VIV+cuiiPv/OF1BJQAAD+Dic7d+6UBQsWyJ49e2Tr1q3S1NQkU6dOlbq6unbHPfXUU1JaWupdVq9e7d3ndDp1MHE4HLJr1y5Zu3atrFmzRl544YXunD9MZHz/FPndk7dJQnSE7P3mkjz228/lYp3D6NMCAAQZi7sHd3E7d+6cbvlQoeXOO+/0tpyMHTtWfvnLX3b4ms2bN8t9990nJSUlkpGRoZ9766235Pnnn9fvZ7fbb/i51dXVkpSUJFVVVZKYmNjd04efHD5TKY/+5nOpbmiWXvF2WfnQaLk7t+VnDQAIX9Wd/Pvdo5oT9eZKampqu+ffe+896dWrl4wcOVKWLl0q9fXf3odl9+7dMmrUKG8wUaZNm6ZPOD8/v8PPaWxs1PvbLjCv0TnJ8od5eTI0I17O1zr0MOP/vv6QVDc0GX1qAIAg0O1w4nK55Nlnn5WJEyfqEOLxyCOPyLvvvivbt2/XweR3v/udPProo979ZWVl7YKJ4nms9l2r1kUlLc/St2/f7p42AmR4ZqL8+ZnvyD/feZNYLCIf7Dsj33/lE9lVcJ6fAQDguiKkm1TtydGjR+XTTz9t9/zcuXO926qFJCsrSyZPniyFhYUyaNCgbn2WCjmLFy/2PlYtJwQU84uOtMnSe2+WKbkZ8twfDknRxXp55Defy+N3DJDnvz9cYuw2o08RABAqLSfPPPOMbNq0SbeO5OTkXPfYCRMm6HVBQYFeZ2ZmSnl5ebtjPI/Vvo5ERUXpvqm2C4LHrQNSZfPC78qsCf304zW7vpF7X/sv2V90yehTAwAEezhRtbMqmGzYsEG2bdsmAwcOvOFrDh48qNeqBUXJy8uTI0eOSEVFhfcYNfJHBY7c3Nyu/xcgKMRFRcjPHxwla5+4TTITo+XU+Tr5wZu7ZPVfjzPtPQCg+6N1nn76aXn//fflT3/6kwwbNsz7vKoDiYmJ0V03av+9994raWlpcvjwYVm0aJFuXVEjejxDidVonuzsbD3EWNWZzJ49W370ox/JSy+91KnzYLROcKuqb5KfbcyXDQfO6sfpCVEy986b5JEJ/STW3u2eRgCAyXX273eXwomaUK0j77zzjjz++ONSXFysi19VLYqa+0TVhTz44IOybNmydidx+vRpmT9/vuzYsUPi4uJkzpw5smrVKomI6NwfJsJJaPjr0VJZvvGYlFY16MepcXZ58jsD5bG8/pIQHWn06QEAgiGcmAXhJHQ4ml3y4f4z8saOQl0wqyRGR8jjEwfKD+8YIClxN573BgAQHAgnCCrNTpdsOlwq/3t7gRRU1Orn1L16Hr29vzz53YGSnhBt9CkCAHqIcIKg5HK55aP8MvnVtgI5Vtoy2V5UhFXuH5Mt/zAmW+4YlCYRNu5XCQDBiHCCoKZ6G7efqNAh5UBRpfd5NR3+9FFZ8g9js2Vcv5Rr1kEBAMyHcIKQCSn7Tl+SPx48K385XCqX6r+dAr9Pcoy3ReXmrASCCgCYHOEEIafJ6ZJPC87LxoMluuunzuH07huSHi/fH5kpeYPSdIuKmp0WAGAuhBOEtMsOp2w7XiF/PnRWth8/Jw6ny7vPHmGVcf2S5Y5BvXRYGZOTrJ8DABiLcIKwUXW5Sbbkl8l/fX1edp+8IOdqGtvtj4m0yS0DUuT2m9JkwsBUuTkrUc9YCwAILMIJwrZGpfBcnQ4pewovyJ6TF+RCnaPdMaqGdmBanNycnSi5WYmSm50oI7ISpXdCFHUrAOBHhBOgdWjy1xW1srvwvOwqvCCHzlRKeXX7lpW2I4FUq4paBqTFSf+0WL1kJcWIzcqoIADoKcIJcA3naxvlq9JqOVZSredSUevCc7XiusZcyXabVXJSY6R/qgor34aWzMQYyUiM0tPuM6QZAG6McAJ0QUOTU06U1eiwotanL9TJ6Yv1UnyxXpqc17/DgwovqktIBZWMxOg2S5SkxUdJSmykpMTadYiJtdsIMgDCVnUn761DVSAgoocej+mbrJe2nC63lFZdltMX6luWi3Vy+ny9vg9QeXWDrmdRI4XOVl7Wy42oIJMS1xJW9BIXKUkxkfpGh/FREZIQHdG6jmyz3bKOsdv0XZvpYgIQ6ggnwHWoIJCTEquXiYM7vnHhudpGHVTKqxpa1jWtj6sb5GJdk1yqc8jFeoc+VgUZVfNyrbqXzlDDotUIJNUK0xJYbPpxjD1CT/XfstgkKtIq0a1rz3PqtXqxWSTSZvUu9oj2jyOsFoloPUZdg0irVWzqsdWiH0e0HmPzLBaLWKnLAeAjhBOgB9QfejVTrVpuNIrocpNTz3CrwsqleodcVOs6h1Q3NEttY7PUNKilybtd2/q4prFZ6hqbvTUxOuQ0u/QQarNRgUWFlIg2gUWFF5VbrOqxpeWxGjHV8nz7fer5tmu1z9LmGPW8Rf2ffr6j59RWy76W9beP1TMt79dyrm3fx/O49f9bHre+tmW74+c9G55n2t5Noe1r2x7T7rkr36iNK+/M0FFJdkd3b2j7OdfSmbs+dKYEnForc7L4qH5/8vAM+c6QXmIEwgkQAOqXuOqSUcuNgsy1wk1js0tPPlff5JTLjma57HBJvaO59bFT6h1OHVoam5362MamNttq3eSShmaXvgO0mm1XPa/WqqZGrdVr9drpEqd6zuXW3Vrq+Ga1bt2+VuGwoo5RB7QfvA0gGPVOiCKcALh+uFF1MWpJMcHw7ObW4OJ0t67bLm53u2NUsGp5TuUWt17U8y05xnNMSwDzPKcW/Zy0vE693nOMykZqWx/TGty8j/Vrvj1O/Y/a8jyvjtFPtz6WNu+nXtP6Ev342+32aezbfd++19Wva7/vSm3fs+0xbQ9v//z1i7I7es2V73fNgzo6pBufFY46+3MJVuP6GffbhpYTAF2iumrs1JcA8CNuOAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEyFcAIAAEwlKO9K7LndeHV1tdGnAgAAOsnzd9vzdzykwklNTY1e9+3b1+hTAQAA3fg7npSUdM39FveN4osJuVwuKSkpkYSEBLFYLD5PdSr0FBcXS2Jiok/fG1xvo/H95nqHMr7f5r/eKnKoYJKdnS1WqzW0Wk7Uf1BOTo5fP0NdaMJJ4HC9A4vrzfUOZXy/zX29r9di4kFBLAAAMBXCCQAAMBXCyRWioqLkpz/9qV7D/7jegcX15nqHMr7foXO9g7IgFgAAhC5aTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTtp4/fXXZcCAARIdHS0TJkyQL774wrifTAj55JNP5P7779czAqoZff/4xz+2269qsl944QXJysqSmJgYmTJlinz99deGnW+wW7lypdx66616BuX09HR54IEH5MSJE+2OaWhokAULFkhaWprEx8fLzJkzpby83LBzDmZvvvmmjB492jsRVV5enmzevNm7n2vtX6tWrdK/V5599lmuuR/87Gc/09e37TJ8+HC/f78JJ61+//vfy+LFi/WwqP3798uYMWNk2rRpUlFR0eOLHO7q6ur09VThryOrV6+W1157Td566y35/PPPJS4uTl979aVH1+3cuVP/stizZ49s3bpVmpqaZOrUqfrn4LFo0SLZuHGjrF+/Xh+vbgfx0EMPcbm7Qc1Wrf5A7tu3T7788ku56667ZMaMGZKfn8+19rO9e/fK22+/rcNhW3y/fWvEiBFSWlrqXT799FP/X2s1lBhu92233eZesGCB91I4nU53dna2e+XKlVweH1JfuQ0bNngfu1wud2ZmpvsXv/iF97nKykp3VFSU+z/+4z+49j5QUVGhr/vOnTu91zcyMtK9fv167zFfffWVPmb37t1ccx9ISUlx/+Y3v+Fa+1FNTY17yJAh7q1bt7q/973vuRcuXKif5/vtWz/96U/dY8aM6XCfP681LSci4nA49L96VHdC2/v3qMe7d+/ueQLENZ06dUrKysraXXt13wXVrca1942qqiq9Tk1N1Wv1XVetKW2vuWqm7devH9e8h5xOp6xbt063UqnuHa61/6jWwenTp7f7Hitcc99T3eyqW/6mm26SWbNmSVFRkd+vdVDe+M/Xzp8/r3+pZGRktHtePT5+/Lhh5xUOVDBROrr2nn3o2R28VV/8xIkTZeTIkd5rbrfbJTk5mWvuI0eOHNFhRHVFqn73DRs2SG5urhw8eJBr7QcqAKrud9WtcyW+376l/qG4Zs0aGTZsmO7SWb58uXz3u9+Vo0eP+vVaE06AEP/Xpfol0raPGL6nfnGrIKJaqT744AOZM2eO7n+H7xUXF8vChQt1PZUavAD/uueee7zbqrZHhZX+/fvLH/7wBz2AwV/o1hGRXr16ic1mu6rCWD3OzMz028WHeK8v1973nnnmGdm0aZNs375dF222veaqK7OysrLd8Xzfu0/963Hw4MEyfvx4PVpKFYC/+uqrXGs/UF0JaqDCuHHjJCIiQi8qCKqierWt/tXO99t/VCvJ0KFDpaCgwK/fb8JJ6y8W9Uvl448/btccrh6rplr4z8CBA/WXuO21r66u1qN2uPbdo+qOVTBRXQvbtm3T17gt9V2PjIxsd83VUGPVj8w19w31+6OxsZFr7QeTJ0/W3Wiqpcqz3HLLLboWwrPN99t/amtrpbCwUE/94NffJT0qpw0h69at0yNE1qxZ4z527Jh77ty57uTkZHdZWZnRpxYSVfUHDhzQi/rKvfzyy3r79OnTev+qVav0tf7Tn/7kPnz4sHvGjBnugQMHui9fvmz0qQel+fPnu5OSktw7duxwl5aWepf6+nrvMfPmzXP369fPvW3bNveXX37pzsvL0wu67sc//rEeCXXq1Cn9/VWPLRaLe8uWLVzrAGk7Wkfh++07zz33nP5dor7fn332mXvKlCnuXr166VGA/rzWhJM2fvWrX+mLbLfb9dDiPXv29PgCw+3evn27DiVXLnPmzPEOJ/7JT37izsjI0AFx8uTJ7hMnTnDpuqmja62Wd955x3uMCn5PP/20HvIaGxvrfvDBB3WAQdc98cQT7v79++vfG71799bfX08w4VobE074fvvOww8/7M7KytLf7z59+ujHBQUFfr/WFvU/PW/oAQAA8A1qTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgKkQTgAAgJjJ/wd3F0PrWfcv4wAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def train_step(model, optimizer, x, y):\n", + " # Create state-based loss function for LBFGS\n", + " graphdef = nnx.graphdef(model)\n", + " loss_fn_state = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)\n", + "\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(\n", + " model,\n", + " grads,\n", + " grad=grads,\n", + " value=loss,\n", + " value_fn=loss_fn_state)\n", + " return loss\n", + "\n", + "model = make_model(rngs)\n", + "optimizer = nnx.Optimizer(\n", + " model,\n", + " tx=optax.lbfgs(1e-3),\n", + " wrt=nnx.Param)\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(model, optimizer, x, y)\n", + " losses.append(loss)\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "id": "924d3641", + "metadata": {}, + "source": [ + "# Per-Parameter Learning Rates\n", + "\n", + "In some training regimes, you will want to optimize different parameters with different learning rates.\n", + "\n", + "In Jax, we map from each leaf to the type of parameter it is (weight or bias). We then create a dictionary giving the learning rates to use for each parameter type. Finally, we can make a compound optimizers that uses each rate appropriately.\n", + "\n", + "To do this in Flax, we can map from each leaf to the type of parameter it is (weight or bias). With this pytree of parameter types, we can make a compound optimizer that uses each rate appropriately. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6a18b321", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "model = make_model(rngs)\n", + "state = nnx.state(model, nnx.Param)\n", + "rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}\n", + "param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)\n", + "optimizer = nnx.Optimizer(model, tx=optax.partition(rates, param_tys), wrt=nnx.Param)\n", + "\n", + "@nnx.jit\n", + "def train_step(model, optimizer, x, y):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(model, grads)\n", + " return loss\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(model, optimizer, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "8c12883b", + "metadata": {}, + "source": [ + "# Gradient Accumulation" + ] + }, + { + "cell_type": "markdown", + "id": "f8e25577", + "metadata": {}, + "source": [ + "Gradient accumulation in Flax is easy: just use the `optax.MultiSteps` optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f590a02b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAPMVJREFUeJzt3Qd8ldUB9/F/dsgmJCFAElbYAjIEggtZioo4Whwg1Lq3aK3yVq3YV6G1ta+2Fq2I4kAULLZaEVEEBMIUFFnKTICEMLMg+76fc0JSoqyQe3NHft/P5zE393ny5NzD/eT+PdPP4XA4BAAA4EH83V0AAACAnyKgAAAAj0NAAQAAHoeAAgAAPA4BBQAAeBwCCgAA8DgEFAAA4HEIKAAAwOMEygtVVFRoz549ioyMlJ+fn7uLAwAAzoBZGzY/P1/NmzeXv7+/7wUUE06Sk5PdXQwAAHAWMjMzlZSU5HsBxbScVL3AqKgodxcHAACcgby8PNvAUPU57nMBpapbx4QTAgoAAN7lTIZnMEgWAAB4HAIKAADwOAQUAADgcQgoAADA4xBQAACAxyGgAAAAj0NAAQAAvhVQJk2aZOcyP/TQQ9XPFRUV6d5771WTJk0UERGh6667Tnv37q3xcxkZGbriiisUFhamhIQEPfrooyorK6tLUQAAgA8564CycuVKvfrqq+rWrVuN58eNG6ePP/5YM2fO1MKFC+2y9Ndee231+fLychtOSkpKtHTpUk2bNk1vvvmmnnrqqbq9EgAA0LADSkFBgUaNGqXXXntNjRs3rn4+NzdXr7/+ul544QUNHDhQvXr10htvvGGDyLJly+w1n3/+uTZs2KB33nlH5557roYNG6Y//OEPevnll21oAQAAOKuAYrpwTCvI4MGDazy/evVqlZaW1ni+Y8eOSklJUXp6uv3efO3atauaNm1afc2ll15q1+dfv379CX9fcXGxPX/8AQAAfFet9+KZMWOGvvnmG9vF81PZ2dkKDg5WTExMjedNGDHnqq45PpxUna86dyITJ07UhAkTaltUAADQEAKK2T34wQcf1Lx58xQaGqr6Mn78eD388MM/2w3R2VbvPKhPvsty2v26tojWtT1PvZ00AACoY0AxXTg5OTnq2bNnjUGvixYt0t///nfNnTvXjiM5fPhwjVYUM4snMTHRPjZfV6xYUeO+VbN8qq75qZCQEHu42ubsAr2xZIdT79mpWZQ9AACAiwLKoEGDtG7duhrP3XLLLXacyWOPPWZbNYKCgvTll1/a6cXG5s2b7bTitLQ0+735+uyzz9qgY6YYG6ZFJioqSp07d5Y7dWkepXsvaeuUe3394359tytXM1ZkaMKIc5xyTwAAGopaBZTIyEidc07ND9vw8HC75knV87feeqvtjomNjbWh4/7777ehpF+/fvb80KFDbRC5+eab9ac//cmOO3niiSfswNv6aCU5le7JMfZwhr6t92nM1BWavWa3Hh/WSY2CA5xyXwAAGgKnryT717/+VVdeeaVtQbnoootst82//vWv6vMBAQH65JNP7FcTXEaPHq0xY8bomWeekS+5IDVOSY0bKa+oTJ+uc964FgAAGgI/h8PhkJcxg2Sjo6PtuiumlcZT/X3+j/rz5z/ovFaNNfOu/u4uDgAAXvP5zV48LvTL3skK8PfTyh2H9OPefFf+KgAAfAoBxYWaRoVqYMfKgcDvrch05a8CAMCnEFBc7KY+Kfbrv9bsUlFpuat/HQAAPoGA4mIXtY9X8+hQHT5SqrnrT7xSLgAAqImA4mJmDMrI8ypXvZ2+PMPVvw4AAJ9AQKkHI3sny99PWr79oLbtK6iPXwkAgFcjoNSD5jGNdEmHysGyM1YyWBYAgNMhoNSTG44Nlp21epeKyxgsCwDAqRBQ6sklHeLVNCpEBwtLNG9D5eaIAADgxAgo9SQwwN+ORTHeW8FgWQAAToWAUo9MQPHzk5ZsOaCdBwrr81cDAOBVCCj1KDk2TBe1i7ePGSwLAMDJEVDq2Y19Krt5Zq7apdLyivr+9QAAeAUCSj0b1Kmp4iJCtL+gWF9uZLAsAAAnQkCpZ0F2sGySfTydDQQBADghAoobXH9s6fuvf9ynzINH3FEEAAA8GgHFDVo2CdcFqXFyOKQPVrGyLAAAPxX4s2dQL27sk6LFW/br/ZWZSm4c5pRNCQd2TFDj8GCnlA8AAHcioLjJkM5N1SQ8WDn5xfrth9855Z6DOibo9V+d55R7AQDgTgQUNwkO9Ncfr+umGSszVOGo270cDoe+2rxP8zfn2DEtZr0VAAC8GQHFjQZ3bmoPZxg1ZZldodZ0Gf3m0g5OuScAAO7CIFkfcVOflvbr+6syWQAOAOD1CCg+NKYlLiJY+/JZAA4A4P0IKD40puWXx3ZLfnc5uyUDALwbAcWH3Hheiv369Y/7lXGABeAAAN6LgOJDUpqE6cJ2cfbxeytpRQEAeC8Cio8Z1beyFWXmqkyVlLFbMgDAOxFQfHC35PhIs1tyieZtYLdkAIB3IqD44G7J1x8bLDt9xU53FwcAgLNCQPFBN/RJlp+f7MJtO/YXurs4AADUGgHFByU1DtPF7ePtYwbLAgC8EQHFh3dLNmat2sVgWQCA1yGg+Cizs3HTqBAdKCzR3PXZ7i4OAAC1QkDxUYHHD5ZlZVkAgJchoPiw6/ukyN9PSt92QNv2Fbi7OAAAnDECig9rEdNIAzok2MfvrWBlWQCA9yCg+LibqgbLrt6lotJydxcHAADnB5TJkyerW7duioqKskdaWprmzJlTfX7r1q265pprFB8fb8+PHDlSe/fWXM20VatW8vPzq3FMmjSpNsVALQzoEK9m0aE6dKSUwbIAAN8MKElJSTZMrF69WqtWrdLAgQM1YsQIrV+/XoWFhRo6dKgNHPPnz9eSJUtUUlKi4cOHq6Ki5p4wzzzzjLKysqqP+++/39mvC8cPlj2PwbIAAO8SWJuLTdg43rPPPmtbVZYtW6bdu3drx44dWrNmjW09MaZNm6bGjRvbwDJ48ODqn4uMjFRiYqKzXgNOwwSUl778Ucu3H9SWnAKlJkRQZwAA3wkoxysvL9fMmTNty4np6jHdO6b1JCQkpPqa0NBQ+fv7a/HixTUCimmF+cMf/qCUlBTddNNNGjdunAIDT16U4uJie1TJy8s722I3SM2iG2lgx6b6YuNeTfh4vXq3jK3zPc2GhDeclyx/M00IAAB3B5R169bZQFJUVKSIiAjNnj1bnTt3tuNOwsPD9dhjj+m5556Tw+HQ448/boOM6cap8sADD6hnz56KjY3V0qVLNX78eHv+hRdeOOnvnDhxoiZMmHD2rxIa1TfFBpSvf9xvD2cIDfLXtT2TqF0AgNP5OUySqAUzriQjI0O5ubmaNWuWpkyZooULF9qQ8vnnn+vuu+/W9u3bbcvJjTfeqA0bNqhPnz62K+hEpk6dqjvvvFMFBQU1Wl9O14KSnJxsy1DVnYRTM//Mry/eru1O2DzQ3GPp1gPqmRKjf91zPlUPADgj5vM7Ojr6jD6/a92CEhwcrNTUVPu4V69eWrlypV588UW9+uqrdpCs6erZv3+/7bKJiYmxY03atGlz0vv17dtXZWVldvxKhw4dTniNCS4nCy84M6b77bYLT/7vUBs5+UXqP3G+vsk4rA178tS5OSERAOBh66CYGTrHt24YcXFxNpyYwbE5OTm66qqrTvrza9euta0tCQmVC4rB8yVEhurSLpWDnKev2Onu4gAAfFCtWlDMeJFhw4bZwa35+fmaPn26FixYoLlz59rzb7zxhjp16mTHo6Snp+vBBx+0A2CrWkbMc8uXL9cll1xiZ/KY78350aNH29k+8K4xLf9dl6XZ3+zW48M6KSLkrMdbAwDwM7X6VDGtIWPGjLGDWk0fklm0zYSTIUOG2PObN2+2IebgwYN2Qbbf/e53NoBUMd00M2bM0NNPP21bXVq3bm3PP/zww7UpBjxAWtsmahMXrm37C/Xvtbs1qm9LdxcJANCQB8l62yAbuM6Ur7fp//53ozo1i9KnD1xgx7kAAOCMz2/24sFZ+0WvJAUH+mtjVp7WZh6mJgEATkNAwVmLCQvWld2a2cfvLGO3ZACA8xBQUCdVY08++W6PDh8poTYBAE5BQEGdmMXaOiZGqrisQh9+s5vaBAA4BQEFdWIGxo7uV9mK8u7ynXbFWgAA6oqAgjq7ukcLhQcHaNu+QqVvO0CNAgDqjICCOjOLtI3o0cI+fnc5g2UBAHVHQIHTVpY15n6frX35Nbc+AACgtggocIouzaPVIyVGZRUOfbAqk1oFANQJAQVOn3I8fXmGyisYLAsAOHsEFDiNWbQtKjRQuw8f1aIf9lGzAICzRkCB04QGBegXvZKrpxwDAHC2CChwqlH9KgfLzt+UY1tSAAA4GwQUOFXb+AiltWkiMwRlxgqmHAMAzg4BBS5rRZmxMlOl5RXUMACg1gJr/yPAqQ3tnKi4iBC7Hkq/575UgL9fnatsePfmevLKzlQ9ADQQBBQ4XXCgv245v5Wen7tZBwqds8Px64u32z1/WseFO+V+AADPRkCBS9wzoK0u7ZKo4rLyOt9r0pxN+vrH/Xp32U49QSsKADQIBBS4bJfj1IQIp9zr1+e3tgFl5upd+s2lHex0ZgCAb2OQLDzeRe3jlRzbSLlHS/Xxt3vcXRwAQD0goMDjmUG2N/WpXEb/nWUsAAcADQEBBV5hZO8kBQf469tdufpu12F3FwcA4GIEFHiFJhEhurxron1MKwoA+D4CCrzGzWmV3Tz/XrtHuUdK3V0cAIALEVDgNXqmNFanZlEqLqvQzNWZ7i4OAMCFCCjwqqnLo48to//u8gxVmA1/AAA+iYACr3L1uS0UERKo7fsLtXTrAXcXBwDgIgQUeJXwkEBd17OFfcxgWQDwXQQUeJ1R/SoHy87buFdZuUfdXRwAgAsQUOB12jeNVN/WsSqvcOi9FQyWBQBfRECBVzI7GxszVmSotLzC3cUBADgZAQVeyeyUHBcRopz8Ys3bsNfdxQEAOBkBBV4pONBfN/ZJto/fTmd/HgDwNQQUeK0b+6TI309K33ZAW3Ly3V0cAIATEVDgtZrHNNKgTk3t43eWZbi7OAAAJyKgwKvdfGyw7Ierd+lISZm7iwMAcEdAmTx5srp166aoqCh7pKWlac6cOdXnt27dqmuuuUbx8fH2/MiRI7V3b80BjAcPHtSoUaPs+ZiYGN16660qKChw1utBA3NBapxaNglTfnGZ/rN2j7uLAwBwksDaXJyUlKRJkyapXbt2cjgcmjZtmkaMGKE1a9aoVatWGjp0qLp376758+fb65988kkNHz5cy5Ytk79/ZRYy4SQrK0vz5s1TaWmpbrnlFt1xxx2aPn26s14TGhB/fz+N7ttSz366UX+bv0Xf7sqt8z0bBQXorovbKCEq1CllBADUnp/DJI06iI2N1fPPP6/k5GQNGzZMhw4dsq0jRm5urho3bqzPP/9cgwcP1saNG9W5c2etXLlSvXv3ttd89tlnuvzyy7Vr1y41b978jH5nXl6eoqOj7f2rfhcarkOFJUqb9KWKSp23Hsp1PZP0l5HdnXY/AIBq9fldqxaU45WXl2vmzJkqLCy0XT2me8fsNhsSElJ9TWhoqG05Wbx4sQ0o6enptlunKpwY5nlzzfLly233EFBbjcOD9dav+2r5trpvHmi6iv65aJs+/m6PfndFJ8WGB/MPAgBuUOuAsm7dOhtIioqKFBERodmzZ9tWETPuJDw8XI899piee+452wX0+OOP2yBjunSM7OxsJSQk1CxAYKBthTHnTqa4uNgexycw4Hh9Wsfao67M+zZ96wGt252rD1Zl6q6L21LRAOANs3g6dOigtWvX2haPu+++W2PHjtWGDRtsQDEtKh9//LENLqYJ5/Dhw+rZs2f1+JOzNXHiRHu/qsN0JwGuYFoBb05rWb1bstnvBwBQ/2qdHIKDg5WamqpevXrZ4GAGxb744ov2nBkka7p6cnJytH//fr399tvavXu32rRpY88nJibac8crKyuzM3vMuZMZP3687a+qOjIz2SAOrnNV9+aKCQvSrkNHtWBzzfcrAMBL1kGpqKio0f1ixMXF2bEmZjaPCSRXXXWVfd50DZlWldWrV1dfa64x9+jbt+9Jf4cZ11I1tbnqAFwlNChAI3tXttK9xTL6AOD5Y1BMS4aZqZOSkqL8/Hw7NXjBggWaO3euPf/GG2+oU6dOtrvHDIh98MEHNW7cONstZJhzl112mW6//Xa98sordprxfffdpxtuuOGMZ/AA9cFMXX7t621a+MM+7dhfqFZx4VQ8AHhqQDGtIWPGjLGDXs1YELNomwknQ4YMsec3b95sQ4zpsjHrovzud7+zAeV47777rg0lgwYNsmNTrrvuOr300kvOfVVAHaU0CdOA9vH6avM+OxbliSs7U6cA4E3roLgD66CgPny1KUe3vLlSUaGBWv5/BqtRcAAVDwD19PnNXjzASVzcPl4psWHKKyrTv9fupp4AoB4RUIBTLaPfL6V6sKwXNjYCgNcioACnYGbzhAT6a0NWnr7JOERdAUA9IaAApxATFmzXRTGYcgwA9YeAApzGmLRW9uun67K0L7/mmj8AANcgoACn0TUpWucmx6i03KH3V2ZQXwBQDwgowBkYc2x/nneXZ6isvII6AwAXI6AAZ+Dyrs0UGx6srNwifbGR/XkAwNUIKMAZ7s9z/XmV+/O8vWwHdQYALkZAAc7QqL4p8veTlmw5oC05BdQbALgQAQU4Q0mNwzSwY1P72OzPAwBwHQIKcBaDZWet3qWC4jLqDgA8YTdjoKG7IDVOrePCtX1/oe6f/o2axzSq8z07NovSzf0qgw8AoBIBBajl/jymFWXCxxv01eZ9Tqu7jomROq9VLP8WAHAMAQWopdH9Wsrfz0+Hj5TWue6Wbt2v5dsPatrSHQQUADgOAQWopaAAf43tX7n8fV0N7pygK15arM++z9bevCI1jQrl3wMAGCQLuFeX5tE6r1VjlVU47Cq1AIBKzOIBPGQzwunLM1RSxjL6AGAQUAA3u+ycRCVEhmh/QbHmfJ/l7uIAgEcgoAAeMKZlVN/KacZmsCwAgIACeIQb+yYrKMBP32Qc1rpdue4uDgC4HS0ogAdIiAy1OyYb09JpRQEAAgrgIaqmLv/n2z06WFji7uIAgFsRUAAP0SM5Rl1bRNuZPDNWMuUYQMNGQAE8hJ+fX3UryrvLMlRWzpRjAA0XAQXwIFd2a6bY8GDtPnxUX2zMcXdxAMBtCCiABwkNCtAN5yXbx28xWBZAA0ZAATzMKLsZodlI8IB+2Jvv7uIAgFsQUAAP0yKmkYZ2TrSPaUUB0FARUAAPNKZ/5cqy//pmt/KKSt1dHACodwQUwAOltWmi9k0jdKSkXLNW7XJ3cQCg3hFQAA+dcly1y7Hp5qmocLi7SABQrwgogIe6pkcLRYYGaseBI1r04z53FwcA6hUBBfBQ4SGB+mWvyinH7HIMoKHxczgcXtd2nJeXp+joaOXm5ioqKsrdxQFcZvv+Ql3y5wX2cVLjRnW+X4C/n+4Z0FbXn5fihNIBgOs+vwNreW8A9ah1XLgu65Koz9Zna9eho06558Q5m3RV9xZqFBzglPsBgCsQUAAP9+KN5+qH7AKVO6Gx877p39ig8++1u3VDH1pRAPjIGJTJkyerW7dutlnGHGlpaZozZ071+ezsbN18881KTExUeHi4evbsqQ8//LDGPVq1amVnKBx/TJo0yXmvCPAxIYEB6poUrXOTY+p8jD02M+jNpTvkhb27ABqQWgWUpKQkGyZWr16tVatWaeDAgRoxYoTWr19vz48ZM0abN2/Wf/7zH61bt07XXnutRo4cqTVr1tS4zzPPPKOsrKzq4/7773fuqwJwQiN7J6tRUIA2Zedr2baD1BIA3wgow4cP1+WXX6527dqpffv2evbZZxUREaFly5bZ80uXLrVho0+fPmrTpo2eeOIJxcTE2EBzvMjISNvKUnWY1hYArhcdFqRre7awj99cup0qB+B704zLy8s1Y8YMFRYW2q4eo3///nr//fd18OBBVVRU2PNFRUUaMGBAjZ81rTBNmjRRjx499Pzzz6usrKzurwTAGflV/8punnkb9irz4BFqDYBvDJI1XTcmkJjgYVpPZs+erc6dO9tzH3zwga6//nobPgIDAxUWFmbPp6amVv/8Aw88YMemxMbG2haX8ePH226eF1544aS/s7i42B7HT1MCcHbaNY3UBalxWrxlv95ZtlPjL+9EVQLw/nVQSkpKlJGRYecwz5o1S1OmTNHChQttSDHdOytWrNBzzz2nuLg4ffTRR/rrX/+qr7/+Wl27dj3h/aZOnao777xTBQUFCgkJOeE1Tz/9tCZMmPCz51kHBTg7X2zYq9veWqXoRkFaNn4QU44BeNw6KHVeqG3w4MFq27atfvvb39qWku+//15dunSpcd48/8orr5zw580A23POOUebNm1Shw4dzrgFJTk5mYACnKXyCoddAC7j4BFNvLarbmTKMQAPCyh1XurejDUx4eHIkcq+bH//mrcMCAiw15zM2rVr7c8kJCSc9BrTslI1tbnqAFC3FWXHpLW0j99cwpRjAF4+BsWMFxk2bJhSUlKUn5+v6dOna8GCBZo7d646duxoW0pMd82f//xnOw7FdPHMmzdPn3zyif359PR0LV++XJdccomdyWO+HzdunEaPHq3GjRu76jUCOIFf9k7WC/N+0Oa9+UrfdkD928ZRTwC8M6Dk5OTYtU7MoFbTRGMWbTPhZMiQIfb8p59+qscff9xORzZjSkxgmTZtmp2aXNUSYmb2mDElptWldevWNqA8/PDDrnl1AE7KjD+5rmeS3l6207aiEFAAeBI2CwQasC05+Rr8wiL5+0kLH71EybFh7i4SAB+WV59jUAB4r9SESF3YLk4VDtmWFADwFAQUoIG75fzKhdtmrMjQkRIWTQTgGQgoQAM3oH2CWjYJU15RmWav2e3u4gCARUABGjh/O+W4shVlGrscA/AQBBQA+mXvJIUFB+iHvQVK33qAGgHgdgQUAIoKDdIveiXZmnhj6Q5qBIDbEVAAWFXdPF9sZJdjAF64mzEA35SaEKGL2sdr0Q/7NOHj9erXpkmd79kkIlgjurew41wAoDYIKACq3dK/lQ0oX2zMsYczlJY7NLJ3MrUMoFYIKACqDegQr98Mba8fcwrqXCtZh4u0YsdBTV28Xb/slSQ/P1pRAJw5AgqAaiZE3DewnVNqJPdIqfpN/FKbstmMEEDtMUgWgEtEhwXpul4t7OOpi5kZBKB2CCgAXOZX/Vvbr19u2qudBwqpaQBnjIACwKUzgy5uHy+HQ3qT9VUA1AIBBUC9bEY4c9Uu5ReVUtsAzggBBYBLXdQuXm3jw1VQXGZDCgCcCQIKAJcyi7T96vzKsSjT0neovMJBjQM4LQIKAJe7rmcLRYUGaueBI/pqk3MWgAPg2wgoAFwuLDhQN/ZJsY+nLtlOjQM4LQIKgHpxc1pLmS15lm49oE3ZedQ6gFMioACoF0mNw3TZOYn28Rss3AbgNAgoAOrNLccGy85eu1sHCoqpeQAnRUABUG96t2ysri2iVVJWofdWZFDzAE6KgAKgXjcjrFq47e1lO1VaXkHtAzghAgqAenVFt2aKjwzR3rxifboui9oHcEIEFAD1KiQwQKP7trSPpy5hl2MAJ0ZAAVDvbuqbouAAf32beVjfZBziXwDAzxBQANQ708Vz1bnN7eOpi1m4DcDPBZ7gOQBwOTNYdtbqXZrzfbY++W6PGgUF1Pme57SIVtOoUKeUD4B7EVAAuEWX5tHq2zpWy7cf1H3T1zjlns2iQ7Xg0QF2nAsA70ZAAeA2v7uik57970YVldV9uvHWnAJl5RbpP2v36Je9k51SPgDu4+dwOLxu7/O8vDxFR0crNzdXUVFR7i4OAA8wecFW/fGzTeqYGKk5D15o11wB4L2f3wySBeATbuqTYsexbMrOV/rWA+4uDoA6IqAA8AnRYUH6Ra8k+/h1ZgYBXo+AAsBnVC2j/+WmHG3fX+ju4gCoAwIKAJ/RJj5CAzsm2MdvLGF9FcCbEVAA+JRbL2htv85ctUu5R0rdXRwA9RFQJk+erG7dutmRt+ZIS0vTnDlzqs9nZ2fr5ptvVmJiosLDw9WzZ099+OGHNe5x8OBBjRo1yv58TEyMbr31VhUUFJxt+QGghv5tm9iZPEdLy/XeygxqB2gIASUpKUmTJk3S6tWrtWrVKg0cOFAjRozQ+vXr7fkxY8Zo8+bN+s9//qN169bp2muv1ciRI7Vmzf8WYTLhxFw/b948ffLJJ1q0aJHuuOMO578yAA2SmV7862OtKNOW7lBped3XWAHgheugxMbG6vnnn7ctIREREbaVxbSiVGnSpIn++Mc/6rbbbtPGjRvVuXNnrVy5Ur1797bnP/vsM11++eXatWuXmjev3JvjdFgHBcCpFJWW64I/ztf+ghK9dGMPXdX9zP62APCBdVDKy8s1Y8YMFRYW2q4eo3///nr//fdtN05FRYU9X1RUpAEDBtjz6enptlunKpwYgwcPlr+/v5YvX362RQGAGkKDAjS6X8vqKcdeuB4l0ODVOqCYrhvTUhISEqK77rpLs2fPtq0ixgcffKDS0lLbamLO33nnnfZ8ampq9RiVhITKEfZVAgMDbSuMOXcyxcXFNnUdfwDAqZiAEhzor28zD+ubjENUFuDrAaVDhw5au3atbfG4++67NXbsWG3YsMGee/LJJ3X48GF98cUXdozKww8/bMegmFBTFxMnTrRNQlVHcjL7bAA4tbiIEF19bmXXDgu3AQ1wDIrpomnbtq1++9vf2paS77//Xl26dKlx3jz/yiuvaOrUqXrkkUd06ND//m+mrKxMoaGhmjlzpq655pqTtqCYo4ppQTEhhb14AJzKpuw8Xfb/vpa/n7Tw0UuUHBtGhQENZS8eM9bEhIcjR45U3tC/5i0DAgLsNYYZq2JaWMwsoCrz58+35/v27XvS32G6i6qmNlcdAHA6HROjdEFqnCoclTN6AHiPWgWU8ePH22nBO3bssN025vsFCxbYqcMdO3a0LSVm3MmKFSu0detW/eUvf7HTia+++mr78506ddJll12m22+/3V6zZMkS3XfffbrhhhvOeAYPAJzNwm3vr8xUQXEZlQf4YkDJycmxa52YcSiDBg2y04Xnzp2rIUOGKCgoSJ9++qni4+M1fPhwu6DbW2+9pWnTptlpxFXeffddG2bMz5vnL7jgAv3zn/90xWsDAF3cPl5t4sOVX1ymD1ZmUiNAQxmD4g6sgwKgNt5ZtlNPfPS9UmLD9NVvBijADEoB4NtjUADA013XM0nRjYKUcfCIvti4193FAXAGAs/kIgDwZo2CA3RT3xRNXrBVL37xozIPVg7qr+ticCPOba7I0CCnlBFATQQUAA3C2LRWem3RNm3IytOG/zpnscf1e/I08dquTrkXgJoIKAAahMToUP1lZHd9tSmnzvcqKq3QZ+uz9eE3u/TI0PZ2UTgAzkVAAdBgjDi3hT3qyswtuPofS+0y+mYA7kOD2zulfAD+h0GyAFBLfn5+uu3Y+ipvp++0uycDcC4CCgCchWHnJKpFTCMdKCzR7DW7qUPAyQgoAHAWAgP8dcv5rao3I6ww6+kDcBoCCgCcpevPS1ZESKC25BRo4Q/7qEfAiQgoAHCWzBooN5yXbB9PWbyNegSciIACAHXwq/Nb2aXzl2w5oPV7cqlLwEkIKABQB0mNw+yA2aqxKACcg4ACAHV0+4Vt7NePv92jvXlF1CfgBAQUAKij7skx6tMqVqXlDk1buoP6BJyAgAIATnDrhZULt727PENHSsqoU6COCCgA4ASDOzVVqyZhyj1aqlmrd1GnQB0RUADACcxMnl8fW/7eDJYtZ+E2oE4IKADgJL/olaToRkHaeeCIvti4l3oF6oCAAgBOEhYcqFF9U+zjKV+zcBtQFwQUAHCisf1bKSjATyt3HNLazMPULXCWCCgA4ERNo0I1vHtz+5hWFODsEVAAwMluu6By4bY532dr16Ej1C9wFgLP5ocAACfXuXmUzk9tYvfnufJvixUeXPc/tW0TIvTPm3spNCiAqkeDQEABABe4d0CqDSiHj5Tao652Hz6qmat36eZ+LZ1SPsDT+TkcDoe8TF5enqKjo5Wbm6uoqCh3FwcATijjwBEdOlJS59r5cuNevTR/i10I7stHBtg1VwBvVJvPb1pQAMBFUpqE2aOuUhMiNC19p3YcOKJ5G/bqsmO7JwO+jEGyAODhwkMCNbpf5foqr7G+ChoIAgoAeIGxaa0UHOCv1TsPafXOg+4uDuByBBQA8AIJUaG6ukfl+ir/XMQqtfB9BBQA8BK3X1i5vsrnG/Zq+/5CdxcHcCkCCgB4iXZNIzWwY4LM3MvXF9OKAt9GQAEAL2xFmblqlw4UFLu7OIDLEFAAwIv0axOrri2iVVxWoXeWZbi7OIDLEFAAwIv4+fnp9osqW1HeSt+hotJydxcJcAkCCgB4mcvPSVSLmEY6UFiiD7/Z5e7iAC5BQAEALxMY4K9bL2htH0/5ersqKrxuxxLgtAgoAOCFRp6XrKjQQDvd+IuNe91dHMC9AWXy5Mnq1q2b3eDHHGlpaZozZ449t2PHDts3eqJj5syZ1fc40fkZM2Y4/5UBgA+LCAnUqGM7G7P8PdTQA0pSUpImTZqk1atXa9WqVRo4cKBGjBih9evXKzk5WVlZWTWOCRMmKCIiQsOGDatxnzfeeKPGdVdffbWzXxcA+Lxf9W+loAA/rdxxSN9kHHJ3cQCnqtVuxsOHD6/x/bPPPmtbVZYtW6YuXbooMbHmDpuzZ8/WyJEjbUg5XkxMzM+uBQDUTtOoUI04t4Vmrd6l1xZt0+TRvahC+IyzHoNSXl5uu2YKCwttV89PmVaWtWvX6tZbb/3ZuXvvvVdxcXHq06ePpk6dKodZFhEAUGt3HJty/Nn6bO08wPL3aKAtKMa6detsICkqKrItI6aVpHPnzj+77vXXX1enTp3Uv3//Gs8/88wztmsoLCxMn3/+ue655x4VFBTogQceOOnvLC4utkeVvLy82hYbAHxS+6aRGtAhXgs279Pri7frmRHnuLtIgFP4OWrZfFFSUqKMjAzl5uZq1qxZmjJlihYuXFgjpBw9elTNmjXTk08+qUceeeSU93vqqafsmJTMzMyTXvP000/b8Sw/ZcpgBusCQEO2dMt+3TRluUKD/HVNjxZ1vp+ZvPCLXknqmdLYKeUDjm9giI6OPqPP71oHlJ8aPHiw2rZtq1dffbX6ubffftt27ezevVvx8fGn/Pn//ve/uvLKK22LTEhIyBm3oJhBuQQUAJDtJh/+98X6frfzWpfNQnALHh2goABWo4B7Akqtu3h+qqKiokZ4qOreueqqq04bTgwzTqVx48YnDSeGOXeq8wDQkJkWj3/c1EufrNvjlEXb3ly6Q7sPH9V/v8vS1U5okQHORq0Cyvjx4+2U4ZSUFOXn52v69OlasGCB5s6dW33Nli1btGjRIn366ac/+/mPP/5Ye/fuVb9+/RQaGqp58+bpueee029+8xv+9QCgDlKahOmeAalOCzzPz92sVxZu1Yhzm9vvAY8OKDk5ORozZoxdu8Q00ZhF20w4GTJkSPU1ZlaOWS9l6NChP/v5oKAgvfzyyxo3bpxtkkxNTdULL7yg22+/3TmvBgBQZ6P7ttTLX23Rpux8Lfpxvy5uf/rWcMDZ6jwGxdP7sAAAtffMxxs0dcl2nZ/aRO/e1o8qRL1/fjP6CQDwM7de2FoB/n5asuWA1u3KpYZQ7wgoAIATzuK5qntz+/jVRVupIdQ7AgoA4JSr1H66LksZB45QS6hXBBQAwAl1ahZlB8iamctTFm+jllCvCCgAgJO681grygerMnWgoOaaV4ArEVAAACeV1raJuraIVlFphd5K30lNod4QUAAAJ2UWabvz4spWlLfSd+hoSTm1hXpBQAEAnNJlXRKVEhumQ0dKNXP1yTd2BZyJgAIAOKXAAH/dfmFr+/i1r7eprLyCGoPLEVAAAKf1i17Jig0PVubBo5rzfTY1BpcjoAAATqtRcIDGpLWsXrjNC3dJgZchoAAAzsiYtFYKDfLX97vztHTrAWoNLkVAAQCcEdPFc33vZPv4lYUsfw/XIqAAAM7YbRe2kb+f9PWP+7V+D5sIwnUCXXhvAICPSY4N0xXdmuvjb/fo0ZnfqUvzqDrfMyYsSA8Nbq/wED6S8D+8GwAAtV7+3gSUDVl59nCGRkEBenhoB/4lUI2AAgColXNaROv1sb21eW9+nWsu63CR3l62U9PSd+qOi9sqglYUHENAAQDU2qBOTe1RV+UVDi3Zsl/b9hdqxooMO8YFMBgkCwBwmwB/P91xbMfkKV9vV0kZq9SiEgEFAOBW1/RsoYTIEGXnFemjtbv514BFQAEAuFVIYIBuO7bXz6sLt6qiglVqQUABAHiAG/ukKDI0UFv3FWrexr3uLg48AC0oAAC3iwwNqt7rZ/IC9voBAQUA4CF+1b+1ggP9tTbzsJZvP+ju4sDNaEEBAHiE+MgQjeydVN2KgoaNgAIA8Bh3XNjW7vWz8Id92rDHOavUwjsRUAAAHiOlSeVePwY7JjdsBBQAgEe56+LKhds++W6PMg4ccXdx4CYEFACAR+nSPFoXtY+XWQ7lta+3ubs4cBMCCgDA49x9cVv79YNVmdpfUOzu4sANCCgAAI/Tr02szk2OUXFZhd5cssPdxYEbEFAAAB7Hz89Pdx1rRXkrfYcKisvcXSTUMwIKAMAjDe3cVG3iw5VXVKb3lme4uzioZwQUAIBH8vf/XyvKlMXbVFxW7u4ioR4F1ucvAwCgNq4+t4Ve+PwHZecV6bz/+4UCA+r+/9UXtYvTX68/13YjwXMRUAAAHsvszXPfwFQ98dH3tqvHGT5au0e/6JWsC9rFOeV+cA0CCgDAo43qm6KL28erqLTuXTxTvt6u91dl6h8LthBQPFyt2somT56sbt26KSoqyh5paWmaM2eOPbdjxw7bXHaiY+bMmdX3yMjI0BVXXKGwsDAlJCTo0UcfVVkZo7MBACdmPkeSY8PUrmlknY8HBrdToL+flm49oDUZh6hyXwkoSUlJmjRpklavXq1Vq1Zp4MCBGjFihNavX6/k5GRlZWXVOCZMmKCIiAgNGzbM/nx5ebkNJyUlJVq6dKmmTZumN998U0899ZSrXh8AANVaxDTS1T1a2Mf/YMdkj+bncDgcdblBbGysnn/+ed16660/O9ejRw/17NlTr7/+uv3etLZceeWV2rNnj5o2bWqfe+WVV/TYY49p3759Cg4OPqPfmZeXp+joaOXm5tqWHAAAztSWnAIN+etCmU+/z8ddpPZNI6m8elKbz++zHg5tWkNmzJihwsJC29XzU6aVZe3atTWCS3p6urp27VodToxLL73UFti0wpxMcXGxveb4AwCAs5GaEKHLuiTax5NpRfFYtQ4o69ats902ISEhuuuuuzR79mx17tz5Z9eZVpNOnTqpf//+1c9lZ2fXCCdG1ffm3MlMnDjRJq6qw3QnAQBwtu4ZkGq//ufbPco8yI7JPhFQOnToYFtGli9frrvvvltjx47Vhg0balxz9OhRTZ8+/YTdPmdj/Pjxtjmo6sjMzHTKfQEADVPXpGhd2C5O5RUOvbpoq7uLA2cEFDNOJDU1Vb169bItG927d9eLL75Y45pZs2bpyJEjGjNmTI3nExMTtXfv3hrPVX1vzp2Maa2pmjlUdQAA4IxWlA9W7VJOfhGV6WHqvCRfRUWFHSPy0+6dq666SvHx8TWeN2NVTBdRTk5O9XPz5s2zgeNE3UQAALhyx+QeKTEqKavQ64u3U9HeHFBMV8uiRYvsmicmaJjvFyxYoFGjRlVfs2XLFnvNbbfd9rOfHzp0qA0iN998s7799lvNnTtXTzzxhO69917bSgIAQH2ur3LvsVaUd5dlKPdoKZXvrQHFtHyYbhszDmXQoEFauXKlDRlDhgypvmbq1Kl2vRQTRn4qICBAn3zyif1qWlNGjx5t7/fMM88459UAAFALAzsmqEPTSBUUl+nt9B3UnS+tg+IOrIMCAHCWf6/drQdnrFVseLCWPDZQjYIDqFxvXgcFAABfcEXXZkqJDdPBwhLNWJnh7uLgGAIKAKBBCwzw150Xt7GP/7lomx00C/cjoAAAGrzreiYpPjJEWblF+mjt7gZfH56AgAIAaPBCgwJ0+4WtbT28snCrXcAN7kVAAQBA0k19Wyq6UZC27SvU3PUn334F9SOwnn4PAAAeLSIkUGPTWuql+Vv05883a2NW3TemDQ7w1/V9kpUQGeqUMjYkBBQAAI751fmt9drX220ryt/mb3FKvazbnat/julNHdcSAQUAgGPMWiiv3txL8zf9b0uWs1VWUaF3l2fo8w17tTk7Xx0SI6nnWiCgAABwnIvax9vDGczaKp+uy9Y/FmzRizf0oJ5rgUGyAAC4eMfkj7/dox37C6nnWiCgAADgIue0iNYlHeJlZi1PXrCVeq4FAgoAAC5038DKVpR/rdmlPYePUtdniIACAIAL9WoZq7Q2TVRa7rBL6ePMEFAAAKinVpT3VmRoX34x9X0GCCgAALhY/7ZNdG5yjIrLKjRlMa0oZ4KAAgCAi/n5+em+SypbUd5J36nDR0qo89MgoAAAUA8GdUpQx8RIFZaU682lO6jz0yCgAABQX60ox8aivLFkhwqKy6j3UyCgAABQT4ad00xt4sOVe7RU7yzbSb2fAgEFAIB6EuDvp7svbmsfT/l6u4pKy6n7kyCgAABQj67u0UItYhppf0Gx3l+ZSd2fBAEFAIB6FBTgr7sGVLaivLpwq0rKKqj/EyCgAABQz37ZK0kJkSHak1ukj9bspv5PgIACAEA9Cw0K0O0XtrGP/7Fgi8rKaUX5qcCfPQMAAFzupr4pennBFu04cEST5mxS24SIOt8zJTZM56fGyRcQUAAAcIPwkED9+vzWemHeD5qyeLvT7jvzrjSd1ypW3o6AAgCAm9x2YWtl5xU5ZQPBjANHtHlvvl768ke9fWtfeTsCCgAAbhIWHKjnrunqlHtlHjyiAX9eoK9/3K+1mYft5oTejEGyAAD4gOTYMF3To4V9/Lcvf5S3I6AAAOAj7r0kVf5+0pebcvT97lx5MwIKAAA+onVcuIZ3b24f/22+d7eiEFAAAPAh912SKj8/ae76vdqUnSdvRUABAMCHtGsaqWHnJNrHf5+/Rd6KgAIAgI+575J29ut/12VpS06BvBEBBQAAH9O5eZSGdG4qh0P6x1fe2YpCQAEAwAc9MLCyFeXf3+7RzgOF8jYEFAAAfFDXpGgN6BCv8gqH/vHVVvl0QJk8ebK6deumqKgoe6SlpWnOnDk1rklPT9fAgQMVHh5ur7nooot09OjR6vOtWrWSn59fjWPSpEnOe0UAAMC6/1gryoff7NKuQ0d8N6AkJSXZMLF69WqtWrXKBpERI0Zo/fr11eHksssu09ChQ7VixQqtXLlS9913n/z9a/6aZ555RllZWdXH/fff79xXBQAA1KtlY52f2kRlFQ69snCr7+7FM3z48BrfP/vss7ZVZdmyZerSpYvGjRunBx54QI8//nj1NR06dPjZfSIjI5WYWDkFCgAAuLYVZcmWA/pg5S47uycxOtS3x6CUl5drxowZKiwstF09OTk5Wr58uRISEtS/f381bdpUF198sRYvXvyznzWtME2aNFGPHj30/PPPq6ys7JS/q7i4WHl5eTUOAABwev3aNFGfVrEqKa/Qq4u8pxWl1gFl3bp1ioiIUEhIiO666y7Nnj1bnTt31rZt2+z5p59+Wrfffrs+++wz9ezZU4MGDdKPP/5vuV3TwmKCzVdffaU777xTzz33nH7729+e8ndOnDhR0dHR1UdycvLZvFYAABqk+wel2q/Tl2coJ79I3sDP4TCzpM9cSUmJMjIylJubq1mzZmnKlClauHChDh8+rPPPP1/jx4+3oaOKGVR7xRVX2JBxIlOnTrVBpaCgwIaek7WgmKOKaUExIcWUwQzEBQAAJ2c+6q+dvFRrMg7rjova6P9c3knuYD6/TUPDmXx+12oMihEcHKzU1Mok1qtXLzsQ9sUXX6wed2JaU47XqVMnG2hOpm/fvraLZ8eOHSccr2KY4HKy8AIAAE7NzJg166Lc8uZKvbNsp0rLK85ogO2V3So3HnSHWgeUn6qoqLCtG2b6cPPmzbV58+Ya53/44QcNGzbspD+/du1aO8vHjF0BAACuYdZE6doiWut25+qNJTtOe31xWYX3BBTTfWPCRkpKivLz8zV9+nQtWLBAc+fOtens0Ucf1e9//3t1795d5557rqZNm6ZNmzbZrqCqachmIO0ll1xiZ/KY783Mn9GjR6tx48aueo0AADR4fn5++tuNPfSvNbtVXnH6FpTuSTFurbNaBRQzU2fMmDF27RLTh2TGl5hwMmTIEHv+oYceUlFRkQ0dBw8etEFl3rx5atu2rT1vumnMAFkzkNa0urRu3dpe+/DDD7vm1QEAgGqt4sL18JD28slBsp6gNoNsAACA931+sxcPAADwOAQUAADgcQgoAADA4xBQAACAxyGgAAAAj0NAAQAAHoeAAgAAPA4BBQAAeBwCCgAA8DgEFAAA4HEIKAAAwOMQUAAAgMep1W7GnqJqf0Oz6RAAAPAOVZ/bZ7JPsVcGlPz8fPs1OTnZ3UUBAABn8TludjU+FT/HmcQYD1NRUaE9e/YoMjJSfn5+Tk93JvhkZmaeditoUN/ehvc39e3LeH97fn2byGHCSfPmzeXv7+97LSjmRSUlJbn0d5jKJqDUH+q7flHf1Lcv4/3t2fV9upaTKgySBQAAHoeAAgAAPA4B5SdCQkL0+9//3n6F61Hf9Yv6pr59Ge9v36pvrxwkCwAAfBstKAAAwOMQUAAAgMchoAAAAI9DQAEAAB6HgHKcl19+Wa1atVJoaKj69u2rFStWuO9fxocsWrRIw4cPtysHmpV/P/rooxrnzTjtp556Ss2aNVOjRo00ePBg/fjjj24rr7ebOHGizjvvPLvSckJCgq6++mpt3ry5xjVFRUW699571aRJE0VEROi6667T3r173VZmbzZ58mR169aterGqtLQ0zZkzp/o8de1akyZNsn9XHnroIercBZ5++mlbv8cfHTt2rJf3NwHlmPfff18PP/ywnTL1zTffqHv37rr00kuVk5PjlIpuyAoLC219mgB4In/605/00ksv6ZVXXtHy5csVHh5u69688VF7CxcutH8wli1bpnnz5qm0tFRDhw61/w5Vxo0bp48//lgzZ86015utI6699lqq+yyYVa3Nh+Tq1au1atUqDRw4UCNGjND69eupaxdbuXKlXn31VRsQj8f727m6dOmirKys6mPx4sX1U9dmmjEcjj59+jjuvffe6qooLy93NG/e3DFx4kSqx4nMW2727NnV31dUVDgSExMdzz//fPVzhw8fdoSEhDjee+896t4JcnJybL0vXLiwun6DgoIcM2fOrL5m48aN9pr09HTq3AkaN27smDJlCnXtQvn5+Y527do55s2b57j44osdDz74oH2e97dz/f73v3d07979hOdcXde0oEgqKSmx//djuhaO3+/HfJ+enu6cJIgT2r59u7Kzs2vUvdmnwXSxUffOkZuba7/Gxsbar+a9blpVjq9z02SbkpJCnddReXm5ZsyYYVurTFcPde06ppXwiiuuqPE+Nqhz5zNd7qaLvk2bNho1apQyMjLqpa69crNAZ9u/f7/9w9K0adMaz5vvN23a5LZyNQQmnBgnqvuqc6jbzt+mb/7888/XOeecU13nwcHBiomJoc6dZN26dTaQmG5J0w8/e/Zsde7cWWvXrqWuXcCEQNMVb7p4for3t3OZ/1l888031aFDB9u9M2HCBF144YX6/vvvXV7XBBTAx/8v0/whOb7PGM5n/nibMGJaq2bNmqWxY8fa/ng4X2Zmph588EE7vspMaIBrDRs2rPqxGetjAkvLli31wQcf2EkNrkQXj6S4uDgFBAT8bOSx+T4xMdGl/wANXVX9UvfOd9999+mTTz7RV199ZQdyHl/nplvz8OHDNa7n/X72zP9FpqamqlevXnYWlRkU/uKLL1LXLmC6FczkhZ49eyowMNAeJgyagfbmsfm/d97frmNaS9q3b68tW7a4/P1NQDn2x8X8Yfnyyy9rNI2b702zLVyndevW9o18fN3n5eXZ2TzU/dkxY5FNODHdDPPnz7d1fDzzXg8KCqpR52YasulXps6dw/z9KC4upq5dYNCgQbZLzbRYVR29e/e2YyOqHvP+dp2CggJt3brVLgvh8r8ldR5m6yNmzJhhZ468+eabjg0bNjjuuOMOR0xMjCM7O9vdRfOJ0fZr1qyxh3nLvfDCC/bxzp077flJkybZuv73v//t+O677xwjRoxwtG7d2nH06FF3F90r3X333Y7o6GjHggULHFlZWdXHkSNHqq+56667HCkpKY758+c7Vq1a5UhLS7MHau/xxx+3M6S2b99u37/mez8/P8fnn39OXdeT42fxGLy/neeRRx6xf0vM+3vJkiWOwYMHO+Li4uzsQFfXNQHlOH/7299sRQcHB9tpx8uWLXNKJTd0X331lQ0mPz3Gjh1bPdX4ySefdDRt2tSGxEGDBjk2b97s7mJ7rRPVtTneeOON6mtM+LvnnnvsdNiwsDDHNddcY0MMau/Xv/61o2XLlvbvRnx8vH3/VoUT6to9AYX3t/Ncf/31jmbNmtn3d4sWLez3W7ZsqZe69jP/qXs7DAAAgPMwBgUAAHgcAgoAAPA4BBQAAOBxCCgAAMDjEFAAAIDHIaAAAACPQ0ABAAAeh4ACAAA8DgEFAAB4HAIKAADwOAQUAADgcQgoAABAnub/Awwtx7a6IEpxAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model = make_model(rngs)\n", + "optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param)\n", + "\n", + "@nnx.jit\n", + "def train_step(model, optimizer, x, y):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(model, grads)\n", + " return loss\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(model, optimizer, x, y)\n", + " losses.append(loss)\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "id": "1f9d184a", + "metadata": {}, + "source": [ + "# Sharding Optimization State Differently from Parameters" + ] + }, + { + "cell_type": "markdown", + "id": "1723871a", + "metadata": {}, + "source": [ + "Say we're doing data parallelism. We want to replicate our parameters across all GPUs so we can do the forward and backward passes without communication latency." + ] + }, + { + "cell_type": "markdown", + "id": "755ec67f", + "metadata": {}, + "source": [ + "But we don't need to replicate the optimizer state, as it's not invovled in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves." + ] + }, + { + "cell_type": "markdown", + "id": "bd19eaa8", + "metadata": {}, + "source": [ + "To do this, we can pass the params initializer given the the optimizer a `sharding` argument. This will shard the optimization state the same way. But when we initialize the model parameters themselves, we won't provide a sharding, allowing for data parallelism." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e7c31cca", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard\n", + "mesh = jax.make_mesh((2, 4), (\"x\", \"y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", + "jax.set_mesh(mesh)\n", + "\n", + "ghost_model = jax.eval_shape(lambda: make_model(nnx.Rngs(0), out_sharding=P('x', 'y')))\n", + "optimizer = nnx.Optimizer(ghost_model, optax.adam(1e-3), wrt=nnx.Param)\n", + "model = make_model(rngs)\n", + "\n", + "@nnx.jit\n", + "def train_step(model, optimizer, x, y):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", + " optimizer.update(model, grads)\n", + " return loss\n", + "\n", + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(model, optimizer, x, y)\n", + " model = reshard(model, P(None, None))\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "18e25a6f", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "The optimizer state is sharded:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a087ec2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ShapedArray(float32[2@x,8@y])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.typeof(optimizer.opt_state[0][1].layers[0]['kernel'][...])" + ] + }, + { + "cell_type": "markdown", + "id": "9862dddd", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "But the model is not:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d539947e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ShapedArray(float32[2,8])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.typeof(model.layers[0].kernel[...])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e64d951-5eab-426c-ad54-fb2a67b12087", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/optimization_cookbook.md b/docs_nnx/guides/optimization_cookbook.md new file mode 100644 index 000000000..3422f1d28 --- /dev/null +++ b/docs_nnx/guides/optimization_cookbook.md @@ -0,0 +1,246 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 +--- + +# A Flax Optimization Cookbook + + +This notebook goes through some common problems in nontrivial training loops for Flax models. For clarity, all sections below will be training the following toy model. We allow extra keyword arguments so that the sharding and dtype can be determined on an instance by instance basis. + +```python +import jax +from flax import nnx +jax.config.update('jax_num_cpu_devices', 8) +import jax.numpy as jnp +import functools as ft +import matplotlib.pyplot as plt + +param_init = jax.nn.initializers.lecun_normal() + +rngs = nnx.Rngs(0) + +def make_model(rngs, **kwargs): + return nnx.Sequential( + nnx.Linear(2,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)), + nnx.Linear(8,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs))) + +def loss_fn(model, x, y): + return jnp.sum((model(x) - y) ** 2) +``` + +We'll operate on the following fake data: + +```python +x = rngs.normal((32, 2)) +y = rngs.normal((32, 8)) +``` + +# Exponential Moving Average + +Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Flax training loop to accomodate calculating exponential moving averages. + +```python +import optax +from jax import tree + +class Ema(nnx.Module): + def __init__(self, model, decay=0.9): + self.decay = decay + self.ema = jax.tree.map(jnp.copy, model) # Make a copy + def update(self, model): + def ema_update(ema, new_val): + return self.decay * ema + (1 - self.decay) * new_val + self.ema = tree.map(ema_update, self.ema, model) + +model = make_model(rngs) +ema = Ema(model) + +optimizer = nnx.Optimizer( + model, + tx=optax.adam(1e-3), + wrt=nnx.Param) + +@nnx.jit +def train_step(model, optimizer, ema, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + ema.update(model) + return loss + +losses = [] +for _ in range(50): + loss = train_step(model, optimizer, ema, x, y) + losses.append(loss) +plt.plot(losses) +``` + +# Low Rank Adaptation + + +The pattern for adding low rank adaptation to an optimization loop is very similar to adding an exponential moving average. As before, we create a new pytree with the same structure as our model parameters, but here we store low rank additions to these parameters rather than weighted average values. + +```python +def add_rank2_lora(path, node): + if isinstance(node, nnx.Linear): + return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs) + return node + +base_model = make_model(rngs) +lora_model = nnx.recursive_map(add_rank2_lora, base_model) +nnx.display(lora_model) +``` + + +To indicate that we only want to to update the low rank corrections, we add the `wrt=nnx.LoRAParam` argument to `nnx.Optimizer`. This will filter out all the variables in the gradient that are not `nnx.LoRAParam`s. The other components of the gradient will go unused, so Jax's dead code elimination passes should prevent us from computing them in the first place once the code gets compiled. + +```python +@nnx.jit +def train_step(model, optimizer, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + return loss + +optimizer = nnx.Optimizer( + lora_model, + tx=optax.adam(1e-3), + wrt=nnx.LoRAParam, +) + +losses = [] +for _ in range(50): + loss = train_step(lora_model, optimizer, x, y) + losses.append(loss) +``` + +# LBFGS + + +So far, we've been using optax optimizers with the interface ``optimizer.update(grads, opt_state)``. This works for simple optimization algorithms like ADAM, but for algorithms that use a line search like LBFGS, we need to pass more parameters. Below, we can see how the call to ``optimizer.update`` is given additional parameters when using LBFGS. + +```python +def train_step(model, optimizer, x, y): + # Create state-based loss function for LBFGS + graphdef = nnx.graphdef(model) + loss_fn_state = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) + + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update( + model, + grads, + grad=grads, + value=loss, + value_fn=loss_fn_state) + return loss + +model = make_model(rngs) +optimizer = nnx.Optimizer( + model, + tx=optax.lbfgs(1e-3), + wrt=nnx.Param) + +losses = [] +for _ in range(50): + loss = train_step(model, optimizer, x, y) + losses.append(loss) +plt.plot(losses) +``` + +# Per-Parameter Learning Rates + +In some training regimes, you will want to optimize different parameters with different learning rates. + +In Jax, we map from each leaf to the type of parameter it is (weight or bias). We then create a dictionary giving the learning rates to use for each parameter type. Finally, we can make a compound optimizers that uses each rate appropriately. + +To do this in Flax, we can map from each leaf to the type of parameter it is (weight or bias). With this pytree of parameter types, we can make a compound optimizer that uses each rate appropriately. + +```python +model = make_model(rngs) +state = nnx.state(model, nnx.Param) +rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)} +param_tys = nnx.map_state(lambda p, v: list(p)[-1], state) +optimizer = nnx.Optimizer(model, tx=optax.partition(rates, param_tys), wrt=nnx.Param) + +@nnx.jit +def train_step(model, optimizer, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + return loss + +losses = [] +for _ in range(50): + loss = train_step(model, optimizer, x, y) + losses.append(loss) +``` + + +# Gradient Accumulation + + +Gradient accumulation in Flax is easy: just use the `optax.MultiSteps` optimizer. + +```python +model = make_model(rngs) +optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param) + +@nnx.jit +def train_step(model, optimizer, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + return loss + +losses = [] +for _ in range(50): + loss = train_step(model, optimizer, x, y) + losses.append(loss) +plt.plot(losses) +``` + +# Sharding Optimization State Differently from Parameters + + +Say we're doing data parallelism. We want to replicate our parameters across all GPUs so we can do the forward and backward passes without communication latency. + + +But we don't need to replicate the optimizer state, as it's not invovled in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves. + + +To do this, we can pass the params initializer given the the optimizer a `sharding` argument. This will shard the optimization state the same way. But when we initialize the model parameters themselves, we won't provide a sharding, allowing for data parallelism. + +```python +from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard +mesh = jax.make_mesh((2, 4), ("x", "y"), + axis_types=(AxisType.Explicit, AxisType.Explicit)) +jax.set_mesh(mesh) + +ghost_model = jax.eval_shape(lambda: make_model(nnx.Rngs(0), out_sharding=P('x', 'y'))) +optimizer = nnx.Optimizer(ghost_model, optax.adam(1e-3), wrt=nnx.Param) +model = make_model(rngs) + +@nnx.jit +def train_step(model, optimizer, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + return loss + +losses = [] +for _ in range(50): + loss = train_step(model, optimizer, x, y) + model = reshard(model, P(None, None)) + losses.append(loss) +``` + +The optimizer state is sharded: +```python +jax.typeof(optimizer.opt_state[0][1].layers[0]['kernel'][...]) +``` + +But the model is not: +```python +jax.typeof(model.layers[0].kernel[...]) +``` diff --git a/docs_nnx/guides_advanced.rst b/docs_nnx/guides_advanced.rst index 3cbf647ab..985727ba8 100644 --- a/docs_nnx/guides_advanced.rst +++ b/docs_nnx/guides_advanced.rst @@ -8,4 +8,5 @@ Advanced Guides guides/flax_gspmd guides/performance guides/bridge_guide + guides/optimization_cookbook guides/surgery diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 82edc5ee8..27edd982f 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -899,11 +899,12 @@ def _pytree__flatten(self): else: key_fn = None node_attributes = self._pytree__nodes - node_names: list[str] = [] + node_names: list[tp.Union[int,str]] = [] node_attrs: list[tp.Any] = [] - static_attrs: list[tuple[str, tp.Any]] = [] + static_attrs: list[tuple[tp.Union[int,str], tp.Any]] = [] for name, value in sorted(obj_items, key=key_fn): - if name in node_attributes and node_attributes[name]: + str_name = str(name) + if str_name in node_attributes and node_attributes[str_name]: node_names.append(name) node_attrs.append(value) else: @@ -924,6 +925,9 @@ def _pytree__unflatten( node_names = tuple( str(name) if isinstance(name, int) else name for name in node_names ) + static_attrs = tuple( + (str(name) if isinstance(name, int) else name, val) for name, val in static_attrs + ) for name, value in zip(node_names, node_attrs, strict=True): object.__setattr__(obj, name, value) for name, value in static_attrs: @@ -1000,4 +1004,4 @@ def _maybe_int(x): try: return int(x) except (ValueError, TypeError): - return x \ No newline at end of file + return x diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index d73900c49..b843d25c2 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -39,6 +39,14 @@ def __init__(self, a, b): self.assertEqual(jax.tree.leaves(foo), [1]) + def test_sequential_map(self): + model = nnx.Sequential(nnx.Linear(2,8, rngs=nnx.Rngs(0))) + jax.tree.map(lambda x: x + 1, model) # shouldn't error + + def test_sequential_has_leaves(self): + model = nnx.Sequential(nnx.Linear(2,8, rngs=nnx.Rngs(0))) + self.assertLen(jax.tree.leaves(model), 2) + def test_consistent_attrs(self): class Foo(nnx.Pytree): def __init__(self, a, b, c):