@@ -51,7 +51,7 @@ def main():
51
51
"--training-config" ,
52
52
default = "./configs/config_maisi_controlnet_train.json" ,
53
53
help = "config json file that stores training hyper-parameters" ,
54
- )
54
+ )
55
55
parser .add_argument ("-g" , "--gpus" , default = 1 , type = int , help = "number of gpus per node" )
56
56
parser .add_argument (
57
57
"--include_body_region" ,
@@ -199,7 +199,9 @@ def main():
199
199
if isinstance (noise_scheduler , RFlowScheduler ):
200
200
timesteps = noise_scheduler .sample_timesteps (images )
201
201
else :
202
- timesteps = torch .randint (0 , noise_scheduler .num_train_timesteps , (images .shape [0 ],), device = images .device ).long ()
202
+ timesteps = torch .randint (
203
+ 0 , noise_scheduler .num_train_timesteps , (images .shape [0 ],), device = images .device
204
+ ).long ()
203
205
204
206
# create noisy latent
205
207
noisy_latent = noise_scheduler .add_noise (original_samples = images , noise = noise , timesteps = timesteps )
@@ -241,7 +243,7 @@ def main():
241
243
"noise scheduler prediction type has to be chosen from " ,
242
244
f"[{ DDPMPredictionType .EPSILON } ,{ DDPMPredictionType .SAMPLE } ,{ DDPMPredictionType .V_PREDICTION } ]" ,
243
245
)
244
-
246
+
245
247
if weighted_loss > 1.0 :
246
248
weights = torch .ones_like (images ).to (images .device )
247
249
roi = torch .zeros ([noise_shape [0 ]] + [1 ] + noise_shape [2 :]).to (images .device )
@@ -253,7 +255,7 @@ def main():
253
255
loss = (F .l1_loss (noise_pred .float (), model_gt .float (), reduction = "none" ) * weights ).mean ()
254
256
else :
255
257
loss = F .l1_loss (model_output .float (), model_gt .float ())
256
-
258
+
257
259
scaler .scale (loss ).backward ()
258
260
scaler .step (optimizer )
259
261
scaler .update ()
0 commit comments