@@ -303,7 +303,7 @@ def INPUT_TYPES(s):
303303 RETURN_TYPES = ("MODEL" ,)
304304 FUNCTION = "patch"
305305
306- CATEGORY = "KJNodes/experimental "
306+ CATEGORY = "KJNodes/torchcompile "
307307 EXPERIMENTAL = True
308308
309309 def parse_blocks (self , blocks_str ):
@@ -378,7 +378,7 @@ def INPUT_TYPES(s):
378378 RETURN_TYPES = ("MODEL" ,)
379379 FUNCTION = "patch"
380380
381- CATEGORY = "KJNodes/experimental "
381+ CATEGORY = "KJNodes/torchcompile "
382382 EXPERIMENTAL = True
383383
384384 def patch (self , model , backend , fullgraph , mode , dynamic , dynamo_cache_size_limit , compile_single_blocks , compile_double_blocks , compile_txt_in , compile_vector_in , compile_final_layer ):
@@ -415,6 +415,51 @@ def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limi
415415 except :
416416 raise RuntimeError ("Failed to compile model" )
417417 return (m , )
418+
class TorchCompileModelWanVideo:
    """ComfyUI node: patch a WanVideo diffusion model so its transformer
    blocks are wrapped with torch.compile.

    The node clones the incoming model and registers the compiled blocks as
    object patches, so the original model object is left untouched.
    """

    def __init__(self):
        # Compile only once per node instance; later executions of the same
        # node reuse the already-patched clone instead of recompiling.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "backend": (["inductor", "cudagraphs"], {"default": "inductor"}),
                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
                "compile_transformer_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile all transformer blocks"}),
            },
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True

    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks):
        """Return a clone of *model* with each transformer block replaced by a
        torch.compile'd wrapper.

        Args:
            model: ComfyUI model patcher (must support clone/get_model_object/
                add_object_patch).
            backend, fullgraph, mode, dynamic: forwarded to torch.compile.
            dynamo_cache_size_limit: written to torch._dynamo.config.cache_size_limit.
            compile_transformer_blocks: when False, nothing is compiled but the
                node still marks itself compiled and records the settings.

        Returns:
            One-tuple containing the patched model clone.

        Raises:
            RuntimeError: if wrapping any block with torch.compile fails; the
                original exception is chained as the cause.
        """
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        # Raise the recompile cache limit before any compilation is attempted.
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
        if not self._compiled:
            try:
                if compile_transformer_blocks:
                    for i, block in enumerate(diffusion_model.blocks):
                        compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
                        # Patch the clone rather than mutating the shared model.
                        m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
                self._compiled = True
                # Record the settings on the model so downstream nodes can
                # inspect how it was compiled.
                compile_settings = {
                    "backend": backend,
                    "mode": mode,
                    "fullgraph": fullgraph,
                    "dynamic": dynamic,
                }
                setattr(m.model, "compile_settings", compile_settings)
            except Exception as e:
                # Was a bare `except:`: it swallowed the real error (and even
                # KeyboardInterrupt). Chain the cause so failures are debuggable.
                raise RuntimeError("Failed to compile model") from e
        return (m,)
418463
419464class TorchCompileVAE :
420465 def __init__ (self ):
@@ -434,7 +479,7 @@ def INPUT_TYPES(s):
434479 RETURN_TYPES = ("VAE" ,)
435480 FUNCTION = "compile"
436481
437- CATEGORY = "KJNodes/experimental "
482+ CATEGORY = "KJNodes/torchcompile "
438483 EXPERIMENTAL = True
439484
440485 def compile (self , vae , backend , mode , fullgraph , compile_encoder , compile_decoder ):
@@ -495,7 +540,7 @@ def INPUT_TYPES(s):
495540 RETURN_TYPES = ("CONTROL_NET" ,)
496541 FUNCTION = "compile"
497542
498- CATEGORY = "KJNodes/experimental "
543+ CATEGORY = "KJNodes/torchcompile "
499544 EXPERIMENTAL = True
500545
501546 def compile (self , controlnet , backend , mode , fullgraph ):
@@ -528,7 +573,7 @@ def INPUT_TYPES(s):
528573 RETURN_TYPES = ("MODEL" ,)
529574 FUNCTION = "patch"
530575
531- CATEGORY = "KJNodes/experimental "
576+ CATEGORY = "KJNodes/torchcompile "
532577 EXPERIMENTAL = True
533578
534579 def patch (self , model , backend , mode , fullgraph , dynamic ):
@@ -571,7 +616,7 @@ def INPUT_TYPES(s):
571616 RETURN_TYPES = ("MODEL" ,)
572617 FUNCTION = "patch"
573618
574- CATEGORY = "KJNodes/experimental "
619+ CATEGORY = "KJNodes/torchcompile "
575620 EXPERIMENTAL = True
576621
577622 def patch (self , model , backend , mode , fullgraph , dynamic , dynamo_cache_size_limit ):
0 commit comments