Skip to content

Commit ca3da34

Browse files
committed
handle other types being passed to attn_types keyword arg
1 parent 5e2eff6 commit ca3da34

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

dalle_pytorch/transformer.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def exists(val):
1717
def default(val, d):
1818
return val if exists(val) else d
1919

20-
def cast_tuple(val, depth):
20+
def cast_tuple(val, depth = 1):
21+
if isinstance(val, list):
22+
val = tuple(val)
2123
return val if isinstance(val, tuple) else (val,) * depth
2224

2325
# classes
@@ -72,7 +74,9 @@ def __init__(
7274
super().__init__()
7375
layers = nn.ModuleList([])
7476
sparse_layer = cast_tuple(sparse_attn, depth)
77+
7578
attn_types = default(attn_types, ('full',))
79+
attn_types = cast_tuple(attn_types)
7680
attn_type_layer = islice(cycle(attn_types), depth)
7781

7882
for _, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.56',
6+
version = '0.0.58',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)