Skip to content

Commit 08fd961

Browse files
committed
Update to bring Flowchain in for method chaining
1 parent 55a6ed2 commit 08fd961

5 files changed

Lines changed: 22 additions & 42 deletions

File tree

gato/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
from gato.config import GatoConfig
22
from gato.models import Gato
3+
from flowchain import enable_tensor_chaining
4+
5+
enable_tensor_chaining()

gato/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def call(self, inputs, training=None, mask=None):
3636

3737
ones = tf.ones((input_ids.shape[0], 1, self.config.layer_width), dtype=tf.float32)
3838
image_embed = self.image_embedding((input_ids, (row_pos, col_pos)), training=training)
39-
image_embed *= tf.matmul(encoding[..., 0], ones, transpose_a=True) # image patch masking
39+
image_embed *= encoding[..., 0].transpose().matmul(ones) # image patch masking
4040

4141
# continuous value is taken from the first value of input_ids
4242
continuous_embed = self.continuous_encoding(input_ids[..., 0])
4343
continuous_embed = self.discrete_embedding(continuous_embed)
44-
continuous_embed *= tf.matmul(encoding[..., 1], ones, transpose_a=True) # continuous value masking
44+
continuous_embed *= encoding[..., 1].transpose().matmul(ones) # continuous value masking
4545

4646
discrete_embed = self.discrete_embedding(input_ids[..., 0])
47-
discrete_embed *= tf.matmul(encoding[..., 2], ones, transpose_a=True) # discrete value masking
47+
discrete_embed *= encoding[..., 2].transpose().matmul(ones) # discrete value masking
4848

4949
# Appendix C.3. Position Encodings > Local Observation Position Encodings
5050
# add local observation position encodings
@@ -101,7 +101,7 @@ def call(self, inputs, training=None, mask=None):
101101
patch_size = self.config.img_patch_size
102102
depth = self.config.input_dim // (patch_size * patch_size)
103103

104-
x = tf.reshape(input_ids, (-1, input_ids.shape[1], patch_size, patch_size, depth))
104+
x = input_ids.reshape((-1, input_ids.shape[1], patch_size, patch_size, depth))
105105
x = self.residual_embedding(x)
106106
x = self.pos_encoding((x, (row_pos, col_pos)))
107107
return x

gato/models/embedding.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,13 @@
77

88
def _randomized_positions(from_v, to_v):
99
pos = tf.random.uniform(from_v.shape, minval=0, maxval=1, dtype=tf.float32)
10-
pos = pos * tf.cast(to_v - from_v, dtype=tf.float32)
11-
pos = tf.cast(pos, dtype=tf.int32)
12-
return pos
10+
pos = pos * (to_v - from_v).cast(tf.float32)
11+
return pos.cast(tf.int32)
1312

1413

1514
def _rounded_mean_positions(from_v, to_v):
16-
pos = tf.cast(from_v + to_v, tf.float32)
17-
pos = pos / 2
18-
pos = tf.round(pos)
19-
return pos
20-
21-
22-
def _broadcast(row_pos, col_pos, row_ones, col_ones):
23-
# broadcast (5,) to (20,) with column-axis
24-
row_pos = tf.expand_dims(row_pos, 1)
25-
row_pos = tf.matmul(row_pos, col_ones, transpose_b=True)
26-
row_pos = tf.reshape(row_pos, (-1,))
27-
row_pos = tf.stop_gradient(row_pos)
28-
29-
# broadcast (4,) to (20,) with row-axis
30-
col_pos = tf.expand_dims(col_pos, 1)
31-
col_pos = tf.matmul(row_ones, col_pos, transpose_b=True)
32-
col_pos = tf.reshape(col_pos, (-1,))
33-
col_pos = tf.stop_gradient(col_pos)
34-
35-
return row_pos, col_pos
15+
pos = (from_v + to_v).cast(tf.float32) / 2.
16+
return pos.round()
3617

3718

