
Commit f93d2fb

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9dcaa42 commit f93d2fb

Showing 5 changed files with 13 additions and 11 deletions.

generation/maisi/scripts/diff_model_infer.py (+1, -1)

@@ -291,7 +291,7 @@ def diff_model_infer(
         output_size,
         divisor,
         logger,
-        include_body_region = include_body_region,
+        include_body_region=include_body_region,
     )

     timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
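
The hunk above applies the PEP 8 keyword-argument spacing rule (flake8 code E251): when "=" binds a keyword argument at a call site, there are no spaces around it. A minimal runnable sketch of the rule, using a hypothetical helper rather than the actual MAISI function:

    # E251: no spaces around "=" for keyword arguments.
    def build_volume(output_size, divisor, include_body_region=False):
        # toy stand-in for the reformatted call site, not the real diff_model_infer code
        return {"size": output_size, "divisor": divisor, "body": include_body_region}

    volume = build_volume((128, 128, 128), 4, include_body_region=True)  # compliant
    # volume = build_volume((128, 128, 128), 4, include_body_region = True)  # flagged by E251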

generation/maisi/scripts/diff_model_train.py (+1, -1)

@@ -288,7 +288,7 @@ def train_one_epoch(
                 "noise scheduler prediction type has to be chosen from ",
                 f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
             )
-
+
         loss = loss_pt(model_output.float(), model_gt.float())

         if amp:

generation/maisi/scripts/infer_controlnet.py (+1, -1)

@@ -188,7 +188,7 @@ def main():
                 noise_factor=1.0,
                 num_inference_steps=args.controlnet_infer["num_inference_steps"],
                 autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
-                autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"]
+                autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
             )
             # save image/label pairs
             labels = decollate_batch(batch)[0]["label"]
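
This hunk only appends the trailing comma after the last argument of an already-exploded call, consistent with black's "magic trailing comma"; a future argument then shows up as a one-line diff. A minimal sketch with placeholder values (the keyword names mirror the diff, but the call itself is hypothetical):

    # Black keeps a trailing comma after the last element of a multi-line call.
    config = dict(
        noise_factor=1.0,
        num_inference_steps=30,
        autoencoder_sliding_window_infer_overlap=0.6667,  # comma added by the hook
    )
    print(config)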

generation/maisi/scripts/sample.py (+4, -4)

@@ -182,13 +182,13 @@ def ldm_conditional_sample_one_image(
     noise_scheduler,
     scale_factor,
     device,
-    combine_label_or,
+    combine_label_or,
     spacing_tensor,
     latent_shape,
     output_size,
     noise_factor,
-    top_region_index_tensor = None,
-    bottom_region_index_tensor = None,
+    top_region_index_tensor=None,
+    bottom_region_index_tensor=None,
     num_inference_steps=1000,
     autoencoder_sliding_window_infer_size=[96, 96, 96],
     autoencoder_sliding_window_infer_overlap=0.6667,
@@ -203,7 +203,7 @@ def ldm_conditional_sample_one_image(
         noise_scheduler: The noise scheduler for the diffusion process.
         scale_factor (float): Scaling factor for the latent space.
         device (torch.device): The device to run the computation on.
-        combine_label_or (torch.Tensor): The combined label tensor.
+        combine_label_or (torch.Tensor): The combined label tensor.
         spacing_tensor (torch.Tensor): Tensor specifying the spacing.
         latent_shape (tuple): The shape of the latent space.
         output_size (tuple): The desired output size of the image.
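
E251 applies to parameter defaults in a signature the same way it applies to call-site keywords, which is why top_region_index_tensor = None becomes top_region_index_tensor=None here. A minimal sketch with hypothetical names and toy logic:

    # Defaults in a def are written without spaces around "=".
    def sample_one(latent_shape, top_region_index_tensor=None, num_inference_steps=1000):
        # toy fallback: halve the step count when a region tensor is supplied
        return num_inference_steps if top_region_index_tensor is None else num_inference_steps // 2

    print(sample_one((4, 24, 24, 24)))  # 1000

Note that PEP 8 reverses this when the parameter carries an annotation: def f(x: int = 3) does keep the spaces.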

generation/maisi/scripts/train_controlnet.py (+6, -4)

@@ -51,7 +51,7 @@ def main():
         "--training-config",
         default="./configs/config_maisi_controlnet_train.json",
         help="config json file that stores training hyper-parameters",
-    )
+    )
     parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
     parser.add_argument(
         "--include_body_region",
@@ -199,7 +199,9 @@ def main():
             if isinstance(noise_scheduler, RFlowScheduler):
                 timesteps = noise_scheduler.sample_timesteps(images)
             else:
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device).long()
+                timesteps = torch.randint(
+                    0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device
+                ).long()

             # create noisy latent
             noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
@@ -241,7 +243,7 @@ def main():
                     "noise scheduler prediction type has to be chosen from ",
                     f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
                 )
-
+
             if weighted_loss > 1.0:
                 weights = torch.ones_like(images).to(images.device)
                 roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(images.device)
@@ -253,7 +255,7 @@ def main():
                 loss = (F.l1_loss(noise_pred.float(), model_gt.float(), reduction="none") * weights).mean()
             else:
                 loss = F.l1_loss(model_output.float(), model_gt.float())
-
+
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
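
The second hunk in this file is a line-length wrap: the long torch.randint(...) call is exploded onto an indented continuation line and re-closed with ).long(), leaving behavior unchanged. A runnable sketch of the reformatted shape, with a dummy batch standing in for the training tensors:

    import torch

    num_train_timesteps = 1000
    images = torch.zeros(2, 4, 24, 24, 24)  # hypothetical latent batch
    # draw one random integer timestep per batch element, on the same device
    timesteps = torch.randint(
        0, num_train_timesteps, (images.shape[0],), device=images.device
    ).long()
    print(timesteps.shape)  # torch.Size([2])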
