|
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | | - "execution_count": null, |
| 5 | + "execution_count": 1, |
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
|
16 | 16 | }, |
17 | 17 | { |
18 | 18 | "cell_type": "code", |
19 | | - "execution_count": null, |
| 19 | + "execution_count": 2, |
20 | 20 | "metadata": {}, |
21 | 21 | "outputs": [], |
22 | 22 | "source": [ |
23 | | - "mouse_id = 4\n", |
| 23 | + "mouse_id = 5\n", |
24 | 24 | "\n", |
25 | 25 | "data_path = './data'\n", |
26 | | - "weight_path = './checkpoints/fullmodel'\n", |
| 26 | + "weight_path = './checkpoints'\n", |
27 | 27 | "np.random.seed(1)" |
28 | 28 | ] |
29 | 29 | }, |
30 | 30 | { |
31 | 31 | "cell_type": "code", |
32 | | - "execution_count": null, |
| 32 | + "execution_count": 3, |
33 | 33 | "metadata": {}, |
34 | | - "outputs": [], |
| 34 | + "outputs": [ |
| 35 | + { |
| 36 | + "name": "stdout", |
| 37 | + "output_type": "stream", |
| 38 | + "text": [ |
| 39 | + "raw image shape: (68000, 66, 264)\n", |
| 40 | + "cropped image shape: (68000, 66, 130)\n", |
| 41 | + "img: (68000, 66, 130) -2.0829253 2.1060908 float32\n" |
| 42 | + ] |
| 43 | + } |
| 44 | + ], |
35 | 45 | "source": [ |
36 | 46 | "# load images\n", |
37 | 47 | "img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])" |
38 | 48 | ] |
39 | 49 | }, |
40 | 50 | { |
41 | 51 | "cell_type": "code", |
42 | | - "execution_count": null, |
| 52 | + "execution_count": 4, |
43 | 53 | "metadata": {}, |
44 | | - "outputs": [], |
| 54 | + "outputs": [ |
| 55 | + { |
| 56 | + "name": "stdout", |
| 57 | + "output_type": "stream", |
| 58 | + "text": [ |
| 59 | + "\n", |
| 60 | + "loading activities from ./data/FX20_nat60k_2023_09_29.npz\n" |
| 61 | + ] |
| 62 | + } |
| 63 | + ], |
45 | 64 | "source": [ |
46 | 65 | "# load neurons\n", |
47 | 66 | "fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])\n", |
|
51 | 70 | }, |
52 | 71 | { |
53 | 72 | "cell_type": "code", |
54 | | - "execution_count": null, |
| 73 | + "execution_count": 5, |
55 | 74 | "metadata": {}, |
56 | | - "outputs": [], |
| 75 | + "outputs": [ |
| 76 | + { |
| 77 | + "name": "stdout", |
| 78 | + "output_type": "stream", |
| 79 | + "text": [ |
| 80 | + "\n", |
| 81 | + "splitting training and validation set...\n", |
| 82 | + "itrain: (43081,)\n", |
| 83 | + "ival: (4787,)\n" |
| 84 | + ] |
| 85 | + } |
| 86 | + ], |
57 | 87 | "source": [ |
58 | 88 | "# split train and validation set\n", |
59 | 89 | "itrain, ival = data.split_train_val(istim_train, train_frac=0.9)" |
60 | 90 | ] |
61 | 91 | }, |
62 | 92 | { |
63 | 93 | "cell_type": "code", |
64 | | - "execution_count": null, |
| 94 | + "execution_count": 6, |
65 | 95 | "metadata": {}, |
66 | | - "outputs": [], |
| 96 | + "outputs": [ |
| 97 | + { |
| 98 | + "name": "stdout", |
| 99 | + "output_type": "stream", |
| 100 | + "text": [ |
| 101 | + "\n", |
| 102 | + "normalizing neural data...\n", |
| 103 | + "finished\n" |
| 104 | + ] |
| 105 | + } |
| 106 | + ], |
67 | 107 | "source": [ |
68 | 108 | "# normalize data\n", |
69 | 109 | "spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)" |
70 | 110 | ] |
71 | 111 | }, |
72 | 112 | { |
73 | 113 | "cell_type": "code", |
74 | | - "execution_count": null, |
| 114 | + "execution_count": 7, |
75 | 115 | "metadata": {}, |
76 | | - "outputs": [], |
| 116 | + "outputs": [ |
| 117 | + { |
| 118 | + "name": "stdout", |
| 119 | + "output_type": "stream", |
| 120 | + "text": [ |
| 121 | + "spks_train: torch.Size([43081, 2746]) tensor(-1.4092e-15) tensor(48.7427)\n", |
| 122 | + "spks_val: torch.Size([4787, 2746]) tensor(-6.8745e-16) tensor(44.7361)\n", |
| 123 | + "img_train: torch.Size([43081, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n", |
| 124 | + "img_val: torch.Size([4787, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n", |
| 125 | + "img_test: torch.Size([500, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n" |
| 126 | + ] |
| 127 | + } |
| 128 | + ], |
77 | 129 | "source": [ |
78 | 130 | "ineur = np.arange(0, n_neurons) #np.arange(0, n_neurons, 5)\n", |
79 | 131 | "spks_train = torch.from_numpy(spks[itrain][:,ineur])\n", |
|
95 | 147 | }, |
96 | 148 | { |
97 | 149 | "cell_type": "code", |
98 | | - "execution_count": null, |
| 150 | + "execution_count": 8, |
99 | 151 | "metadata": {}, |
100 | | - "outputs": [], |
| 152 | + "outputs": [ |
| 153 | + { |
| 154 | + "name": "stdout", |
| 155 | + "output_type": "stream", |
| 156 | + "text": [ |
| 157 | + "core shape: torch.Size([1, 320, 33, 65])\n", |
| 158 | + "input shape of readout: (320, 33, 65)\n", |
| 159 | + "model name: FX20_092923_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt\n" |
| 160 | + ] |
| 161 | + } |
| 162 | + ], |
101 | 163 | "source": [ |
102 | 164 | "# build model\n", |
103 | 165 | "from minimodel import model_builder\n", |
|
113 | 175 | }, |
114 | 176 | { |
115 | 177 | "cell_type": "code", |
116 | | - "execution_count": null, |
| 178 | + "execution_count": 9, |
117 | 179 | "metadata": {}, |
118 | | - "outputs": [], |
| 180 | + "outputs": [ |
| 181 | + { |
| 182 | + "name": "stdout", |
| 183 | + "output_type": "stream", |
| 184 | + "text": [ |
| 185 | + "loaded model ./checkpoints/fullmodel/FX20_092923_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt\n" |
| 186 | + ] |
| 187 | + } |
| 188 | + ], |
119 | 189 | "source": [ |
120 | 190 | "# train model\n", |
121 | 191 | "from minimodel import model_trainer\n", |
|
129 | 199 | }, |
130 | 200 | { |
131 | 201 | "cell_type": "code", |
132 | | - "execution_count": null, |
| 202 | + "execution_count": 10, |
133 | 203 | "metadata": {}, |
134 | | - "outputs": [], |
| 204 | + "outputs": [ |
| 205 | + { |
| 206 | + "name": "stdout", |
| 207 | + "output_type": "stream", |
| 208 | + "text": [ |
| 209 | + "test_pred: (500, 2746) 0.0017536283 8.742443\n" |
| 210 | + ] |
| 211 | + } |
| 212 | + ], |
135 | 213 | "source": [ |
136 | 214 | "# test model\n", |
137 | 215 | "test_pred = model_trainer.test_epoch(model, img_test)\n", |
|
140 | 218 | }, |
141 | 219 | { |
142 | 220 | "cell_type": "code", |
143 | | - "execution_count": null, |
| 221 | + "execution_count": 11, |
144 | 222 | "metadata": {}, |
145 | | - "outputs": [], |
| 223 | + "outputs": [ |
| 224 | + { |
| 225 | + "name": "stdout", |
| 226 | + "output_type": "stream", |
| 227 | + "text": [ |
| 228 | + "filtering neurons with FEV > 0.15\n", |
| 229 | + "valid neurons: 1239 / 2746\n", |
| 230 | + "FEVE (test): 0.7267250418663025\n" |
| 231 | + ] |
| 232 | + } |
| 233 | + ], |
146 | 234 | "source": [ |
147 | 235 | "from minimodel import metrics\n", |
148 | 236 | "test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)\n", |
|
0 commit comments