diff --git a/disentangled_rnns/notebooks/train_single_gru.ipynb b/disentangled_rnns/notebooks/train_single_gru.ipynb new file mode 100644 index 0000000..bab94ad --- /dev/null +++ b/disentangled_rnns/notebooks/train_single_gru.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZST5GqoRtfZz" + }, + "outputs": [], + "source": [ + "# Install disentangled_rnns repo from github\n", + "!git clone https://github.com/google-deepmind/disentangled_rnns\n", + "%cd disentangled_rnns\n", + "!pip install .\n", + "!pip install -r requirements.txt\n", + "%cd ..\n", + "\n", + "\n", + "import optax\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import haiku as hk\n", + "\n", + "from disentangled_rnns.library import rnn_utils\n", + "from disentangled_rnns.library import get_datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ApA1YfVGz9Uq" + }, + "source": [ + "# Define a dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MIhLKbgHPYmQ" + }, + "outputs": [], + "source": [ + "dataset = get_datasets.get_q_learning_dataset(n_sessions=500, n_trials=200)\n", + "dataset_train, dataset_eval = rnn_utils.split_dataset(dataset, eval_every_n=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ONzEfURn0DU4" + }, + "source": [ + "# Define and train RNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h5lNm21PRJti" + }, + "outputs": [], + "source": [ + "# Define the architecture of the network we'd like to train\n", + "n_hidden = 16\n", + "output_size = 2\n", + "\n", + "def make_network():\n", + " model = hk.DeepRNN(\n", + " [hk.GRU(n_hidden), hk.Linear(output_size=output_size)]\n", + " )\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OULn6VOf0l-R" + }, + "outputs": [], + "source": [ + "# INITIALIZE THE NETWORK\n", + "# Running rnn_utils.train_network with n_steps=0 does no training but sets up the\n", + "# parameters and optimizer state.\n", + "optimizer = optax.adam(learning_rate=1e-3)\n", + "\n", + "params, opt_state, losses = rnn_utils.train_network(\n", + " make_network = make_network,\n", + " training_dataset=dataset_train,\n", + " validation_dataset=dataset_eval,\n", + " opt = optimizer,\n", + " loss=\"categorical\",\n", + " n_steps=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JwFLIG_U1Eli" + }, + "outputs": [], + "source": [ + "# TRAIN THE NETWORK\n", + "# Running this cell repeatedly continues to train the same network.\n", + "# The cell below gives insight into what's going on in your network.\n", + "# If you'd like to reinitialize the network and start over, re-run the above cell\n", + "\n", + "n_steps = 1000\n", + "optimizer = optax.adam(learning_rate=1e-3)\n", + "\n", + "params, opt_state, losses = rnn_utils.train_network(\n", + " make_network = make_network,\n", + " training_dataset=dataset_train,\n", + " validation_dataset=dataset_eval,\n", + " loss=\"categorical\",\n", + " params=params,\n", + " opt_state=opt_state,\n", + " opt = optimizer,\n", + " loss_param = 1,\n", + " n_steps=n_steps,\n", + " do_plot = True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oiPRjxjQSFLH" + }, + "outputs": [], + "source": [ + "# Run forward pass on the unseen data\n", + "xs_eval, ys_eval = dataset_eval.get_all()\n", + "network_output, network_states = rnn_utils.eval_network(make_network, params, xs_eval)\n", + "\n", + "# Compute normalized likelihood\n", + "score = rnn_utils.normalized_likelihood(ys_eval, network_output)\n", + "print(f'Normalized Likelihood: {100*score:.1f}%')\n", + "\n", + "# Plot network activations on an example session\n", + "example_session = 0\n", + "plt.plot(network_states[:,example_session,:])\n", + "plt.xlabel('Trial Number')\n", + "plt.ylabel('Network Activations')" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook3", + "kind": "private" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "1tbH1PMKB0rz4ajkRQlzBV7iYCll5RlhS", + "timestamp": 1746631254097 + }, + { + "file_id": "/piper/depot/google3/learning/deepmind/research/neuroexp/disrnn/notebooks/train_single_disrnn.ipynb?workspaceId=kevinjmiller:disentangled_rnns::citc", + "timestamp": 1746630089612 + }, + { + "file_id": "1b5VOqHaVDOJ3fAW2E853NBQbSu2Yi-CP", + "timestamp": 1727798409618 + }, + { + "file_id": "1xgFbsQ34Of-WBTEQM_Hf7Di7N9YpRmdR", + "timestamp": 1726760254895 + }, + { + "file_id": "1IuwwEfCic7w3NsyVoVPtZSQCzrvTgh_X", + "timestamp": 1696507812638 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}