Skip to content

Commit df49789

Browse files
committed
Fix for inference notebook
1 parent 84f55be commit df49789

File tree

1 file changed

+58
-39
lines changed

1 file changed

+58
-39
lines changed

lsd/tutorial/notebooks/inference.ipynb

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
"metadata": {
55
"colab": {
66
"name": "inference.ipynb",
7-
"provenance": [],
8-
"collapsed_sections": []
7+
"provenance": []
98
},
109
"kernelspec": {
1110
"name": "python3",
@@ -37,8 +36,7 @@
3736
{
3837
"cell_type": "code",
3938
"metadata": {
40-
"id": "LFzbgV1YZ2n0",
41-
"cellView": "form"
39+
"id": "LFzbgV1YZ2n0"
4240
},
4341
"source": [
4442
"#@title install packages + repos\n",
@@ -58,8 +56,7 @@
5856
{
5957
"cell_type": "code",
6058
"metadata": {
61-
"id": "3UzsrNvAZ5LF",
62-
"cellView": "form"
59+
"id": "3UzsrNvAZ5LF"
6360
},
6461
"source": [
6562
"#@title import packages\n",
@@ -89,8 +86,7 @@
8986
{
9087
"cell_type": "code",
9188
"metadata": {
92-
"id": "6exTbzR9Z_uV",
93-
"cellView": "form"
89+
"id": "6exTbzR9Z_uV"
9490
},
9591
"source": [
9692
"#@title utility function to view labels\n",
@@ -126,16 +122,15 @@
126122
{
127123
"cell_type": "code",
128124
"metadata": {
129-
"id": "al-0cyZlaC57",
130-
"cellView": "form"
125+
"id": "al-0cyZlaC57"
131126
},
132127
"source": [
133128
"#@title utility function to download / save data as zarr\n",
134129
"\n",
135130
"def create_data(\n",
136-
" url, \n",
137-
" name, \n",
138-
" offset, \n",
131+
" url,\n",
132+
" name,\n",
133+
" offset,\n",
139134
" resolution,\n",
140135
" sections=None,\n",
141136
" squeeze=True):\n",
@@ -144,7 +139,7 @@
144139
"\n",
145140
" raw = in_f['volumes/raw']\n",
146141
" labels = in_f['volumes/labels/neuron_ids']\n",
147-
" \n",
142+
"\n",
148143
" container = zarr.open(name, 'a')\n",
149144
"\n",
150145
" if sections is None:\n",
@@ -164,7 +159,7 @@
164159
" for ds_name, data in [\n",
165160
" ('raw', raw_slice),\n",
166161
" ('labels', labels_slice)]:\n",
167-
" \n",
162+
"\n",
168163
" container[f'{ds_name}/{index}'] = data\n",
169164
" container[f'{ds_name}/{index}'].attrs['offset'] = offset\n",
170165
" container[f'{ds_name}/{index}'].attrs['resolution'] = resolution"
@@ -216,38 +211,45 @@
216211
{
217212
"cell_type": "code",
218213
"metadata": {
219-
"id": "bkxBq8BYchlf",
220-
"cellView": "form"
214+
"id": "bkxBq8BYchlf"
221215
},
222216
"source": [
223217
"#@title create mtlsd model\n",
224218
"\n",
225219
"class MtlsdModel(torch.nn.Module):\n",
226220
"\n",
227-
" def __init__(self):\n",
221+
" def __init__(\n",
222+
" self,\n",
223+
" in_channels,\n",
224+
" num_fmaps,\n",
225+
" fmap_inc_factor,\n",
226+
" downsample_factors,\n",
227+
" kernel_size_down,\n",
228+
" kernel_size_up,\n",
229+
" constant_upsample):\n",
230+
"\n",
228231
" super().__init__()\n",
229232
"\n",
233+
" # create unet\n",
230234
" self.unet = UNet(\n",
231-
" in_channels=1,\n",
232-
" num_fmaps=num_fmaps,\n",
233-
" fmap_inc_factor=5,\n",
234-
" downsample_factors=[\n",
235-
" [2, 2],\n",
236-
" [2, 2]],\n",
237-
" kernel_size_down=[\n",
238-
" [[3, 3], [3, 3]],\n",
239-
" [[3, 3], [3, 3]],\n",
240-
" [[3, 3], [3, 3]]],\n",
241-
" kernel_size_up=[\n",
242-
" [[3, 3], [3, 3]],\n",
243-
" [[3, 3], [3, 3]]])\n",
244-
"\n",
235+
" in_channels=in_channels,\n",
236+
" num_fmaps=num_fmaps,\n",
237+
" fmap_inc_factor=fmap_inc_factor,\n",
238+
" downsample_factors=downsample_factors,\n",
239+
" kernel_size_down=kernel_size_down,\n",
240+
" kernel_size_up=kernel_size_up,\n",
241+
" constant_upsample=constant_upsample)\n",
242+
"\n",
243+
" # create lsd and affs heads\n",
245244
" self.lsd_head = ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid')\n",
246245
" self.aff_head = ConvPass(num_fmaps, 2, [[1, 1]], activation='Sigmoid')\n",
247246
"\n",
248247
" def forward(self, input):\n",
249248
"\n",
249+
" # pass raw through unet\n",
250250
" z = self.unet(input)\n",
251+
"\n",
252+
" # pass output through heads\n",
251253
" lsds = self.lsd_head(z)\n",
252254
" affs = self.aff_head(z)\n",
253255
"\n",
@@ -266,7 +268,7 @@
266268
" checkpoint,\n",
267269
" raw_file,\n",
268270
" raw_dataset):\n",
269-
" \n",
271+
"\n",
270272
" raw = gp.ArrayKey('RAW')\n",
271273
" pred_lsds = gp.ArrayKey('PRED_LSDS')\n",
272274
" pred_affs = gp.ArrayKey('PRED_AFFS')\n",
@@ -287,12 +289,29 @@
287289
" {\n",
288290
" raw: gp.ArraySpec(interpolatable=True)\n",
289291
" })\n",
290-
" \n",
292+
"\n",
291293
" with gp.build(source):\n",
292294
" total_input_roi = source.spec[raw].roi\n",
293295
" total_output_roi = source.spec[raw].roi.grow(-context,-context)\n",
294296
"\n",
295-
" model = MtlsdModel()\n",
297+
" in_channels=1\n",
298+
" num_fmaps=12\n",
299+
" fmap_inc_factor=5\n",
300+
" ds_fact = [(2,2),(2,2)]\n",
301+
" num_levels = len(ds_fact) + 1\n",
302+
" ksd = [[(3,3), (3,3)]]*num_levels\n",
303+
" ksu = [[(3,3), (3,3)]]*(num_levels - 1)\n",
304+
" constant_upsample = True\n",
305+
"\n",
306+
" model = MtlsdModel(\n",
307+
" in_channels,\n",
308+
" num_fmaps,\n",
309+
" fmap_inc_factor,\n",
310+
" ds_fact,\n",
311+
" ksd,\n",
312+
" ksu,\n",
313+
" constant_upsample\n",
314+
" )\n",
296315
"\n",
297316
" # set model to eval mode\n",
298317
" model.eval()\n",
@@ -307,10 +326,10 @@
307326
" outputs = {\n",
308327
" 0: pred_lsds,\n",
309328
" 1: pred_affs})\n",
310-
" \n",
329+
"\n",
311330
" # this will scan in chunks equal to the input/output sizes of the respective arrays\n",
312331
" scan = gp.Scan(scan_request)\n",
313-
" \n",
332+
"\n",
314333
" pipeline = source\n",
315334
" pipeline += gp.Normalize(raw)\n",
316335
"\n",
@@ -357,7 +376,7 @@
357376
"cell_type": "code",
358377
"source": [
359378
"# fetch checkpoint\n",
360-
"!wget https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000"
379+
"!gdown 1dx5P08Cmml4N2RmQ0AU1RE7b-Tmn-LaC"
361380
],
362381
"metadata": {
363382
"id": "YCbEfvEXcBzn"
@@ -371,7 +390,7 @@
371390
"id": "AYbHmh4dTL3V"
372391
},
373392
"source": [
374-
"checkpoint = 'model_checkpoint_50000' \n",
393+
"checkpoint = 'model_checkpoint_10000'\n",
375394
"raw_file = 'testing_data.zarr'\n",
376395
"raw_dataset = 'raw/0'\n",
377396
"\n",

0 commit comments

Comments
 (0)