Skip to content

Commit b9270fc

Browse files
Better logging of loss for soft-masked vs. non-soft-masked regions
1 parent 8d37bdd commit b9270fc

16 files changed

+247
-157
lines changed

configs/data/default.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ num_workers: 4
1111
pin_memory: true
1212

1313
# Soft masking for genomic soft-masked regions (lowercase nucleotides)
14-
soft_masked_loss_weight_train: 0.01 # Low weight for soft-masked regions during training
15-
soft_masked_loss_weight_eval: 0.0 # No weight for soft-masked regions during eval
14+
soft_masked_weight: 0.01 # Loss weight for soft-masked regions in main training loss
1615

1716
# Data augmentation
1817
data_augmentation: true # Reverse complement augmentation (training only)

configs/experiment/clm_transformer_small.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ logger:
1111
tags: ["debug", "clm", "transformer", "small"]
1212

1313
trainer:
14-
max_steps: 100
15-
log_every_n_steps: 10
16-
val_check_interval: 10
17-
limit_val_batches: 2
14+
max_steps: 300
15+
log_every_n_steps: 100
16+
val_check_interval: 100
17+
limit_val_batches: 10
1818
check_val_every_n_epoch: null
1919

2020
model:
@@ -33,5 +33,6 @@ data:
3333
_target_: glm_experiments.data.lm_datamodule.CLMDataModule
3434
batch_size: 8
3535
per_device_batch_size: 8
36+
soft_masked_weight: 0.5
3637

3738
compile: false

configs/model/bert_bytenet_small.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.MLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.MLM
57
embedder:

configs/model/clm_transformer_base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.CLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.CLM
57
embedder:

configs/model/clm_transformer_small.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.CLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.CLM
57
embedder:

configs/model/gpn_animal_promoter.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.MLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.MLM
57
embedder:

configs/model/mlm_transformer_base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.MLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.MLM
57
embedder:

configs/model/mlm_transformer_small.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
_target_: glm_experiments.models.lm_lit_module.MLMLitModule
22

3+
soft_masked_weight: ${data.soft_masked_weight}
4+
35
net:
46
_target_: glm_experiments.models.components.lm.MLM
57
embedder:

glm_experiments/data/lm_datamodule.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ class LMDataModule(LightningDataModule):
125125
per_device_batch_size: Batch size per device (what fits in GPU memory)
126126
num_workers: Number of workers for data loading
127127
pin_memory: Whether to pin memory for faster GPU transfer
128-
soft_masked_loss_weight_train: Loss weight for soft-masked regions during training
129-
soft_masked_loss_weight_eval: Loss weight for soft-masked regions during evaluation
128+
soft_masked_weight: Loss weight for soft-masked regions (unused by the data module itself; model configs read it via ${data.soft_masked_weight})
130129
data_augmentation: Whether to apply reverse complement augmentation (training only)
131130
max_val_lm_samples: Maximum number of samples for LM validation (None = unlimited)
132131
seed: Random seed for reproducibility
@@ -140,8 +139,7 @@ def __init__(
140139
per_device_batch_size: int = 256, # Batch size that fits in GPU memory
141140
num_workers: int = 8,
142141
pin_memory: bool = True,
143-
soft_masked_loss_weight_train: float = 0.01,
144-
soft_masked_loss_weight_eval: float = 0.0,
142+
soft_masked_weight: float = 0.01,
145143
data_augmentation: bool = True,
146144
max_val_lm_samples: int | None = None,
147145
seed: int = 42,
@@ -256,16 +254,15 @@ def tokenize(seq: list[str]) -> list[list[int]]:
256254
return_special_tokens_mask=False,
257255
)["input_ids"]
258256

259-
def transform_batch(examples: dict, soft_masked_weight: float, data_aug: bool) -> dict:
257+
def transform_batch(examples: dict, data_aug: bool) -> dict:
260258
"""Transform a batch of examples.
261259
262260
Args:
263261
examples: Batch of examples with 'seq' field
264-
soft_masked_weight: Loss weight for lowercase nucleotides
265262
data_aug: Whether to apply reverse complement augmentation
266263
267264
Returns:
268-
Dictionary with input_ids, labels, and loss_weight (all tensors)
265+
Dictionary with input_ids, labels, and soft_masked (all tensors)
269266
"""
270267
seq = examples["seq"]
271268

@@ -276,19 +273,19 @@ def transform_batch(examples: dict, soft_masked_weight: float, data_aug: bool) -
276273
# Tokenize
277274
input_ids = torch.tensor(tokenize(seq), dtype=torch.int8)
278275

