
Commit c41e844

Add Bloom Model (#1382)
* Add Bloom Model
* Add Backbone test and some fixes
* Add BloomBackbone to keras_nlp.models
* Fix a typo in layer naming
* Remove self.built = True
* Revert "Remove self.built = True" (reverts commit 889f204)
* Add built=True to MLP layer
* Add checkpoint conversion script
* Change LayerNorm name
* Fix typo
* Fix getting HF model output
* Add and to allclose function in checkpoint conversion script
* Remove allclose check
* Add doc for bloom
* Write batch size instead of _
* Rename out_dense to output_dense
* Rename out_dense to output_dense
* Format to 80 chars and remove unnecessary check
* Remove exporting BloomDecoder
* Add intermediate_dim arg
* Format the code
* Remove unnecessary comment
* Use keras gelu
* Remove MLP layer and implement it inside BloomDecoder
* Split q k v heads
* Remove shapes comments
* Revert "Split q k v heads" (reverts commit 2d03d2c)
* Revert "Revert "Split q k v heads"" (reverts commit 531b1ff)
* Revert "Remove shapes comments" (reverts commit 2eeb5f4)
* Add bias axes
* Add bias axes to the correct axes
* Update conversion script for splitting q, k, v
* Format the code
* Rename _dropout -> _dropout_layer
* Use clone_initializer instead of passing str name
* Serialize kernel & bias initializers
* Format the code
* Add alibi_bias_max to _build_alibi_tensor function
* Format the code
* Lowercase variable names
1 parent f89bf90 commit c41e844

File tree

7 files changed: 829 additions, 0 deletions

keras_nlp/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 )
 from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
 from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
+from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
 from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
 from keras_nlp.models.deberta_v3.deberta_v3_classifier import (
     DebertaV3Classifier,

keras_nlp/models/bloom/__init__.py

Lines changed: 13 additions & 0 deletions
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 212 additions & 0 deletions

# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.utils.keras_utils import clone_initializer


class BloomAttention(keras.layers.Layer):
    def __init__(
        self,
        num_heads,
        dropout=0.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.dropout = dropout
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)

    def build(self, inputs_shape):
        batch_size, seq_length, hidden_dim = inputs_shape

        self.head_dim = hidden_dim // self.num_heads

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

        self._query_dense = keras.layers.EinsumDense(
            equation="btm,mnh->btnh",
            output_shape=(None, self.num_heads, self.head_dim),
            bias_axes="nh",
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="query_dense",
        )
        self._query_dense.build(inputs_shape)

        self._key_dense = keras.layers.EinsumDense(
            equation="bsm,mnh->bsnh",
            output_shape=(None, self.num_heads, self.head_dim),
            bias_axes="nh",
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="key_dense",
        )
        self._key_dense.build(inputs_shape)

        self._value_dense = keras.layers.EinsumDense(
            equation="bsm,mnh->bsnh",
            output_shape=(None, self.num_heads, self.head_dim),
            bias_axes="nh",
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="value_dense",
        )
        self._value_dense.build(inputs_shape)

        self._output_dense = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="output_dense",
        )
        self._output_dense.build(inputs_shape)

        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout, dtype=self.dtype_policy, name="dropout"
        )
        self._softmax = keras.layers.Softmax(
            dtype=self.dtype_policy, name="softmax"
        )

        self.built = True

    @staticmethod
    def _build_alibi_tensor(num_heads, seq_length, alibi_bias_max=8):
        # this function is adopted from fairseq
        # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        def get_slopes(n):
            def get_slopes_power_of_2(n):
                start = 2 ** (
                    -(2 ** -(math.log2(n) - math.log2(alibi_bias_max)))
                )
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2 ** math.floor(math.log2(n))
                return (
                    get_slopes_power_of_2(closest_power_of_2)
                    + get_slopes(2 * closest_power_of_2)[0::2][
                        : n - closest_power_of_2
                    ]
                )

        slopes = ops.convert_to_tensor(get_slopes(num_heads), dtype=float)
        slopes = ops.expand_dims(slopes, 1)

        alibi = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0)
        alibi = ops.expand_dims(alibi, 1)
        alibi = ops.expand_dims(alibi, 0)

        return alibi

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
    ):
        batch_size, seq_length, hidden_dim = ops.shape(hidden_states)

        query = self._query_dense(hidden_states)
        key = self._key_dense(hidden_states)
        value = self._value_dense(hidden_states)

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key)
                value = ops.slice_update(value_cache, start, value)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )

        # query (batch_size, num_heads, query_length, head_dim)
        query = ops.transpose(query, [0, 2, 1, 3])
        # value (batch_size, num_heads, kv_length, head_dim)
        value = ops.transpose(value, [0, 2, 1, 3])
        # key (batch_size, num_heads, head_dim, kv_length)
        key = ops.transpose(key, [0, 2, 3, 1])

        alibi = self._build_alibi_tensor(
            num_heads=self.num_heads, seq_length=seq_length
        )

        scores = (
            ops.matmul(query, key) * self.inv_norm_factor + alibi
        )  # [batch_size, num_heads, query_length, kv_length]

        scores = self._softmax(scores, ops.expand_dims(attention_mask, 1))

        scores = self._dropout_layer(scores)

        attention_output = ops.matmul(
            scores, value
        )  # [batch_size, num_heads, query_length, head_dim]

        attention_output = ops.transpose(
            attention_output, [0, 2, 1, 3]
        )  # [batch_size, query_length, num_heads, head_dim]
        attention_output = ops.reshape(
            attention_output,
            [batch_size, seq_length, self.num_heads * self.head_dim],
        )  # [batch_size, query_length, hidden_dim]

        attention_output = self._output_dense(attention_output)
        attention_output = self._dropout_layer(attention_output)

        if cache is not None:
            return attention_output, cache

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "dropout": self.dropout,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config
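For orientation, here is a minimal sketch of exercising this attention layer on its own. The import path, shapes, and mask layout are illustrative assumptions rather than part of this commit; inside the model the layer is built and called by `BloomDecoder`.

```python
import numpy as np

from keras_nlp.backend import ops
# Illustrative import path; in the model this layer is used via BloomDecoder.
from keras_nlp.models.bloom.bloom_attention import BloomAttention

batch_size, seq_length, hidden_dim, num_heads = 2, 8, 32, 4

layer = BloomAttention(num_heads=num_heads, dropout=0.0)
layer.build((batch_size, seq_length, hidden_dim))

hidden_states = np.random.uniform(
    size=(batch_size, seq_length, hidden_dim)
).astype("float32")
# Attend to every position (no padding) in this toy example.
attention_mask = np.ones((batch_size, seq_length, seq_length), dtype="int32")

output = layer(hidden_states, attention_mask=attention_mask)
print(ops.shape(output))  # (2, 8, 32): same shape as the input hidden states.

# The ALiBi bias has one slope per head and grows linearly with key position.
# For num_heads=4 and alibi_bias_max=8 the slopes are 2^-2, 2^-4, 2^-6, 2^-8.
alibi = BloomAttention._build_alibi_tensor(
    num_heads=num_heads, seq_length=seq_length
)
print(ops.shape(alibi))  # (1, 4, 1, 8)
```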
Lines changed: 153 additions & 0 deletions

# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bloom.bloom_decoder import BloomDecoder


def _bloom_kernel_initializer(stddev=0.02):
    return keras.initializers.RandomNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.BloomBackbone")
class BloomBackbone(Backbone):
    """A Bloom decoder network.

    This network implements a Transformer-based decoder network, BigScience
    Language Open-science Open-access Multilingual (BLOOM), as described in
    ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf).

    The default constructor gives a fully customizable, randomly initialized
    Bloom model with any number of layers, heads, and embedding dimensions. To
    load preset architectures and weights, use the `from_preset()` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The dimensionality of the embeddings and hidden states.
        intermediate_dim: int. The output dimension of the first Dense layer in
            the MLP network of each transformer.
        dropout: float. Dropout probability for the Transformer decoder.
        layer_norm_epsilon: float. Epsilon for the layer normalization layers in
            the transformer decoder.
        max_sequence_length: int. The maximum sequence length that this decoder
            can consume.

    Examples:
    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Randomly initialized BLOOM decoder with a custom config.
    model = keras_nlp.models.BloomBackbone(
        vocabulary_size=10,
        num_layers=2,
        num_heads=2,
        hidden_dim=32,
        intermediate_dim=32*4,
        dropout=0.0,
        layer_norm_epsilon=1e-5,
        max_sequence_length=128,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.0,
        layer_norm_epsilon=1e-5,
        max_sequence_length=512,
        **kwargs,
    ):
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens
        token_embedding_layer = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=_bloom_kernel_initializer(stddev=0.02),
            tie_weights=False,
            name="token_embedding",
        )
        token_embedding = token_embedding_layer(token_ids)

        x = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon, name="token_embedding_layernorm"
        )(token_embedding)

        for i in range(num_layers):
            x = BloomDecoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                dropout=dropout,
                layer_norm_epsilon=layer_norm_epsilon,
                name=f"transformer_layer_{i}",
            )(x, decoder_padding_mask=padding_mask)

        sequence_output = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon, name="final_layernorm"
        )(x)

        super().__init__(
            inputs={
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            },
            outputs=sequence_output,
            **kwargs,
        )
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.max_sequence_length = max_sequence_length
        self.token_embedding = token_embedding_layer

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "max_sequence_length": self.max_sequence_length,
            }
        )
        return config
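As a quick sanity check of the functional graph above, here is a minimal sketch of running the backbone and round-tripping it through `get_config()`. The configuration values are toy ones chosen for illustration, not a real BLOOM preset.

```python
import numpy as np

import keras_nlp

# Toy configuration, chosen only to keep the example small.
backbone = keras_nlp.models.BloomBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_heads=4,
    hidden_dim=64,
    intermediate_dim=256,
    max_sequence_length=128,
)

input_data = {
    "token_ids": np.ones((2, 16), dtype="int32"),
    "padding_mask": np.ones((2, 16), dtype="int32"),
}
sequence_output = backbone(input_data)
print(sequence_output.shape)  # (2, 16, 64): (batch_size, seq_length, hidden_dim)

# get_config() captures the constructor arguments, so the architecture can be
# rebuilt from config alone (weights are re-initialized, not copied).
restored = keras_nlp.models.BloomBackbone.from_config(backbone.get_config())
```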
