Skip to content

Commit e5a4d28

Browse files
committed
fix path
1 parent f8fcdb9 commit e5a4d28

File tree

5 files changed

+118
-30
lines changed

5 files changed

+118
-30
lines changed

notebooks/fullmodel_monkey.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1313
"\n",
1414
"monkey_data_path = './data'\n",
15-
"weight_path = './checkpoints/fullmodel'"
15+
"weight_path = './checkpoints'"
1616
]
1717
},
1818
{

notebooks/fullmodel_mouse.ipynb

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -16,32 +16,51 @@
1616
},
1717
{
1818
"cell_type": "code",
19-
"execution_count": null,
19+
"execution_count": 2,
2020
"metadata": {},
2121
"outputs": [],
2222
"source": [
23-
"mouse_id = 4\n",
23+
"mouse_id = 5\n",
2424
"\n",
2525
"data_path = './data'\n",
26-
"weight_path = './checkpoints/fullmodel'\n",
26+
"weight_path = './checkpoints'\n",
2727
"np.random.seed(1)"
2828
]
2929
},
3030
{
3131
"cell_type": "code",
32-
"execution_count": null,
32+
"execution_count": 3,
3333
"metadata": {},
34-
"outputs": [],
34+
"outputs": [
35+
{
36+
"name": "stdout",
37+
"output_type": "stream",
38+
"text": [
39+
"raw image shape: (68000, 66, 264)\n",
40+
"cropped image shape: (68000, 66, 130)\n",
41+
"img: (68000, 66, 130) -2.0829253 2.1060908 float32\n"
42+
]
43+
}
44+
],
3545
"source": [
3646
"# load images\n",
3747
"img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])"
3848
]
3949
},
4050
{
4151
"cell_type": "code",
42-
"execution_count": null,
52+
"execution_count": 4,
4353
"metadata": {},
44-
"outputs": [],
54+
"outputs": [
55+
{
56+
"name": "stdout",
57+
"output_type": "stream",
58+
"text": [
59+
"\n",
60+
"loading activities from ./data/FX20_nat60k_2023_09_29.npz\n"
61+
]
62+
}
63+
],
4564
"source": [
4665
"# load neurons\n",
4766
"fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])\n",
@@ -51,29 +70,62 @@
5170
},
5271
{
5372
"cell_type": "code",
54-
"execution_count": null,
73+
"execution_count": 5,
5574
"metadata": {},
56-
"outputs": [],
75+
"outputs": [
76+
{
77+
"name": "stdout",
78+
"output_type": "stream",
79+
"text": [
80+
"\n",
81+
"splitting training and validation set...\n",
82+
"itrain: (43081,)\n",
83+
"ival: (4787,)\n"
84+
]
85+
}
86+
],
5787
"source": [
5888
"# split train and validation set\n",
5989
"itrain, ival = data.split_train_val(istim_train, train_frac=0.9)"
6090
]
6191
},
6292
{
6393
"cell_type": "code",
64-
"execution_count": null,
94+
"execution_count": 6,
6595
"metadata": {},
66-
"outputs": [],
96+
"outputs": [
97+
{
98+
"name": "stdout",
99+
"output_type": "stream",
100+
"text": [
101+
"\n",
102+
"normalizing neural data...\n",
103+
"finished\n"
104+
]
105+
}
106+
],
67107
"source": [
68108
"# normalize data\n",
69109
"spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)"
70110
]
71111
},
72112
{
73113
"cell_type": "code",
74-
"execution_count": null,
114+
"execution_count": 7,
75115
"metadata": {},
76-
"outputs": [],
116+
"outputs": [
117+
{
118+
"name": "stdout",
119+
"output_type": "stream",
120+
"text": [
121+
"spks_train: torch.Size([43081, 2746]) tensor(-1.4092e-15) tensor(48.7427)\n",
122+
"spks_val: torch.Size([4787, 2746]) tensor(-6.8745e-16) tensor(44.7361)\n",
123+
"img_train: torch.Size([43081, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n",
124+
"img_val: torch.Size([4787, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n",
125+
"img_test: torch.Size([500, 1, 66, 130]) tensor(-2.0829, device='cuda:0') tensor(2.1061, device='cuda:0')\n"
126+
]
127+
}
128+
],
77129
"source": [
78130
"ineur = np.arange(0, n_neurons) #np.arange(0, n_neurons, 5)\n",
79131
"spks_train = torch.from_numpy(spks[itrain][:,ineur])\n",
@@ -95,9 +147,19 @@
95147
},
96148
{
97149
"cell_type": "code",
98-
"execution_count": null,
150+
"execution_count": 8,
99151
"metadata": {},
100-
"outputs": [],
152+
"outputs": [
153+
{
154+
"name": "stdout",
155+
"output_type": "stream",
156+
"text": [
157+
"core shape: torch.Size([1, 320, 33, 65])\n",
158+
"input shape of readout: (320, 33, 65)\n",
159+
"model name: FX20_092923_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt\n"
160+
]
161+
}
162+
],
101163
"source": [
102164
"# build model\n",
103165
"from minimodel import model_builder\n",
@@ -113,9 +175,17 @@
113175
},
114176
{
115177
"cell_type": "code",
116-
"execution_count": null,
178+
"execution_count": 9,
117179
"metadata": {},
118-
"outputs": [],
180+
"outputs": [
181+
{
182+
"name": "stdout",
183+
"output_type": "stream",
184+
"text": [
185+
"loaded model ./checkpoints/fullmodel/FX20_092923_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt\n"
186+
]
187+
}
188+
],
119189
"source": [
120190
"# train model\n",
121191
"from minimodel import model_trainer\n",
@@ -129,9 +199,17 @@
129199
},
130200
{
131201
"cell_type": "code",
132-
"execution_count": null,
202+
"execution_count": 10,
133203
"metadata": {},
134-
"outputs": [],
204+
"outputs": [
205+
{
206+
"name": "stdout",
207+
"output_type": "stream",
208+
"text": [
209+
"test_pred: (500, 2746) 0.0017536283 8.742443\n"
210+
]
211+
}
212+
],
135213
"source": [
136214
"# test model\n",
137215
"test_pred = model_trainer.test_epoch(model, img_test)\n",
@@ -140,9 +218,19 @@
140218
},
141219
{
142220
"cell_type": "code",
143-
"execution_count": null,
221+
"execution_count": 11,
144222
"metadata": {},
145-
"outputs": [],
223+
"outputs": [
224+
{
225+
"name": "stdout",
226+
"output_type": "stream",
227+
"text": [
228+
"filtering neurons with FEV > 0.15\n",
229+
"valid neurons: 1239 / 2746\n",
230+
"FEVE (test): 0.7267250418663025\n"
231+
]
232+
}
233+
],
146234
"source": [
147235
"from minimodel import metrics\n",
148236
"test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)\n",

notebooks/minimodel_monkey.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
"source": [
110110
"if not os.path.exists(model_path):\n",
111111
" # initialize model conv1\n",
112-
" pretrained_model_path = os.path.join(weight_path, 'fullmodel', 'monkeyV1_2019_2layer_16_320_clamp_norm_depthsep_pool.pt')\n",
112+
" pretrained_model_path = os.path.join(weight_path, 'monkeyV1_2019_2layer_16_320_clamp_norm_depthsep_pool.pt')\n",
113113
" pretrained_state_dict = torch.load(pretrained_model_path, map_location=device)\n",
114114
" model.core.features.layer0.conv.weight.data = pretrained_state_dict['core.features.layer0.conv.weight']\n",
115115
" # set the weight fix\n",

notebooks/minimodel_mouse.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
"model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, Wc_coef=wc_coef)\n",
110110
"model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, seed=seed,hs_readout=hs_readout)\n",
111111
"\n",
112-
"model_path = os.path.join(weight_path, 'minimodel', model_name)\n",
112+
"model_path = os.path.join(weight_path, model_name)\n",
113113
"model = model.to(device)"
114114
]
115115
},
@@ -122,8 +122,8 @@
122122
"# train model\n",
123123
"from minimodel import model_trainer\n",
124124
"if not os.path.exists(model_path):\n",
125-
" if mouse_id == 5: pretrained_model_path = os.path.join(weight_path, 'fullmodel', f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt')\n",
126-
" else: pretrained_model_path = os.path.join(weight_path, 'fullmodel', f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_2layer_16_320_clamp_norm_depthsep_pool.pt')\n",
125+
" if mouse_id == 5: pretrained_model_path = os.path.join(weight_path, f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_2layer_16_320_clamp_norm_depthsep_pool_xrange_176.pt')\n",
126+
" else: pretrained_model_path = os.path.join(weight_path, f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_2layer_16_320_clamp_norm_depthsep_pool.pt')\n",
127127
" print('pretrained_model_path: ', pretrained_model_path)\n",
128128
" pretrained_state_dict = torch.load(pretrained_model_path, map_location=device)\n",
129129
" # initialize conv1 with the fullmodel weights\n",
@@ -205,7 +205,7 @@
205205
"nconv2 = 320\n",
206206
"fullmodel, in_channels = model_builder.build_model(NN=n_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)\n",
207207
"model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels)\n",
208-
"model_path = os.path.join(weight_path, 'fullmodel', model_name)\n",
208+
"model_path = os.path.join(weight_path, model_name)\n",
209209
"\n",
210210
"fullmodel.load_state_dict(torch.load(model_path))\n",
211211
"print('loaded model', model_path)\n",

notebooks/mouse_pipeline.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
"fullmodel, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)\n",
173173
"model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels)\n",
174174
"\n",
175-
"model_path = os.path.join(weight_path, 'fullmodel', model_name)\n",
175+
"model_path = os.path.join(weight_path, model_name)\n",
176176
"fullmodel = fullmodel.to(device)"
177177
]
178178
},
@@ -268,7 +268,7 @@
268268
"model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, Wc_coef=wc_coef)\n",
269269
"model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, seed=seed,hs_readout=hs_readout)\n",
270270
"\n",
271-
"model_path = os.path.join(weight_path, 'minimodel', model_name)\n",
271+
"model_path = os.path.join(weight_path, model_name)\n",
272272
"model = model.to(device)"
273273
]
274274
},

0 commit comments

Comments
 (0)