@@ -643,6 +643,18 @@ class MappingSpec:
     dropout: float
 
 
+def make_layer_factory(spec, mapping):
+    if isinstance(spec.self_attn, GlobalAttentionSpec):
+        return lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
+        return lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
+        return lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NoAttentionSpec):
+        return lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
+    raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+
+
 # Model class
 
 
 class ImageTransformerDenoiserModelV2(nn.Module):
@@ -662,16 +674,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c
 
         self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
         for i, spec in enumerate(levels):
-            if isinstance(spec.self_attn, GlobalAttentionSpec):
-                layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
-                layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
-                layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NoAttentionSpec):
-                layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
-            else:
-                raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+            layer_factory = make_layer_factory(spec, mapping)
 
             if i < len(levels) - 1:
                 self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
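For context, a minimal usage sketch of the extracted helper, assuming the `LevelSpec`/`MappingSpec` dataclasses, the `Level` container, and the transformer layer classes from the surrounding module; the factory returns a callable that takes the layer index, which only the shifted-window variant actually uses:

```python
# Sketch under the assumptions above, mirroring the level construction in __init__.
layer_factory = make_layer_factory(spec, mapping)        # spec: LevelSpec, mapping: MappingSpec
layers = [layer_factory(i) for i in range(spec.depth)]   # i selects the shift for ShiftedWindowTransformerLayer
level = Level(layers)                                    # same pattern used for both down and up levels
```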