-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlm.py
More file actions
251 lines (197 loc) · 8.31 KB
/
lm.py
File metadata and controls
251 lines (197 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""Language model base class with MLM and CLM subclasses."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class LM(nn.Module):
    """Base language model.

    Pipeline: input_ids -> embedder -> encoder -> layer_norm -> decoder -> logits.
    Subclasses implement :meth:`prepare_for_loss` to select which positions
    contribute to the loss (masked-LM filtering vs. causal shifting).

    Args:
        embedder: Token embedding layer.
        encoder: Encoder module (e.g., ByteNet or Transformer).
        layer_norm: Layer normalization module.
        decoder: Output projection layer.
    """

    def __init__(
        self,
        embedder: nn.Module,
        encoder: nn.Module,
        layer_norm: nn.Module,
        decoder: nn.Module,
    ):
        super().__init__()
        self.embedder = embedder
        self.encoder = encoder
        self.layer_norm = layer_norm
        self.decoder = decoder

    def get_logits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token IDs to vocabulary logits.

        Args:
            input_ids: Token IDs of shape (batch, seq_len); int8 or long.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        # Embedding lookup requires int64 indices; inputs may arrive as int8.
        hidden = self.embedder(input_ids.long())  # (batch, seq_len, hidden_dim)
        hidden = self.encoder(hidden)             # (batch, seq_len, hidden_dim)
        hidden = self.layer_norm(hidden)          # (batch, seq_len, hidden_dim)
        return self.decoder(hidden)               # (batch, seq_len, vocab_size)

    def prepare_for_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        soft_masked: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select/reshape logits, labels, and soft_masked for loss computation.

        Subclasses override this to implement MLM- vs CLM-specific
        slicing and filtering.

        Args:
            logits: Logits of shape (batch, seq_len, vocab_size).
            labels: Target labels of shape (batch, seq_len).
            soft_masked: Boolean mask of shape (batch, seq_len).

        Returns:
            Tuple of (logits, labels, soft_masked) ready for loss computation.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError("Subclasses must implement prepare_for_loss")

    def compute_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        soft_masked: torch.Tensor,
        soft_masked_weight: float,
    ) -> dict[str, torch.Tensor]:
        """Compute three weighted cross-entropy variants from one CE pass.

        Variants:
            loss_full: every token weighted equally (baseline metric).
            loss_non_soft_masked: only non-soft-masked tokens.
            loss: training loss, soft-masked tokens down/up-weighted by
                ``soft_masked_weight``.

        Args:
            logits: Logits (already flattened/filtered), shape (N, vocab).
            labels: Target labels, shape (N,).
            soft_masked: Boolean mask, shape (N,); True at soft-masked positions.
            soft_masked_weight: Weight applied to soft-masked positions in
                the training loss.

        Returns:
            Dict with keys: ``loss``, ``loss_full``, ``loss_non_soft_masked``.
        """
        # One unreduced CE pass feeds all three weighted reductions.
        per_token = F.cross_entropy(logits, labels, reduction="none")

        def weighted_mean(weights: torch.Tensor) -> torch.Tensor:
            # Weighted average of per-token losses; zero total weight
            # (e.g. every token soft-masked with weight 0) yields 0.
            total = weights.sum()
            if total > 0:
                return (per_token * weights / total).sum()
            return torch.tensor(0.0, device=per_token.device, dtype=per_token.dtype)

        return {
            "loss": weighted_mean(torch.where(soft_masked, soft_masked_weight, 1.0)),
            "loss_full": weighted_mean(torch.ones_like(per_token)),
            "loss_non_soft_masked": weighted_mean((~soft_masked).float()),
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        soft_masked: torch.Tensor,
        soft_masked_weight: float,
    ) -> dict[str, torch.Tensor]:
        """Full forward pass: logits -> subclass selection -> weighted losses.

        Args:
            input_ids: Token IDs of shape (batch, seq_len); int8 or long.
            labels: True token IDs of shape (batch, seq_len).
            soft_masked: Boolean mask of shape (batch, seq_len); True at
                soft-masked positions.
            soft_masked_weight: Weight for soft-masked positions in the
                training loss.

        Returns:
            Dict with loss components (loss, loss_full, loss_non_soft_masked).
        """
        raw_logits = self.get_logits(input_ids)
        prepared = self.prepare_for_loss(raw_logits, labels, soft_masked)
        return self.compute_loss(*prepared, soft_masked_weight)
class GeneralMaskedLM(LM):
    """Shared base for bidirectional masked language models (MLM, DLM).

    Subclasses differ only in the masking strategy, which is applied in the
    data module. Both restrict the loss to masked positions, identified by
    ``labels != -100``.
    """

    def prepare_for_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        soft_masked: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Flatten all inputs and keep only masked positions (labels != -100).

        Args:
            logits: Logits of shape (batch, seq_len, vocab_size).
            labels: Target labels of shape (batch, seq_len); -100 marks
                positions excluded from the loss.
            soft_masked: Boolean mask of shape (batch, seq_len).

        Returns:
            Filtered (logits, labels, soft_masked), each 1-D over the
            masked positions (logits keeps its trailing vocab dimension).
        """
        flat_labels = labels.view(-1).long()
        # Loss is computed only where a label is provided.
        keep = flat_labels != -100
        return (
            logits.view(-1, logits.size(-1))[keep],
            flat_labels[keep],
            soft_masked.view(-1)[keep],
        )
class MLM(GeneralMaskedLM):
    """Masked language model using BERT-style masking.

    Fixed 15% masking ratio with token replacement (80% [MASK], 10% random,
    10% unchanged); the masking itself lives in MLMDataModule.apply_labels().
    All model-side behavior is inherited from GeneralMaskedLM.
    """

    pass
class DLM(GeneralMaskedLM):
    """Diffusion language model with a variable masking ratio.

    Per-sequence masking ratio r ~ Uniform(0, 1), with no token replacement
    (100% [MASK]); the masking itself lives in DLMDataModule.apply_labels().
    All model-side behavior is inherited from GeneralMaskedLM.
    """

    pass
class CLM(LM):
    """Causal (autoregressive) language model.

    Predicts the next token at every position; requires an encoder with
    causal attention.
    """

    def __init__(
        self,
        embedder: nn.Module,
        encoder: nn.Module,
        layer_norm: nn.Module,
        decoder: nn.Module,
    ):
        super().__init__(embedder, encoder, layer_norm, decoder)
        # Reject encoders that explicitly declare themselves non-causal.
        # NOTE(review): encoders without an `is_causal` attribute are
        # assumed causal — confirm that default is intended.
        if not getattr(encoder, "is_causal", True):
            raise ValueError(
                "CLM requires causal encoder (is_causal=True). "
                f"Got {type(encoder).__name__} with is_causal={encoder.is_causal}"
            )

    def prepare_for_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        soft_masked: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Shift and flatten for next-token prediction.

        Position t's logits predict the token at position t+1, so logits
        drop the last position and labels/masks drop the first.

        Args:
            logits: Logits of shape (batch, seq_len, vocab_size).
            labels: Target labels of shape (batch, seq_len) (same as input_ids).
            soft_masked: Boolean mask of shape (batch, seq_len).

        Returns:
            Flattened (logits, labels, soft_masked) aligned for
            next-token prediction.
        """
        vocab_size = logits.size(-1)
        shifted_logits = logits[:, :-1]
        shifted_labels = labels[:, 1:]
        shifted_mask = soft_masked[:, 1:]
        return (
            shifted_logits.reshape(-1, vocab_size),
            shifted_labels.reshape(-1).long(),
            shifted_mask.reshape(-1),
        )