Skip to content

Commit f448e9e

Browse files
committed
reformat
Signed-off-by: Can-Zhao <[email protected]>
2 parents 4187565 + 9be20ed commit f448e9e

File tree

1 file changed

+87
-66
lines changed

1 file changed

+87
-66
lines changed

generation/maisi/maisi_diff_unet_training_tutorial.ipynb

+87-66
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@
3131
"`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi-rflow'`. Compared with the previous version `'maisi-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section."
3232
]
3333
},
34+
{
35+
"cell_type": "markdown",
36+
"id": "aa51792d",
37+
"metadata": {},
38+
"source": [
39+
"## Set up the MAISI version\n",
40+
"\n",
41+
"Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n",
42+
"- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n",
43+
"- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n",
44+
"- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`."
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 1,
50+
"id": "828b9ece-7759-40e8-ac4c-6467c3399701",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"maisi_version = \"maisi-ddpm\"\n",
55+
"assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]"
56+
]
57+
},
3458
{
3559
"cell_type": "markdown",
3660
"id": "c9ecfb90",
@@ -41,7 +65,7 @@
4165
},
4266
{
4367
"cell_type": "code",
44-
"execution_count": 1,
68+
"execution_count": 2,
4569
"id": "58cbde9b",
4670
"metadata": {},
4771
"outputs": [],
@@ -59,7 +83,7 @@
5983
},
6084
{
6185
"cell_type": "code",
62-
"execution_count": 13,
86+
"execution_count": 3,
6387
"id": "e3bf0346",
6488
"metadata": {},
6589
"outputs": [
@@ -136,8 +160,8 @@
136160
},
137161
{
138162
"cell_type": "code",
139-
"execution_count": 3,
140-
"id": "828b9ece-7759-40e8-ac4c-6467c3399701",
163+
"execution_count": 4,
164+
"id": "31684f74",
141165
"metadata": {},
142166
"outputs": [],
143167
"source": [
@@ -159,7 +183,7 @@
159183
},
160184
{
161185
"cell_type": "code",
162-
"execution_count": 4,
186+
"execution_count": 5,
163187
"id": "fc32a7fe",
164188
"metadata": {},
165189
"outputs": [],
@@ -181,15 +205,15 @@
181205
},
182206
{
183207
"cell_type": "code",
184-
"execution_count": 5,
208+
"execution_count": 6,
185209
"id": "1b199078",
186210
"metadata": {},
187211
"outputs": [
188212
{
189213
"name": "stderr",
190214
"output_type": "stream",
191215
"text": [
192-
"[2025-03-11 21:46:47.184][ INFO](notebook) - Generated simulated images.\n"
216+
"[2025-03-11 22:05:02.952][ INFO](notebook) - Generated simulated images.\n"
193217
]
194218
}
195219
],
@@ -228,26 +252,26 @@
228252
},
229253
{
230254
"cell_type": "code",
231-
"execution_count": 6,
255+
"execution_count": 7,
232256
"id": "6c7b434c",
233257
"metadata": {},
234258
"outputs": [
235259
{
236260
"name": "stderr",
237261
"output_type": "stream",
238262
"text": [
239-
"[2025-03-11 21:46:47.199][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n",
240-
"[2025-03-11 21:46:47.199][ INFO](notebook) - number of GPUs: 1.\n"
263+
"[2025-03-11 22:05:02.966][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n",
264+
"[2025-03-11 22:05:02.966][ INFO](notebook) - number of GPUs: 1.\n"
241265
]
242266
}
243267
],
244268
"source": [
245269
"env_config_path = \"./configs/environment_maisi_diff_model.json\"\n",
246270
"model_config_path = \"./configs/config_maisi_diff_model.json\"\n",
247-
"if maisi_version == 'maisi-ddpm':\n",
271+
"if maisi_version == \"maisi-ddpm\":\n",
248272
" model_def_path = \"./configs/config_maisi-ddpm.json\"\n",
249273
" include_body_region = True\n",
250-
"elif maisi_version == 'maisi-rflow':\n",
274+
"elif maisi_version == \"maisi-rflow\":\n",
251275
" model_def_path = \"./configs/config_maisi-rflow.json\"\n",
252276
" include_body_region = False\n",
253277
"else:\n",
@@ -315,7 +339,7 @@
315339
},
316340
{
317341
"cell_type": "code",
318-
"execution_count": 7,
342+
"execution_count": 8,
319343
"id": "95ea6972",
320344
"metadata": {},
321345
"outputs": [],
@@ -375,24 +399,24 @@
375399
},
376400
{
377401
"cell_type": "code",
378-
"execution_count": 8,
402+
"execution_count": 9,
379403
"id": "f45ea863",
380404
"metadata": {},
381405
"outputs": [
382406
{
383407
"name": "stderr",
384408
"output_type": "stream",
385409
"text": [
386-
"[2025-03-11 21:46:47.210][ INFO](notebook) - Creating training data...\n"
410+
"[2025-03-11 22:05:02.977][ INFO](notebook) - Creating training data...\n"
387411
]
388412
},
389413
{
390414
"name": "stdout",
391415
"output_type": "stream",
392416
"text": [
393417
"\n",
394-
"[2025-03-11 21:46:57.369][ INFO](creating training data) - Using device cuda:0\n",
395-
"[2025-03-11 21:46:58.170][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
418+
"[2025-03-11 22:05:10.881][ INFO](creating training data) - Using device cuda:0\n",
419+
"[2025-03-11 22:05:11.686][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
396420
"\n"
397421
]
398422
}
@@ -428,17 +452,17 @@
428452
},
429453
{
430454
"cell_type": "code",
431-
"execution_count": 9,
455+
"execution_count": 10,
432456
"id": "0221a658",
433457
"metadata": {},
434458
"outputs": [
435459
{
436460
"name": "stderr",
437461
"output_type": "stream",
438462
"text": [
439-
"[2025-03-11 21:47:00.412][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
440-
"[2025-03-11 21:47:00.414][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
441-
"[2025-03-11 21:47:00.415][ INFO](notebook) - Completed creating .json files for all embedding files.\n"
463+
"[2025-03-11 22:05:13.881][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
464+
"[2025-03-11 22:05:13.884][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
465+
"[2025-03-11 22:05:13.885][ INFO](notebook) - Completed creating .json files for all embedding files.\n"
442466
]
443467
}
444468
],
@@ -467,10 +491,7 @@
467491
" spacing = [float(_item) for _item in spacing]\n",
468492
"\n",
469493
" # Create the dictionary with the specified keys and values\n",
470-
" data = {\n",
471-
" \"dim\": dimensions,\n",
472-
" \"spacing\": spacing\n",
473-
" }\n",
494+
" data = {\"dim\": dimensions, \"spacing\": spacing}\n",
474495
" if include_body_region:\n",
475496
" # The region can be selected from one of four regions from top to bottom.\n",
476497
" # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region, [0,0,1,0]\n",
@@ -510,42 +531,42 @@
510531
},
511532
{
512533
"cell_type": "code",
513-
"execution_count": 10,
534+
"execution_count": 11,
514535
"id": "ade6389d",
515536
"metadata": {},
516537
"outputs": [
517538
{
518539
"name": "stderr",
519540
"output_type": "stream",
520541
"text": [
521-
"[2025-03-11 21:47:00.420][ INFO](notebook) - Training the model...\n"
542+
"[2025-03-11 22:05:13.892][ INFO](notebook) - Training the model...\n"
522543
]
523544
},
524545
{
525546
"name": "stdout",
526547
"output_type": "stream",
527548
"text": [
528549
"\n",
529-
"[2025-03-11 21:47:09.081][ INFO](training) - Using cuda:0 of 1\n",
530-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n",
531-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n",
532-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
533-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] lr -> 0.0001.\n",
534-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] num_epochs -> 2.\n",
535-
"[2025-03-11 21:47:09.081][ INFO](training) - [config] num_train_timesteps -> 1000.\n",
536-
"[2025-03-11 21:47:09.081][ INFO](training) - num_files_train: 2\n",
537-
"[2025-03-11 21:47:10.815][ INFO](training) - Training from scratch.\n",
538-
"[2025-03-11 21:47:11.273][ INFO](training) - Scaling factor set to 1.159977912902832.\n",
539-
"[2025-03-11 21:47:11.273][ INFO](training) - scale_factor -> 1.159977912902832.\n",
540-
"[2025-03-11 21:47:11.276][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n",
541-
"[2025-03-11 21:47:11.276][ INFO](training) - Epoch 1, lr 0.0001.\n",
542-
"[2025-03-11 21:47:12.253][ INFO](training) - [2025-03-11 21:47:12] epoch 1, iter 1/2, loss: 0.7979, lr: 0.000100000000.\n",
543-
"[2025-03-11 21:47:12.535][ INFO](training) - [2025-03-11 21:47:12] epoch 1, iter 2/2, loss: 0.7931, lr: 0.000056250000.\n",
544-
"[2025-03-11 21:47:12.572][ INFO](training) - epoch 1 average loss: 0.7955.\n",
545-
"[2025-03-11 21:47:14.031][ INFO](training) - Epoch 2, lr 2.5e-05.\n",
546-
"[2025-03-11 21:47:14.420][ INFO](training) - [2025-03-11 21:47:14] epoch 2, iter 1/2, loss: 0.7883, lr: 0.000025000000.\n",
547-
"[2025-03-11 21:47:14.517][ INFO](training) - [2025-03-11 21:47:14] epoch 2, iter 2/2, loss: 0.7893, lr: 0.000006250000.\n",
548-
"[2025-03-11 21:47:14.594][ INFO](training) - epoch 2 average loss: 0.7888.\n",
550+
"[2025-03-11 22:05:24.419][ INFO](training) - Using cuda:0 of 1\n",
551+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n",
552+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n",
553+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
554+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] lr -> 0.0001.\n",
555+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] num_epochs -> 2.\n",
556+
"[2025-03-11 22:05:24.419][ INFO](training) - [config] num_train_timesteps -> 1000.\n",
557+
"[2025-03-11 22:05:24.420][ INFO](training) - num_files_train: 2\n",
558+
"[2025-03-11 22:05:26.152][ INFO](training) - Training from scratch.\n",
559+
"[2025-03-11 22:05:26.539][ INFO](training) - Scaling factor set to 1.159977912902832.\n",
560+
"[2025-03-11 22:05:26.539][ INFO](training) - scale_factor -> 1.159977912902832.\n",
561+
"[2025-03-11 22:05:26.542][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n",
562+
"[2025-03-11 22:05:26.542][ INFO](training) - Epoch 1, lr 0.0001.\n",
563+
"[2025-03-11 22:05:28.578][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n",
564+
"[2025-03-11 22:05:28.719][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 2/2, loss: 0.7943, lr: 0.000056250000.\n",
565+
"[2025-03-11 22:05:28.762][ INFO](training) - epoch 1 average loss: 0.7958.\n",
566+
"[2025-03-11 22:05:30.615][ INFO](training) - Epoch 2, lr 2.5e-05.\n",
567+
"[2025-03-11 22:05:31.002][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 1/2, loss: 0.7898, lr: 0.000025000000.\n",
568+
"[2025-03-11 22:05:31.105][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 2/2, loss: 0.7886, lr: 0.000006250000.\n",
569+
"[2025-03-11 22:05:31.168][ INFO](training) - epoch 2 average loss: 0.7892.\n",
549570
"\n"
550571
]
551572
}
@@ -583,40 +604,40 @@
583604
},
584605
{
585606
"cell_type": "code",
586-
"execution_count": 11,
607+
"execution_count": 12,
587608
"id": "1626526d",
588609
"metadata": {},
589610
"outputs": [
590611
{
591612
"name": "stderr",
592613
"output_type": "stream",
593614
"text": [
594-
"[2025-03-11 21:47:18.262][ INFO](notebook) - Running inference...\n",
595-
"[2025-03-11 21:47:35.148][ INFO](notebook) - Completed all steps.\n"
615+
"[2025-03-11 22:05:35.033][ INFO](notebook) - Running inference...\n",
616+
"[2025-03-11 22:05:50.259][ INFO](notebook) - Completed all steps.\n"
596617
]
597618
},
598619
{
599620
"name": "stdout",
600621
"output_type": "stream",
601622
"text": [
602623
"\n",
603-
"[2025-03-11 21:47:27.859][ INFO](inference) - Using cuda:0 of 1 with random seed: 99760\n",
604-
"[2025-03-11 21:47:27.859][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
605-
"[2025-03-11 21:47:27.860][ INFO](inference) - [config] random_seed -> 99760.\n",
606-
"[2025-03-11 21:47:27.860][ INFO](inference) - [config] output_prefix -> unet_3d.\n",
607-
"[2025-03-11 21:47:27.860][ INFO](inference) - [config] output_size -> (256, 256, 128).\n",
608-
"[2025-03-11 21:47:27.860][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n",
609-
"[2025-03-11 21:47:27.860][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
610-
"[2025-03-11 21:47:30.510][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
611-
"[2025-03-11 21:47:30.512][ INFO](inference) - scale_factor -> 1.159977912902832.\n",
612-
"[2025-03-11 21:47:30.512][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n",
613-
"[2025-03-11 21:47:30.514][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
624+
"[2025-03-11 22:05:43.502][ INFO](inference) - Using cuda:0 of 1 with random seed: 7854\n",
625+
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
626+
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] random_seed -> 7854.\n",
627+
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_prefix -> unet_3d.\n",
628+
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_size -> (256, 256, 128).\n",
629+
"[2025-03-11 22:05:43.502][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n",
630+
"[2025-03-11 22:05:43.502][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
631+
"[2025-03-11 22:05:45.793][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
632+
"[2025-03-11 22:05:45.795][ INFO](inference) - scale_factor -> 1.159977912902832.\n",
633+
"[2025-03-11 22:05:45.796][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n",
634+
"[2025-03-11 22:05:45.798][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
614635
"\n",
615636
" 0%| | 0/10 [00:00<?, ?it/s]\n",
616-
" 10%|█ | 1/10 [00:00<00:07, 1.24it/s]\n",
617-
" 60%|██████ | 6/10 [00:00<00:00, 8.39it/s]\n",
618-
"100%|██████████| 10/10 [00:01<00:00, 9.80it/s]\n",
619-
"[2025-03-11 21:47:33.116][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed99760_size256x256x128_spacing1.00x1.00x0.75_20250311214732_rank0.nii.gz.\n",
637+
" 10%|█ | 1/10 [00:00<00:05, 1.78it/s]\n",
638+
" 60%|██████ | 6/10 [00:00<00:00, 11.19it/s]\n",
639+
"100%|██████████| 10/10 [00:00<00:00, 12.88it/s]\n",
640+
"[2025-03-11 22:05:48.356][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed7854_size256x256x128_spacing1.00x1.00x0.75_20250311220547_rank0.nii.gz.\n",
620641
"\n"
621642
]
622643
}
@@ -654,7 +675,7 @@
654675
},
655676
{
656677
"cell_type": "code",
657-
"execution_count": 12,
678+
"execution_count": 13,
658679
"id": "0d8a344d",
659680
"metadata": {},
660681
"outputs": [

0 commit comments

Comments
 (0)