@@ -1086,6 +1086,7 @@ def INPUT_TYPES(s):
10861086 "optional" : {
10871087 "mask" : ("MASK" , ),
10881088 "scheduler" : (["FlowMatchEulerDiscreteScheduler" , "ConsistencyFlowMatchEulerDiscreteScheduler" ],),
1089+ "force_offload" : ("BOOLEAN" , {"default" : True , "tooltip" : "Offloads the model to the offload device once the process is done." }),
10891090 }
10901091 }
10911092
@@ -1094,7 +1095,8 @@ def INPUT_TYPES(s):
10941095 FUNCTION = "process"
10951096 CATEGORY = "Hunyuan3DWrapper"
10961097
1097- def process (self , pipeline , image , steps , guidance_scale , seed , mask = None , front = None , back = None , left = None , right = None , scheduler = "FlowMatchEulerDiscreteScheduler" ):
1098+ def process (self , pipeline , image , steps , guidance_scale , seed , mask = None , front = None , back = None , left = None , right = None ,
1099+ scheduler = "FlowMatchEulerDiscreteScheduler" , force_offload = True ):
10981100
10991101 mm .unload_all_models ()
11001102 mm .soft_empty_cache ()
@@ -1136,8 +1138,9 @@ def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front
11361138 torch .cuda .reset_peak_memory_stats (device )
11371139 except :
11381140 pass
1139-
1140- pipeline .to (offload_device )
1141+
1142+          if force_offload :
1143+ pipeline .to (offload_device )
11411144
11421145 return (latents , )
11431146
@@ -1254,6 +1257,8 @@ def INPUT_TYPES(s):
12541257 },
12551258 "optional" : {
12561259 "enable_flash_vdm" : ("BOOLEAN" , {"default" : True }),
1260+ "force_offload" : ("BOOLEAN" , {"default" : True , "tooltip" : "Offloads the model to the offload device once the process is done." }),
1261+
12571262 }
12581263 }
12591264
@@ -1262,7 +1267,7 @@ def INPUT_TYPES(s):
12621267 FUNCTION = "process"
12631268 CATEGORY = "Hunyuan3DWrapper"
12641269
1265- def process (self , vae , latents , box_v , octree_resolution , mc_level , num_chunks , mc_algo , enable_flash_vdm = True ):
1270+ def process (self , vae , latents , box_v , octree_resolution , mc_level , num_chunks , mc_algo , enable_flash_vdm = True , force_offload = True ):
12661271 device = mm .get_torch_device ()
12671272 offload_device = mm .unet_offload_device ()
12681273
@@ -1283,7 +1288,8 @@ def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks,
12831288 octree_resolution = octree_resolution ,
12841289 mc_algo = mc_algo ,
12851290 )[0 ]
1286- vae .to (offload_device )
1291+ if force_offload :
1292+ vae .to (offload_device )
12871293
12881294 outputs .mesh_f = outputs .mesh_f [:, ::- 1 ]
12891295 mesh_output = Trimesh .Trimesh (outputs .mesh_v , outputs .mesh_f )
0 commit comments