Skip to content

Commit 151177d

Browse files
committed
add modality as input, make inference notebook executable for ddpm and rflow
Signed-off-by: Can-Zhao <[email protected]>
1 parent 2dc8039 commit 151177d

15 files changed

+821
-990
lines changed

generation/maisi/configs/config_infer.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"body_region": ["abdomen"],
44
"anatomy_list": ["liver","hepatic tumor"],
55
"controllable_anatomy_size": [],
6-
"num_inference_steps": 1000,
6+
"num_inference_steps": 30,
77
"mask_generation_num_inference_steps": 1000,
88
"output_size": [
99
256,
@@ -23,5 +23,6 @@
2323
"diffusion_unet": "$@diffusion_unet_def",
2424
"autoencoder": "$@autoencoder_def",
2525
"mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
26-
"mask_generation_diffusion": "$@mask_generation_diffusion_def"
26+
"mask_generation_diffusion": "$@mask_generation_diffusion_def",
27+
"modality": 1
2728
}

generation/maisi/configs/config_maisi3d-ddpm.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"spatial_dims": 3,
33
"image_channels": 1,
44
"latent_channels": 4,
5+
"include_body_region": true,
56
"mask_generation_latent_shape": [
67
4,
78
64,
@@ -60,8 +61,8 @@
6061
],
6162
"num_res_blocks": 2,
6263
"use_flash_attention": true,
63-
"include_top_region_index_input": true,
64-
"include_bottom_region_index_input": true,
64+
"include_top_region_index_input": "@include_body_region",
65+
"include_bottom_region_index_input": "@include_body_region",
6566
"include_spacing_input": true
6667
},
6768
"controlnet_def": {

generation/maisi/configs/config_maisi3d-rflow.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"spatial_dims": 3,
33
"image_channels": 1,
44
"latent_channels": 4,
5+
"include_body_region": false,
56
"mask_generation_latent_shape": [
67
4,
78
64,
@@ -55,8 +56,8 @@
5556
],
5657
"num_res_blocks": 2,
5758
"use_flash_attention": true,
58-
"include_top_region_index_input": false,
59-
"include_bottom_region_index_input": false,
59+
"include_top_region_index_input": "@include_body_region",
60+
"include_bottom_region_index_input": "@include_body_region",
6061
"include_spacing_input": true,
6162
"num_class_embeds": 128,
6263
"resblock_updown": true,

generation/maisi/configs/config_maisi_controlnet_train.json

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
"weighted_loss": 100
1010
},
1111
"controlnet_infer": {
12-
"num_inference_steps": 1000,
13-
"autoencoder_sliding_window_infer_size": [96, 96, 96]
12+
"num_inference_steps": 10,
13+
"autoencoder_sliding_window_infer_size": [80, 80, 80],
14+
"autoencoder_sliding_window_infer_overlap": 0.4,
15+
"modality": 1
1416
}
1517
}

generation/maisi/configs/config_maisi_diff_model.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
0
3030
],
3131
"random_seed": 0,
32-
"num_inference_steps": 10
32+
"num_inference_steps": 10,
33+
"modality": 1
3334
}
3435
}

generation/maisi/maisi_inference_tutorial.ipynb

+161-244
Large diffs are not rendered by default.

generation/maisi/maisi_train_controlnet_tutorial.ipynb

