
Commit 36dc02d

Fix up loop-binding issues in ImageTransformerV2
1 parent c3f6608 commit 36dc02d

1 file changed: +13 −10 lines changed

Diff for: k_diffusion/models/image_transformer_v2.py

@@ -643,6 +643,18 @@ class MappingSpec:
     dropout: float
 
 
+def make_layer_factory(spec, mapping):
+    if isinstance(spec.self_attn, GlobalAttentionSpec):
+        return lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
+        return lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
+        return lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NoAttentionSpec):
+        return lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
+    raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+
+
 # Model class
 
 class ImageTransformerDenoiserModelV2(nn.Module):
@@ -662,16 +674,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c
 
         self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
         for i, spec in enumerate(levels):
-            if isinstance(spec.self_attn, GlobalAttentionSpec):
-                layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
-                layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
-                layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NoAttentionSpec):
-                layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
-            else:
-                raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+            layer_factory = self.make_layer_factory(spec, mapping)
 
             if i < len(levels) - 1:
                 self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
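
For context, the "loop-binding issues" in the commit title refer to Python's late-binding closure behaviour: a lambda created inside a for loop closes over the loop variable itself rather than its current value, so a layer_factory built directly in the old __init__ loop refers to whichever spec is current when the factory is actually called, not necessarily the level it was created for. Building the lambda inside a separate factory-maker function avoids this, because each returned lambda then closes over that call's own arguments. The sketch below is a minimal, standalone illustration of the pitfall and the fix; it uses plain strings and a simplified single-argument make_layer_factory rather than the repository's layer classes, so the names are illustrative only.

    # Broken pattern: the lambdas share the loop variable `spec`, so once the
    # loop finishes they all see its last value.
    specs = ["global", "neighborhood", "shifted-window"]

    broken_factories = []
    for spec in specs:
        broken_factories.append(lambda: f"layer({spec})")

    print([f() for f in broken_factories])
    # -> ['layer(shifted-window)', 'layer(shifted-window)', 'layer(shifted-window)']

    # Fixed pattern: build the lambda inside a helper function, so it closes
    # over the helper's parameter, which is bound fresh on every call.
    # (Simplified analogue of the helper added in this commit.)
    def make_layer_factory(spec):
        return lambda: f"layer({spec})"

    fixed_factories = [make_layer_factory(spec) for spec in specs]
    print([f() for f in fixed_factories])
    # -> ['layer(global)', 'layer(neighborhood)', 'layer(shifted-window)']

Binding the current value as a default argument (lambda spec=spec, ...: ...) would have the same effect; extracting a named factory-maker, as this commit does, additionally keeps the per-attention-type dispatch out of __init__.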

0 commit comments
