|
31 | 31 | "`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi-rflow'`. Compared with the previous version `'maisi-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section."
|
32 | 32 | ]
|
33 | 33 | },
|
| 34 | + { |
| 35 | + "cell_type": "markdown", |
| 36 | + "id": "aa51792d", |
| 37 | + "metadata": {}, |
| 38 | + "source": [ |
| 39 | + "## Set up the MAISI version\n", |
| 40 | + "\n", |
| 41 | + "Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n", |
| 42 | + "- `'maisi-ddpm'` uses the basic DDPM noise scheduler, while `'maisi-rflow'` uses the Rectified Flow scheduler and can be 33 times faster during inference.\n", |
| 43 | + "- `'maisi-ddpm'` requires training images to be labeled with a body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` has no such requirement, so training data is easier to prepare for `'maisi-rflow'`.\n", |
| 44 | + "- With the released model weights, `'maisi-rflow'` can generate better-quality images for the head region and for small output volumes, and comparable quality to `'maisi-ddpm'` in other cases." |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "code", |
| 49 | + "execution_count": 1, |
| 50 | + "id": "828b9ece-7759-40e8-ac4c-6467c3399701", |
| 51 | + "metadata": {}, |
| 52 | + "outputs": [], |
| 53 | + "source": [ |
| 54 | + "maisi_version = \"maisi-ddpm\"\n", |
| 55 | + "assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]" |
| 56 | + ] |
| 57 | + }, |
34 | 58 | {
|
35 | 59 | "cell_type": "markdown",
|
36 | 60 | "id": "c9ecfb90",
|
|
41 | 65 | },
|
42 | 66 | {
|
43 | 67 | "cell_type": "code",
|
44 |
| - "execution_count": 1, |
| 68 | + "execution_count": 2, |
45 | 69 | "id": "58cbde9b",
|
46 | 70 | "metadata": {},
|
47 | 71 | "outputs": [],
|
|
59 | 83 | },
|
60 | 84 | {
|
61 | 85 | "cell_type": "code",
|
62 |
| - "execution_count": 13, |
| 86 | + "execution_count": 3, |
63 | 87 | "id": "e3bf0346",
|
64 | 88 | "metadata": {},
|
65 | 89 | "outputs": [
|
|
136 | 160 | },
|
137 | 161 | {
|
138 | 162 | "cell_type": "code",
|
139 |
| - "execution_count": 3, |
140 |
| - "id": "828b9ece-7759-40e8-ac4c-6467c3399701", |
| 163 | + "execution_count": 4, |
| 164 | + "id": "31684f74", |
141 | 165 | "metadata": {},
|
142 | 166 | "outputs": [],
|
143 | 167 | "source": [
|
|
159 | 183 | },
|
160 | 184 | {
|
161 | 185 | "cell_type": "code",
|
162 |
| - "execution_count": 4, |
| 186 | + "execution_count": 5, |
163 | 187 | "id": "fc32a7fe",
|
164 | 188 | "metadata": {},
|
165 | 189 | "outputs": [],
|
|
181 | 205 | },
|
182 | 206 | {
|
183 | 207 | "cell_type": "code",
|
184 |
| - "execution_count": 5, |
| 208 | + "execution_count": 6, |
185 | 209 | "id": "1b199078",
|
186 | 210 | "metadata": {},
|
187 | 211 | "outputs": [
|
188 | 212 | {
|
189 | 213 | "name": "stderr",
|
190 | 214 | "output_type": "stream",
|
191 | 215 | "text": [
|
192 |
| - "[2025-03-11 21:46:47.184][ INFO](notebook) - Generated simulated images.\n" |
| 216 | + "[2025-03-11 22:05:02.952][ INFO](notebook) - Generated simulated images.\n" |
193 | 217 | ]
|
194 | 218 | }
|
195 | 219 | ],
|
|
228 | 252 | },
|
229 | 253 | {
|
230 | 254 | "cell_type": "code",
|
231 |
| - "execution_count": 6, |
| 255 | + "execution_count": 7, |
232 | 256 | "id": "6c7b434c",
|
233 | 257 | "metadata": {},
|
234 | 258 | "outputs": [
|
235 | 259 | {
|
236 | 260 | "name": "stderr",
|
237 | 261 | "output_type": "stream",
|
238 | 262 | "text": [
|
239 |
| - "[2025-03-11 21:46:47.199][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
240 |
| - "[2025-03-11 21:46:47.199][ INFO](notebook) - number of GPUs: 1.\n" |
| 263 | + "[2025-03-11 22:05:02.966][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
| 264 | + "[2025-03-11 22:05:02.966][ INFO](notebook) - number of GPUs: 1.\n" |
241 | 265 | ]
|
242 | 266 | }
|
243 | 267 | ],
|
244 | 268 | "source": [
|
245 | 269 | "env_config_path = \"./configs/environment_maisi_diff_model.json\"\n",
|
246 | 270 | "model_config_path = \"./configs/config_maisi_diff_model.json\"\n",
|
247 |
| - "if maisi_version == 'maisi-ddpm':\n", |
| 271 | + "if maisi_version == \"maisi-ddpm\":\n", |
248 | 272 | " model_def_path = \"./configs/config_maisi-ddpm.json\"\n",
|
249 | 273 | " include_body_region = True\n",
|
250 |
| - "elif maisi_version == 'maisi-rflow':\n", |
| 274 | + "elif maisi_version == \"maisi-rflow\":\n", |
251 | 275 | " model_def_path = \"./configs/config_maisi-rflow.json\"\n",
|
252 | 276 | " include_body_region = False\n",
|
253 | 277 | "else:\n",
|
|
315 | 339 | },
|
316 | 340 | {
|
317 | 341 | "cell_type": "code",
|
318 |
| - "execution_count": 7, |
| 342 | + "execution_count": 8, |
319 | 343 | "id": "95ea6972",
|
320 | 344 | "metadata": {},
|
321 | 345 | "outputs": [],
|
|
375 | 399 | },
|
376 | 400 | {
|
377 | 401 | "cell_type": "code",
|
378 |
| - "execution_count": 8, |
| 402 | + "execution_count": 9, |
379 | 403 | "id": "f45ea863",
|
380 | 404 | "metadata": {},
|
381 | 405 | "outputs": [
|
382 | 406 | {
|
383 | 407 | "name": "stderr",
|
384 | 408 | "output_type": "stream",
|
385 | 409 | "text": [
|
386 |
| - "[2025-03-11 21:46:47.210][ INFO](notebook) - Creating training data...\n" |
| 410 | + "[2025-03-11 22:05:02.977][ INFO](notebook) - Creating training data...\n" |
387 | 411 | ]
|
388 | 412 | },
|
389 | 413 | {
|
390 | 414 | "name": "stdout",
|
391 | 415 | "output_type": "stream",
|
392 | 416 | "text": [
|
393 | 417 | "\n",
|
394 |
| - "[2025-03-11 21:46:57.369][ INFO](creating training data) - Using device cuda:0\n", |
395 |
| - "[2025-03-11 21:46:58.170][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
| 418 | + "[2025-03-11 22:05:10.881][ INFO](creating training data) - Using device cuda:0\n", |
| 419 | + "[2025-03-11 22:05:11.686][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
396 | 420 | "\n"
|
397 | 421 | ]
|
398 | 422 | }
|
|
428 | 452 | },
|
429 | 453 | {
|
430 | 454 | "cell_type": "code",
|
431 |
| - "execution_count": 9, |
| 455 | + "execution_count": 10, |
432 | 456 | "id": "0221a658",
|
433 | 457 | "metadata": {},
|
434 | 458 | "outputs": [
|
435 | 459 | {
|
436 | 460 | "name": "stderr",
|
437 | 461 | "output_type": "stream",
|
438 | 462 | "text": [
|
439 |
| - "[2025-03-11 21:47:00.412][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
440 |
| - "[2025-03-11 21:47:00.414][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
441 |
| - "[2025-03-11 21:47:00.415][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
| 463 | + "[2025-03-11 22:05:13.881][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 464 | + "[2025-03-11 22:05:13.884][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 465 | + "[2025-03-11 22:05:13.885][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
442 | 466 | ]
|
443 | 467 | }
|
444 | 468 | ],
|
|
467 | 491 | " spacing = [float(_item) for _item in spacing]\n",
|
468 | 492 | "\n",
|
469 | 493 | " # Create the dictionary with the specified keys and values\n",
|
470 |
| - " data = {\n", |
471 |
| - " \"dim\": dimensions,\n", |
472 |
| - " \"spacing\": spacing\n", |
473 |
| - " }\n", |
| 494 | + " data = {\"dim\": dimensions, \"spacing\": spacing}\n", |
474 | 495 | " if include_body_region:\n",
|
475 | 496 | " # The region can be selected from one of four regions from top to bottom.\n",
|
476 | 497 | " # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region, [0,0,1,0]\n",
|
|
510 | 531 | },
|
511 | 532 | {
|
512 | 533 | "cell_type": "code",
|
513 |
| - "execution_count": 10, |
| 534 | + "execution_count": 11, |
514 | 535 | "id": "ade6389d",
|
515 | 536 | "metadata": {},
|
516 | 537 | "outputs": [
|
517 | 538 | {
|
518 | 539 | "name": "stderr",
|
519 | 540 | "output_type": "stream",
|
520 | 541 | "text": [
|
521 |
| - "[2025-03-11 21:47:00.420][ INFO](notebook) - Training the model...\n" |
| 542 | + "[2025-03-11 22:05:13.892][ INFO](notebook) - Training the model...\n" |
522 | 543 | ]
|
523 | 544 | },
|
524 | 545 | {
|
525 | 546 | "name": "stdout",
|
526 | 547 | "output_type": "stream",
|
527 | 548 | "text": [
|
528 | 549 | "\n",
|
529 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - Using cuda:0 of 1\n", |
530 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
531 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
532 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
533 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] lr -> 0.0001.\n", |
534 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] num_epochs -> 2.\n", |
535 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
536 |
| - "[2025-03-11 21:47:09.081][ INFO](training) - num_files_train: 2\n", |
537 |
| - "[2025-03-11 21:47:10.815][ INFO](training) - Training from scratch.\n", |
538 |
| - "[2025-03-11 21:47:11.273][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
539 |
| - "[2025-03-11 21:47:11.273][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
540 |
| - "[2025-03-11 21:47:11.276][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
541 |
| - "[2025-03-11 21:47:11.276][ INFO](training) - Epoch 1, lr 0.0001.\n", |
542 |
| - "[2025-03-11 21:47:12.253][ INFO](training) - [2025-03-11 21:47:12] epoch 1, iter 1/2, loss: 0.7979, lr: 0.000100000000.\n", |
543 |
| - "[2025-03-11 21:47:12.535][ INFO](training) - [2025-03-11 21:47:12] epoch 1, iter 2/2, loss: 0.7931, lr: 0.000056250000.\n", |
544 |
| - "[2025-03-11 21:47:12.572][ INFO](training) - epoch 1 average loss: 0.7955.\n", |
545 |
| - "[2025-03-11 21:47:14.031][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
546 |
| - "[2025-03-11 21:47:14.420][ INFO](training) - [2025-03-11 21:47:14] epoch 2, iter 1/2, loss: 0.7883, lr: 0.000025000000.\n", |
547 |
| - "[2025-03-11 21:47:14.517][ INFO](training) - [2025-03-11 21:47:14] epoch 2, iter 2/2, loss: 0.7893, lr: 0.000006250000.\n", |
548 |
| - "[2025-03-11 21:47:14.594][ INFO](training) - epoch 2 average loss: 0.7888.\n", |
| 550 | + "[2025-03-11 22:05:24.419][ INFO](training) - Using cuda:0 of 1\n", |
| 551 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
| 552 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
| 553 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
| 554 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] lr -> 0.0001.\n", |
| 555 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_epochs -> 2.\n", |
| 556 | + "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
| 557 | + "[2025-03-11 22:05:24.420][ INFO](training) - num_files_train: 2\n", |
| 558 | + "[2025-03-11 22:05:26.152][ INFO](training) - Training from scratch.\n", |
| 559 | + "[2025-03-11 22:05:26.539][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
| 560 | + "[2025-03-11 22:05:26.539][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
| 561 | + "[2025-03-11 22:05:26.542][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
| 562 | + "[2025-03-11 22:05:26.542][ INFO](training) - Epoch 1, lr 0.0001.\n", |
| 563 | + "[2025-03-11 22:05:28.578][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n", |
| 564 | + "[2025-03-11 22:05:28.719][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 2/2, loss: 0.7943, lr: 0.000056250000.\n", |
| 565 | + "[2025-03-11 22:05:28.762][ INFO](training) - epoch 1 average loss: 0.7958.\n", |
| 566 | + "[2025-03-11 22:05:30.615][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
| 567 | + "[2025-03-11 22:05:31.002][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 1/2, loss: 0.7898, lr: 0.000025000000.\n", |
| 568 | + "[2025-03-11 22:05:31.105][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 2/2, loss: 0.7886, lr: 0.000006250000.\n", |
| 569 | + "[2025-03-11 22:05:31.168][ INFO](training) - epoch 2 average loss: 0.7892.\n", |
549 | 570 | "\n"
|
550 | 571 | ]
|
551 | 572 | }
|
|
583 | 604 | },
|
584 | 605 | {
|
585 | 606 | "cell_type": "code",
|
586 |
| - "execution_count": 11, |
| 607 | + "execution_count": 12, |
587 | 608 | "id": "1626526d",
|
588 | 609 | "metadata": {},
|
589 | 610 | "outputs": [
|
590 | 611 | {
|
591 | 612 | "name": "stderr",
|
592 | 613 | "output_type": "stream",
|
593 | 614 | "text": [
|
594 |
| - "[2025-03-11 21:47:18.262][ INFO](notebook) - Running inference...\n", |
595 |
| - "[2025-03-11 21:47:35.148][ INFO](notebook) - Completed all steps.\n" |
| 615 | + "[2025-03-11 22:05:35.033][ INFO](notebook) - Running inference...\n", |
| 616 | + "[2025-03-11 22:05:50.259][ INFO](notebook) - Completed all steps.\n" |
596 | 617 | ]
|
597 | 618 | },
|
598 | 619 | {
|
599 | 620 | "name": "stdout",
|
600 | 621 | "output_type": "stream",
|
601 | 622 | "text": [
|
602 | 623 | "\n",
|
603 |
| - "[2025-03-11 21:47:27.859][ INFO](inference) - Using cuda:0 of 1 with random seed: 99760\n", |
604 |
| - "[2025-03-11 21:47:27.859][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
605 |
| - "[2025-03-11 21:47:27.860][ INFO](inference) - [config] random_seed -> 99760.\n", |
606 |
| - "[2025-03-11 21:47:27.860][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
607 |
| - "[2025-03-11 21:47:27.860][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
608 |
| - "[2025-03-11 21:47:27.860][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
609 |
| - "[2025-03-11 21:47:27.860][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
610 |
| - "[2025-03-11 21:47:30.510][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
611 |
| - "[2025-03-11 21:47:30.512][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
612 |
| - "[2025-03-11 21:47:30.512][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
613 |
| - "[2025-03-11 21:47:30.514][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
| 624 | + "[2025-03-11 22:05:43.502][ INFO](inference) - Using cuda:0 of 1 with random seed: 7854\n", |
| 625 | + "[2025-03-11 22:05:43.502][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
| 626 | + "[2025-03-11 22:05:43.502][ INFO](inference) - [config] random_seed -> 7854.\n", |
| 627 | + "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
| 628 | + "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
| 629 | + "[2025-03-11 22:05:43.502][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
| 630 | + "[2025-03-11 22:05:43.502][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
| 631 | + "[2025-03-11 22:05:45.793][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
| 632 | + "[2025-03-11 22:05:45.795][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
| 633 | + "[2025-03-11 22:05:45.796][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
| 634 | + "[2025-03-11 22:05:45.798][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
614 | 635 | "\n",
|
615 | 636 | " 0%| | 0/10 [00:00<?, ?it/s]\n",
|
616 |
| - " 10%|█ | 1/10 [00:00<00:07, 1.24it/s]\n", |
617 |
| - " 60%|██████ | 6/10 [00:00<00:00, 8.39it/s]\n", |
618 |
| - "100%|██████████| 10/10 [00:01<00:00, 9.80it/s]\n", |
619 |
| - "[2025-03-11 21:47:33.116][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed99760_size256x256x128_spacing1.00x1.00x0.75_20250311214732_rank0.nii.gz.\n", |
| 637 | + " 10%|█ | 1/10 [00:00<00:05, 1.78it/s]\n", |
| 638 | + " 60%|██████ | 6/10 [00:00<00:00, 11.19it/s]\n", |
| 639 | + "100%|██████████| 10/10 [00:00<00:00, 12.88it/s]\n", |
| 640 | + "[2025-03-11 22:05:48.356][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed7854_size256x256x128_spacing1.00x1.00x0.75_20250311220547_rank0.nii.gz.\n", |
620 | 641 | "\n"
|
621 | 642 | ]
|
622 | 643 | }
|
|
654 | 675 | },
|
655 | 676 | {
|
656 | 677 | "cell_type": "code",
|
657 |
| - "execution_count": 12, |
| 678 | + "execution_count": 13, |
658 | 679 | "id": "0d8a344d",
|
659 | 680 | "metadata": {},
|
660 | 681 | "outputs": [
|
|
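Review note: the version switch introduced in this diff (an `assert` in the new cell, followed by an `if`/`elif`/`else` in the config cell) could also be expressed as a single lookup, so that the accepted versions and their settings live in one place. The sketch below is only an illustration under stated assumptions: `MaisiVersionConfig`, `resolve_maisi_version`, and `build_embedding_metadata` are hypothetical names that do not appear in the notebook, while the config paths, the `include_body_region` values, and the `"dim"`, `"spacing"`, `"top_region_index"`, and `"bottom_region_index"` keys are taken from the diff above.

```python
from typing import NamedTuple, Optional, Sequence


class MaisiVersionConfig(NamedTuple):
    """Settings that depend on the chosen MAISI version (hypothetical helper type)."""

    model_def_path: str
    include_body_region: bool


# Paths and flags copied from the notebook cell shown in the diff.
_MAISI_VERSIONS = {
    "maisi-ddpm": MaisiVersionConfig("./configs/config_maisi-ddpm.json", True),
    "maisi-rflow": MaisiVersionConfig("./configs/config_maisi-rflow.json", False),
}


def resolve_maisi_version(maisi_version: str) -> MaisiVersionConfig:
    """Return the settings for a version, or raise ValueError for an unknown one."""
    try:
        return _MAISI_VERSIONS[maisi_version]
    except KeyError:
        supported = ", ".join(sorted(_MAISI_VERSIONS))
        raise ValueError(
            f"Unsupported maisi_version {maisi_version!r}; expected one of: {supported}"
        ) from None


def build_embedding_metadata(
    dimensions: Sequence[int],
    spacing: Sequence[float],
    include_body_region: bool,
    top_region_index: Optional[Sequence[int]] = None,
    bottom_region_index: Optional[Sequence[int]] = None,
) -> dict:
    """Build the per-embedding metadata dict; region labels are added only when required."""
    data = {"dim": tuple(dimensions), "spacing": [float(item) for item in spacing]}
    if include_body_region:
        if top_region_index is None or bottom_region_index is None:
            raise ValueError("Body region indices are required when include_body_region is True")
        data["top_region_index"] = list(top_region_index)
        data["bottom_region_index"] = list(bottom_region_index)
    return data


config = resolve_maisi_version("maisi-ddpm")
metadata = build_embedding_metadata(
    (64, 64, 32),
    [0.875, 0.875, 0.75],
    config.include_body_region,
    top_region_index=[0, 1, 0, 0],
    bottom_region_index=[0, 0, 1, 0],
)
print(config.model_def_path, metadata)
```

Unlike `assert`, a `ValueError` is still raised when Python runs with `-O`, which may matter if the cell is later moved into a script. Whether that is worth the extra code in a tutorial notebook is a judgment call for the authors.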