diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9be5464 --- /dev/null +++ b/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# Poetry +poetry.lock + +# PEP 582 +__pypackages__/ + +# Celery +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ +.DS_Store + +# Claude settings +.claude/* \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dca8a27 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,116 @@ +[tool.poetry] +name = "lightweight_mmm" +version = "0.1.9" +description = "Package for Media-Mix-Modelling" +authors = ["Google LLC "] +license = "Apache-2.0" +readme = "README.md" +homepage = "https://github.com/google/lightweight_mmm" +repository = "https://github.com/google/lightweight_mmm" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering :: Mathematics", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] +packages = [{include = "lightweight_mmm"}] + +[tool.poetry.dependencies] +python = "^3.8" +absl-py = "*" +arviz = ">=0.11.2" +immutabledict = ">=2.0.0" +jax = ">=0.3.18" +jaxlib = ">=0.3.18" +matplotlib = "==3.6.1" +numpy = ">=1.21.0" +numpyro = ">=0.9.2" +pandas = ">=1.1.5" +scipy = "*" +seaborn = "==0.11.1" +scikit-learn = "*" +statsmodels = ">=0.13.0" +tensorflow = ">=2.7.2" + +[tool.poetry.group.dev.dependencies] +pytest = "^7.4.0" +pytest-cov = "^4.1.0" +pytest-mock = "^3.11.1" +pytest-xdist = "^3.3.1" + +[tool.poetry.scripts] +test = "pytest:main" +tests = "pytest:main" + +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +addopts = [ + "-ra", + "--strict-markers", + "--cov=lightweight_mmm", + "--cov-report=term-missing:skip-covered", + "--cov-report=html", + "--cov-report=xml", + "--cov-fail-under=80", + "-vv", + "--tb=short", + "--maxfail=3" +] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +markers = [ + "unit: marks tests as unit tests (fast, isolated)", + "integration: marks tests as integration tests (may be slower)", + "slow: marks tests as slow (deselect with '-m \"not slow\"')" +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::PendingDeprecationWarning" +] + +[tool.coverage.run] +source = ["lightweight_mmm"] +branch = true +omit = [ + "*/tests/*", + "*/__pycache__/*", + "*/conftest.py", + "*/setup.py", + "*/.venv/*", + "*/venv/*" +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if self.debug:", + "if __name__ == .__main__.:", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if False:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod" +] +show_missing = true +precision = 2 +fail_under = 80 + +[tool.coverage.html] +directory = "htmlcov" + +[tool.coverage.xml] +output = "coverage.xml" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1c17851 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,119 @@ +"""Shared pytest fixtures and configuration for lightweight_mmm tests.""" + +import os +import tempfile +from pathlib import Path +from typing import Generator + +import numpy as np +import pandas as pd +import pytest + + +@pytest.fixture +def temp_dir() -> Generator[Path, None, None]: + """Create a temporary directory for test files.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def mock_data(): + """Create mock data for testing MMM models.""" + np.random.seed(42) + n_time_periods = 52 + n_media_channels = 3 + n_geos = 2 + + # Generate synthetic data + data = { + 'date': pd.date_range('2023-01-01', periods=n_time_periods, freq='W'), + 'sales': np.random.poisson(1000, n_time_periods) + np.random.normal(0, 50, n_time_periods), + } + + # Add media spend data + for i in range(n_media_channels): + data[f'media_{i}'] = np.random.exponential(1000, n_time_periods) + + # Add geo data + for i in range(n_geos): + data[f'geo_{i}_sales'] = np.random.poisson(500, n_time_periods) + + return pd.DataFrame(data) + + +@pytest.fixture +def mock_config(): + """Create a mock configuration dictionary.""" + return { + 'n_media_channels': 3, + 'n_geos': 2, + 'model_type': 'adstock', + 'priors': { + 'intercept': {'mean': 0, 'std': 1}, + 'coef_media': {'mean': 0, 'std': 0.1}, + }, + 'hyperparameters': { + 'learning_rate': 0.001, + 'n_iterations': 1000, + 'batch_size': 32, + } + } + + +@pytest.fixture +def sample_media_data(): + """Generate sample media spend data.""" + np.random.seed(123) + return np.random.rand(52, 3) * 10000 # 52 weeks, 3 channels + + +@pytest.fixture +def sample_target_data(): + """Generate sample target (sales) data.""" + np.random.seed(123) + base_sales = 10000 + trend = np.linspace(0, 1000, 52) + seasonality = 500 * np.sin(np.linspace(0, 4 * np.pi, 52)) + noise = np.random.normal(0, 200, 52) + return base_sales + trend + seasonality + noise + + +@pytest.fixture(autouse=True) +def reset_random_seed(): + """Reset random seeds before each test for reproducibility.""" + np.random.seed(42) + import random + random.seed(42) + + # Reset JAX random seed if JAX is available + try: + import jax + jax.random.PRNGKey(42) + except ImportError: + pass + + +@pytest.fixture +def mock_model_params(): + """Create mock model parameters.""" + return { + 'intercept': np.array([1000.0]), + 'coef_media': np.array([0.1, 0.2, 0.15]), + 'coef_trend': np.array([10.0]), + 'saturation_parameters': { + 'alphas': np.array([2.0, 1.5, 2.5]), + 'betas': np.array([0.5, 0.6, 0.4]) + }, + 'adstock_parameters': { + 'convolve_window': 3, + 'decay_rates': np.array([0.3, 0.4, 0.35]) + } + } + + +@pytest.fixture +def capture_logs(caplog): + """Fixture to capture and assert log messages.""" + with caplog.at_level('DEBUG'): + yield caplog \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_setup_validation.py b/tests/test_setup_validation.py new file mode 100644 index 0000000..8370398 --- /dev/null +++ b/tests/test_setup_validation.py @@ -0,0 +1,86 @@ +"""Validation tests to verify testing infrastructure is set up correctly.""" + +import pytest + + +class TestInfrastructureSetup: + """Test class to validate the testing infrastructure.""" + + @pytest.mark.unit + def test_pytest_installed(self): + """Verify pytest is available.""" + import pytest + assert pytest.__version__ + + @pytest.mark.unit + def test_coverage_installed(self): + """Verify pytest-cov is available.""" + import pytest_cov + assert pytest_cov + + @pytest.mark.unit + def test_mock_installed(self): + """Verify pytest-mock is available.""" + import pytest_mock + assert pytest_mock + + @pytest.mark.unit + def test_fixtures_available(self, temp_dir, mock_data, mock_config): + """Verify custom fixtures are available and working.""" + # Test temp_dir fixture + assert temp_dir.exists() + assert temp_dir.is_dir() + + # Test mock_data fixture + assert mock_data is not None + assert len(mock_data) == 52 # 52 weeks of data + assert 'sales' in mock_data.columns + assert 'media_0' in mock_data.columns + + # Test mock_config fixture + assert isinstance(mock_config, dict) + assert 'n_media_channels' in mock_config + assert mock_config['n_media_channels'] == 3 + + @pytest.mark.unit + def test_markers_defined(self, request): + """Verify custom markers are defined.""" + markers = request.config.getini('markers') + marker_names = [m.split(':')[0].strip() for m in markers] + assert 'unit' in marker_names + assert 'integration' in marker_names + assert 'slow' in marker_names + + @pytest.mark.integration + def test_integration_marker(self): + """Test that integration marker works.""" + assert True + + @pytest.mark.slow + def test_slow_marker(self): + """Test that slow marker works.""" + import time + time.sleep(0.1) # Simulate slow test + assert True + + def test_project_structure(self): + """Verify the project structure is accessible.""" + from pathlib import Path + + project_root = Path(__file__).parent.parent + assert project_root.exists() + assert (project_root / 'lightweight_mmm').exists() + assert (project_root / 'pyproject.toml').exists() + assert (project_root / 'tests').exists() + assert (project_root / 'tests' / 'conftest.py').exists() + + +def test_basic_assertion(): + """Basic test to ensure pytest runs.""" + assert 1 + 1 == 2 + + +def test_fixture_usage(sample_media_data, sample_target_data): + """Test that fixtures from conftest.py are accessible.""" + assert sample_media_data.shape == (52, 3) + assert len(sample_target_data) == 52 \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29