|
28 | 28 | "\n",
|
29 | 29 | "In this notebook, we detail the procedure for training a 3D latent diffusion model to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) generating image embeddings and 2) training 3D latent diffusion models. The subsequent sections will demonstrate the entire process using a simulated dataset.\n",
|
30 | 30 | "\n",
|
31 |
| - "`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi-rflow'`. Compared with the previous version `'maisi-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section." |
| 31 | + "`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section." |
32 | 32 | ]
|
33 | 33 | },
|
34 | 34 | {
|
|
38 | 38 | "source": [
|
39 | 39 | "## Set up the MAISI version\n",
|
40 | 40 | "\n",
|
41 |
| - "Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n", |
42 |
| - "- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
43 |
| - "- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n", |
44 |
| - "- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`." |
| 41 | + "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n", |
| 42 | + "- The maisi version `'maisi3d-ddpm'` uses basic noise scheduler DDPM. `'maisi3d-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
| 43 | + "- The maisi version `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n", |
| 44 | + "- For the released model weights, `'maisi3d-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi3d-ddpm'`." |
45 | 45 | ]
|
46 | 46 | },
|
47 | 47 | {
|
|
51 | 51 | "metadata": {},
|
52 | 52 | "outputs": [],
|
53 | 53 | "source": [
|
54 |
| - "maisi_version = \"maisi-ddpm\"\n", |
55 |
| - "assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]" |
| 54 | + "maisi_version = \"maisi3d-ddpm\"\n", |
| 55 | + "assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]" |
56 | 56 | ]
|
57 | 57 | },
|
58 | 58 | {
|
|
131 | 131 | "import numpy as np\n",
|
132 | 132 | "import nibabel as nib\n",
|
133 | 133 | "import subprocess\n",
|
| 134 | + "from IPython.display import Image, display\n", |
134 | 135 | "\n",
|
135 | 136 | "from monai.apps import download_url\n",
|
136 | 137 | "from monai.data import create_test_image_3d\n",
|
137 | 138 | "from monai.config import print_config\n",
|
138 | 139 | "\n",
|
139 |
| - "from IPython.display import Image, display\n", |
140 |
| - "\n", |
141 | 140 | "from scripts.diff_model_setting import setup_logging\n",
|
142 | 141 | "\n",
|
143 | 142 | "print_config()\n",
|
|
152 | 151 | "source": [
|
153 | 152 | "## Set up the MAISI version\n",
|
154 | 153 | "\n",
|
155 |
| - "Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n", |
156 |
| - "- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
157 |
| - "- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n", |
158 |
| - "- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`." |
| 154 | + "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n", |
| 155 | + "- The maisi version `'maisi3d-ddpm'` uses basic noise scheduler DDPM. `'maisi3d-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
| 156 | + "- The maisi version `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n", |
| 157 | + "- For the released model weights, `'maisi3d-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi3d-ddpm'`." |
159 | 158 | ]
|
160 | 159 | },
|
161 | 160 | {
|
|
165 | 164 | "metadata": {},
|
166 | 165 | "outputs": [],
|
167 | 166 | "source": [
|
168 |
| - "maisi_version = \"maisi-ddpm\"\n", |
169 |
| - "assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]" |
| 167 | + "maisi_version = \"maisi3d-ddpm\"\n", |
| 168 | + "assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]" |
170 | 169 | ]
|
171 | 170 | },
|
172 | 171 | {
|
|
213 | 212 | "name": "stderr",
|
214 | 213 | "output_type": "stream",
|
215 | 214 | "text": [
|
216 |
| - "[2025-03-11 22:05:02.952][ INFO](notebook) - Generated simulated images.\n" |
| 215 | + "[2025-03-11 22:16:41.000][ INFO](notebook) - Generated simulated images.\n" |
217 | 216 | ]
|
218 | 217 | }
|
219 | 218 | ],
|
|
260 | 259 | "name": "stderr",
|
261 | 260 | "output_type": "stream",
|
262 | 261 | "text": [
|
263 |
| - "[2025-03-11 22:05:02.966][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
264 |
| - "[2025-03-11 22:05:02.966][ INFO](notebook) - number of GPUs: 1.\n" |
| 262 | + "[2025-03-11 22:16:41.012][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
| 263 | + "[2025-03-11 22:16:41.012][ INFO](notebook) - number of GPUs: 1.\n" |
265 | 264 | ]
|
266 | 265 | }
|
267 | 266 | ],
|
268 | 267 | "source": [
|
269 | 268 | "env_config_path = \"./configs/environment_maisi_diff_model.json\"\n",
|
270 | 269 | "model_config_path = \"./configs/config_maisi_diff_model.json\"\n",
|
271 |
| - "if maisi_version == \"maisi-ddpm\":\n", |
272 |
| - " model_def_path = \"./configs/config_maisi-ddpm.json\"\n", |
| 270 | + "if maisi_version == \"maisi3d-ddpm\":\n", |
| 271 | + " model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n", |
273 | 272 | " include_body_region = True\n",
|
274 |
| - "elif maisi_version == \"maisi-rflow\":\n", |
275 |
| - " model_def_path = \"./configs/config_maisi-rflow.json\"\n", |
| 273 | + "elif maisi_version == \"maisi3d-rflow\":\n", |
| 274 | + " model_def_path = \"./configs/config_maisi3d-rflow.json\"\n", |
276 | 275 | " include_body_region = False\n",
|
277 | 276 | "else:\n",
|
278 |
| - " raise ValueError(f\"maisi_version has to be chosen from ['maisi-ddpm', 'maisi-rflow'], yet got {maisi_version}.\")\n", |
| 277 | + " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n", |
279 | 278 | "\n",
|
280 | 279 | "# Load environment configuration, model configuration and model definition\n",
|
281 | 280 | "with open(env_config_path, \"r\") as f:\n",
|
|
407 | 406 | "name": "stderr",
|
408 | 407 | "output_type": "stream",
|
409 | 408 | "text": [
|
410 |
| - "[2025-03-11 22:05:02.977][ INFO](notebook) - Creating training data...\n" |
| 409 | + "[2025-03-11 22:16:41.021][ INFO](notebook) - Creating training data...\n" |
411 | 410 | ]
|
412 | 411 | },
|
413 | 412 | {
|
414 | 413 | "name": "stdout",
|
415 | 414 | "output_type": "stream",
|
416 | 415 | "text": [
|
417 | 416 | "\n",
|
418 |
| - "[2025-03-11 22:05:10.881][ INFO](creating training data) - Using device cuda:0\n", |
419 |
| - "[2025-03-11 22:05:11.686][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
| 417 | + "[2025-03-11 22:16:50.396][ INFO](creating training data) - Using device cuda:0\n", |
| 418 | + "[2025-03-11 22:16:51.402][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
420 | 419 | "\n"
|
421 | 420 | ]
|
422 | 421 | }
|
|
460 | 459 | "name": "stderr",
|
461 | 460 | "output_type": "stream",
|
462 | 461 | "text": [
|
463 |
| - "[2025-03-11 22:05:13.881][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
464 |
| - "[2025-03-11 22:05:13.884][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
465 |
| - "[2025-03-11 22:05:13.885][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
| 462 | + "[2025-03-11 22:16:53.638][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 463 | + "[2025-03-11 22:16:53.640][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 464 | + "[2025-03-11 22:16:53.641][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
466 | 465 | ]
|
467 | 466 | }
|
468 | 467 | ],
|
|
539 | 538 | "name": "stderr",
|
540 | 539 | "output_type": "stream",
|
541 | 540 | "text": [
|
542 |
| - "[2025-03-11 22:05:13.892][ INFO](notebook) - Training the model...\n" |
| 541 | + "[2025-03-11 22:16:53.646][ INFO](notebook) - Training the model...\n" |
543 | 542 | ]
|
544 | 543 | },
|
545 | 544 | {
|
546 | 545 | "name": "stdout",
|
547 | 546 | "output_type": "stream",
|
548 | 547 | "text": [
|
549 | 548 | "\n",
|
550 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - Using cuda:0 of 1\n", |
551 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
552 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
553 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
554 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] lr -> 0.0001.\n", |
555 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_epochs -> 2.\n", |
556 |
| - "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
557 |
| - "[2025-03-11 22:05:24.420][ INFO](training) - num_files_train: 2\n", |
558 |
| - "[2025-03-11 22:05:26.152][ INFO](training) - Training from scratch.\n", |
559 |
| - "[2025-03-11 22:05:26.539][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
560 |
| - "[2025-03-11 22:05:26.539][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
561 |
| - "[2025-03-11 22:05:26.542][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
562 |
| - "[2025-03-11 22:05:26.542][ INFO](training) - Epoch 1, lr 0.0001.\n", |
563 |
| - "[2025-03-11 22:05:28.578][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n", |
564 |
| - "[2025-03-11 22:05:28.719][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 2/2, loss: 0.7943, lr: 0.000056250000.\n", |
565 |
| - "[2025-03-11 22:05:28.762][ INFO](training) - epoch 1 average loss: 0.7958.\n", |
566 |
| - "[2025-03-11 22:05:30.615][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
567 |
| - "[2025-03-11 22:05:31.002][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 1/2, loss: 0.7898, lr: 0.000025000000.\n", |
568 |
| - "[2025-03-11 22:05:31.105][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 2/2, loss: 0.7886, lr: 0.000006250000.\n", |
569 |
| - "[2025-03-11 22:05:31.168][ INFO](training) - epoch 2 average loss: 0.7892.\n", |
| 549 | + "[2025-03-11 22:17:02.004][ INFO](training) - Using cuda:0 of 1\n", |
| 550 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
| 551 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
| 552 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
| 553 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] lr -> 0.0001.\n", |
| 554 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] num_epochs -> 2.\n", |
| 555 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
| 556 | + "[2025-03-11 22:17:02.005][ INFO](training) - num_files_train: 2\n", |
| 557 | + "[2025-03-11 22:17:03.887][ INFO](training) - Training from scratch.\n", |
| 558 | + "[2025-03-11 22:17:04.338][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
| 559 | + "[2025-03-11 22:17:04.339][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
| 560 | + "[2025-03-11 22:17:04.341][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
| 561 | + "[2025-03-11 22:17:04.341][ INFO](training) - Epoch 1, lr 0.0001.\n", |
| 562 | + "[2025-03-11 22:17:05.278][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 1/2, loss: 0.7973, lr: 0.000100000000.\n", |
| 563 | + "[2025-03-11 22:17:05.673][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 2/2, loss: 0.7969, lr: 0.000056250000.\n", |
| 564 | + "[2025-03-11 22:17:05.718][ INFO](training) - epoch 1 average loss: 0.7971.\n", |
| 565 | + "[2025-03-11 22:17:07.383][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
| 566 | + "[2025-03-11 22:17:07.777][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 1/2, loss: 0.7932, lr: 0.000025000000.\n", |
| 567 | + "[2025-03-11 22:17:07.881][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 2/2, loss: 0.7904, lr: 0.000006250000.\n", |
| 568 | + "[2025-03-11 22:17:07.942][ INFO](training) - epoch 2 average loss: 0.7918.\n", |
570 | 569 | "\n"
|
571 | 570 | ]
|
572 | 571 | }
|
|
612 | 611 | "name": "stderr",
|
613 | 612 | "output_type": "stream",
|
614 | 613 | "text": [
|
615 |
| - "[2025-03-11 22:05:35.033][ INFO](notebook) - Running inference...\n", |
616 |
| - "[2025-03-11 22:05:50.259][ INFO](notebook) - Completed all steps.\n" |
| 614 | + "[2025-03-11 22:17:11.993][ INFO](notebook) - Running inference...\n", |
| 615 | + "[2025-03-11 22:17:27.730][ INFO](notebook) - Completed all steps.\n" |
617 | 616 | ]
|
618 | 617 | },
|
619 | 618 | {
|
620 | 619 | "name": "stdout",
|
621 | 620 | "output_type": "stream",
|
622 | 621 | "text": [
|
623 | 622 | "\n",
|
624 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - Using cuda:0 of 1 with random seed: 7854\n", |
625 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
626 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] random_seed -> 7854.\n", |
627 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
628 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
629 |
| - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
630 |
| - "[2025-03-11 22:05:43.502][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
631 |
| - "[2025-03-11 22:05:45.793][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
632 |
| - "[2025-03-11 22:05:45.795][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
633 |
| - "[2025-03-11 22:05:45.796][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
634 |
| - "[2025-03-11 22:05:45.798][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
| 623 | + "[2025-03-11 22:17:20.465][ INFO](inference) - Using cuda:0 of 1 with random seed: 23141\n", |
| 624 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
| 625 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] random_seed -> 23141.\n", |
| 626 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
| 627 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
| 628 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
| 629 | + "[2025-03-11 22:17:20.466][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
| 630 | + "[2025-03-11 22:17:23.065][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
| 631 | + "[2025-03-11 22:17:23.067][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
| 632 | + "[2025-03-11 22:17:23.068][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
| 633 | + "[2025-03-11 22:17:23.070][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
635 | 634 | "\n",
|
636 | 635 | " 0%| | 0/10 [00:00<?, ?it/s]\n",
|
637 |
| - " 10%|█ | 1/10 [00:00<00:05, 1.78it/s]\n", |
638 |
| - " 60%|██████ | 6/10 [00:00<00:00, 11.19it/s]\n", |
639 |
| - "100%|██████████| 10/10 [00:00<00:00, 12.88it/s]\n", |
640 |
| - "[2025-03-11 22:05:48.356][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed7854_size256x256x128_spacing1.00x1.00x0.75_20250311220547_rank0.nii.gz.\n", |
| 636 | + " 10%|█ | 1/10 [00:00<00:07, 1.24it/s]\n", |
| 637 | + " 60%|██████ | 6/10 [00:00<00:00, 8.37it/s]\n", |
| 638 | + "100%|██████████| 10/10 [00:01<00:00, 9.78it/s]\n", |
| 639 | + "[2025-03-11 22:17:25.828][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed23141_size256x256x128_spacing1.00x1.00x0.75_20250311221725_rank0.nii.gz.\n", |
641 | 640 | "\n"
|
642 | 641 | ]
|
643 | 642 | }
|
|
0 commit comments