Skip to content

Commit cc7dcff

Browse files
No public description
PiperOrigin-RevId: 711491914
1 parent 1bdb87d commit cc7dcff

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

Diff for: official/nlp/modeling/layers/block_sparse_attention.py

+11
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ def __init__(
8484
"sigmoid_attn_bias must be specified for sigmoid attn."
8585
)
8686

87+
def get_config(self):
  """Return the parent's config extended with this layer's own options.

  Returns:
    The object returned by the parent class's `get_config()`, updated in
    place with the `src_block_size`, `tgt_block_size`, `use_sigmoid_attn`,
    `sigmoid_attn_bias` and `num_kv_heads` entries read from the
    corresponding private attributes on this instance.
  """
  merged = super().get_config()
  # Options stored on this instance that the parent config does not add here.
  layer_options = {
      "src_block_size": self._src_block_size,
      "tgt_block_size": self._tgt_block_size,
      "use_sigmoid_attn": self._use_sigmoid_attn,
      "sigmoid_attn_bias": self._sigmoid_attn_bias,
      "num_kv_heads": self._num_kv_heads,
  }
  merged.update(layer_options)
  return merged
97+
8798
def _build_from_signature(self, query, value, key=None):
8899
# pytype: disable=attribute-error
89100
super()._build_from_signature(query, value, key)

Diff for: official/nlp/modeling/layers/multi_query_attention.py

+5
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def __init__(self, num_kv_heads=None, **kwargs):
9393
self._num_heads % self._num_kv_heads == 0
9494
), "num_kv_heads needs to divide num_heads exactly."
9595

96+
def get_config(self):
  """Return the parent's config with the `num_kv_heads` entry added.

  Returns:
    The object returned by the parent class's `get_config()`, with its
    `"num_kv_heads"` key set to this instance's `_num_kv_heads` value.
  """
  base_config = super().get_config()
  # NOTE(review): assumes the parent returns a plain dict, where item
  # assignment matches a single-key update — confirm against the base class.
  base_config["num_kv_heads"] = self._num_kv_heads
  return base_config
100+
96101
def _build_from_signature(
97102
self,
98103
query: Union[tf.Tensor, tf.TensorShape],

0 commit comments

Comments
 (0)