Skip to content

Commit 1b12478

Browse files
authored
Merge branch 'kijai:main' into main
2 parents 424573d + 6ee278a commit 1b12478

File tree

6 files changed

+277
-146
lines changed

6 files changed

+277
-146
lines changed

nodes/image_nodes.py

Lines changed: 141 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -2468,14 +2468,15 @@ def INPUT_TYPES(s):
24682468
"width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
24692469
"height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
24702470
"upscale_method": (s.upscale_methods,),
2471-
"keep_proportion": (["stretch", "resize", "pad", "pad_edge", "pad_edge_pixel", "crop", "pillarbox_blur"], { "default": False }),
2471+
"keep_proportion": (["stretch", "resize", "pad", "pad_edge", "pad_edge_pixel", "crop", "pillarbox_blur", "total_pixels"], { "default": False }),
24722472
"pad_color": ("STRING", { "default": "0, 0, 0", "tooltip": "Color to use for padding."}),
24732473
"crop_position": (["center", "top", "bottom", "left", "right"], { "default": "center" }),
24742474
"divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
24752475
},
24762476
"optional" : {
24772477
"mask": ("MASK",),
24782478
"device": (["cpu", "gpu"],),
2479+
#"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
24792480
},
24802481
"hidden": {
24812482
"unique_id": "UNIQUE_ID",
@@ -2494,7 +2495,7 @@ def INPUT_TYPES(s):
24942495
highest dimension.
24952496
"""
24962497

2497-
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None):
2498+
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=64):
24982499
B, H, W, C = image.shape
24992500

25002501
if device == "gpu":
@@ -2504,15 +2505,23 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25042505
else:
25052506
device = torch.device("cpu")
25062507

2507-
if width == 0:
2508-
width = W
2509-
if height == 0:
2510-
height = H
2511-
25122508
pillarbox_blur = keep_proportion == "pillarbox_blur"
2513-
if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur:
2509+
2510+
# Initialize padding variables
2511+
pad_left = pad_right = pad_top = pad_bottom = 0
2512+
2513+
if keep_proportion in ["resize", "total_pixels"] or keep_proportion.startswith("pad") or pillarbox_blur:
2514+
if keep_proportion == "total_pixels":
2515+
total_pixels = width * height
2516+
aspect_ratio = W / H
2517+
new_height = int(math.sqrt(total_pixels / aspect_ratio))
2518+
new_width = int(math.sqrt(total_pixels * aspect_ratio))
2519+
25142520
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
2515-
if width == 0 and height != 0:
2521+
elif width == 0 and height == 0:
2522+
new_width = W
2523+
new_height = H
2524+
elif width == 0 and height != 0:
25162525
ratio = height / H
25172526
new_width = round(W * ratio)
25182527
new_height = height
@@ -2528,7 +2537,6 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25282537
new_width = width
25292538
new_height = height
25302539

2531-
pad_left = pad_right = pad_top = pad_bottom = 0
25322540
if keep_proportion.startswith("pad") or pillarbox_blur:
25332541
# Calculate padding based on position
25342542
if crop_position == "center":
@@ -2559,76 +2567,136 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25592567

25602568
width = new_width
25612569
height = new_height
2570+
else:
2571+
if width == 0:
2572+
width = W
2573+
if height == 0:
2574+
height = H
25622575

25632576
if divisible_by > 1:
25642577
width = width - (width % divisible_by)
25652578
height = height - (height % divisible_by)
25662579

2567-
out_image = image.clone().to(device)
2568-
if mask is not None:
2569-
out_mask = mask.clone().to(device)
2570-
else:
2571-
out_mask = None
2572-
2573-
# Crop logic
2574-
if keep_proportion == "crop":
2575-
old_width = W
2576-
old_height = H
2577-
old_aspect = old_width / old_height
2578-
new_aspect = width / height
2579-
if old_aspect > new_aspect:
2580-
crop_w = round(old_height * new_aspect)
2581-
crop_h = old_height
2582-
else:
2583-
crop_w = old_width
2584-
crop_h = round(old_width / new_aspect)
2585-
if crop_position == "center":
2586-
x = (old_width - crop_w) // 2
2587-
y = (old_height - crop_h) // 2
2588-
elif crop_position == "top":
2589-
x = (old_width - crop_w) // 2
2590-
y = 0
2591-
elif crop_position == "bottom":
2592-
x = (old_width - crop_w) // 2
2593-
y = old_height - crop_h
2594-
elif crop_position == "left":
2595-
x = 0
2596-
y = (old_height - crop_h) // 2
2597-
elif crop_position == "right":
2598-
x = old_width - crop_w
2599-
y = (old_height - crop_h) // 2
2600-
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
2601-
if mask is not None:
2602-
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
2580+
# Preflight estimate (log-only when batching is active)
2581+
if per_batch != 0 and B > per_batch:
2582+
try:
2583+
bytes_per_elem = image.element_size() # typically 4 for float32
2584+
est_total_bytes = B * height * width * C * bytes_per_elem
2585+
est_mb = est_total_bytes / (1024 * 1024)
2586+
msg = f"<tr><td>Resize v2</td><td>estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}</td></tr>"
2587+
if unique_id and PromptServer is not None:
2588+
try:
2589+
PromptServer.instance.send_progress_text(msg, unique_id)
2590+
except:
2591+
pass
2592+
else:
2593+
print(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}")
2594+
except:
2595+
pass
26032596

2604-
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
2605-
if mask is not None:
2606-
if upscale_method == "lanczos":
2607-
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
2608-
else:
2609-
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
2597+
def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
2598+
# Avoid unnecessary clones; only move if needed
2599+
out_image = in_image if in_image.device == device else in_image.to(device)
2600+
out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))
2601+
2602+
# Crop logic
2603+
if keep_proportion == "crop":
2604+
old_height = out_image.shape[-3]
2605+
old_width = out_image.shape[-2]
2606+
old_aspect = old_width / old_height
2607+
new_aspect = width / height
2608+
if old_aspect > new_aspect:
2609+
crop_w = round(old_height * new_aspect)
2610+
crop_h = old_height
2611+
else:
2612+
crop_w = old_width
2613+
crop_h = round(old_width / new_aspect)
2614+
if crop_position == "center":
2615+
x = (old_width - crop_w) // 2
2616+
y = (old_height - crop_h) // 2
2617+
elif crop_position == "top":
2618+
x = (old_width - crop_w) // 2
2619+
y = 0
2620+
elif crop_position == "bottom":
2621+
x = (old_width - crop_w) // 2
2622+
y = old_height - crop_h
2623+
elif crop_position == "left":
2624+
x = 0
2625+
y = (old_height - crop_h) // 2
2626+
elif crop_position == "right":
2627+
x = old_width - crop_w
2628+
y = (old_height - crop_h) // 2
2629+
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
2630+
if out_mask is not None:
2631+
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
2632+
2633+
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
2634+
if out_mask is not None:
2635+
if upscale_method == "lanczos":
2636+
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
2637+
else:
2638+
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
2639+
2640+
# Pad logic
2641+
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
2642+
padded_width = width + pad_left + pad_right
2643+
padded_height = height + pad_top + pad_bottom
2644+
if divisible_by > 1:
2645+
width_remainder = padded_width % divisible_by
2646+
height_remainder = padded_height % divisible_by
2647+
if width_remainder > 0:
2648+
extra_width = divisible_by - width_remainder
2649+
pad_right += extra_width
2650+
if height_remainder > 0:
2651+
extra_height = divisible_by - height_remainder
2652+
pad_bottom += extra_height
2653+
2654+
pad_mode = (
2655+
"pillarbox_blur" if pillarbox_blur else
2656+
"edge" if keep_proportion == "pad_edge" else
2657+
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
2658+
"color"
2659+
)
2660+
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
26102661

2611-
# Pad logic
2612-
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
2613-
padded_width = width + pad_left + pad_right
2614-
padded_height = height + pad_top + pad_bottom
2615-
if divisible_by > 1:
2616-
width_remainder = padded_width % divisible_by
2617-
height_remainder = padded_height % divisible_by
2618-
if width_remainder > 0:
2619-
extra_width = divisible_by - width_remainder
2620-
pad_right += extra_width
2621-
if height_remainder > 0:
2622-
extra_height = divisible_by - height_remainder
2623-
pad_bottom += extra_height
2624-
2625-
pad_mode = (
2626-
"pillarbox_blur" if pillarbox_blur else
2627-
"edge" if keep_proportion == "pad_edge" else
2628-
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
2629-
"color"
2630-
)
2631-
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
2662+
return out_image, out_mask
2663+
2664+
# If batching disabled (per_batch==0) or batch fits, process whole batch
2665+
if per_batch == 0 or B <= per_batch:
2666+
out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
2667+
else:
2668+
chunks = []
2669+
mask_chunks = [] if mask is not None else None
2670+
total_batches = (B + per_batch - 1) // per_batch
2671+
current_batch = 0
2672+
for start_idx in range(0, B, per_batch):
2673+
current_batch += 1
2674+
end_idx = min(start_idx + per_batch, B)
2675+
sub_img = image[start_idx:end_idx]
2676+
sub_mask = mask[start_idx:end_idx] if mask is not None else None
2677+
sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
2678+
chunks.append(sub_out_img.cpu())
2679+
if mask is not None:
2680+
mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)
2681+
# Per-batch progress update
2682+
if unique_id and PromptServer is not None:
2683+
try:
2684+
PromptServer.instance.send_progress_text(
2685+
f"<tr><td>Resize v2</td><td>batch {current_batch}/{total_batches} · images {end_idx}/{B}</td></tr>",
2686+
unique_id
2687+
)
2688+
except:
2689+
pass
2690+
else:
2691+
try:
2692+
print(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}")
2693+
except:
2694+
pass
2695+
out_image = torch.cat(chunks, dim=0)
2696+
if mask is not None and any(m is not None for m in mask_chunks):
2697+
out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0)
2698+
else:
2699+
out_mask = None
26322700

26332701
# Progress UI
26342702
if unique_id and PromptServer is not None:

nodes/lora_nodes.py

Lines changed: 25 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,27 @@ def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptiv
4848
s_cum = torch.cumsum(S, dim=0)
4949
min_cum_sum = adaptive_param * torch.sum(S)
5050
lora_rank = torch.sum(s_cum < min_cum_sum).item()
51-
print(f"{key} Extracted LoRA rank: {lora_rank}")
51+
elif lora_type == "adaptive_fro":
52+
S_squared = S.pow(2)
53+
S_fro_sq = float(torch.sum(S_squared))
54+
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
55+
lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
56+
lora_rank = max(1, min(lora_rank, len(S)))
57+
else:
58+
pass # Will print after capping
59+
60+
# Cap adaptive rank by the specified max rank
61+
lora_rank = min(lora_rank, rank)
62+
63+
# Calculate and print actual fro percentage retained after capping
64+
if lora_type == "adaptive_fro":
65+
S_squared = S.pow(2)
66+
s_fro = torch.sqrt(torch.sum(S_squared))
67+
s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
68+
fro_percent = float(s_red_fro / s_fro)
69+
print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
70+
else:
71+
print(f"{key} Extracted LoRA rank: {lora_rank}")
5272
else:
5373
lora_rank = rank
5474

@@ -141,13 +161,13 @@ def INPUT_TYPES(s):
141161
"finetuned_model": ("MODEL",),
142162
"original_model": ("MODEL",),
143163
"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
144-
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
145-
"lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy"],),
164+
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}),
165+
"lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
146166
"algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}),
147167
"lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for lowrank SVD algorithm."}),
148168
"output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
149169
"bias_diff": ("BOOLEAN", {"default": True}),
150-
"adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values."}),
170+
"adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
151171
"clamp_quantile": ("BOOLEAN", {"default": True}),
152172
},
153173

@@ -520,7 +540,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
520540
fro_retained = param_dict["fro_retained"]
521541
if not np.isnan(fro_retained):
522542
fro_list.append(float(fro_retained))
523-
log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
543+
log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
524544
tqdm.write(log_str)
525545
verbose_str += log_str
526546

0 commit comments

Comments
 (0)