Skip to content

Commit 24c8573

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent eb631d1 commit 24c8573

File tree

3 files changed

+55
-32
lines changed

3 files changed

+55
-32
lines changed

generation/maisi/scripts/diff_model_infer.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def run_inference(
149149
if isinstance(noise_scheduler, RFlowScheduler):
150150
noise_scheduler.set_timesteps(
151151
num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
152-
input_img_size_numel=torch.prod(torch.tensor(noise.shape[-3:]))
152+
input_img_size_numel=torch.prod(torch.tensor(noise.shape[-3:])),
153153
)
154154
else:
155155
noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
@@ -161,9 +161,9 @@ def run_inference(
161161
all_timesteps = noise_scheduler.timesteps
162162
all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
163163
progress_bar = tqdm(
164-
zip(all_timesteps, all_next_timesteps),
165-
total=min(len(all_timesteps), len(all_next_timesteps)),
166-
)
164+
zip(all_timesteps, all_next_timesteps),
165+
total=min(len(all_timesteps), len(all_next_timesteps)),
166+
)
167167
with torch.amp.autocast("cuda", enabled=True):
168168
for t, next_t in progress_bar:
169169
model_output = unet(
@@ -178,7 +178,6 @@ def run_inference(
178178
else:
179179
image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore
180180

181-
182181
inferer = SlidingWindowInferer(
183182
roi_size=(
184183
min(output_size[0] // divisor // 4 * 3, 96),
@@ -228,7 +227,9 @@ def save_image(
228227

229228

230229
@torch.inference_mode()
231-
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, include_body_region: bool = False ) -> None:
230+
def diff_model_infer(
231+
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, include_body_region: bool = False
232+
) -> None:
232233
"""
233234
Main function to run the diffusion model inference.
234235
@@ -335,7 +336,14 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
335336
default=1,
336337
help="Number of GPUs to use for distributed inference",
337338
)
338-
parser.add_argument("--include_body_region", dest="include_body_region", action="store_true", help="Whether to include body region in data")
339+
parser.add_argument(
340+
"--include_body_region",
341+
dest="include_body_region",
342+
action="store_true",
343+
help="Whether to include body region in data",
344+
)
339345

340346
args = parser.parse_args()
341-
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus, include_body_region=args.include_body_region)
347+
diff_model_infer(
348+
args.env_config, args.model_config, args.model_def, args.num_gpus, include_body_region=args.include_body_region
349+
)

generation/maisi/scripts/diff_model_train.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ def load_filenames(data_list_path: str) -> list:
5151

5252

5353
def prepare_data(
54-
train_files: list,
55-
device: torch.device,
56-
cache_rate: float,
57-
num_workers: int = 2,
58-
batch_size: int = 1,
59-
include_body_region: bool = False
54+
train_files: list,
55+
device: torch.device,
56+
cache_rate: float,
57+
num_workers: int = 2,
58+
batch_size: int = 1,
59+
include_body_region: bool = False,
6060
) -> DataLoader:
6161
"""
6262
Prepare training data.
@@ -78,11 +78,11 @@ def _load_data_from_file(file_path, key):
7878
return torch.FloatTensor(json.load(f)[key])
7979

8080
train_transforms_list = [
81-
monai.transforms.LoadImaged(keys=["image"]),
82-
monai.transforms.EnsureChannelFirstd(keys=["image"]),
83-
monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
84-
monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
85-
]
81+
monai.transforms.LoadImaged(keys=["image"]),
82+
monai.transforms.EnsureChannelFirstd(keys=["image"]),
83+
monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
84+
monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
85+
]
8686
if include_body_region:
8787
train_transforms_list += [
8888
monai.transforms.Lambdad(
@@ -202,7 +202,7 @@ def train_one_epoch(
202202
logger: logging.Logger,
203203
local_rank: int,
204204
amp: bool = True,
205-
include_body_region: bool = False
205+
include_body_region: bool = False,
206206
) -> torch.Tensor:
207207
"""
208208
Train the model for one epoch.
@@ -284,9 +284,10 @@ def train_one_epoch(
284284
# predict velocity
285285
loss = loss_pt(model_output.float(), (images - noise).float())
286286
else:
287-
raise ValueError("noise scheduler prediction type has to be chosen from ",
288-
f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]"
289-
)
287+
raise ValueError(
288+
"noise scheduler prediction type has to be chosen from ",
289+
f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
290+
)
290291

291292
if amp:
292293
scaler.scale(loss).backward()
@@ -349,7 +350,12 @@ def save_checkpoint(
349350

350351

351352
def diff_model_train(
352-
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True, include_body_region: bool = False
353+
env_config_path: str,
354+
model_config_path: str,
355+
model_def_path: str,
356+
num_gpus: int,
357+
amp: bool = True,
358+
include_body_region: bool = False,
353359
) -> None:
354360
"""
355361
Main function to train a diffusion model.
@@ -400,9 +406,11 @@ def diff_model_train(
400406
)[local_rank]
401407

402408
train_loader = prepare_data(
403-
train_files, device, args.diffusion_unet_train["cache_rate"],
409+
train_files,
410+
device,
411+
args.diffusion_unet_train["cache_rate"],
404412
batch_size=args.diffusion_unet_train["batch_size"],
405-
include_body_region = include_body_region
413+
include_body_region=include_body_region,
406414
)
407415

408416
unet = load_unet(args, device, logger)
@@ -438,7 +446,7 @@ def diff_model_train(
438446
logger,
439447
local_rank,
440448
amp=amp,
441-
include_body_region=include_body_region
449+
include_body_region=include_body_region,
442450
)
443451

444452
loss_torch = loss_torch.tolist()
@@ -479,7 +487,14 @@ def diff_model_train(
479487
)
480488
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
481489
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
482-
parser.add_argument("--include_body_region", dest="include_body_region", action="store_true", help="Whether to include body region in data")
490+
parser.add_argument(
491+
"--include_body_region",
492+
dest="include_body_region",
493+
action="store_true",
494+
help="Whether to include body region in data",
495+
)
483496

484497
args = parser.parse_args()
485-
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp, args.include_body_region)
498+
diff_model_train(
499+
args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp, args.include_body_region
500+
)

generation/maisi/scripts/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -712,14 +712,14 @@ def dynamic_infer(inferer, model, images):
712712
# Extract the spatial dimensions from the images tensor (H, W, D)
713713
spatial_dims = images.shape[2:]
714714
orig_roi = inferer.roi_size
715-
715+
716716
# Check that roi has the same number of dimensions as spatial_dims
717717
if len(orig_roi) != len(spatial_dims):
718718
raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")
719-
719+
720720
# Iterate and adjust each ROI dimension
721721
adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
722722
inferer.roi_size = adjusted_roi
723723
output = inferer(network=model, inputs=images)
724724
inferer.roi_size = orig_roi
725-
return output
725+
return output

0 commit comments

Comments
 (0)