279-
# Create loss weights (lower weight for soft-masked lowercase regions)
280-
loss_weight = torch.ones(input_ids.shape, dtype=torch.float16)
276+
# Create soft_masked boolean tensor (True for lowercase nucleotides)
277+
soft_masked = torch.zeros(input_ids.shape, dtype=torch.bool)
281278
for i, s in enumerate(seq):
282279
lowercase_mask = np.array([c.islower() for c in s])
283-
loss_weight[i][lowercase_mask] = soft_masked_weight
280+
soft_masked[i][lowercase_mask] = True
284281

285282
# Apply objective-specific label creation (MLM vs CLM)
286283
input_ids, labels = self.apply_labels(input_ids)
287284

288285
return {
289286
"input_ids": input_ids,
290287
"labels": labels,
291-
"loss_weight": loss_weight,
288+
"soft_masked": soft_masked,
292289
}
293290

294291
# Load raw dataset with streaming
@@ -301,7 +298,6 @@ def transform_batch(examples: dict, soft_masked_weight: float, data_aug: bool) -
301298
train_dataset = train_dataset.map(
302299
lambda ex: transform_batch(
303300
ex,
304-
soft_masked_weight=self.hparams.soft_masked_loss_weight_train,
305301
data_aug=self.hparams.data_augmentation,
306302
),
307303
batched=True,
@@ -322,7 +318,6 @@ def transform_batch(examples: dict, soft_masked_weight: float, data_aug: bool) -
322318
val_dataset = val_dataset.map(
323319
lambda ex: transform_batch(
324320
ex,
325-
soft_masked_weight=self.hparams.soft_masked_loss_weight_eval,
326321
data_aug=False,
327322
),
328323
batched=True,

glm_experiments/models/components/lm.py

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -60,64 +60,89 @@ def prepare_for_loss(
6060
self,
6161
logits: torch.Tensor,
6262
labels: torch.Tensor,
63-
loss_weight: torch.Tensor,
63+
soft_masked: torch.Tensor,
6464
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
65-
"""Prepare logits, labels, and weights for loss computation.
65+
"""Prepare logits, labels, and soft_masked for loss computation.
6666
6767
Override in subclasses to implement MLM vs CLM-specific slicing/filtering.
6868
6969
Args:
7070
logits: Logits of shape (batch, seq_len, vocab_size)
7171
labels: Target labels of shape (batch, seq_len)
72-
loss_weight: Loss weights of shape (batch, seq_len)
72+
soft_masked: Boolean mask of shape (batch, seq_len)
7373
7474
Returns:
75-
Tuple of (logits, labels, loss_weight) ready for loss computation
75+
Tuple of (logits, labels, soft_masked) ready for loss computation
7676
"""
7777
raise NotImplementedError("Subclasses must implement prepare_for_loss")
7878

7979
def compute_loss(
8080
self,
8181
logits: torch.Tensor,
8282
labels: torch.Tensor,
83-
loss_weight: torch.Tensor,
84-
) -> torch.Tensor:
85-
"""Compute weighted cross-entropy loss.
83+
soft_masked: torch.Tensor,
84+
soft_masked_weight: float,
85+
) -> dict[str, torch.Tensor]:
86+
"""Compute weighted cross-entropy loss with three variants.
8687
87-
Shared loss computation logic for MLM and CLM.
88+
Computes three loss values:
89+
1. loss_full: All tokens weighted equally (baseline)
90+
2. loss_non_soft_masked: Only non-soft-masked tokens
91+
3. loss: Training loss; soft-masked tokens weighted by soft_masked_weight, all other tokens by 1.0
8892
8993
Args:
9094
logits: Flattened logits of shape (num_tokens, vocab_size)
9195
labels: Target labels (1D)
92-
loss_weight: Loss weights (1D)
96+
soft_masked: Boolean mask (1D), True for soft-masked positions
97+
soft_masked_weight: Weight for soft-masked positions in training loss
9398
9499
Returns:
95-
Scalar loss value
100+
Dictionary with keys: loss, loss_full, loss_non_soft_masked
96101
"""
97-
loss = F.cross_entropy(logits, labels, reduction="none")
98-
loss = (loss * loss_weight / loss_weight.sum()).sum()
99-
return loss
102+
# Single cross-entropy computation (efficient)
103+
loss_per_token = F.cross_entropy(logits, labels, reduction="none")
104+
105+
# Create three weight masks
106+
weight_full = torch.ones_like(loss_per_token)
107+
weight_non_soft_masked = (~soft_masked).float()
108+
weight_training = torch.where(soft_masked, soft_masked_weight, 1.0)
109+
110+
# Compute normalized losses
111+
def normalize_and_sum(loss: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
112+
weight_sum = weight.sum()
113+
if weight_sum > 0:
114+
return (loss * weight / weight_sum).sum()
115+
else:
116+
# Edge case: total weight is zero (e.g. every token is soft-masked, or no tokens at all); return a constant 0 instead of dividing by zero
117+
return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
118+
119+
return {
120+
"loss": normalize_and_sum(loss_per_token, weight_training),
121+
"loss_full": normalize_and_sum(loss_per_token, weight_full),
122+
"loss_non_soft_masked": normalize_and_sum(loss_per_token, weight_non_soft_masked),
123+
}
100124

101125
def forward(
102126
self,
103127
input_ids: torch.Tensor,
104128
labels: torch.Tensor,
105-
loss_weight: torch.Tensor,
106-
) -> torch.Tensor:
129+
soft_masked: torch.Tensor,
130+
soft_masked_weight: float,
131+
) -> dict[str, torch.Tensor]:
107132
"""Forward pass with loss calculation.
108133
109134
Args:
110135
input_ids: Input token IDs of shape (batch, seq_len), int8 or long
111136
labels: True token IDs of shape (batch, seq_len)
112-
loss_weight: Per-token loss weights of shape (batch, seq_len)
137+
soft_masked: Boolean mask of shape (batch, seq_len), True for soft-masked positions
138+
soft_masked_weight: Weight for soft-masked positions in training loss
113139
114140
Returns:
115-
Weighted cross-entropy loss (scalar)
141+
Dictionary with loss components (loss, loss_full, loss_non_soft_masked)
116142
"""
117143
logits = self.get_logits(input_ids)
118-
logits, labels, loss_weight = self.prepare_for_loss(logits, labels, loss_weight)
119-
loss = self.compute_loss(logits, labels, loss_weight)
120-
return loss
144+
logits, labels, soft_masked = self.prepare_for_loss(logits, labels, soft_masked)
145+
return self.compute_loss(logits, labels, soft_masked, soft_masked_weight)
121146

122147

123148
class MLM(LM):
@@ -130,30 +155,30 @@ def prepare_for_loss(
130155
self,
131156
logits: torch.Tensor,
132157
labels: torch.Tensor,
133-
loss_weight: torch.Tensor,
158+
soft_masked: torch.Tensor,
134159
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
135160
"""Filter to masked positions only.
136161
137162
Args:
138163
logits: Logits of shape (batch, seq_len, vocab_size)
139164
labels: Target labels of shape (batch, seq_len), -100 for ignored positions
140-
loss_weight: Loss weights of shape (batch, seq_len)
165+
soft_masked: Boolean mask of shape (batch, seq_len)
141166
142167
Returns:
143-
Filtered (logits, labels, loss_weight) for masked positions only
168+
Filtered (logits, labels, soft_masked) for masked positions only
144169
"""
145170
# Reshape to 1D
146171
logits = logits.view(-1, logits.size(-1))
147172
labels = labels.view(-1).long()
148-
loss_weight = loss_weight.view(-1)
173+
soft_masked = soft_masked.view(-1)
149174

150175
# Filter to masked positions (labels != -100)
151176
mask = labels != -100
152177
logits = logits[mask]
153178
labels = labels[mask]
154-
loss_weight = loss_weight[mask]
179+
soft_masked = soft_masked[mask]
155180

156-
return logits, labels, loss_weight
181+
return logits, labels, soft_masked
157182

158183

159184
class CLM(LM):
@@ -182,21 +207,21 @@ def prepare_for_loss(
182207
self,
183208
logits: torch.Tensor,
184209
labels: torch.Tensor,
185-
loss_weight: torch.Tensor,
210+
soft_masked: torch.Tensor,
186211
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187212
"""Slice for next-token prediction.
188213
189214
Args:
190215
logits: Logits of shape (batch, seq_len, vocab_size)
191216
labels: Target labels of shape (batch, seq_len) (same as input_ids)
192-
loss_weight: Loss weights of shape (batch, seq_len)
217+
soft_masked: Boolean mask of shape (batch, seq_len)
193218
194219
Returns:
195-
Sliced (logits, labels, loss_weight) for next-token prediction
220+
Sliced (logits, labels, soft_masked) for next-token prediction
196221
"""
197222
# Slice: logits[:, :-1] predicts labels[:, 1:]
198223
logits = logits[:, :-1].reshape(-1, logits.size(-1))
199224
labels = labels[:, 1:].reshape(-1).long()
200-
loss_weight = loss_weight[:, 1:].reshape(-1)
225+
soft_masked = soft_masked[:, 1:].reshape(-1)
201226

202-
return logits, labels, loss_weight
227+
return logits, labels, soft_masked

0 commit comments

Comments
 (0)