Skip to content

Commit ca44960

Browse files
authored
Merge pull request #21 from Rishit-dagli/Rishit-dagli-patch-1
Fix error with tf.function
2 parents 4d3b9b0 + c97c3b4 commit ca44960

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

perceiver/perceiver.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def __init__(
7676
# self.existing_layers = get_latent_attn()(self.existing_layers)
7777
# self.existing_layers = get_latent_ff()(self.existing_layers)
7878

79+
self.existing_layers = tf.keras.Sequential(self.existing_layers)
80+
7981
self.to_logits = tf.keras.Sequential(
8082
[
8183
tf.keras.layers.LayerNormalization(axis=-1),
@@ -103,7 +105,7 @@ def call(self, data, mask=None):
103105

104106
x = repeat(self.latents, "n d -> b n d", b=b)
105107

106-
x = tf.keras.Sequential(self.existing_layers)(x)
108+
x = self.existing_layers(x)
107109

108110
x = tf.math.reduce_mean(x, axis=-2)
109111
return self.to_logits(x)

perceiver/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2"

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="perceiver",
10-
version="0.1.1",
10+
version="0.1.2",
1111
description="Implement of Perceiver, General Perception with Iterative Attention in TensorFlow",
1212
packages=["perceiver"],
1313
long_description=long_description,

0 commit comments

Comments
 (0)