3819
class PatchPositionEncoding(layers.Layer):
@@ -57,7 +38,7 @@ def __init__(self,
5738
self.col_embedding = layers.Embedding(self.discretize_depth, self.embedding_dim, name='col_embedding')
5839

5940
def _discretize(self, pos):
60-
return tf.round(pos * self.discretize_depth)
41+
return (pos * self.discretize_depth).round()
6142

6243
def _discretize_interval(self, interval):
6344
pos_from, pos_to = interval
@@ -83,12 +64,9 @@ def call(self, inputs, *args, **kwargs):
8364
row_pos = _rounded_mean_positions(row_pos_from, row_pos_to)
8465
col_pos = _rounded_mean_positions(col_pos_from, col_pos_to)
8566

86-
col_pos = tf.cast(col_pos, dtype=tf.int32)
87-
row_pos = tf.cast(row_pos, dtype=tf.int32)
88-
8967
# > Once row and column position encoding are retrieved from the embedding table,
9068
# > they are added onto the token embedding produced by the resnet embedding function.
91-
return input_ids + self.row_embedding(row_pos) + self.col_embedding(col_pos)
69+
return input_ids + self.row_embedding(row_pos.cast(tf.int32)) + self.col_embedding(col_pos.cast(tf.int32))
9270

9371
def get_config(self):
9472
config = super(PatchPositionEncoding, self).get_config()
@@ -127,10 +105,10 @@ def call(self, inputs, *args, **kwargs):
127105

128106
residual = self.conv_proj(self.gn_proj(x))
129107

130-
x = tf.nn.gelu(self.gn1(x))
108+
x = self.gn1(x).gelu()
131109
x = self.conv1(x)
132110

133-
x = tf.nn.gelu(self.gn2(x))
111+
x = self.gn2(x).gelu()
134112
x = self.conv2(x)
135113

136114
return x + residual
@@ -185,7 +163,7 @@ def call(self, inputs, *args, **kwargs):
185163
x = block(x)
186164
if self.conv_proj is not None:
187165
x = self.conv_proj(x)
188-
x = tf.reshape(x, shape=(-1, inputs.shape[1], self.config.layer_width))
166+
x = x.reshape((-1, inputs.shape[1], self.config.layer_width))
189167
return x
190168

191169
def get_config(self):
@@ -222,8 +200,7 @@ def call(self, inputs, *args, **kwargs):
222200
embed = self.embedding(obs_pos)
223201

224202
ones = tf.ones((embed.shape[0], 1, self.config.layer_width), dtype=tf.float32)
225-
obs_mask = tf.cast(obs_mask, dtype=tf.float32)
226-
obs_mask = tf.matmul(obs_mask, ones, transpose_a=True)
203+
obs_mask = obs_mask.cast(tf.float32).transpose().matmul(ones)
227204
return embed * obs_mask
228205

229206
def get_config(self):

gato/models/tokenizers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
def mu_law_encode(x, mu=100, m=256):
99
# Appendix B. Agent Data Tokenization Details
10-
sign = tf.math.sign(x)
11-
numerator = tf.math.log(tf.abs(x) * mu + 1.0)
10+
numerator = tf.math.log(x.abs() * mu + 1.0)
1211
denominator = tf.math.log(m * mu + 1.0)
13-
return (numerator / denominator) * sign
12+
return (numerator / denominator) * x.sign()
1413

1514

1615
def tokenize_continuous_values(x, mu=100, m=256, bins=1024, shift=None):
@@ -21,7 +20,7 @@ def tokenize_continuous_values(x, mu=100, m=256, bins=1024, shift=None):
2120
# > We use 1024 bins and shift the resulting integers
2221
# > so they are not overlapping with the ones used for text tokens.
2322
c = (c + 1) * (bins / 2)
24-
c = tf.cast(c, tf.int32)
23+
c = c.cast(tf.int32)
2524
if shift is not None:
2625
c += shift
2726
return c

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='gato-tf',
5-
version='0.0.2',
5+
version='0.0.3',
66
description='Unofficial Gato: A Generalist Agent',
77
url='https://github.com/OrigamiDream/gato.git',
88
author='OrigamiDream',
@@ -11,6 +11,7 @@
1111
packages=find_packages(),
1212
install_requires=[
1313
'tensorflow>=2.11',
14+
'flowchain>=0.0.4'
1415
],
1516
keywords=[
1617
'deep learning',

0 commit comments

Comments
 (0)