Commit a90cbd0

reformat
Signed-off-by: Can-Zhao <[email protected]>
1 parent 81ea271 commit a90cbd0

1 file changed

+66
-67
lines changed

generation/maisi/maisi_diff_unet_training_tutorial.ipynb

@@ -28,7 +28,7 @@
2828
"\n",
2929
"In this notebook, we detail the procedure for training a 3D latent diffusion model to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) generating image embeddings and 2) training 3D latent diffusion models. The subsequent sections will demonstrate the entire process using a simulated dataset.\n",
3030
"\n",
31-
"`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi-rflow'`. Compared with the previous version `'maisi-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section."
31+
"`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section."
3232
]
3333
},
3434
{
@@ -38,10 +38,10 @@
3838
"source": [
3939
"## Set up the MAISI version\n",
4040
"\n",
41-
"Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n",
42-
"- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n",
43-
"- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n",
44-
"- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`."
41+
"Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n",
42+
"- The maisi version `'maisi3d-ddpm'` uses basic noise scheduler DDPM. `'maisi3d-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n",
43+
"- The maisi version `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n",
44+
"- For the released model weights, `'maisi3d-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi3d-ddpm'`."
4545
]
4646
},
4747
{
@@ -51,8 +51,8 @@
5151
"metadata": {},
5252
"outputs": [],
5353
"source": [
54-
"maisi_version = \"maisi-ddpm\"\n",
55-
"assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]"
54+
"maisi_version = \"maisi3d-ddpm\"\n",
55+
"assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]"
5656
]
5757
},
5858
{
@@ -131,13 +131,12 @@
131131
"import numpy as np\n",
132132
"import nibabel as nib\n",
133133
"import subprocess\n",
134+
"from IPython.display import Image, display\n",
134135
"\n",
135136
"from monai.apps import download_url\n",
136137
"from monai.data import create_test_image_3d\n",
137138
"from monai.config import print_config\n",
138139
"\n",
139-
"from IPython.display import Image, display\n",
140-
"\n",
141140
"from scripts.diff_model_setting import setup_logging\n",
142141
"\n",
143142
"print_config()\n",
@@ -152,10 +151,10 @@
152151
"source": [
153152
"## Set up the MAISI version\n",
154153
"\n",
155-
"Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n",
156-
"- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n",
157-
"- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n",
158-
"- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`."
154+
"Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n",
155+
"- The maisi version `'maisi3d-ddpm'` uses basic noise scheduler DDPM. `'maisi3d-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n",
156+
"- The maisi version `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n",
157+
"- For the released model weights, `'maisi3d-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi3d-ddpm'`."
159158
]
160159
},
161160
{
@@ -165,8 +164,8 @@
165164
"metadata": {},
166165
"outputs": [],
167166
"source": [
168-
"maisi_version = \"maisi-ddpm\"\n",
169-
"assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]"
167+
"maisi_version = \"maisi3d-ddpm\"\n",
168+
"assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]"
170169
]
171170
},
172171
{
@@ -213,7 +212,7 @@
213212
"name": "stderr",
214213
"output_type": "stream",
215214
"text": [
216-
"[2025-03-11 22:05:02.952][ INFO](notebook) - Generated simulated images.\n"
215+
"[2025-03-11 22:16:41.000][ INFO](notebook) - Generated simulated images.\n"
217216
]
218217
}
219218
],
@@ -260,22 +259,22 @@
260259
"name": "stderr",
261260
"output_type": "stream",
262261
"text": [
263-
"[2025-03-11 22:05:02.966][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n",
264-
"[2025-03-11 22:05:02.966][ INFO](notebook) - number of GPUs: 1.\n"
262+
"[2025-03-11 22:16:41.012][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n",
263+
"[2025-03-11 22:16:41.012][ INFO](notebook) - number of GPUs: 1.\n"
265264
]
266265
}
267266
],
268267
"source": [
269268
"env_config_path = \"./configs/environment_maisi_diff_model.json\"\n",
270269
"model_config_path = \"./configs/config_maisi_diff_model.json\"\n",
271-
"if maisi_version == \"maisi-ddpm\":\n",
272-
" model_def_path = \"./configs/config_maisi-ddpm.json\"\n",
270+
"if maisi_version == \"maisi3d-ddpm\":\n",
271+
" model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
273272
" include_body_region = True\n",
274-
"elif maisi_version == \"maisi-rflow\":\n",
275-
" model_def_path = \"./configs/config_maisi-rflow.json\"\n",
273+
"elif maisi_version == \"maisi3d-rflow\":\n",
274+
" model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
276275
" include_body_region = False\n",
277276
"else:\n",
278-
" raise ValueError(f\"maisi_version has to be chosen from ['maisi-ddpm', 'maisi-rflow'], yet got {maisi_version}.\")\n",
277+
" raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
279278
"\n",
280279
"# Load environment configuration, model configuration and model definition\n",
281280
"with open(env_config_path, \"r\") as f:\n",
@@ -407,16 +406,16 @@
407406
"name": "stderr",
408407
"output_type": "stream",
409408
"text": [
410-
"[2025-03-11 22:05:02.977][ INFO](notebook) - Creating training data...\n"
409+
"[2025-03-11 22:16:41.021][ INFO](notebook) - Creating training data...\n"
411410
]
412411
},
413412
{
414413
"name": "stdout",
415414
"output_type": "stream",
416415
"text": [
417416
"\n",
418-
"[2025-03-11 22:05:10.881][ INFO](creating training data) - Using device cuda:0\n",
419-
"[2025-03-11 22:05:11.686][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
417+
"[2025-03-11 22:16:50.396][ INFO](creating training data) - Using device cuda:0\n",
418+
"[2025-03-11 22:16:51.402][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
420419
"\n"
421420
]
422421
}
@@ -460,9 +459,9 @@
460459
"name": "stderr",
461460
"output_type": "stream",
462461
"text": [
463-
"[2025-03-11 22:05:13.881][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
464-
"[2025-03-11 22:05:13.884][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
465-
"[2025-03-11 22:05:13.885][ INFO](notebook) - Completed creating .json files for all embedding files.\n"
462+
"[2025-03-11 22:16:53.638][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
463+
"[2025-03-11 22:16:53.640][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
464+
"[2025-03-11 22:16:53.641][ INFO](notebook) - Completed creating .json files for all embedding files.\n"
466465
]
467466
}
468467
],
@@ -539,34 +538,34 @@
539538
"name": "stderr",
540539
"output_type": "stream",
541540
"text": [
542-
"[2025-03-11 22:05:13.892][ INFO](notebook) - Training the model...\n"
541+
"[2025-03-11 22:16:53.646][ INFO](notebook) - Training the model...\n"
543542
]
544543
},
545544
{
546545
"name": "stdout",
547546
"output_type": "stream",
548547
"text": [
549548
"\n",
550-
"[2025-03-11 22:05:24.419][ INFO](training) - Using cuda:0 of 1\n",
551-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n",
552-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n",
553-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
554-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] lr -> 0.0001.\n",
555-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] num_epochs -> 2.\n",
556-
"[2025-03-11 22:05:24.419][ INFO](training) - [config] num_train_timesteps -> 1000.\n",
557-
"[2025-03-11 22:05:24.420][ INFO](training) - num_files_train: 2\n",
558-
"[2025-03-11 22:05:26.152][ INFO](training) - Training from scratch.\n",
559-
"[2025-03-11 22:05:26.539][ INFO](training) - Scaling factor set to 1.159977912902832.\n",
560-
"[2025-03-11 22:05:26.539][ INFO](training) - scale_factor -> 1.159977912902832.\n",
561-
"[2025-03-11 22:05:26.542][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n",
562-
"[2025-03-11 22:05:26.542][ INFO](training) - Epoch 1, lr 0.0001.\n",
563-
"[2025-03-11 22:05:28.578][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n",
564-
"[2025-03-11 22:05:28.719][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 2/2, loss: 0.7943, lr: 0.000056250000.\n",
565-
"[2025-03-11 22:05:28.762][ INFO](training) - epoch 1 average loss: 0.7958.\n",
566-
"[2025-03-11 22:05:30.615][ INFO](training) - Epoch 2, lr 2.5e-05.\n",
567-
"[2025-03-11 22:05:31.002][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 1/2, loss: 0.7898, lr: 0.000025000000.\n",
568-
"[2025-03-11 22:05:31.105][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 2/2, loss: 0.7886, lr: 0.000006250000.\n",
569-
"[2025-03-11 22:05:31.168][ INFO](training) - epoch 2 average loss: 0.7892.\n",
549+
"[2025-03-11 22:17:02.004][ INFO](training) - Using cuda:0 of 1\n",
550+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n",
551+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n",
552+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
553+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] lr -> 0.0001.\n",
554+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] num_epochs -> 2.\n",
555+
"[2025-03-11 22:17:02.004][ INFO](training) - [config] num_train_timesteps -> 1000.\n",
556+
"[2025-03-11 22:17:02.005][ INFO](training) - num_files_train: 2\n",
557+
"[2025-03-11 22:17:03.887][ INFO](training) - Training from scratch.\n",
558+
"[2025-03-11 22:17:04.338][ INFO](training) - Scaling factor set to 1.159977912902832.\n",
559+
"[2025-03-11 22:17:04.339][ INFO](training) - scale_factor -> 1.159977912902832.\n",
560+
"[2025-03-11 22:17:04.341][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n",
561+
"[2025-03-11 22:17:04.341][ INFO](training) - Epoch 1, lr 0.0001.\n",
562+
"[2025-03-11 22:17:05.278][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 1/2, loss: 0.7973, lr: 0.000100000000.\n",
563+
"[2025-03-11 22:17:05.673][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 2/2, loss: 0.7969, lr: 0.000056250000.\n",
564+
"[2025-03-11 22:17:05.718][ INFO](training) - epoch 1 average loss: 0.7971.\n",
565+
"[2025-03-11 22:17:07.383][ INFO](training) - Epoch 2, lr 2.5e-05.\n",
566+
"[2025-03-11 22:17:07.777][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 1/2, loss: 0.7932, lr: 0.000025000000.\n",
567+
"[2025-03-11 22:17:07.881][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 2/2, loss: 0.7904, lr: 0.000006250000.\n",
568+
"[2025-03-11 22:17:07.942][ INFO](training) - epoch 2 average loss: 0.7918.\n",
570569
"\n"
571570
]
572571
}
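Note on the learning rates in the training log of this hunk: the four logged values (1e-4, 5.625e-5, 2.5e-5, 6.25e-6) are consistent with a quadratic polynomial decay over four total iterations (2 epochs of 2 iterations each). The training script is not part of this diff, so the schedule is an inference from the log, not a confirmed implementation. The function name `polynomial_lr` and its parameters are hypothetical and used only for illustration.

```python
def polynomial_lr(base_lr: float, step: int, total_steps: int, power: float = 2.0) -> float:
    """Hypothetical helper: polynomial decay from base_lr towards 0.

    Assumption: power=2.0 and total_steps=4 are chosen because they
    reproduce the values in the log; the actual scheduler is not shown.
    """
    return base_lr * (1.0 - step / total_steps) ** power


# Four iterations in total: 2 epochs x 2 iterations per epoch.
for step in range(4):
    print(f"step {step}: lr {polynomial_lr(1e-4, step, 4):.12f}")
```

Under these assumptions the printed values match the `lr:` fields logged for iterations 1/2 and 2/2 of both epochs.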
@@ -612,32 +611,32 @@
612611
"name": "stderr",
613612
"output_type": "stream",
614613
"text": [
615-
"[2025-03-11 22:05:35.033][ INFO](notebook) - Running inference...\n",
616-
"[2025-03-11 22:05:50.259][ INFO](notebook) - Completed all steps.\n"
614+
"[2025-03-11 22:17:11.993][ INFO](notebook) - Running inference...\n",
615+
"[2025-03-11 22:17:27.730][ INFO](notebook) - Completed all steps.\n"
617616
]
618617
},
619618
{
620619
"name": "stdout",
621620
"output_type": "stream",
622621
"text": [
623622
"\n",
624-
"[2025-03-11 22:05:43.502][ INFO](inference) - Using cuda:0 of 1 with random seed: 7854\n",
625-
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
626-
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] random_seed -> 7854.\n",
627-
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_prefix -> unet_3d.\n",
628-
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_size -> (256, 256, 128).\n",
629-
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n",
630-
"[2025-03-11 22:05:43.502][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
631-
"[2025-03-11 22:05:45.793][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
632-
"[2025-03-11 22:05:45.795][ INFO](inference) - scale_factor -> 1.159977912902832.\n",
633-
"[2025-03-11 22:05:45.796][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n",
634-
"[2025-03-11 22:05:45.798][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
623+
"[2025-03-11 22:17:20.465][ INFO](inference) - Using cuda:0 of 1 with random seed: 23141\n",
624+
"[2025-03-11 22:17:20.466][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
625+
"[2025-03-11 22:17:20.466][ INFO](inference) - [config] random_seed -> 23141.\n",
626+
"[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_prefix -> unet_3d.\n",
627+
"[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_size -> (256, 256, 128).\n",
628+
"[2025-03-11 22:17:20.466][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n",
629+
"[2025-03-11 22:17:20.466][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
630+
"[2025-03-11 22:17:23.065][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
631+
"[2025-03-11 22:17:23.067][ INFO](inference) - scale_factor -> 1.159977912902832.\n",
632+
"[2025-03-11 22:17:23.068][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n",
633+
"[2025-03-11 22:17:23.070][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
635634
"\n",
636635
" 0%| | 0/10 [00:00<?, ?it/s]\n",
637-
" 10%|█ | 1/10 [00:00<00:05, 1.78it/s]\n",
638-
" 60%|██████ | 6/10 [00:00<00:00, 11.19it/s]\n",
639-
"100%|██████████| 10/10 [00:00<00:00, 12.88it/s]\n",
640-
"[2025-03-11 22:05:48.356][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed7854_size256x256x128_spacing1.00x1.00x0.75_20250311220547_rank0.nii.gz.\n",
636+
" 10%|█ | 1/10 [00:00<00:07, 1.24it/s]\n",
637+
" 60%|██████ | 6/10 [00:00<00:00, 8.37it/s]\n",
638+
"100%|██████████| 10/10 [00:01<00:00, 9.78it/s]\n",
639+
"[2025-03-11 22:17:25.828][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed23141_size256x256x128_spacing1.00x1.00x0.75_20250311221725_rank0.nii.gz.\n",
641640
"\n"
642641
]
643642
}
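Note on the rename: the `if`/`elif`/`else` branch changed in the `@@ -260,22 +259,22 @@` hunk maps each version string to a model definition path and an `include_body_region` flag. A minimal sketch of the same mapping as a lookup table follows, assuming only the two version names and config paths shown in the diff. The names `MAISI_VERSIONS` and `resolve_maisi_version` are hypothetical and do not appear in the notebook.

```python
# Hypothetical lookup table; the keys, paths and flags are copied from the diff.
MAISI_VERSIONS = {
    "maisi3d-ddpm": {
        "model_def_path": "./configs/config_maisi3d-ddpm.json",
        "include_body_region": True,  # DDPM version needs body-region labels
    },
    "maisi3d-rflow": {
        "model_def_path": "./configs/config_maisi3d-rflow.json",
        "include_body_region": False,  # Rectified Flow version does not
    },
}


def resolve_maisi_version(maisi_version: str) -> tuple:
    """Hypothetical helper: return (model_def_path, include_body_region)."""
    try:
        entry = MAISI_VERSIONS[maisi_version]
    except KeyError:
        raise ValueError(
            f"maisi_version has to be chosen from {sorted(MAISI_VERSIONS)}, "
            f"yet got {maisi_version}."
        ) from None
    return entry["model_def_path"], entry["include_body_region"]


model_def_path, include_body_region = resolve_maisi_version("maisi3d-rflow")
print(model_def_path, include_body_region)
```

With a table like this, a pre-rename name such as `"maisi-ddpm"` raises a `ValueError`, which is the same failure the notebook's `else` branch produces.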
