|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from torch import nn |
| 7 | + |
| 8 | +from torchtune.models.t5._encoder import ( |
| 9 | + T5Encoder, |
| 10 | + T5EncoderLayer, |
| 11 | + T5EncoderSelfAttention, |
| 12 | +) |
| 13 | +from torchtune.modules.feed_forward import FeedForward |
| 14 | +from torchtune.modules.rms_norm import RMSNorm |
| 15 | + |
| 16 | + |
def t5_encoder(
    embed_dim: int,
    mlp_dim: int,
    num_heads: int,
    head_dim: int,
    num_layers: int,
    rel_pos_num_buckets: int,
    rel_pos_max_dist: int,
    vocab_size: int,
    norm_eps: float,
    max_seq_len: int,
) -> T5Encoder:
    """
    Builder for the T5 encoder.

    T5 paper: https://arxiv.org/abs/1910.10683

    Args:
        embed_dim (int): The model dimension.
        mlp_dim (int): The inner dimension of the feed forward layers.
        num_heads (int): The number of attention heads.
        head_dim (int): The dimension of the attention heads (should equal `embed_dim // num_heads`)
        num_layers (int): Number of encoder layers.
        rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        rel_pos_max_dist (int): Maximum distance for relative positions.
            Distances beyond this are grouped into the last bucket.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        vocab_size (int): Vocab size of the model's tokenizer.
        norm_eps (float): Small value added to denominator for numerical stability.
        max_seq_len (int): The maximum sequence length (context length) of the model.

    Returns:
        T5Encoder: The instantiated encoder module.
    """
    token_embedding = nn.Embedding(vocab_size, embed_dim)

    # All linear projections in this builder use bias=False.
    attn = T5EncoderSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        k_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        v_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    )

    # Gated feed-forward: gate/up project embed_dim -> mlp_dim,
    # down projects mlp_dim -> embed_dim, with GELU activation.
    mlp = FeedForward(
        gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        down_proj=nn.Linear(mlp_dim, embed_dim, bias=False),
        up_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        activation=nn.GELU(),
    )

    # A single layer template; T5Encoder receives num_layers separately
    # (presumably it clones/stacks this layer internally — see T5Encoder).
    layer = T5EncoderLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=RMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(embed_dim, eps=norm_eps),
    )

    final_norm = RMSNorm(embed_dim, eps=norm_eps)

    return T5Encoder(
        token_embedding=token_embedding,
        layer=layer,
        final_norm=final_norm,
        num_layers=num_layers,
        num_heads=num_heads,
        rel_pos_num_buckets=rel_pos_num_buckets,
        rel_pos_max_dist=rel_pos_max_dist,
        max_seq_len=max_seq_len,
    )
0 commit comments