+48-56
Original file line numberDiff line numberDiff line change
@@ -141,20 +141,22 @@
141141
"name": "stderr",
142142
"output_type": "stream",
143143
"text": [
144-
"[2025-03-11 23:38:43.304][ INFO](notebook) - Using MAISI version maisi3d-ddpm. Will need body region as data input.\n"
144+
"[2025-03-12 22:27:22.838][ INFO](notebook) - MAISI version is maisi3d-ddpm, whether to use body_region is True\n"
145145
]
146146
}
147147
],
148148
"source": [
149149
"maisi_version = \"maisi3d-ddpm\"\n",
150150
"if maisi_version == \"maisi3d-ddpm\":\n",
151-
" include_body_region = True\n",
152-
" logger.info(\"Using MAISI version maisi3d-ddpm. Will need body region as data input.\")\n",
151+
" model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
153152
"elif maisi_version == \"maisi3d-rflow\":\n",
154-
" include_body_region = False\n",
155-
" logger.info(\"Using MAISI version maisi3d-rflow. Does not need body region as data input.\")\n",
153+
" model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
156154
"else:\n",
157-
" raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")"
155+
" raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
156+
"with open(model_def_path, \"r\") as f:\n",
157+
" model_def = json.load(f)\n",
158+
"include_body_region = model_def[\"include_body_region\"]\n",
159+
"logger.info(f\"MAISI version is {maisi_version}, whether to use body_region is {include_body_region}\")"
158160
]
159161
},
160162
{
@@ -246,9 +248,9 @@
246248
"name": "stderr",
247249
"output_type": "stream",
248250
"text": [
249-
"[2025-03-11 23:38:45.473][ INFO](notebook) - Generated simulated images.\n",
250-
"[2025-03-11 23:38:45.474][ INFO](notebook) - img_emb shape: (64, 64, 32, 4)\n",
251-
"[2025-03-11 23:38:45.475][ INFO](notebook) - label shape: (256, 256, 128)\n"
251+
"[2025-03-12 22:27:25.046][ INFO](notebook) - Generated simulated images.\n",
252+
"[2025-03-12 22:27:25.047][ INFO](notebook) - img_emb shape: (64, 64, 32, 4)\n",
253+
"[2025-03-12 22:27:25.048][ INFO](notebook) - label shape: (256, 256, 128)\n"
252254
]
253255
}
254256
],
@@ -320,20 +322,14 @@
320322
"name": "stderr",
321323
"output_type": "stream",
322324
"text": [
323-
"[2025-03-11 23:38:45.489][ INFO](notebook) - files and folders under work_dir: ['config_maisi.json', 'models', 'config_maisi_controlnet_train.json', 'outputs', 'sim_dataroot', 'environment_maisi_controlnet_train.json', 'sim_datalist.json'].\n",
324-
"[2025-03-11 23:38:45.490][ INFO](notebook) - number of GPUs: 1.\n"
325+
"[2025-03-12 22:27:25.062][ INFO](notebook) - files and folders under work_dir: ['config_maisi.json', 'models', 'config_maisi_controlnet_train.json', 'outputs', 'sim_dataroot', 'environment_maisi_controlnet_train.json', 'sim_datalist.json'].\n",
326+
"[2025-03-12 22:27:25.063][ INFO](notebook) - number of GPUs: 1.\n"
325327
]
326328
}
327329
],
328330
"source": [
329331
"env_config_path = \"./configs/environment_maisi_controlnet_train.json\"\n",
330332
"train_config_path = \"./configs/config_maisi_controlnet_train.json\"\n",
331-
"if maisi_version == \"maisi3d-ddpm\":\n",
332-
" model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
333-
"elif maisi_version == \"maisi3d-rflow\":\n",
334-
" model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
335-
"else:\n",
336-
" raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
337333
"\n",
338334
"# Load environment configuration, model configuration and model definition\n",
339335
"with open(env_config_path, \"r\") as f:\n",
@@ -472,29 +468,29 @@
472468
"name": "stderr",
473469
"output_type": "stream",
474470
"text": [
475-
"[2025-03-11 23:38:45.501][ INFO](notebook) - Training the model...\n"
471+
"[2025-03-12 22:27:25.074][ INFO](notebook) - Training the model...\n"
476472
]
477473
},
478474
{
479475
"name": "stdout",
480476
"output_type": "stream",
481477
"text": [
482-
"[2025-03-11 23:38:54.835][ INFO](maisi.controlnet.training) - Number of GPUs: 8\n",
483-
"[2025-03-11 23:38:54.835][ INFO](maisi.controlnet.training) - World_size: 1\n",
484-
"[2025-03-11 23:38:56.401][ INFO](maisi.controlnet.training) - trained diffusion model is not loaded.\n",
485-
"[2025-03-11 23:38:56.401][ INFO](maisi.controlnet.training) - set scale_factor -> 1.0.\n",
486-
"2025-03-11 23:38:56,899 - INFO - 'dst' model updated: 158 of 206 variables.\n",
487-
"[2025-03-11 23:38:56.903][ INFO](maisi.controlnet.training) - train controlnet model from scratch.\n",
488-
"[2025-03-11 23:38:56.925][ INFO](maisi.controlnet.training) - total number of training steps: 4.0.\n",
489-
"[2025-03-11 23:38:58.871][ INFO](maisi.controlnet.training) -\n",
490-
"[Epoch 1/2] [Batch 1/2] [LR: 0.00000563] [loss: 0.7972] ETA: 0:00:01.944427\n",
491-
"[2025-03-11 23:38:59.018][ INFO](maisi.controlnet.training) -\n",
492-
"[Epoch 1/2] [Batch 2/2] [LR: 0.00000250] [loss: 0.7981] ETA: 0:00:00\n",
493-
"[2025-03-11 23:38:59.775][ INFO](maisi.controlnet.training) - best loss -> 0.7976870536804199.\n",
494-
"[2025-03-11 23:39:00.998][ INFO](maisi.controlnet.training) -\n",
495-
"[Epoch 2/2] [Batch 1/2] [LR: 0.00000063] [loss: 0.7971] ETA: 0:00:01.979231\n",
496-
"[2025-03-11 23:39:01.129][ INFO](maisi.controlnet.training) -\n",
497-
"[Epoch 2/2] [Batch 2/2] [LR: 0.00000000] [loss: 0.7994] ETA: 0:00:00\n",
478+
"[2025-03-12 22:27:33.707][ INFO](maisi.controlnet.training) - Number of GPUs: 8\n",
479+
"[2025-03-12 22:27:33.708][ INFO](maisi.controlnet.training) - World_size: 1\n",
480+
"[2025-03-12 22:27:35.410][ INFO](maisi.controlnet.training) - trained diffusion model is not loaded.\n",
481+
"[2025-03-12 22:27:35.410][ INFO](maisi.controlnet.training) - set scale_factor -> 1.0.\n",
482+
"2025-03-12 22:27:35,902 - INFO - 'dst' model updated: 158 of 206 variables.\n",
483+
"[2025-03-12 22:27:35.907][ INFO](maisi.controlnet.training) - train controlnet model from scratch.\n",
484+
"[2025-03-12 22:27:35.930][ INFO](maisi.controlnet.training) - total number of training steps: 4.0.\n",
485+
"[2025-03-12 22:27:38.006][ INFO](maisi.controlnet.training) -\n",
486+
"[Epoch 1/2] [Batch 1/2] [LR: 0.00000563] [loss: 0.7976] ETA: 0:00:02.073507\n",
487+
"[2025-03-12 22:27:38.147][ INFO](maisi.controlnet.training) -\n",
488+
"[Epoch 1/2] [Batch 2/2] [LR: 0.00000250] [loss: 0.7985] ETA: 0:00:00\n",
489+
"[2025-03-12 22:27:38.683][ INFO](maisi.controlnet.training) - best loss -> 0.7980280518531799.\n",
490+
"[2025-03-12 22:27:39.955][ INFO](maisi.controlnet.training) -\n",
491+
"[Epoch 2/2] [Batch 1/2] [LR: 0.00000063] [loss: 0.7992] ETA: 0:00:01.807460\n",
492+
"[2025-03-12 22:27:40.086][ INFO](maisi.controlnet.training) -\n",
493+
"[Epoch 2/2] [Batch 2/2] [LR: 0.00000000] [loss: 0.7980] ETA: 0:00:00\n",
498494
"\n"
499495
]
500496
}
@@ -512,8 +508,6 @@
512508
" \"--training-config\",\n",
513509
" train_config_filepath,\n",
514510
"]\n",
515-
"if include_body_region:\n",
516-
" module_args.append(\"--include_body_region\")\n",
517511
"\n",
518512
"run_torchrun(module, module_args, num_gpus=num_gpus)"
519513
]
@@ -539,32 +533,32 @@
539533
"name": "stderr",
540534
"output_type": "stream",
541535
"text": [
542-
"[2025-03-11 23:39:03.635][ INFO](notebook) - Inference...\n"
536+
"[2025-03-12 22:27:42.632][ INFO](notebook) - Inference...\n"
543537
]
544538
},
545539
{
546540
"name": "stdout",
547541
"output_type": "stream",
548542
"text": [
549-
"[2025-03-11 23:39:13.628][ INFO](maisi.controlnet.infer) - Number of GPUs: 8\n",
550-
"[2025-03-11 23:39:13.628][ INFO](maisi.controlnet.infer) - World_size: 1\n",
551-
"[2025-03-11 23:39:14.205][ INFO](maisi.controlnet.infer) - trained autoencoder model is not loaded.\n",
552-
"[2025-03-11 23:39:15.418][ INFO](maisi.controlnet.infer) - trained diffusion model is not loaded.\n",
553-
"[2025-03-11 23:39:15.418][ INFO](maisi.controlnet.infer) - set scale_factor -> 1.0.\n",
554-
"2025-03-11 23:39:15,917 - INFO - 'dst' model updated: 158 of 206 variables.\n",
555-
"[2025-03-11 23:39:15.922][ INFO](maisi.controlnet.infer) - trained controlnet is not loaded.\n",
556-
"[2025-03-11 23:39:16.582][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
557-
"[2025-03-11 23:39:16.584][ INFO](root) - ---- Start generating latent features... ----\n",
558-
"[2025-03-11 23:39:17.178][ INFO](root) - ---- Latent features generation time: 0.5939664840698242 seconds ----\n",
559-
"[2025-03-11 23:39:17.180][ INFO](root) - ---- Start decoding latent features into images... ----\n",
560-
"[2025-03-11 23:39:18.003][ INFO](root) - ---- Image decoding time: 0.8231167793273926 seconds ----\n",
561-
"2025-03-11 23:39:18,299 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250311_233918_283950_image.nii.gz\n",
562-
"2025-03-11 23:39:18,649 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250311_233918_283950_label.nii.gz\n",
543+
"[2025-03-12 22:27:53.399][ INFO](maisi.controlnet.infer) - Number of GPUs: 8\n",
544+
"[2025-03-12 22:27:53.400][ INFO](maisi.controlnet.infer) - World_size: 1\n",
545+
"[2025-03-12 22:27:54.101][ INFO](maisi.controlnet.infer) - trained autoencoder model is not loaded.\n",
546+
"[2025-03-12 22:27:55.286][ INFO](maisi.controlnet.infer) - trained diffusion model is not loaded.\n",
547+
"[2025-03-12 22:27:55.286][ INFO](maisi.controlnet.infer) - set scale_factor -> 1.0.\n",
548+
"2025-03-12 22:27:55,756 - INFO - 'dst' model updated: 158 of 206 variables.\n",
549+
"[2025-03-12 22:27:55.761][ INFO](maisi.controlnet.infer) - trained controlnet is not loaded.\n",
550+
"[2025-03-12 22:27:56.340][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
551+
"[2025-03-12 22:27:56.344][ INFO](root) - ---- Start generating latent features... ----\n",
552+
"[2025-03-12 22:27:58.065][ INFO](root) - ---- Latent features generation time: 1.7215001583099365 seconds ----\n",
553+
"[2025-03-12 22:27:58.066][ INFO](root) - ---- Start decoding latent features into images... ----\n",
554+
"[2025-03-12 22:27:58.838][ INFO](root) - ---- Image decoding time: 0.7712326049804688 seconds ----\n",
555+
"2025-03-12 22:27:59,142 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250312_222759_124463_image.nii.gz\n",
556+
"2025-03-12 22:27:59,487 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250312_222759_124463_label.nii.gz\n",
563557
"\n",
564558
"\n",
565-
" 0%| | 0/1 [00:00<?, ?it/s]\n",
566-
"100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.92it/s]\n",
567-
"100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.92it/s]\n",
559+
" 0%| | 0/1 [00:00<?, ?it/s]\n",
560+
"100%|██████████| 1/1 [00:01<00:00, 1.62s/it]\n",
561+
"100%|██████████| 1/1 [00:01<00:00, 1.62s/it]\n",
568562
"\n"
569563
]
570564
}
@@ -582,8 +576,6 @@
582576
" \"--training-config\",\n",
583577
" train_config_filepath,\n",
584578
"]\n",
585-
"if include_body_region:\n",
586-
" module_args.append(\"--include_body_region\")\n",
587579
"\n",
588580
"run_torchrun(module, module_args, num_gpus=num_gpus)"
589581
]

0 commit comments

Comments
 (0)