|
4 | 4 | "metadata": { |
5 | 5 | "colab": { |
6 | 6 | "name": "inference.ipynb", |
7 | | - "provenance": [], |
8 | | - "collapsed_sections": [] |
| 7 | + "provenance": [] |
9 | 8 | }, |
10 | 9 | "kernelspec": { |
11 | 10 | "name": "python3", |
|
37 | 36 | { |
38 | 37 | "cell_type": "code", |
39 | 38 | "metadata": { |
40 | | - "id": "LFzbgV1YZ2n0", |
41 | | - "cellView": "form" |
| 39 | + "id": "LFzbgV1YZ2n0" |
42 | 40 | }, |
43 | 41 | "source": [ |
44 | 42 | "#@title install packages + repos\n", |
|
58 | 56 | { |
59 | 57 | "cell_type": "code", |
60 | 58 | "metadata": { |
61 | | - "id": "3UzsrNvAZ5LF", |
62 | | - "cellView": "form" |
| 59 | + "id": "3UzsrNvAZ5LF" |
63 | 60 | }, |
64 | 61 | "source": [ |
65 | 62 | "#@title import packages\n", |
|
89 | 86 | { |
90 | 87 | "cell_type": "code", |
91 | 88 | "metadata": { |
92 | | - "id": "6exTbzR9Z_uV", |
93 | | - "cellView": "form" |
| 89 | + "id": "6exTbzR9Z_uV" |
94 | 90 | }, |
95 | 91 | "source": [ |
96 | 92 | "#@title utility function to view labels\n", |
|
126 | 122 | { |
127 | 123 | "cell_type": "code", |
128 | 124 | "metadata": { |
129 | | - "id": "al-0cyZlaC57", |
130 | | - "cellView": "form" |
| 125 | + "id": "al-0cyZlaC57" |
131 | 126 | }, |
132 | 127 | "source": [ |
133 | 128 | "#@title utility function to download / save data as zarr\n", |
134 | 129 | "\n", |
135 | 130 | "def create_data(\n", |
136 | | - " url, \n", |
137 | | - " name, \n", |
138 | | - " offset, \n", |
| 131 | + " url,\n", |
| 132 | + " name,\n", |
| 133 | + " offset,\n", |
139 | 134 | " resolution,\n", |
140 | 135 | " sections=None,\n", |
141 | 136 | " squeeze=True):\n", |
|
144 | 139 | "\n", |
145 | 140 | " raw = in_f['volumes/raw']\n", |
146 | 141 | " labels = in_f['volumes/labels/neuron_ids']\n", |
147 | | - " \n", |
| 142 | + "\n", |
148 | 143 | " container = zarr.open(name, 'a')\n", |
149 | 144 | "\n", |
150 | 145 | " if sections is None:\n", |
|
164 | 159 | " for ds_name, data in [\n", |
165 | 160 | " ('raw', raw_slice),\n", |
166 | 161 | " ('labels', labels_slice)]:\n", |
167 | | - " \n", |
| 162 | + "\n", |
168 | 163 | " container[f'{ds_name}/{index}'] = data\n", |
169 | 164 | " container[f'{ds_name}/{index}'].attrs['offset'] = offset\n", |
170 | 165 | " container[f'{ds_name}/{index}'].attrs['resolution'] = resolution" |
|
216 | 211 | { |
217 | 212 | "cell_type": "code", |
218 | 213 | "metadata": { |
219 | | - "id": "bkxBq8BYchlf", |
220 | | - "cellView": "form" |
| 214 | + "id": "bkxBq8BYchlf" |
221 | 215 | }, |
222 | 216 | "source": [ |
223 | 217 | "#@title create mtlsd model\n", |
224 | 218 | "\n", |
225 | 219 | "class MtlsdModel(torch.nn.Module):\n", |
226 | 220 | "\n", |
227 | | - " def __init__(self):\n", |
| 221 | + " def __init__(\n", |
| 222 | + " self,\n", |
| 223 | + " in_channels,\n", |
| 224 | + " num_fmaps,\n", |
| 225 | + " fmap_inc_factor,\n", |
| 226 | + " downsample_factors,\n", |
| 227 | + " kernel_size_down,\n", |
| 228 | + " kernel_size_up,\n", |
| 229 | + " constant_upsample):\n", |
| 230 | + "\n", |
228 | 231 | " super().__init__()\n", |
229 | 232 | "\n", |
| 233 | + " # create unet\n", |
230 | 234 | " self.unet = UNet(\n", |
231 | | - " in_channels=1,\n", |
232 | | - " num_fmaps=num_fmaps,\n", |
233 | | - " fmap_inc_factor=5,\n", |
234 | | - " downsample_factors=[\n", |
235 | | - " [2, 2],\n", |
236 | | - " [2, 2]],\n", |
237 | | - " kernel_size_down=[\n", |
238 | | - " [[3, 3], [3, 3]],\n", |
239 | | - " [[3, 3], [3, 3]],\n", |
240 | | - " [[3, 3], [3, 3]]],\n", |
241 | | - " kernel_size_up=[\n", |
242 | | - " [[3, 3], [3, 3]],\n", |
243 | | - " [[3, 3], [3, 3]]])\n", |
244 | | - "\n", |
| 235 | + " in_channels=in_channels,\n", |
| 236 | + " num_fmaps=num_fmaps,\n", |
| 237 | + " fmap_inc_factor=fmap_inc_factor,\n", |
| 238 | + " downsample_factors=downsample_factors,\n", |
| 239 | + " kernel_size_down=kernel_size_down,\n", |
| 240 | + " kernel_size_up=kernel_size_up,\n", |
| 241 | + " constant_upsample=constant_upsample)\n", |
| 242 | + "\n", |
| 243 | + " # create lsd and affs heads\n", |
245 | 244 | " self.lsd_head = ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid')\n", |
246 | 245 | " self.aff_head = ConvPass(num_fmaps, 2, [[1, 1]], activation='Sigmoid')\n", |
247 | 246 | "\n", |
248 | 247 | " def forward(self, input):\n", |
249 | 248 | "\n", |
| 249 | + " # pass raw through unet\n", |
250 | 250 | " z = self.unet(input)\n", |
| 251 | + "\n", |
| 252 | + " # pass output through heads\n", |
251 | 253 | " lsds = self.lsd_head(z)\n", |
252 | 254 | " affs = self.aff_head(z)\n", |
253 | 255 | "\n", |
|
266 | 268 | " checkpoint,\n", |
267 | 269 | " raw_file,\n", |
268 | 270 | " raw_dataset):\n", |
269 | | - " \n", |
| 271 | + "\n", |
270 | 272 | " raw = gp.ArrayKey('RAW')\n", |
271 | 273 | " pred_lsds = gp.ArrayKey('PRED_LSDS')\n", |
272 | 274 | " pred_affs = gp.ArrayKey('PRED_AFFS')\n", |
|
287 | 289 | " {\n", |
288 | 290 | " raw: gp.ArraySpec(interpolatable=True)\n", |
289 | 291 | " })\n", |
290 | | - " \n", |
| 292 | + "\n", |
291 | 293 | " with gp.build(source):\n", |
292 | 294 | " total_input_roi = source.spec[raw].roi\n", |
293 | 295 | " total_output_roi = source.spec[raw].roi.grow(-context,-context)\n", |
294 | 296 | "\n", |
295 | | - " model = MtlsdModel()\n", |
| 297 | + " in_channels=1\n", |
| 298 | + " num_fmaps=12\n", |
| 299 | + " fmap_inc_factor=5\n", |
| 300 | + " ds_fact = [(2,2),(2,2)]\n", |
| 301 | + " num_levels = len(ds_fact) + 1\n", |
| 302 | + " ksd = [[(3,3), (3,3)]]*num_levels\n", |
| 303 | + " ksu = [[(3,3), (3,3)]]*(num_levels - 1)\n", |
| 304 | + " constant_upsample = True\n", |
| 305 | + "\n", |
| 306 | + " model = MtlsdModel(\n", |
| 307 | + " in_channels,\n", |
| 308 | + " num_fmaps,\n", |
| 309 | + " fmap_inc_factor,\n", |
| 310 | + " ds_fact,\n", |
| 311 | + " ksd,\n", |
| 312 | + " ksu,\n", |
| 313 | + " constant_upsample\n", |
| 314 | + " )\n", |
296 | 315 | "\n", |
297 | 316 | " # set model to eval mode\n", |
298 | 317 | " model.eval()\n", |
|
307 | 326 | " outputs = {\n", |
308 | 327 | " 0: pred_lsds,\n", |
309 | 328 | " 1: pred_affs})\n", |
310 | | - " \n", |
| 329 | + "\n", |
311 | 330 | " # this will scan in chunks equal to the input/output sizes of the respective arrays\n", |
312 | 331 | " scan = gp.Scan(scan_request)\n", |
313 | | - " \n", |
| 332 | + "\n", |
314 | 333 | " pipeline = source\n", |
315 | 334 | " pipeline += gp.Normalize(raw)\n", |
316 | 335 | "\n", |
|
357 | 376 | "cell_type": "code", |
358 | 377 | "source": [ |
359 | 378 | "# fetch checkpoint\n", |
360 | | - "!wget https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000" |
| 379 | + "!gdown 1dx5P08Cmml4N2RmQ0AU1RE7b-Tmn-LaC" |
361 | 380 | ], |
362 | 381 | "metadata": { |
363 | 382 | "id": "YCbEfvEXcBzn" |
|
371 | 390 | "id": "AYbHmh4dTL3V" |
372 | 391 | }, |
373 | 392 | "source": [ |
374 | | - "checkpoint = 'model_checkpoint_50000' \n", |
| 393 | + "checkpoint = 'model_checkpoint_10000'\n", |
375 | 394 | "raw_file = 'testing_data.zarr'\n", |
376 | 395 | "raw_dataset = 'raw/0'\n", |
377 | 396 | "\n", |
|
0 commit comments