Skip to content

Commit 9af6814

Browse files
authored
Fix typos and bugs (#1963)
Fixes some typo in `scripts/train_controlnet.py` and `scripts/diff_model_train.py`. --------- Signed-off-by: MaybeRichard <[email protected]>
1 parent 223d8ff commit 9af6814

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

generation/maisi/scripts/diff_model_train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -480,13 +480,13 @@ def diff_model_train(
480480
parser.add_argument(
481481
"--env_config",
482482
type=str,
483-
default="./configs/environment_maisi_diff_model_train.json",
483+
default="./configs/environment_maisi_diff_model.json",
484484
help="Path to environment configuration file",
485485
)
486486
parser.add_argument(
487487
"--model_config",
488488
type=str,
489-
default="./configs/config_maisi_diff_model_train.json",
489+
default="./configs/config_maisi_diff_model.json",
490490
help="Path to model training/inference configuration",
491491
)
492492
parser.add_argument(

generation/maisi/scripts/train_controlnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def main():
4343
parser.add_argument(
4444
"-c",
4545
"--config-file",
46-
default="./configs/config_maisi.json",
46+
default="./configs/config_maisi-ddpm.json",
4747
help="config json file that stores network hyper-parameters",
4848
)
4949
parser.add_argument(
@@ -269,7 +269,7 @@ def main():
269269
for label in weighted_loss_label:
270270
roi[interpolate_label == label] = 1
271271
weights[roi.repeat(1, images.shape[1], 1, 1, 1) == 1] = weighted_loss
272-
loss = (F.l1_loss(noise_pred.float(), model_gt.float(), reduction="none") * weights).mean()
272+
loss = (F.l1_loss(model_output.float(), model_gt.float(), reduction="none") * weights).mean()
273273
else:
274274
loss = F.l1_loss(model_output.float(), model_gt.float())
275275

0 commit comments

Comments
 (0)