diff --git a/bergson/gradients.py b/bergson/gradients.py index 0abd43d3..00754d8d 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -9,6 +9,7 @@ import torch.nn as nn from torch import Tensor from torch.utils.hooks import RemovableHandle +from transformers.pytorch_utils import Conv1D as HFConv1D from .math import reshape_to_nearest_square from .utils import assert_type, create_projection_matrix @@ -297,6 +298,32 @@ class HeadConfig(NamedTuple): head_dim: int """The dimension along which heads are tiled.""" +class LayerAdapter(): + supported_modules = (nn.Linear, HFConv1D, nn.Conv1d, nn.Conv2d, nn.Conv3d) + + @staticmethod + def in_attr(layer: nn.Module) -> str: + match layer: + case nn.Linear(): + return "in_features" + case HFConv1D(): + return 'nx' + case nn.Conv1d() | nn.Conv2d() | nn.Conv3d(): + return 'in_channels' + case _: + raise ValueError(f"Unsupported layer type: {type(layer)}") + + @staticmethod + def out_attr(layer: nn.Module) -> str: + match layer: + case nn.Linear(): + return "out_features" + case HFConv1D(): + return 'nf' + case nn.Conv1d() | nn.Conv2d() | nn.Conv3d(): + return 'out_channels' + case _: + raise ValueError(f"Unsupported layer type: {type(layer)}") @dataclass class GradientCollector(ContextDecorator): @@ -323,7 +350,7 @@ class GradientCollector(ContextDecorator): target_modules: set[str] | None = None """ List of parameter names to collect gradients for. Should consist only of weight - matrices in `nn.Linear` modules. If `None`, the gradients for all weight matrices + matrices in modules supported by LayerAdapter. If `None`, the gradients for all weight matrices will be collected. """ @@ -340,7 +367,7 @@ def __post_init__(self): # Before we add any hooks, we need to peek at what modules we need to track. for name, layer in self.model.named_modules(): - if not isinstance(layer, nn.Linear): + if not isinstance(layer, LayerAdapter.supported_modules): continue if self.target_modules is not None and name not in self.target_modules: @@ -442,7 +469,7 @@ def _save_input(self, module: nn.Module, inp: tuple, _): # to save memory, rather than waiting until the backward pass. p = self.processor.projection_dim if p is not None and not isinstance(norm, AdamNormalizer): - i = module.in_features + i = getattr(module, LayerAdapter.in_attr(module)) x = x @ self.projection(name, p, i, "right", x.device, x.dtype).T # type: ignore module._inputs = x @@ -450,7 +477,7 @@ def _save_input(self, module: nn.Module, inp: tuple, _): def _process_grad(self, module: nn.Module, _, grad_out): """Process the incoming gradient wrt the output of the module.""" # Sanity checks - assert isinstance(module, nn.Linear), "Expected a Linear module" + assert isinstance(module, LayerAdapter.supported_modules), f"Expected a module of type {LayerAdapter.supported_modules}, got {type(module)}" G = grad_out[0] # [N, S, O] I = module._inputs # [N, S, I/q] @@ -463,9 +490,9 @@ def _process_grad(self, module: nn.Module, _, grad_out): module_name, module_inputs, module_out_features = ( module._name, module._inputs, - module.out_features, + getattr(module, LayerAdapter.out_attr(module)), ) - module.out_features = head_size + setattr(module, LayerAdapter.out_attr(module), head_size) for h in range(num_heads): module._name = self.get_head_name(name, h) # type: ignore module._inputs = module_inputs @@ -482,16 +509,17 @@ def _process_grad(self, module: nn.Module, _, grad_out): raise e self._process_grad(module, None, (head_G,)) - module._name, module._inputs, module.out_features = ( + module._name, module._inputs = ( module_name, - module_inputs, - module_out_features, + module_inputs ) + setattr(module, LayerAdapter.out_attr(module), module_out_features) return p = self.processor.projection_dim - o, i = module.out_features, module.in_features + i = getattr(module, LayerAdapter.in_attr(module)) + o = getattr(module, LayerAdapter.out_attr(module)) # Pre-scale G by the Adafactor row statistics norm = self.processor.normalizers.get(name) diff --git a/tests/test_build.py b/tests/test_build.py index c517bdd7..f8f9e722 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,4 +1,5 @@ import pytest +from bergson.data import load_gradients try: import torch @@ -28,14 +29,14 @@ from bergson.data import tokenize -def test_disk_build(tmp_path: Path): +def test_disk_build_linear(tmp_path: Path): run_path = tmp_path / "example_with_heads" run_path.mkdir(parents=True, exist_ok=True) config = IndexConfig( run_path=str(run_path), model="RonenEldan/TinyStories-1M", - data=DataConfig(dataset="RonenEldan/TinyStories"), + data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True), head_cfgs={ "h.0.attn.attention.out_proj": HeadConfig( num_heads=16, head_size=4, head_dim=2 @@ -68,3 +69,44 @@ def test_disk_build(tmp_path: Path): ) assert any(run_path.iterdir()), "Expected artifacts in the temp run_path" + + +def test_disk_build_conv1d(tmp_path: Path): + run_path = tmp_path / "example_with_heads" + run_path.mkdir(parents=True, exist_ok=True) + + config = IndexConfig( + run_path=str(run_path), + model="openai-community/gpt2", + data=DataConfig(dataset="NeelNanda/pile-10k", truncation=True), + ) + + model = AutoModelForCausalLM.from_pretrained( + config.model, trust_remote_code=True, use_safetensors=True + ) + tokenizer = AutoTokenizer.from_pretrained(config.model) + data = load_dataset(config.data.dataset, split="train") + data = data.select(range(8)) # type: ignore + + processor = GradientProcessor(projection_dim=config.projection_dim) + + data = data.map( + tokenize, + batched=True, + fn_kwargs=dict(args=config.data, tokenizer=tokenizer), + remove_columns=data.column_names, + ) + + collect_gradients( + model=model, + data=data, + processor=processor, + path=config.run_path, + head_cfgs=config.head_cfgs, + ) + + assert any(run_path.iterdir()), "Expected artifacts in the temp run_path" + + index = load_gradients(str(run_path)) + assert len(modules := index.dtype.names) != 0 + assert len(first_column := index[modules[0]]) != 0 \ No newline at end of file diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 5bac36a8..44f3f298 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -9,6 +9,7 @@ AdamNormalizer, GradientCollector, GradientProcessor, + LayerAdapter, ) @@ -43,7 +44,8 @@ def closure(name: str, g: torch.Tensor): for name, collected_grad in collected_grads.items(): layer = model.get_submodule(name) - o, i = layer.out_features, layer.in_features + i = getattr(layer, LayerAdapter.in_attr(layer)) + o = getattr(layer, LayerAdapter.out_attr(layer)) g = layer.weight.grad assert g is not None @@ -77,8 +79,8 @@ def closure(name: str, g: torch.Tensor): for name, collected_grad in collected_grads.items(): layer = model.get_submodule(name) - - o, i = layer.out_features, layer.in_features + i = getattr(layer, LayerAdapter.in_attr(layer)) + o = getattr(layer, LayerAdapter.out_attr(layer)) g = layer.weight.grad assert g is not None