1 change: 0 additions & 1 deletion .github/workflows/claude-code-review.yml
@@ -54,4 +54,3 @@ jobs:
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'

1 change: 0 additions & 1 deletion .github/workflows/claude.yml
@@ -47,4 +47,3 @@ jobs:
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
# claude_args: '--allowed-tools Bash(gh pr:*)'

37 changes: 37 additions & 0 deletions configs/experiment/dlm_transformer_small.yaml
@@ -0,0 +1,37 @@
# @package _global_

# Short training run with a small Transformer encoder to quickly test the DLM objective

defaults:
- override /model: dlm_transformer_small

logger:
wandb:
name: debug-dlm-transformer-small
tags: ["debug", "dlm"]

trainer:
max_steps: 100
log_every_n_steps: 10
val_check_interval: 10
limit_val_batches: 2
check_val_every_n_epoch: null

model:
net:
embedder:
d_model: 32
encoder:
n_layers: 2
scheduler:
_target_: transformers.get_cosine_schedule_with_warmup
_partial_: true
num_warmup_steps: 10
num_training_steps: ${trainer.max_steps}

data:
_target_: glm_experiments.data.lm_datamodule.DLMDataModule
batch_size: 8
per_device_batch_size: 8

compile: false
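For a quick check of how this experiment composes with the model config, Hydra's compose API can be used directly. A minimal sketch, assuming the root config is named `train` and lives under `configs/` (both assumptions about this repo; the usual CLI equivalent would be something like `python train.py experiment=dlm_transformer_small`):

from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=dlm_transformer_small"])

# The experiment shrinks the small model further: d_model 128 -> 32, n_layers 6 -> 2.
# Since encoder.hidden_size interpolates ${..embedder.d_model} and num_heads stays 8,
# the effective head dimension becomes 32 / 8 = 4.
print(cfg.model.net.embedder.d_model)  # expected: 32
print(cfg.trainer.max_steps)           # expected: 100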
39 changes: 39 additions & 0 deletions configs/model/dlm_transformer_base.yaml
@@ -0,0 +1,39 @@
_target_: glm_experiments.models.lm_lit_module.DLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.DLM
embedder:
_target_: glm_experiments.models.components.transformer.Embedding
vocab_size: 7
d_model: 768 # Standard BERT-base size
encoder:
_target_: glm_experiments.models.components.transformer.Transformer
hidden_size: ${..embedder.d_model} # 768
n_layers: 12 # CS336 default
num_heads: 12 # 12 heads → d_head = 64
# d_ff: auto-computed as floor(768 * 8/3 / 64) * 64 = 2048
rope_theta: 10000.0
is_causal: false # Bidirectional attention like MLM
layer_norm:
_target_: torch.nn.RMSNorm
normalized_shape: ${..embedder.d_model}
decoder:
_target_: glm_experiments.models.components.transformer.Linear
d_in: ${..embedder.d_model}
d_out: ${..embedder.vocab_size}

optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 0.001 # CS336 default
weight_decay: 0.1 # CS336 default
betas: [0.9, 0.98] # CS336 default (beta1, beta2)
eps: 1.0e-9 # CS336 default

scheduler:
_target_: transformers.get_constant_schedule_with_warmup
_partial_: true
num_warmup_steps: 1000 # More warmup for larger model
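The `_target_`/`_partial_` keys above follow Hydra's instantiate convention: nodes marked `_partial_: true` come back as `functools.partial` objects so the Lightning module can bind the optimizer to its parameters later. A minimal sketch of that behaviour, loading this file standalone (the dummy `nn.Linear` is only a stand-in for the real network's parameters):

import torch.nn as nn
from hydra.utils import instantiate
from omegaconf import OmegaConf

model_cfg = OmegaConf.load("configs/model/dlm_transformer_base.yaml")

# `_partial_: true` -> instantiate returns functools.partial(torch.optim.AdamW, lr=0.001, ...)
opt_partial = instantiate(model_cfg.optimizer)
optimizer = opt_partial(nn.Linear(8, 8).parameters())  # bind to (placeholder) parameters

# Same for the scheduler: partial(get_constant_schedule_with_warmup, num_warmup_steps=1000)
scheduler = instantiate(model_cfg.scheduler)(optimizer)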
39 changes: 39 additions & 0 deletions configs/model/dlm_transformer_small.yaml
@@ -0,0 +1,39 @@
_target_: glm_experiments.models.lm_lit_module.DLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.DLM
embedder:
_target_: glm_experiments.models.components.transformer.Embedding
vocab_size: 7
d_model: 128
encoder:
_target_: glm_experiments.models.components.transformer.Transformer
hidden_size: ${..embedder.d_model} # 128
n_layers: 6 # Fewer layers for fast iteration
num_heads: 8 # 8 heads → d_head = 16
# d_ff: auto-computed as floor(128 * 8/3 / 64) * 64 = 320
rope_theta: 10000.0
is_causal: false # Bidirectional for DLM (like MLM)
layer_norm:
_target_: torch.nn.RMSNorm # Use RMSNorm to match Transformer
normalized_shape: ${..embedder.d_model}
decoder:
_target_: glm_experiments.models.components.transformer.Linear
d_in: ${..embedder.d_model}
d_out: ${..embedder.vocab_size}

optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 0.001 # CS336 default
weight_decay: 0.1 # CS336 default
betas: [0.9, 0.98] # CS336 default (beta1, beta2)
eps: 1.0e-9 # CS336 default

scheduler:
_target_: transformers.get_constant_schedule_with_warmup
_partial_: true
num_warmup_steps: 100
4 changes: 2 additions & 2 deletions experiments/training_data/shard_dataset.py
@@ -21,17 +21,17 @@
import argparse
from pathlib import Path

+import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from tqdm import tqdm
-import numpy as np


def download_split(repo_id: str, split: str, cache_dir: Path) -> Path:
"""Download a single split from HuggingFace."""
filename = f"data/{split}/{split}.jsonl.zst"
return Path(
-hf_hub_download(
+hf_hub_download(  # nosec B615
repo_id=repo_id,
filename=filename,
repo_type="dataset",
4 changes: 1 addition & 3 deletions experiments/training_data/upload_dataset.py
@@ -23,9 +23,7 @@


def main():
-parser = argparse.ArgumentParser(
-description="Upload sharded dataset to HuggingFace Hub"
-)
+parser = argparse.ArgumentParser(description="Upload sharded dataset to HuggingFace Hub")
parser.add_argument(
"--input-dir",
type=Path,
9 changes: 5 additions & 4 deletions glm_experiments/data/evals.py
@@ -100,15 +100,15 @@ def load_eval_dataset(
split: Dataset split to load (default: "test")
window_size: Size of the window around variants (must be even)
cache_dir: Directory to cache transformed dataset
-objective: Training objective ("mlm" or "clm") - determines transform function
+objective: Training objective ("mlm", "dlm", or "clm") - determines transform function
data_dir: Directory for genome downloads (default: "data")
label_column: Name of the label column to preserve (default: "label")

Returns:
Transformed dataset with columns: input_ids, pos, ref, alt, {label_column}

Raises:
-ValueError: If filter_name not in EVAL_FILTERS or objective not mlm/clm
+ValueError: If filter_name not in EVAL_FILTERS or objective not mlm/dlm/clm
"""
from datasets import load_from_disk

@@ -154,12 +154,13 @@ def load_eval_dataset(
original_columns = dataset.column_names

# Select transform function based on objective
-if objective == "mlm":
+if objective in ("mlm", "dlm"):
+# Both MLM and DLM use the same bidirectional masking transform
transform_func = transform_llr_mlm
elif objective == "clm":
transform_func = transform_llr_clm
else:
-raise ValueError(f"Unknown objective: {objective}. Must be 'mlm' or 'clm'.")
+raise ValueError(f"Unknown objective: {objective}. Must be 'mlm', 'dlm', or 'clm'.")

transform_fn = partial(
transform_func,
68 changes: 68 additions & 0 deletions glm_experiments/data/lm_datamodule.py
@@ -105,6 +105,45 @@ def apply_clm_labels(input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso
return input_ids, labels


def apply_dlm_masking(
input_ids: torch.Tensor,
mask_token_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply diffusion masking to input tokens.

For each sequence, samples a random masking ratio r ~ Uniform(0, 1),
then masks each token with probability r. Unlike BERT/MLM, there is
no token replacement (100% of selected tokens become [MASK]).

Args:
input_ids: Token IDs of shape (batch_size, seq_len)
mask_token_id: Token ID for [MASK]

Returns:
Tuple of (masked_input_ids, labels) both as int8.
Labels has -100 for non-masked positions (standard PyTorch ignore_index).
"""
input_ids = input_ids.clone().to(torch.int8)
labels = input_ids.clone()

batch_size, seq_len = input_ids.shape

# Sample masking ratio r ~ Uniform(0, 1) for each sequence
masking_ratios = torch.rand(batch_size, 1) # Shape: (batch_size, 1)

# Create probability matrix: each sequence has its own masking ratio
probability_matrix = masking_ratios.expand(batch_size, seq_len) # (batch_size, seq_len)

# Select tokens for masking based on per-sequence ratio
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # Standard PyTorch ignore_index

# Replace ALL masked tokens with [MASK] (no random replacement)
input_ids[masked_indices] = mask_token_id

return input_ids, labels


class LMDataModule(LightningDataModule):
"""Base DataModule for DNA language modeling.

@@ -442,6 +481,35 @@ def apply_labels(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
)


class DLMDataModule(LMDataModule):
"""DataModule for Diffusion Language Modeling.

Uses per-sequence variable masking ratio r ~ Uniform(0, 1).
No token replacement (100% [MASK]).

Args:
**kwargs: Arguments passed to LMDataModule
"""

def get_objective(self) -> str:
"""Return the objective type for DLM."""
return "dlm"

def apply_labels(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply DLM masking to create labels.

Args:
input_ids: Tokenized input IDs of shape (batch_size, seq_len)

Returns:
Tuple of (masked_input_ids, labels) with -100 for non-masked positions
"""
return apply_dlm_masking(
input_ids,
mask_token_id=self.tokenizer.mask_token_id,
)


class CLMDataModule(LMDataModule):
"""DataModule for Causal Language Modeling.

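As a concrete illustration of the per-sequence masking implemented in `apply_dlm_masking` above, here is a small usage sketch on a toy batch (vocab size 7 matches the configs in this PR; treating ID 6 as the [MASK] token is an assumption made only for this example):

import torch

from glm_experiments.data.lm_datamodule import apply_dlm_masking

torch.manual_seed(0)
toy_ids = torch.randint(0, 6, (4, 10))        # 4 sequences of 10 tokens drawn from IDs 0-5
masked, labels = apply_dlm_masking(toy_ids, mask_token_id=6)

# Each row drew its own ratio r ~ Uniform(0, 1), so the masked fraction varies per sequence,
# and every selected token becomes [MASK] (no 80/10/10 BERT-style replacement).
print((masked == 6).float().mean(dim=1))      # per-row fraction of [MASK] tokens
print((labels != -100).sum(dim=1))            # number of supervised positions per row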
32 changes: 28 additions & 4 deletions glm_experiments/models/components/lm.py
@@ -145,10 +145,11 @@ def forward(
return self.compute_loss(logits, labels, soft_masked, soft_masked_weight)


-class MLM(LM):
-"""Masked language model (bidirectional).
+class GeneralMaskedLM(LM):
+"""Base class for bidirectional masked language models (MLM, DLM).

-Predicts tokens only at masked positions (labels != -100).
+Subclasses differ only in their masking strategy (applied in data module).
+Both filter to masked positions (labels != -100) for loss computation.
"""

def prepare_for_loss(
@@ -157,7 +158,9 @@ def prepare_for_loss(
labels: torch.Tensor,
soft_masked: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Filter to masked positions only.
"""Filter to masked positions only (labels != -100).

Shared by both MLM and DLM since both use masking-based training.

Args:
logits: Logits of shape (batch, seq_len, vocab_size)
@@ -181,6 +184,27 @@ def prepare_for_loss(
return logits, labels, soft_masked


class MLM(GeneralMaskedLM):
"""Masked language model with BERT-style masking.

Uses fixed 15% masking with token replacement (80% [MASK], 10% random, 10% unchanged).
Masking logic is implemented in MLMDataModule.apply_labels().
"""

pass # All logic inherited from GeneralMaskedLM


class DLM(GeneralMaskedLM):
"""Diffusion language model with variable masking ratio.

Uses per-sequence random masking ratio r ~ Uniform(0, 1).
No token replacement (100% [MASK]).
Masking logic is implemented in DLMDataModule.apply_labels().
"""

pass # All logic inherited from GeneralMaskedLM


class CLM(LM):
"""Causal language model (autoregressive).

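The body of `prepare_for_loss` is collapsed in the hunk above; going only by its docstring, the step shared by MLM and DLM is boolean indexing on `labels != -100`. A minimal sketch of that idea, not the repository's exact implementation:

import torch

def filter_to_masked_positions(
    logits: torch.Tensor,       # (batch, seq_len, vocab_size)
    labels: torch.Tensor,       # (batch, seq_len), -100 where there is no target
    soft_masked: torch.Tensor,  # (batch, seq_len) per-position loss weights
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mask = labels != -100  # keep only positions that were masked during training
    # Boolean indexing flattens to (num_masked, vocab_size) and (num_masked,),
    # so the cross-entropy loss only ever sees masked tokens.
    return logits[mask], labels[mask], soft_masked[mask]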
41 changes: 41 additions & 0 deletions glm_experiments/models/lm_lit_module.py
@@ -359,6 +359,47 @@ def _compute_raw_llr(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
)


class DLMLitModule(LMLitModule):
"""Lightning module for Diffusion Language Modeling.

Uses same adapter and VEP scoring as MLM since both are bidirectional
masked models. Only difference is the masking strategy during training
(handled by DLMDataModule).

Args:
net: DLM model
optimizer: Optimizer partial function
scheduler: Scheduler partial function
"""

def create_adapter(self, net: nn.Module) -> nn.Module:
"""Create MaskedLMAdapter for biofoundation scoring."""
return MaskedLMAdapter(net)

def get_loss_name(self) -> str:
"""Return DLM loss metric name."""
return "dlm_loss"

def _compute_raw_llr(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute raw DLM LLR scores (no transformation).

Uses same MLM scoring since DLM is also a bidirectional masked model.

Args:
batch: Batch with keys {input_ids, pos, ref, alt, label}

Returns:
Raw LLR scores (higher = more likely under model)
"""
return compute_llr_mlm(
model=self.adapter,
input_ids=batch["input_ids"],
pos=batch["pos"],
ref=batch["ref"],
alt=batch["alt"],
)


class CLMLitModule(LMLitModule):
"""Lightning module for Causal Language Modeling.

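`compute_llr_mlm` is imported from elsewhere and is not part of this diff. Conceptually, masked-LM LLR scoring hides the variant position and compares the model's log-probability of the alternate allele to that of the reference allele; the sketch below illustrates that idea under the assumption that `model` maps token IDs straight to logits (it is not the actual biofoundation implementation):

import torch
import torch.nn.functional as F

def llr_sketch(model, input_ids, pos, ref, alt, mask_token_id):
    """Illustrative log P(alt) - log P(ref) at the masked variant site."""
    rows = torch.arange(input_ids.size(0))
    masked = input_ids.clone()
    masked[rows, pos] = mask_token_id                      # hide the variant position
    logits = model(masked)                                 # (batch, seq_len, vocab_size)
    log_probs = F.log_softmax(logits[rows, pos], dim=-1)   # distribution over bases at that site
    return log_probs[rows, alt] - log_probs[rows, ref]     # higher = alt more likely than ref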