Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions bergson/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from transformers.pytorch_utils import Conv1D as HFConv1D

from .math import reshape_to_nearest_square
from .utils import assert_type, create_projection_matrix
Expand Down Expand Up @@ -297,6 +298,32 @@ class HeadConfig(NamedTuple):
head_dim: int
"""The dimension along which heads are tiled."""

class LayerAdapter():
supported_modules = (nn.Linear, HFConv1D, nn.Conv1d, nn.Conv2d, nn.Conv3d)

@staticmethod
def in_attr(layer: nn.Module) -> str:
match layer:
case nn.Linear():
return "in_features"
case HFConv1D():
return 'nx'
case nn.Conv1d() | nn.Conv2d() | nn.Conv3d():
return 'in_channels'
case _:
raise ValueError(f"Unsupported layer type: {type(layer)}")

@staticmethod
def out_attr(layer: nn.Module) -> str:
match layer:
case nn.Linear():
return "out_features"
case HFConv1D():
return 'nf'
case nn.Conv1d() | nn.Conv2d() | nn.Conv3d():
return 'out_channels'
case _:
raise ValueError(f"Unsupported layer type: {type(layer)}")

@dataclass
class GradientCollector(ContextDecorator):
Expand All @@ -323,7 +350,7 @@ class GradientCollector(ContextDecorator):
target_modules: set[str] | None = None
"""
List of parameter names to collect gradients for. Should consist only of weight
matrices in `nn.Linear` modules. If `None`, the gradients for all weight matrices
matrices in modules supported by LayerAdapter. If `None`, the gradients for all weight matrices
will be collected.
"""

Expand All @@ -340,7 +367,7 @@ def __post_init__(self):

# Before we add any hooks, we need to peek at what modules we need to track.
for name, layer in self.model.named_modules():
if not isinstance(layer, nn.Linear):
if not isinstance(layer, LayerAdapter.supported_modules):
continue

if self.target_modules is not None and name not in self.target_modules:
Expand Down Expand Up @@ -442,15 +469,15 @@ def _save_input(self, module: nn.Module, inp: tuple, _):
# to save memory, rather than waiting until the backward pass.
p = self.processor.projection_dim
if p is not None and not isinstance(norm, AdamNormalizer):
i = module.in_features
i = getattr(module, LayerAdapter.in_attr(module))
x = x @ self.projection(name, p, i, "right", x.device, x.dtype).T # type: ignore

module._inputs = x

def _process_grad(self, module: nn.Module, _, grad_out):
"""Process the incoming gradient wrt the output of the module."""
# Sanity checks
assert isinstance(module, nn.Linear), "Expected a Linear module"
assert isinstance(module, LayerAdapter.supported_modules), f"Expected a module of type {LayerAdapter.supported_modules}, got {type(module)}"
G = grad_out[0] # [N, S, O]
I = module._inputs # [N, S, I/q]

Expand All @@ -463,9 +490,9 @@ def _process_grad(self, module: nn.Module, _, grad_out):
module_name, module_inputs, module_out_features = (
module._name,
module._inputs,
module.out_features,
getattr(module, LayerAdapter.out_attr(module)),
)
module.out_features = head_size
setattr(module, LayerAdapter.out_attr(module), head_size)
for h in range(num_heads):
module._name = self.get_head_name(name, h) # type: ignore
module._inputs = module_inputs
Expand All @@ -482,16 +509,17 @@ def _process_grad(self, module: nn.Module, _, grad_out):
raise e

self._process_grad(module, None, (head_G,))
module._name, module._inputs, module.out_features = (
module._name, module._inputs = (
module_name,
module_inputs,
module_out_features,
module_inputs
)
setattr(module, LayerAdapter.out_attr(module), module_out_features)

return

p = self.processor.projection_dim
o, i = module.out_features, module.in_features
i = getattr(module, LayerAdapter.in_attr(module))
o = getattr(module, LayerAdapter.out_attr(module))

# Pre-scale G by the Adafactor row statistics
norm = self.processor.normalizers.get(name)
Expand Down
46 changes: 44 additions & 2 deletions tests/test_build.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from bergson.data import load_gradients

try:
import torch
Expand Down Expand Up @@ -28,14 +29,14 @@
from bergson.data import tokenize


def test_disk_build(tmp_path: Path):
def test_disk_build_linear(tmp_path: Path):
run_path = tmp_path / "example_with_heads"
run_path.mkdir(parents=True, exist_ok=True)

config = IndexConfig(
run_path=str(run_path),
model="RonenEldan/TinyStories-1M",
data=DataConfig(dataset="RonenEldan/TinyStories"),
data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True),
head_cfgs={
"h.0.attn.attention.out_proj": HeadConfig(
num_heads=16, head_size=4, head_dim=2
Expand Down Expand Up @@ -68,3 +69,44 @@ def test_disk_build(tmp_path: Path):
)

assert any(run_path.iterdir()), "Expected artifacts in the temp run_path"


def test_disk_build_conv1d(tmp_path: Path):
    """End-to-end gradient collection on GPT-2, which uses HF Conv1D layers."""
    run_path = tmp_path / "example_with_heads"
    run_path.mkdir(parents=True, exist_ok=True)

    config = IndexConfig(
        run_path=str(run_path),
        model="openai-community/gpt2",
        data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True),
    )

    # Load the model, tokenizer, and a tiny slice of the dataset.
    model = AutoModelForCausalLM.from_pretrained(
        config.model, trust_remote_code=True, use_safetensors=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model)

    raw_data = load_dataset(config.data.dataset, split="train")
    raw_data = raw_data.select(range(8))  # type: ignore

    tokenized = raw_data.map(
        tokenize,
        batched=True,
        fn_kwargs=dict(args=config.data, tokenizer=tokenizer),
        remove_columns=raw_data.column_names,
    )

    processor = GradientProcessor(projection_dim=config.projection_dim)

    collect_gradients(
        model=model,
        data=tokenized,
        processor=processor,
        path=config.run_path,
        head_cfgs=config.head_cfgs,
    )

    assert any(run_path.iterdir()), "Expected artifacts in the temp run_path"

    # The saved index should contain at least one module with nonempty gradients.
    index = load_gradients(str(run_path))
    modules = index.dtype.names
    assert len(modules) != 0
    assert len(index[modules[0]]) != 0
8 changes: 5 additions & 3 deletions tests/test_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AdamNormalizer,
GradientCollector,
GradientProcessor,
LayerAdapter,
)


Expand Down Expand Up @@ -43,7 +44,8 @@ def closure(name: str, g: torch.Tensor):
for name, collected_grad in collected_grads.items():
layer = model.get_submodule(name)

o, i = layer.out_features, layer.in_features
i = getattr(layer, LayerAdapter.in_attr(layer))
o = getattr(layer, LayerAdapter.out_attr(layer))
g = layer.weight.grad
assert g is not None

Expand Down Expand Up @@ -77,8 +79,8 @@ def closure(name: str, g: torch.Tensor):

for name, collected_grad in collected_grads.items():
layer = model.get_submodule(name)

o, i = layer.out_features, layer.in_features
i = getattr(layer, LayerAdapter.in_attr(layer))
o = getattr(layer, LayerAdapter.out_attr(layer))
g = layer.weight.grad
assert g is not None

Expand Down