Skip to content

Commit e891bf8

Browse files
[𝘀𝗽𝗿] changes introduced through rebase
Created using spr 1.3.4 [skip ci]
1 parent 9a66d59 commit e891bf8

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

gematria/model/python/model_blocks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,13 @@ def call(self, layer_inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
7878
residual_part = self._linear_transformation(residual_part)
7979

8080
return tf.math.add(output_part, residual_part, name=self.name)
81+
82+
83+
class CastLayer(tf_keras.layers.Layer):
84+
85+
def __init__(self, dtype, **kwargs):
86+
super().__init__(**kwargs)
87+
self._dtype = dtype
88+
89+
def call(self, input_tensor):
90+
return tf.cast(input_tensor, dtype=self._dtype)

0 commit comments

Comments
 (0)