
Commit 1df6836

Replace to_linen function with NNX module
1 parent 6d2a65b commit 1df6836

2 files changed: +131 -158 lines changed
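The gist of the change: factory helpers that returned Linen modules (mlp_block, dense_general, attention_op_as_linen, gpt3_layer_norm) are swapped for their NNX counterparts (MlpBlock, DenseGeneral, AttentionOp, Gpt3LayerNorm), built once in __init__ with an explicit nnx.Rngs and reused in __call__. A minimal self-contained sketch of that NNX pattern, using plain flax.nnx layers rather than MaxText code (TinyBlock and its field names are illustrative, not from this commit):

import jax
import jax.numpy as jnp
from flax import nnx


class TinyBlock(nnx.Module):
  """Toy module showing the pattern: sub-layers are created in __init__, not in the forward pass."""

  def __init__(self, features: int, rngs: nnx.Rngs):
    # Explicit RNGs replace Linen's implicit parameter initialization.
    self.proj_in = nnx.Linear(features, 4 * features, rngs=rngs)
    self.proj_out = nnx.Linear(4 * features, features, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    return self.proj_out(jax.nn.relu(self.proj_in(x)))


block = TinyBlock(features=8, rngs=nnx.Rngs(0))
y = block(jnp.ones((2, 8)))  # shape (2, 8)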

MaxText/layers/gpt3.py

Lines changed: 76 additions & 80 deletions
@@ -32,10 +32,10 @@
 from MaxText import max_logging
 from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
 from MaxText.layers import initializers, nnx_wrappers
-from MaxText.layers.linears import mlp_block
+from MaxText.layers.linears import DenseGeneral, MlpBlock
 from MaxText.layers import models
 from MaxText.layers import quantizations
-from MaxText.layers.attentions import KVQuant, attention_op_as_linen, dense_general
+from MaxText.layers.attentions import AttentionOp, KVQuant
 from MaxText.layers.initializers import Initializer, NdInitializer, nd_dense_init
 from MaxText.layers.quantizations import AqtQuantization as Quant

@@ -166,7 +166,7 @@ class Gpt3MultiHeadAttention(nnx.Module):
     head_dim: dimension of each head.
     max_target_length: maximum length of output
     max_prefill_predict_length: size of the maximum prefill
-    self.mesh: device self.mesh
+    mesh: device self.mesh
     dtype: the dtype of the computation.
     dropout_rate: dropout rate
     kernel_init: initializer for the kernel of the Dense layers.
@@ -180,30 +180,32 @@ class Gpt3MultiHeadAttention(nnx.Module):
   """

   def __init__(
-      self,
-      config: Config,
-      num_heads: int,
-      head_dim: int,
-      max_target_length: int,
-      max_prefill_predict_length: int,
-      mesh: Mesh,
-      attention_kernel: str,
-      dtype: DType = jnp.float32,
-      weight_dtype: DType = jnp.float32,
-      dropout_rate: float = 0.0,
-      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
-      float32_qk_product: bool = False,  # computes logits in float32 for stability.
-      float32_logits: bool = True,  # cast logits in float32 for stability.
-      fused_qkv: bool = True,
-      quant: Optional[Quant] = None,
-      kv_quant: Optional[KVQuant] = None,
-      use_bias: bool = True,
-      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
-      query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      **kwargs: Any
+      self,
+      config: Config,
+      num_heads: int,
+      feature_dim: Array,
+      head_dim: int,
+      max_target_length: int,
+      max_prefill_predict_length: int,
+      mesh: Mesh,
+      attention_kernel: str,
+      dtype: DType = jnp.float32,
+      weight_dtype: DType = jnp.float32,
+      dropout_rate: float = 0.0,
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
+      float32_qk_product: bool = False,  # computes logits in float32 for stability.
+      float32_logits: bool = True,  # cast logits in float32 for stability.
+      fused_qkv: bool = True,
+      quant: Optional[Quant] = None,
+      kv_quant: Optional[KVQuant] = None,
+      use_bias: bool = True,
+      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+      query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      rngs: Optional[nnx.Rngs] = None,
+      **kwargs: Any,
   ):
     self.config = config
     self.num_heads = num_heads
@@ -227,59 +229,48 @@ def __init__(
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
-
-  def qkv_projection(self, inputs: Array, proj_name: str):
-    """Fused QKV projection"""
-
-    qkv_proj = dense_general(
-        inputs_shape=inputs.shape,
+    self.rngs = rngs if rngs is not None else kwargs.get("rngs", nnx.Rngs(0))
+    print(f'feature_dim: {feature_dim}')
+    self.qkv_projection_layer = self.create_projection_layer(feature_dim, ("embed", "qkv", "heads", "kv"))
+    self.q_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.k_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.v_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.out_projection_layer = self.create_projection_layer(feature_dim, ("heads", "kv", "embed"))
+
+  def create_projection_layer(self, input_shape: Array, kernel_axes: str):
+    return DenseGeneral(
+        in_features_shape=input_shape,
         out_features_shape=(3, self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "qkv", "heads", "kv"),
+        kernel_axes=kernel_axes,
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
-        name=proj_name,
         quant=self.quant,
         use_bias=self.use_bias,
         matmul_precision=self.config.matmul_precision,
-    )(inputs)
+        rngs=self.rngs,
+    )
+
+  def qkv_projection(self, projection_layer: Any, inputs: Array, proj_name: str):
+    """Fused QKV projection"""
+    print(f'qkv_projection in_features_shape: {inputs.shape}')
+
+    qkv_proj = self.qkv_projection_layer(inputs)
+
     qkv_proj = checkpoint_name(qkv_proj, "qkv_proj")
     query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...]
     return query, key, value

-  def projection(self, inputs: Array, proj_name: str) -> Array:
+  def projection(self, projection_layer: Any, inputs: Array, proj_name: str) -> Array:
     """individual projection for one of q, k and v."""
-    proj = dense_general(
-        inputs_shape=inputs.shape,
-        out_features_shape=(self.num_heads, self.head_dim),
-        axis=-1,
-        kernel_init=self.kernel_init,
-        kernel_axes=("embed", "heads", "kv"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        name=proj_name,
-        quant=self.quant,
-        use_bias=self.use_bias,
-        matmul_precision=self.config.matmul_precision,
-    )(inputs)
+    proj = projection_layer(inputs)
     return proj

-  def out_projection(self, output_dim: int, out: Array) -> Array:
+  def out_projection(self, projection_layer: Any, output_dim: int, out: Array) -> Array:
     """output projection"""
-    out_proj = dense_general(
-        inputs_shape=out.shape,
-        out_features_shape=output_dim,
-        axis=(-2, -1),
-        kernel_init=self.kernel_init,
-        kernel_axes=("heads", "kv", "embed"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        name="out",
-        quant=self.quant,
-        use_bias=self.use_bias,
-        matmul_precision=self.config.matmul_precision,
-    )(out)
+
+    out_proj = projection_layer(out)
     return out_proj

   def __call__(
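For reference, the fused path above projects to an extra axis of size 3 and slices it back into query, key and value (the qkv_proj[:, :, 0, ...] indexing kept as context lines in this hunk). A self-contained illustration of that split, with made-up shapes rather than commit values:

import jax.numpy as jnp

batch, length, num_heads, head_dim = 2, 5, 4, 16
# Stand-in for the fused projection output: one extra axis of size 3 carries Q, K and V.
qkv_proj = jnp.zeros((batch, length, 3, num_heads, head_dim))
query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...]
assert query.shape == key.shape == value.shape == (batch, length, num_heads, head_dim)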
@@ -292,11 +283,12 @@ def __call__(
   ):
     inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
     if self.fused_qkv:
-      query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
+      print(f'inputs_q size: {inputs_q.shape}')
+      query, key, value = self.qkv_projection(self.qkv_projection_layer, inputs_q, proj_name="qkv_proj")
     else:
-      query = self.projection(inputs_q, proj_name="query")
-      key = self.projection(inputs_q, proj_name="key")
-      value = self.projection(inputs_q, proj_name="value")
+      query = self.projection(self.q_projection_layer, inputs_q, proj_name="query")
+      key = self.projection(self.k_projection_layer, inputs_q, proj_name="key")
+      value = self.projection(self.v_projection_layer, inputs_q, proj_name="value")

     depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
     query /= depth_scaling
@@ -309,7 +301,7 @@ def __call__(
     value = nn.with_logical_constraint(value, self.value_axis_names)
     value = checkpoint_name(value, "value_proj")

-    attention_op = attention_op_as_linen(
+    attention_op = AttentionOp(
         config=self.config,
         mesh=self.mesh,
         attention_kernel=self.attention_kernel,
@@ -328,7 +320,7 @@ def __call__(
     out = nn.with_logical_constraint(out, self.out_axis_names)

     # apply output projection, output dim is set to the input dim.
-    out = self.out_projection(inputs_q.shape[-1], out)
+    out = self.out_projection(self.out_projection_layer, inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
     return out

@@ -340,16 +332,14 @@ def __call__(

 class Gpt3DecoderLayer(nnx.Module):
   """Transformer decoder layer that attends to the encoder."""
+
   def __init__(
-      self,
-      config: models.Config,
-      mesh: Mesh,
-      quant: Optional[Quant] = None,
-      **kwargs: Any
+      self, config: models.Config, mesh: Mesh, quant: Optional[Quant] = None, rngs: Optional[nnx.Rngs] = None, **kwargs: Any
   ):
     self.config = config
     self.mesh = mesh
     self.quant = quant
+    self.rngs = rngs if rngs is not None else kwargs.get("rngs", nnx.Rngs(0))

   def __call__(
       self,
@@ -364,14 +354,14 @@ def __call__(
   ):
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx = gpt3_layer_norm(
+    lnx = Gpt3LayerNorm(
         num_features=inputs.shape[-1],
         dtype=self.config.dtype,
-        name="pre_self_attention_norm",
         kernel_axes=("norm",),
         epsilon=self.config.normalization_layer_epsilon,
         reductions_in_fp32=False,
         use_bias=True,
+        rngs=self.rngs,
     )(inputs)

     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -380,10 +370,15 @@ def __call__(
     assert (
         self.config.num_query_heads == self.config.num_kv_heads
     ), f"{self.config.num_query_heads=} should be the same as {self.config.num_kv_heads=} in gpt3"
+
+    # Todo: make the axis names parameters
+    lnx_sharding = nn.with_logical_constraint(lnx, (BATCH, LENGTH, EMBED))
+
     attention_layer = Gpt3MultiHeadAttention(
         config=self.config,
         num_heads=self.config.num_query_heads,
         dtype=self.config.dtype,
+        feature_dim=lnx_sharding.shape,
         weight_dtype=self.config.weight_dtype,
         head_dim=self.config.head_dim,
         max_target_length=self.config.max_target_length,
@@ -410,18 +405,18 @@ def __call__(
     attention_lnx += inputs

     # MLP block.
-    mlp_lnx = mlp_block(
+    mlp_lnx = MlpBlock(
         in_features=attention_lnx.shape[-1],
         intermediate_dim=self.config.mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
-        name="mlp",
         use_bias=True,
         use_pre_norm=True,
         config=self.config,
         quant=self.quant,
+        rngs=self.rngs,
     )(attention_lnx, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

@@ -448,9 +443,10 @@ def __call__(
     else:
       return layer_output

+
 def gpt3_decoder_layer_class() -> nn.Module:
   """Creates a Gemma3DecoderLayer Linen module."""
   return nnx_wrappers.to_linen_class(
       Gpt3DecoderLayer,
       metadata_fn=initializers.variable_to_logically_partitioned,
-  )
+  )
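Callers that still expect a Linen layer class keep going through gpt3_decoder_layer_class(), which wraps the NNX Gpt3DecoderLayer via nnx_wrappers.to_linen_class. A hedged usage sketch, assuming only what the diff shows about the factory (the variable name is illustrative, not from this commit):

from MaxText.layers.gpt3 import gpt3_decoder_layer_class

# Per the return annotation, the factory yields a Linen-compatible class wrapping Gpt3DecoderLayer.
Gpt3DecoderLayerToLinen = gpt3_decoder_layer_class()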

0 commit comments
