forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembedding.py
More file actions
267 lines (222 loc) · 8.75 KB
/
embedding.py
File metadata and controls
267 lines (222 loc) · 8.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional
from max.dtype import DType
from max.graph import (
BufferValue,
DeviceRef,
TensorValue,
TensorValueLike,
Weight,
ops,
)
from max.graph.quantization import QuantizationEncoding
from max.nn.comm.allreduce import Allreduce
from .layer import Layer, Module
@dataclass
class EmbeddingV1(Layer):
    """A lookup table for embedding integer indices into dense vectors.

    Deprecated: Use `Embedding` instead.
    """

    # The embedding table; may be a quantized `Weight` or any tensor-like.
    weights: TensorValueLike
    # Device on which the lookup is performed.
    device: DeviceRef

    def __call__(self, indices: TensorValueLike) -> TensorValue:
        """Gathers embedding rows for the given integer ``indices``.

        Args:
            indices: Integer indices into the embedding table.

        Returns:
            The gathered embeddings on :obj:`device`; if the weights are a
            quantized :obj:`Weight`, the gathered rows are dequantized.
        """
        # BUG FIX: the previous implementation reassigned `self.weights` to
        # the converted TensorValue *before* the `isinstance(..., Weight)`
        # check, so the check always failed and quantized weights were never
        # dequantized. Keep the original object in a local and convert into
        # a separate local instead of mutating instance state per call.
        weights = self.weights
        weight_tensor = TensorValue(weights).to(self.device)
        indices = TensorValue(indices).to(self.device)
        result = ops.gather(weight_tensor, indices, axis=0)
        if (
            isinstance(weights, Weight)
            and weights.quantization_encoding is not None
        ):
            result = ops.dequantize(weights.quantization_encoding, result)
        return result
class Embedding(Module):
    """A lookup table mapping integer indices to dense embedding vectors.

    Each integer index selects one fixed-size row of the weight matrix.
    The weights start out CPU-resident and are transferred to the configured
    device during model initialization.

    Example:

    .. code-block:: python

        embedding_layer = Embedding(
            vocab_size=1000,
            hidden_dim=256,
            dtype=DType.float32,
            device=DeviceRef.GPU(),
            name="embeddings",
        )
        # Token indices of shape: [batch, ..., num_indices].
        token_indices: TensorValueLike
        embeddings = embedding_layer(token_indices)
    """

    weight: Weight
    """The embedding weight matrix stored on the CPU.
    Model init moves weights to the device specified in :obj:`device`."""

    device: DeviceRef
    """The device on which embedding lookup is performed."""

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        dtype: DType,
        device: DeviceRef,
        quantization_encoding: Optional[QuantizationEncoding] = None,
        name: Optional[str] = None,
    ) -> None:
        """Initializes the embedding layer with the given arguments.

        Args:
            vocab_size: The number of unique items in the vocabulary.
                Indices must be in the range ``[0, vocab_size)``.
            hidden_dim: The dimensionality of each embedding vector.
            dtype: The data type of the embedding weights.
            device: The device where embedding lookups are executed.
                Model init transfers the initially CPU-resident weights to
                this device.
            quantization_encoding: Optional quantization encoding applied to
                the weight matrix.
            name: The name identifier for the embedding weight matrix.
        """
        super().__init__()
        self.device = device
        # Fall back to the conventional "weight" name when none is given.
        weight_name = name or "weight"
        self.weight = Weight(
            weight_name,
            dtype,
            shape=(vocab_size, hidden_dim),
            device=device,
            quantization_encoding=quantization_encoding,
        )

    def __call__(self, indices: TensorValueLike) -> TensorValue:
        """Embeds the input indices by looking up corresponding vectors.

        Args:
            indices: A tensor of integer indices to look up.
                Each index must be in the range ``[0, vocab_size)``.

        Returns:
            A tensor of embeddings for the input indices, residing on the
            device specified in :obj:`device`.
        """
        table = TensorValue(self.weight)
        embeddings = ops.gather(table, indices, axis=0)
        encoding = self.weight.quantization_encoding
        if encoding is not None:
            # Quantized rows are dequantized after the (cheap) gather.
            embeddings = ops.dequantize(encoding, embeddings)
        return embeddings
