diff --git a/runs/test_build_cache.npy b/runs/test_build_cache.npy
new file mode 100644
index 00000000..ef9b3f35
Binary files /dev/null and b/runs/test_build_cache.npy differ
diff --git a/tests/conftest.py b/tests/conftest.py
index 0640d373..f69c9d59 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,15 @@
 import pytest
+import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
 
 @pytest.fixture
 def model():
-    """Create a small test model."""
+    """Randomly initialize a small test model."""
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
     config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM")
     return AutoModelForCausalLM.from_config(config)
 
diff --git a/tests/test_build.py b/tests/test_build.py
index 3ca98dce..d1cfe16f 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -1,22 +1,8 @@
-import pytest
-
-from bergson.data import load_gradients
-
-try:
-    import torch
-
-    HAS_CUDA = torch.cuda.is_available()
-except Exception:
-    HAS_CUDA = False
-
-if not HAS_CUDA:
-    pytest.skip(
-        "Skipping GPU-only tests: no CUDA/NVIDIA driver available.",
-        allow_module_level=True,
-    )
-
 from pathlib import Path
 
+import numpy as np
+import pytest
+import torch
 from transformers import AutoModelForCausalLM
 
 from bergson import (
@@ -24,6 +10,29 @@
     GradientProcessor,
     collect_gradients,
 )
+from bergson.data import load_gradients
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_build_consistency(tmp_path: Path, model, dataset):
+    collect_gradients(
+        model=model,
+        data=dataset,
+        processor=GradientProcessor(),
+        path=str(tmp_path),
+        skip_preconditioners=True,
+    )
+    index = load_gradients(str(tmp_path))
+
+    # Regenerate cache
+    cache_path = Path("runs/test_build_cache.npy")
+    if not cache_path.exists():
+        np.save(cache_path, index[index.dtype.names[0]][0])
+    cached_item_grad = np.load(cache_path)
+
+    first_module_grad = index[index.dtype.names[0]][0]
+
+    assert np.allclose(first_module_grad, cached_item_grad, atol=1e-6)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")