Skip to content

Commit 7fd3b4e

Browse files
fix
Created using spr 1.3.4
2 parents 44683dd + d78b9c3 commit 7fd3b4e

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

gematria/model/python/model_blocks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,13 @@ def call(self, layer_inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
7878
residual_part = self._linear_transformation(residual_part)
7979

8080
return tf.math.add(output_part, residual_part, name=self.name)
81+
82+
83+
class CastLayer(tf_keras.layers.Layer):
84+
85+
def __init__(self, dtype, **kwargs):
86+
super().__init__(**kwargs)
87+
self._dtype = dtype
88+
89+
def call(self, input_tensor):
90+
return tf.cast(input_tensor, dtype=self._dtype)

0 commit comments

Comments
 (0)