Skip to content

Commit eb631d1

Browse files
committed
Update diff_model_train and make it compartitble with previous DDPM. Tsted with DDPM, not with rflow in this PR
Signed-off-by: Can-Zhao <[email protected]>
1 parent 4a40380 commit eb631d1

File tree

5 files changed

+412
-39
lines changed

5 files changed

+412
-39
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
{
2+
"spatial_dims": 3,
3+
"image_channels": 1,
4+
"latent_channels": 4,
5+
"mask_generation_latent_shape": [
6+
4,
7+
64,
8+
64,
9+
64
10+
],
11+
"autoencoder_def": {
12+
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
13+
"spatial_dims": "@spatial_dims",
14+
"in_channels": "@image_channels",
15+
"out_channels": "@image_channels",
16+
"latent_channels": "@latent_channels",
17+
"num_channels": [
18+
64,
19+
128,
20+
256
21+
],
22+
"num_res_blocks": [2,2,2],
23+
"norm_num_groups": 32,
24+
"norm_eps": 1e-06,
25+
"attention_levels": [
26+
false,
27+
false,
28+
false
29+
],
30+
"with_encoder_nonlocal_attn": false,
31+
"with_decoder_nonlocal_attn": false,
32+
"use_checkpointing": false,
33+
"use_convtranspose": false,
34+
"norm_float16": true,
35+
"num_splits": 8,
36+
"dim_split": 1
37+
},
38+
"diffusion_unet_def": {
39+
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
40+
"spatial_dims": "@spatial_dims",
41+
"in_channels": "@latent_channels",
42+
"out_channels": "@latent_channels",
43+
"num_channels": [
44+
64,
45+
128,
46+
256,
47+
512
48+
],
49+
"attention_levels": [
50+
false,
51+
false,
52+
true,
53+
true
54+
],
55+
"num_head_channels": [
56+
0,
57+
0,
58+
32,
59+
32
60+
],
61+
"num_res_blocks": 2,
62+
"use_flash_attention": true,
63+
"include_top_region_index_input": true,
64+
"include_bottom_region_index_input": true,
65+
"include_spacing_input": true
66+
},
67+
"controlnet_def": {
68+
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
69+
"spatial_dims": "@spatial_dims",
70+
"in_channels": "@latent_channels",
71+
"num_channels": [
72+
64,
73+
128,
74+
256,
75+
512
76+
],
77+
"attention_levels": [
78+
false,
79+
false,
80+
true,
81+
true
82+
],
83+
"num_head_channels": [
84+
0,
85+
0,
86+
32,
87+
32
88+
],
89+
"num_res_blocks": 2,
90+
"use_flash_attention": true,
91+
"conditioning_embedding_in_channels": 8,
92+
"conditioning_embedding_num_channels": [8, 32, 64]
93+
},
94+
"mask_generation_autoencoder_def": {
95+
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
96+
"spatial_dims": "@spatial_dims",
97+
"in_channels": 8,
98+
"out_channels": 125,
99+
"latent_channels": "@latent_channels",
100+
"num_channels": [
101+
32,
102+
64,
103+
128
104+
],
105+
"num_res_blocks": [1, 2, 2],
106+
"norm_num_groups": 32,
107+
"norm_eps": 1e-06,
108+
"attention_levels": [
109+
false,
110+
false,
111+
false
112+
],
113+
"with_encoder_nonlocal_attn": false,
114+
"with_decoder_nonlocal_attn": false,
115+
"use_flash_attention": false,
116+
"use_checkpointing": true,
117+
"use_convtranspose": true,
118+
"norm_float16": true,
119+
"num_splits": 8,
120+
"dim_split": 1
121+
},
122+
"mask_generation_diffusion_def": {
123+
"_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
124+
"spatial_dims": "@spatial_dims",
125+
"in_channels": "@latent_channels",
126+
"out_channels": "@latent_channels",
127+
"channels":[64, 128, 256, 512],
128+
"attention_levels":[false, false, true, true],
129+
"num_head_channels":[0, 0, 32, 32],
130+
"num_res_blocks": 2,
131+
"use_flash_attention": true,
132+
"with_conditioning": true,
133+
"upcast_attention": true,
134+
"cross_attention_dim": 10
135+
},
136+
"mask_generation_scale_factor": 1.0055984258651733,
137+
"noise_scheduler": {
138+
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
139+
"num_train_timesteps": 1000,
140+
"beta_start": 0.0015,
141+
"beta_end": 0.0195,
142+
"schedule": "scaled_linear_beta",
143+
"clip_sample": false
144+
},
145+
"mask_generation_noise_scheduler": {
146+
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
147+
"num_train_timesteps": 1000,
148+
"beta_start": 0.0015,
149+
"beta_end": 0.0195,
150+
"schedule": "scaled_linear_beta",
151+
"clip_sample": false
152+
}
153+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
{
2+
"spatial_dims": 3,
3+
"image_channels": 1,
4+
"latent_channels": 4,
5+
"mask_generation_latent_shape": [
6+
4,
7+
64,
8+
64,
9+
64
10+
],
11+
"autoencoder_def": {
12+
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
13+
"spatial_dims": "@spatial_dims",
14+
"in_channels": "@image_channels",
15+
"out_channels": "@image_channels",
16+
"latent_channels": "@latent_channels",
17+
"num_channels": [
18+
64,
19+
128,
20+
256
21+
],
22+
"num_res_blocks": [2,2,2],
23+
"norm_num_groups": 32,
24+
"norm_eps": 1e-06,
25+
"attention_levels": [
26+
false,
27+
false,
28+
false
29+
],
30+
"with_encoder_nonlocal_attn": false,
31+
"with_decoder_nonlocal_attn": false,
32+
"use_checkpointing": false,
33+
"use_convtranspose": false,
34+
"norm_float16": true,
35+
"num_splits": 4,
36+
"dim_split": 1
37+
},
38+
"diffusion_unet_def": {
39+
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
40+
"spatial_dims": "@spatial_dims",
41+
"in_channels": "@latent_channels",
42+
"out_channels": "@latent_channels",
43+
"num_channels": [64, 128, 256, 512],
44+
"attention_levels": [
45+
false,
46+
false,
47+
true,
48+
true
49+
],
50+
"num_head_channels": [
51+
0,
52+
0,
53+
32,
54+
32
55+
],
56+
"num_res_blocks": 2,
57+
"use_flash_attention": true,
58+
"include_top_region_index_input": false,
59+
"include_bottom_region_index_input": false,
60+
"include_spacing_input": true,
61+
"num_class_embeds": 128,
62+
"resblock_updown": true,
63+
"include_fc": true
64+
},
65+
"controlnet_def": {
66+
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
67+
"spatial_dims": "@spatial_dims",
68+
"in_channels": "@latent_channels",
69+
"num_channels": [64, 128, 256, 512],
70+
"attention_levels": [
71+
false,
72+
false,
73+
true,
74+
true
75+
],
76+
"num_head_channels": [
77+
0,
78+
0,
79+
32,
80+
32
81+
],
82+
"num_res_blocks": 2,
83+
"use_flash_attention": true,
84+
"conditioning_embedding_in_channels": 8,
85+
"conditioning_embedding_num_channels": [8, 32, 64],
86+
"num_class_embeds": 128,
87+
"resblock_updown": true,
88+
"include_fc": true
89+
},
90+
"mask_generation_autoencoder_def": {
91+
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
92+
"spatial_dims": "@spatial_dims",
93+
"in_channels": 8,
94+
"out_channels": 125,
95+
"latent_channels": "@latent_channels",
96+
"num_channels": [
97+
32,
98+
64,
99+
128
100+
],
101+
"num_res_blocks": [1, 2, 2],
102+
"norm_num_groups": 32,
103+
"norm_eps": 1e-06,
104+
"attention_levels": [
105+
false,
106+
false,
107+
false
108+
],
109+
"with_encoder_nonlocal_attn": false,
110+
"with_decoder_nonlocal_attn": false,
111+
"use_flash_attention": false,
112+
"use_checkpointing": true,
113+
"use_convtranspose": true,
114+
"norm_float16": true,
115+
"num_splits": 8,
116+
"dim_split": 1
117+
},
118+
"mask_generation_diffusion_def": {
119+
"_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
120+
"spatial_dims": "@spatial_dims",
121+
"in_channels": "@latent_channels",
122+
"out_channels": "@latent_channels",
123+
"channels":[64, 128, 256, 512],
124+
"attention_levels":[false, false, true, true],
125+
"num_head_channels":[0, 0, 32, 32],
126+
"num_res_blocks": 2,
127+
"use_flash_attention": true,
128+
"with_conditioning": true,
129+
"upcast_attention": true,
130+
"cross_attention_dim": 10
131+
},
132+
"mask_generation_scale_factor": 1.0055984258651733,
133+
"noise_scheduler": {
134+
"_target_": "monai.networks.schedulers.rectified_flow.RFlowScheduler",
135+
"num_train_timesteps": 1000,
136+
"use_discrete_timesteps": false,
137+
"use_timestep_transform": true,
138+
"sample_method": "uniform",
139+
"scale":1.4
140+
},
141+
"mask_generation_noise_scheduler": {
142+
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
143+
"num_train_timesteps": 1000,
144+
"beta_start": 0.0015,
145+
"beta_end": 0.0195,
146+
"schedule": "scaled_linear_beta",
147+
"clip_sample": false
148+
}
149+
}

0 commit comments

Comments
 (0)