@@ -2468,14 +2468,15 @@ def INPUT_TYPES(s):
24682468 "width" : ("INT" , { "default" : 512 , "min" : 0 , "max" : MAX_RESOLUTION , "step" : 1 , }),
24692469 "height" : ("INT" , { "default" : 512 , "min" : 0 , "max" : MAX_RESOLUTION , "step" : 1 , }),
24702470 "upscale_method" : (s .upscale_methods ,),
2471- "keep_proportion" : (["stretch" , "resize" , "pad" , "pad_edge" , "pad_edge_pixel" , "crop" , "pillarbox_blur" ], { "default" : False }),
2471+ "keep_proportion" : (["stretch" , "resize" , "pad" , "pad_edge" , "pad_edge_pixel" , "crop" , "pillarbox_blur" , "total_pixels" ], { "default" : False }),
24722472 "pad_color" : ("STRING" , { "default" : "0, 0, 0" , "tooltip" : "Color to use for padding." }),
24732473 "crop_position" : (["center" , "top" , "bottom" , "left" , "right" ], { "default" : "center" }),
24742474 "divisible_by" : ("INT" , { "default" : 2 , "min" : 0 , "max" : 512 , "step" : 1 , }),
24752475 },
24762476 "optional" : {
24772477 "mask" : ("MASK" ,),
24782478 "device" : (["cpu" , "gpu" ],),
2479+ #"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
24792480 },
24802481 "hidden" : {
24812482 "unique_id" : "UNIQUE_ID" ,
@@ -2494,7 +2495,7 @@ def INPUT_TYPES(s):
24942495highest dimension.
24952496"""
24962497
2497- def resize (self , image , width , height , keep_proportion , upscale_method , divisible_by , pad_color , crop_position , unique_id , device = "cpu" , mask = None ):
2498+ def resize (self , image , width , height , keep_proportion , upscale_method , divisible_by , pad_color , crop_position , unique_id , device = "cpu" , mask = None , per_batch = 64 ):
24982499 B , H , W , C = image .shape
24992500
25002501 if device == "gpu" :
@@ -2504,15 +2505,23 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25042505 else :
25052506 device = torch .device ("cpu" )
25062507
2507- if width == 0 :
2508- width = W
2509- if height == 0 :
2510- height = H
2511-
25122508 pillarbox_blur = keep_proportion == "pillarbox_blur"
2513- if keep_proportion == "resize" or keep_proportion .startswith ("pad" ) or pillarbox_blur :
2509+
2510+ # Initialize padding variables
2511+ pad_left = pad_right = pad_top = pad_bottom = 0
2512+
2513+ if keep_proportion in ["resize" , "total_pixels" ] or keep_proportion .startswith ("pad" ) or pillarbox_blur :
2514+ if keep_proportion == "total_pixels" :
2515+ total_pixels = width * height
2516+ aspect_ratio = W / H
2517+ new_height = int (math .sqrt (total_pixels / aspect_ratio ))
2518+ new_width = int (math .sqrt (total_pixels * aspect_ratio ))
2519+
25142520 # If one of the dimensions is zero, calculate it to maintain the aspect ratio
2515- if width == 0 and height != 0 :
2521+ elif width == 0 and height == 0 :
2522+ new_width = W
2523+ new_height = H
2524+ elif width == 0 and height != 0 :
25162525 ratio = height / H
25172526 new_width = round (W * ratio )
25182527 new_height = height
@@ -2528,7 +2537,6 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25282537 new_width = width
25292538 new_height = height
25302539
2531- pad_left = pad_right = pad_top = pad_bottom = 0
25322540 if keep_proportion .startswith ("pad" ) or pillarbox_blur :
25332541 # Calculate padding based on position
25342542 if crop_position == "center" :
@@ -2559,76 +2567,136 @@ def resize(self, image, width, height, keep_proportion, upscale_method, divisibl
25592567
25602568 width = new_width
25612569 height = new_height
2570+ else :
2571+ if width == 0 :
2572+ width = W
2573+ if height == 0 :
2574+ height = H
25622575
25632576 if divisible_by > 1 :
25642577 width = width - (width % divisible_by )
25652578 height = height - (height % divisible_by )
25662579
2567- out_image = image .clone ().to (device )
2568- if mask is not None :
2569- out_mask = mask .clone ().to (device )
2570- else :
2571- out_mask = None
2572-
2573- # Crop logic
2574- if keep_proportion == "crop" :
2575- old_width = W
2576- old_height = H
2577- old_aspect = old_width / old_height
2578- new_aspect = width / height
2579- if old_aspect > new_aspect :
2580- crop_w = round (old_height * new_aspect )
2581- crop_h = old_height
2582- else :
2583- crop_w = old_width
2584- crop_h = round (old_width / new_aspect )
2585- if crop_position == "center" :
2586- x = (old_width - crop_w ) // 2
2587- y = (old_height - crop_h ) // 2
2588- elif crop_position == "top" :
2589- x = (old_width - crop_w ) // 2
2590- y = 0
2591- elif crop_position == "bottom" :
2592- x = (old_width - crop_w ) // 2
2593- y = old_height - crop_h
2594- elif crop_position == "left" :
2595- x = 0
2596- y = (old_height - crop_h ) // 2
2597- elif crop_position == "right" :
2598- x = old_width - crop_w
2599- y = (old_height - crop_h ) // 2
2600- out_image = out_image .narrow (- 2 , x , crop_w ).narrow (- 3 , y , crop_h )
2601- if mask is not None :
2602- out_mask = out_mask .narrow (- 1 , x , crop_w ).narrow (- 2 , y , crop_h )
2580+ # Preflight estimate (log-only when batching is active)
2581+ if per_batch != 0 and B > per_batch :
2582+ try :
2583+ bytes_per_elem = image .element_size () # typically 4 for float32
2584+ est_total_bytes = B * height * width * C * bytes_per_elem
2585+ est_mb = est_total_bytes / (1024 * 1024 )
2586+ msg = f"<tr><td>Resize v2</td><td>estimated output ~{ est_mb :.2f} MB; batching { per_batch } /{ B } </td></tr>"
2587+ if unique_id and PromptServer is not None :
2588+ try :
2589+ PromptServer .instance .send_progress_text (msg , unique_id )
2590+ except :
2591+ pass
2592+ else :
2593+ print (f"[ImageResizeKJv2] estimated output ~{ est_mb :.2f} MB; batching { per_batch } /{ B } " )
2594+ except :
2595+ pass
26032596
2604- out_image = common_upscale (out_image .movedim (- 1 ,1 ), width , height , upscale_method , crop = "disabled" ).movedim (1 ,- 1 )
2605- if mask is not None :
2606- if upscale_method == "lanczos" :
2607- out_mask = common_upscale (out_mask .unsqueeze (1 ).repeat (1 , 3 , 1 , 1 ), width , height , upscale_method , crop = "disabled" ).movedim (1 ,- 1 )[:, :, :, 0 ]
2608- else :
2609- out_mask = common_upscale (out_mask .unsqueeze (1 ), width , height , upscale_method , crop = "disabled" ).squeeze (1 )
2597+ def _process_subbatch (in_image , in_mask , pad_left , pad_right , pad_top , pad_bottom ):
2598+ # Avoid unnecessary clones; only move if needed
2599+ out_image = in_image if in_image .device == device else in_image .to (device )
2600+ out_mask = None if in_mask is None else (in_mask if in_mask .device == device else in_mask .to (device ))
2601+
2602+ # Crop logic
2603+ if keep_proportion == "crop" :
2604+ old_height = out_image .shape [- 3 ]
2605+ old_width = out_image .shape [- 2 ]
2606+ old_aspect = old_width / old_height
2607+ new_aspect = width / height
2608+ if old_aspect > new_aspect :
2609+ crop_w = round (old_height * new_aspect )
2610+ crop_h = old_height
2611+ else :
2612+ crop_w = old_width
2613+ crop_h = round (old_width / new_aspect )
2614+ if crop_position == "center" :
2615+ x = (old_width - crop_w ) // 2
2616+ y = (old_height - crop_h ) // 2
2617+ elif crop_position == "top" :
2618+ x = (old_width - crop_w ) // 2
2619+ y = 0
2620+ elif crop_position == "bottom" :
2621+ x = (old_width - crop_w ) // 2
2622+ y = old_height - crop_h
2623+ elif crop_position == "left" :
2624+ x = 0
2625+ y = (old_height - crop_h ) // 2
2626+ elif crop_position == "right" :
2627+ x = old_width - crop_w
2628+ y = (old_height - crop_h ) // 2
2629+ out_image = out_image .narrow (- 2 , x , crop_w ).narrow (- 3 , y , crop_h )
2630+ if out_mask is not None :
2631+ out_mask = out_mask .narrow (- 1 , x , crop_w ).narrow (- 2 , y , crop_h )
2632+
2633+ out_image = common_upscale (out_image .movedim (- 1 ,1 ), width , height , upscale_method , crop = "disabled" ).movedim (1 ,- 1 )
2634+ if out_mask is not None :
2635+ if upscale_method == "lanczos" :
2636+ out_mask = common_upscale (out_mask .unsqueeze (1 ).repeat (1 , 3 , 1 , 1 ), width , height , upscale_method , crop = "disabled" ).movedim (1 ,- 1 )[:, :, :, 0 ]
2637+ else :
2638+ out_mask = common_upscale (out_mask .unsqueeze (1 ), width , height , upscale_method , crop = "disabled" ).squeeze (1 )
2639+
2640+ # Pad logic
2641+ if (keep_proportion .startswith ("pad" ) or pillarbox_blur ) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0 ):
2642+ padded_width = width + pad_left + pad_right
2643+ padded_height = height + pad_top + pad_bottom
2644+ if divisible_by > 1 :
2645+ width_remainder = padded_width % divisible_by
2646+ height_remainder = padded_height % divisible_by
2647+ if width_remainder > 0 :
2648+ extra_width = divisible_by - width_remainder
2649+ pad_right += extra_width
2650+ if height_remainder > 0 :
2651+ extra_height = divisible_by - height_remainder
2652+ pad_bottom += extra_height
2653+
2654+ pad_mode = (
2655+ "pillarbox_blur" if pillarbox_blur else
2656+ "edge" if keep_proportion == "pad_edge" else
2657+ "edge_pixel" if keep_proportion == "pad_edge_pixel" else
2658+ "color"
2659+ )
2660+ out_image , out_mask = ImagePadKJ .pad (self , out_image , pad_left , pad_right , pad_top , pad_bottom , 0 , pad_color , pad_mode , mask = out_mask )
26102661
2611- # Pad logic
2612- if (keep_proportion .startswith ("pad" ) or pillarbox_blur ) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0 ):
2613- padded_width = width + pad_left + pad_right
2614- padded_height = height + pad_top + pad_bottom
2615- if divisible_by > 1 :
2616- width_remainder = padded_width % divisible_by
2617- height_remainder = padded_height % divisible_by
2618- if width_remainder > 0 :
2619- extra_width = divisible_by - width_remainder
2620- pad_right += extra_width
2621- if height_remainder > 0 :
2622- extra_height = divisible_by - height_remainder
2623- pad_bottom += extra_height
2624-
2625- pad_mode = (
2626- "pillarbox_blur" if pillarbox_blur else
2627- "edge" if keep_proportion == "pad_edge" else
2628- "edge_pixel" if keep_proportion == "pad_edge_pixel" else
2629- "color"
2630- )
2631- out_image , out_mask = ImagePadKJ .pad (self , out_image , pad_left , pad_right , pad_top , pad_bottom , 0 , pad_color , pad_mode , mask = out_mask )
2662+ return out_image , out_mask
2663+
2664+ # If batching disabled (per_batch==0) or batch fits, process whole batch
2665+ if per_batch == 0 or B <= per_batch :
2666+ out_image , out_mask = _process_subbatch (image , mask , pad_left , pad_right , pad_top , pad_bottom )
2667+ else :
2668+ chunks = []
2669+ mask_chunks = [] if mask is not None else None
2670+ total_batches = (B + per_batch - 1 ) // per_batch
2671+ current_batch = 0
2672+ for start_idx in range (0 , B , per_batch ):
2673+ current_batch += 1
2674+ end_idx = min (start_idx + per_batch , B )
2675+ sub_img = image [start_idx :end_idx ]
2676+ sub_mask = mask [start_idx :end_idx ] if mask is not None else None
2677+ sub_out_img , sub_out_mask = _process_subbatch (sub_img , sub_mask , pad_left , pad_right , pad_top , pad_bottom )
2678+ chunks .append (sub_out_img .cpu ())
2679+ if mask is not None :
2680+ mask_chunks .append (sub_out_mask .cpu () if sub_out_mask is not None else None )
2681+ # Per-batch progress update
2682+ if unique_id and PromptServer is not None :
2683+ try :
2684+ PromptServer .instance .send_progress_text (
2685+ f"<tr><td>Resize v2</td><td>batch { current_batch } /{ total_batches } · images { end_idx } /{ B } </td></tr>" ,
2686+ unique_id
2687+ )
2688+ except :
2689+ pass
2690+ else :
2691+ try :
2692+ print (f"[ImageResizeKJv2] batch { current_batch } /{ total_batches } · images { end_idx } /{ B } " )
2693+ except :
2694+ pass
2695+ out_image = torch .cat (chunks , dim = 0 )
2696+ if mask is not None and any (m is not None for m in mask_chunks ):
2697+ out_mask = torch .cat ([m for m in mask_chunks if m is not None ], dim = 0 )
2698+ else :
2699+ out_mask = None
26322700
26332701 # Progress UI
26342702 if unique_id and PromptServer is not None :
0 commit comments