Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions bergson/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from transformers.pytorch_utils import Conv1D

from .math import reshape_to_nearest_square
from .utils import assert_type, create_projection_matrix
Expand Down Expand Up @@ -323,7 +324,7 @@
target_modules: set[str] | None = None
"""
List of parameter names to collect gradients for. Should consist only of weight
matrices in `nn.Linear` modules. If `None`, the gradients for all weight matrices
matrices in `nn.Linear` or `Conv1D` modules. If `None`, the gradients for all weight matrices
will be collected.
"""

Expand All @@ -340,7 +341,7 @@

# Before we add any hooks, we need to peek at what modules we need to track.
for name, layer in self.model.named_modules():
if not isinstance(layer, nn.Linear):
if not isinstance(layer, (nn.Linear, Conv1D)):
continue

if self.target_modules is not None and name not in self.target_modules:
Expand Down Expand Up @@ -442,15 +443,21 @@
# to save memory, rather than waiting until the backward pass.
p = self.processor.projection_dim
if p is not None and not isinstance(norm, AdamNormalizer):
i = module.in_features
# Handle both Linear and Conv1D modules
if isinstance(module, nn.Linear):
i = module.in_features
elif isinstance(module, Conv1D):
i = module.nx
else:
raise ValueError(f"Unsupported module type: {type(module)}")
x = x @ self.projection(name, p, i, "right", x.device, x.dtype).T # type: ignore

module._inputs = x

def _process_grad(self, module: nn.Module, _, grad_out):
"""Process the incoming gradient wrt the output of the module."""
# Sanity checks
assert isinstance(module, nn.Linear), "Expected a Linear module"
assert isinstance(module, (nn.Linear, Conv1D)), "Expected a Linear or Conv1D module"
G = grad_out[0] # [N, S, O]
I = module._inputs # [N, S, I/q]

Expand All @@ -465,7 +472,7 @@
module._inputs,
module.out_features,
)
module.out_features = head_size

Check failure on line 475 in bergson/gradients.py

View workflow job for this annotation

GitHub Actions / build

Argument of type "int" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "int" is not assignable to type "Tensor | Module"     "int" is not assignable to "Tensor"     "int" is not assignable to "Module" (reportArgumentType)
for h in range(num_heads):
module._name = self.get_head_name(name, h) # type: ignore
module._inputs = module_inputs
Expand All @@ -482,7 +489,7 @@
raise e

self._process_grad(module, None, (head_G,))
module._name, module._inputs, module.out_features = (

Check failure on line 492 in bergson/gradients.py

View workflow job for this annotation

GitHub Actions / build

Argument of type "int | Tensor | Module" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__" (reportArgumentType)

Check failure on line 492 in bergson/gradients.py

View workflow job for this annotation

GitHub Actions / build

Cannot assign to attribute "out_features" for class "Linear"   Expression of type "int | Tensor | Module" cannot be assigned to attribute "out_features" of class "Linear"     Attribute "__set__" is unknown     Type "int | Tensor | Module" is not assignable to type "int"       "Module" is not assignable to "int" (reportAttributeAccessIssue)
module_name,
module_inputs,
module_out_features,
Expand All @@ -491,7 +498,13 @@
return

p = self.processor.projection_dim
o, i = module.out_features, module.in_features
# Handle both Linear and Conv1D modules
if isinstance(module, nn.Linear):
o, i = module.out_features, module.in_features
elif isinstance(module, Conv1D):
o, i = module.nf, module.nx
else:
raise ValueError(f"Unsupported module type: {type(module)}")

# Pre-scale G by the Adafactor row statistics
norm = self.processor.normalizers.get(name)
Expand Down
46 changes: 44 additions & 2 deletions tests/test_build.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from bergson.data import load_gradients

try:
import torch
Expand Down Expand Up @@ -28,14 +29,14 @@
from bergson.data import tokenize


def test_disk_build(tmp_path: Path):
def test_disk_build_linear(tmp_path: Path):
run_path = tmp_path / "example_with_heads"
run_path.mkdir(parents=True, exist_ok=True)

config = IndexConfig(
run_path=str(run_path),
model="RonenEldan/TinyStories-1M",
data=DataConfig(dataset="RonenEldan/TinyStories"),
data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True),
head_cfgs={
"h.0.attn.attention.out_proj": HeadConfig(
num_heads=16, head_size=4, head_dim=2
Expand Down Expand Up @@ -68,3 +69,44 @@ def test_disk_build(tmp_path: Path):
)

assert any(run_path.iterdir()), "Expected artifacts in the temp run_path"


def test_disk_build_conv1d(tmp_path: Path):
    """Build a gradient index on disk for a model that uses `Conv1D` layers.

    GPT-2 implements its projections with `transformers`' `Conv1D` rather than
    `nn.Linear`, so this exercises the Conv1D code path of gradient collection
    end to end: tokenize a few documents, collect gradients into `run_path`,
    then reload the index and check that it is non-empty.
    """
    # No head configs are used here, so give the run directory its own name
    # rather than reusing the one from the head-splitting test.
    run_path = tmp_path / "example_conv1d"
    run_path.mkdir(parents=True, exist_ok=True)

    config = IndexConfig(
        run_path=str(run_path),
        model="openai-community/gpt2",
        data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True),
    )

    model = AutoModelForCausalLM.from_pretrained(
        config.model, trust_remote_code=True, use_safetensors=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model)

    # Keep the test fast: a handful of documents is enough to produce gradients.
    data = load_dataset(config.data.dataset, split="train")
    data = data.select(range(8))  # type: ignore

    processor = GradientProcessor(projection_dim=config.projection_dim)

    data = data.map(
        tokenize,
        batched=True,
        fn_kwargs=dict(args=config.data, tokenizer=tokenizer),
        remove_columns=data.column_names,
    )

    collect_gradients(
        model=model,
        data=data,
        processor=processor,
        path=config.run_path,
        head_cfgs=config.head_cfgs,
    )

    assert any(run_path.iterdir()), "Expected artifacts in the temp run_path"

    index = load_gradients(str(run_path))

    # `dtype.names` is None for a non-structured array; a truthiness check
    # fails with a readable message instead of a TypeError from `len(None)`.
    modules = index.dtype.names
    assert modules, "Expected the index to contain at least one module field"
    assert len(index[modules[0]]) != 0, "Expected gradients for the first module"
16 changes: 14 additions & 2 deletions tests/test_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def closure(name: str, g: torch.Tensor):
for name, collected_grad in collected_grads.items():
layer = model.get_submodule(name)

o, i = layer.out_features, layer.in_features
# Handle both Linear and Conv1D modules
if hasattr(layer, 'out_features') and hasattr(layer, 'in_features'):
o, i = layer.out_features, layer.in_features
elif hasattr(layer, 'nf') and hasattr(layer, 'nx'):
o, i = layer.nf, layer.nx
else:
raise ValueError(f"Unsupported layer type: {type(layer)}")
g = layer.weight.grad
assert g is not None

Expand Down Expand Up @@ -78,7 +84,13 @@ def closure(name: str, g: torch.Tensor):
for name, collected_grad in collected_grads.items():
layer = model.get_submodule(name)

o, i = layer.out_features, layer.in_features
# Handle both Linear and Conv1D modules
if hasattr(layer, 'out_features') and hasattr(layer, 'in_features'):
o, i = layer.out_features, layer.in_features
elif hasattr(layer, 'nf') and hasattr(layer, 'nx'):
o, i = layer.nf, layer.nx
else:
raise ValueError(f"Unsupported layer type: {type(layer)}")
g = layer.weight.grad
assert g is not None

Expand Down
Loading