class VocabParallelEmbedding(Module):
    """An embedding lookup table sharded across devices on the vocab axis.

    Behaves like `nn.Embedding`, except each device holds only a contiguous
    slice of the vocabulary; per-device partial results are combined with an
    allreduce.

    Example:

    .. code-block:: python

        embedding_layer = VocabParallelEmbedding(
            vocab_size=1000,
            hidden_dim=256,
            dtype=DType.float32,
            device=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
            name="embeddings",
        )
        # Token indices of shape: [batch, ..., num_indices].
        token_indices: TensorValueLike
        embeddings = embedding_layer(token_indices)
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        dtype: DType,
        devices: list[DeviceRef],
        quantization_encoding: Optional[QuantizationEncoding] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            vocab_size: The number of unique items in the vocabulary.
                Indices must be in the range ``[0, vocab_size)``.
            hidden_dim: The dimensionality of each embedding vector.
            dtype: The data type of the embedding weights.
            devices: The devices where embedding lookups are executed.
                Model init transfers the initially CPU-resident weights to
                this device.
            quantization_encoding: Optional quantization encoding applied to
                the weight matrix.
            name: The name identifier for the embedding weight matrix.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.devices = devices
        self.num_devices = len(devices)
        # Ceil-divide so the last shard absorbs any remainder rows.
        self.shard_size = math.ceil(vocab_size / self.num_devices)
        # The full weight lives on the CPU; each device receives its slice
        # inside `_per_device_call`.
        self.weight = Weight(
            name or "weight",
            dtype,
            shape=(vocab_size, hidden_dim),
            device=DeviceRef.CPU(),
            quantization_encoding=quantization_encoding,
        )
        self.allreduce = Allreduce(num_accelerators=self.num_devices)

    def __call__(
        self, indices: TensorValueLike, signal_buffers: list[BufferValue]
    ) -> list[TensorValue]:
        """Embeds the input indices by looking up corresponding vectors.

        Args:
            indices: A tensor of integer indices to look up.
                Each index must be in the range ``[0, vocab_size)``.
            signal_buffers: Buffers for peer-to-peer communication in allreduce.

        Returns:
            One embedding tensor per device: each device computes its shard's
            partial result (zeros for out-of-shard tokens) and the allreduce
            sums them so every device ends up with the full embeddings.
        """
        indices_tensor = TensorValue(indices)
        partials = []
        for shard_idx in range(self.num_devices):
            partials.append(self._per_device_call(indices_tensor, shard_idx))
        return self.allreduce(partials, signal_buffers)

    def _per_device_call(
        self, indices: TensorValue, device_idx: int
    ) -> TensorValue:
        """Computes the embeddings for the input indices, for a single device."""
        target = self.devices[device_idx]
        # Half-open vocab range [lo, hi) owned by this shard; the last shard
        # is clamped to the true vocab size.
        lo = self.shard_size * device_idx
        hi = min(self.shard_size * (device_idx + 1), self.vocab_size)
        # Copy this shard of the embedding weights onto the device.
        shard = self.weight[lo:hi].to(target)
        indices = indices.to(target)
        # 1 = token owned by this shard, 0 = token belongs elsewhere.
        in_range = ops.logical_and(indices >= lo, indices < hi)
        # Shift into shard-local coordinates, then zero out foreign tokens so
        # the gather never reads out of bounds.
        indices = indices - lo
        indices = indices * in_range
        result = ops.gather(shard, indices, axis=0)
        encoding = self.weight.quantization_encoding
        if encoding is not None:
            result = ops.dequantize(encoding, result)
        # Foreign tokens gathered row 0 above; re-mask so they contribute
        # zeros to the allreduce sum.
        mask = ops.cast(ops.unsqueeze(in_range, 1), result.dtype)
        return result * mask