# CI workflow for lmms-eval.
#
# Every job checks out the repository (with submodules), installs a Python
# interpreter and then drives test/run_suite.py (or pytest directly for the
# coverage job) from inside the test/ directory.
name: Tests

on:
  push:
    branches: [ main, dev/v0d4, feature/* ]
  pull_request:
    branches: [ main, dev/v0d4 ]

jobs:
  # Formatting / import-order checks (black + isort via run_suite.py).
  lint:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
      with:
        submodules: true
        fetch-depth: 0

    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install black isort pytest

    - name: Run linting
      run: |
        cd test
        python run_suite.py lint

  # Unit tests, repeated for every interpreter in the matrix.
  unit-tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11']

    steps:
    - uses: actions/checkout@v4
      with:
        submodules: true
        fetch-depth: 0

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install pytest pytest-cov
        python -m pip install -e .

    - name: Run unit tests
      run: |
        cd test
        python run_suite.py unit

  # Integration tests; the package itself is installed in editable mode.
  integration-tests:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
      with:
        submodules: true
        fetch-depth: 0

    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install pytest pytest-mock
        python -m pip install -e .

    - name: Run integration tests
      run: |
        cd test
        python run_suite.py integration

  # Throughput tests run without installing the package itself.
  throughput-tests:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
      with:
        submodules: true
        fetch-depth: 0

    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install pytest
        # Install minimal dependencies for throughput testing.
        # `time` is part of the Python standard library and must not be
        # passed to pip; only loguru needs installing here.
        python -m pip install loguru

    - name: Run throughput tests
      run: |
        cd test
        python run_suite.py throughput

  # Full pytest run with coverage; the XML report is uploaded to Codecov.
  test-coverage:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
      with:
        submodules: true
        fetch-depth: 0

    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install pytest pytest-cov
        python -m pip install -e .

    - name: Run tests with coverage
      run: |
        cd test
        python -m pytest --cov=../lmms_eval --cov-report=xml --cov-report=html

    - name: Upload coverage reports
      uses: codecov/codecov-action@v3
      with:
        file: ./test/coverage.xml
        # A failed upload must not turn an otherwise green build red.
        fail_ci_if_error: false
+ +## Structure + +``` +test/ +├── __init__.py # Test package initialization +├── conftest.py # pytest fixtures and configuration +├── requirements-test.txt # Testing dependencies +├── run_suite.py # Test suite runner +├── test_api_components.py # Core API component tests +├── test_chat_models.py # Chat model integration tests +├── test_throughput_metrics.py # Original throughput demo script +└── test_throughput_metrics_unit.py # Unit tests for throughput metrics +``` + +## Test Categories + +### Unit Tests +- **test_throughput_metrics_unit.py**: Tests for TPOT and inference speed calculations +- **test_api_components.py**: Tests for core API components (Instance, registries, metrics) + +### Integration Tests +- **test_chat_models.py**: Integration tests for chat models with throughput metrics + +### Throughput Tests +- **test_throughput_metrics.py**: Demo script showing throughput calculations +- **test_throughput_metrics_unit.py**: Comprehensive unit tests for timing logic + +## Running Tests + +### Using the Test Runner +```bash +# Run all tests +python test/run_suite.py all + +# Run specific test suites +python test/run_suite.py unit +python test/run_suite.py integration +python test/run_suite.py throughput +python test/run_suite.py lint +``` + +### Using pytest Directly +```bash +# Install test dependencies +pip install -r test/requirements-test.txt + +# Run all tests +pytest test/ + +# Run specific test files +pytest test/test_throughput_metrics_unit.py -v + +# Run with coverage +pytest test/ --cov=lmms_eval --cov-report=html +``` + +### Using unittest +```bash +# Run individual test files +python test/test_throughput_metrics_unit.py +python test/test_api_components.py +``` + +## CI/CD Integration + +### GitHub Actions +The test suite is integrated with GitHub Actions through `.github/workflows/test.yml`: + +- **Lint Check**: Runs black and isort formatting checks +- **Unit Tests**: Runs on Python 3.9, 3.10, 3.11 +- **Integration Tests**: Tests model 
integration with mocks +- **Throughput Tests**: Validates throughput metric calculations +- **Coverage**: Generates test coverage reports + +### Pre-commit Hooks +Tests are automatically run through pre-commit hooks: +```bash +pre-commit install +pre-commit run --all-files +``` + +## Test Design Principles + +### 1. Fast Unit Tests +- Mock external dependencies (models, APIs) +- Test core logic without heavy I/O +- Focus on edge cases and error handling + +### 2. Comprehensive Integration Tests +- Test real component interactions +- Use minimal mocking for integration points +- Validate end-to-end workflows + +### 3. Throughput-Specific Tests +- Validate TPOT formula: `(e2e_latency - TTFT) / (num_output_tokens - 1)` +- Test inference speed calculation: `1 / TPOT` +- Verify timing measurement accuracy +- Test batch processing scenarios + +### 4. Maintainable Test Code +- Use fixtures for common test data +- Clear test names describing what's being tested +- Comprehensive error message assertions +- Clean separation between test categories + +## Adding New Tests + +### For New Features +1. Add unit tests in appropriate `test_*.py` file +2. Add integration tests if feature involves multiple components +3. Update `run_suite.py` if new test categories are needed +4. Update CI workflow if special setup is required + +### For Throughput Metrics +1. Add calculation tests to `test_throughput_metrics_unit.py` +2. Add integration tests to `test_chat_models.py` +3. 
Ensure timing accuracy tests cover edge cases + +### Test Naming Convention +- Test files: `test_<module>.py` +- Test classes: `Test<Component>` +- Test methods: `test_<behavior>` + +## Dependencies + +### Core Testing +- `pytest`: Test framework +- `pytest-cov`: Coverage reporting +- `pytest-mock`: Mocking utilities + +### Code Quality +- `black`: Code formatting +- `isort`: Import sorting +- `coverage`: Coverage analysis + +### Optional +- `torch`: For model-related tests +- `transformers`: For HuggingFace model tests +- `openai`: For API model tests + +## Best Practices + +### Writing Tests +- Keep tests focused on single behaviors +- Use descriptive assertions with clear error messages +- Mock external dependencies appropriately +- Test both success and failure cases + +### Performance Testing +- Use timing measurements for throughput validation +- Allow reasonable variance in timing tests +- Test edge cases (zero tokens, single token, large batches) + +### CI/CD Considerations +- Tests should be deterministic and reliable +- Avoid network dependencies in CI +- Use matrix testing for multiple Python versions +- Generate coverage reports for code quality tracking + +## Troubleshooting + +### Common Issues +1. **Import Errors**: Ensure lmms-eval is installed with `pip install -e .` +2. **Missing Dependencies**: Install test requirements with `pip install -r test/requirements-test.txt` +3. 
"""
pytest configuration and shared fixtures for the lmms-eval tests.

Every fixture here is function-scoped (the pytest default), so each test
receives a fresh object.
"""

import os
import tempfile
from unittest.mock import Mock, patch

import pytest


@pytest.fixture
def mock_model():
    """Return a ``Mock`` standing in for a model, so no real model is loaded.

    The mock is pre-programmed as follows:

    * ``generate()`` returns ``"test response"``
    * ``tokenizer.encode()`` returns ``[1, 2, 3, 4, 5]``
    * ``tokenizer.decode()`` returns ``"test response"``
    """
    mock = Mock()
    mock.generate.return_value = "test response"
    mock.tokenizer = Mock()
    mock.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
    mock.tokenizer.decode.return_value = "test response"
    return mock


@pytest.fixture
def temp_cache_dir():
    """Yield the path of a temporary directory intended for cache files.

    The directory is created by ``tempfile.TemporaryDirectory`` and removed
    when the context manager exits, i.e. after the test has finished.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir


@pytest.fixture
def mock_task_dict():
    """Return a minimal task dictionary with one task, one split, one document.

    Layout: ``{"test_task": {"test": [<doc>]}}`` where the document has the
    keys ``question``, ``answer``, ``image`` (``None``) and ``doc_id`` (``0``).
    """
    return {
        "test_task": {
            "test": [
                {
                    "question": "What is 2+2?",
                    "answer": "4",
                    "image": None,
                    "doc_id": 0,
                }
            ]
        }
    }


@pytest.fixture
def mock_eval_logger():
    """Patch ``lmms_eval.api.model.eval_logger`` and yield the replacement mock.

    The patch is active only while the requesting test runs; leaving the
    ``with`` block restores the original attribute.
    """
    with patch("lmms_eval.api.model.eval_logger") as mock_logger:
        yield mock_logger
#!/usr/bin/env python3
"""
Test suite runner for lmms-eval CI/CD

This script runs different test suites based on the provided argument:
- unit: Run unit tests only
- integration: Run integration tests only
- throughput: Run throughput-specific tests only
- lint: Run the black and isort checks only
- all: Run every suite listed above
"""
import argparse
import os
import subprocess
import sys
import unittest
from pathlib import Path


def run_command(cmd, description):
    """Run ``cmd`` as a subprocess, echo its output and report the outcome.

    Args:
        cmd: Argument vector (list of strings) handed to ``subprocess.run``.
        description: Human-readable label printed in the banner and in
            error messages.

    Returns:
        ``True`` when the process exits with status 0. ``False`` when it
        exits non-zero or cannot be started at all (for example because the
        executable does not exist).
    """
    print(f"\n{'='*60}")
    print(f"Running: {description}")
    print(f"Command: {' '.join(cmd)}")
    print(f"{'='*60}")

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"ERROR: {description} failed")
        print(f"Return code: {e.returncode}")
        print(f"STDOUT: {e.stdout}")
        print(f"STDERR: {e.stderr}")
        return False
    except OSError as e:
        # Raised when the executable is missing or not runnable; report it
        # as a failed step instead of aborting the whole run with a traceback.
        print(f"ERROR: {description} could not be started: {e}")
        return False

    print(result.stdout)
    if result.stderr:
        print("STDERR:", result.stderr)
    return True


def _run_pytest_files(test_files, label, verbose=False):
    """Run pytest once per file and return ``True`` only if every run passed.

    All files are executed even after a failure so one report shows every
    problem. ``verbose`` switches pytest from ``-v`` to ``-vv``.
    """
    verbosity = "-vv" if verbose else "-v"
    # sys.executable guarantees the tests run under the same interpreter
    # (and therefore the same installed packages) as this script.
    results = [run_command([sys.executable, "-m", "pytest", test_file, verbosity], f"{label}: {test_file}") for test_file in test_files]
    return all(results)


def run_unit_tests(verbose=False):
    """Run the unit-test files through pytest; return ``True`` on success."""
    test_files = [
        "test_throughput_metrics_unit.py",
        "test_api_components.py",
    ]
    return _run_pytest_files(test_files, "Unit tests", verbose)


def run_integration_tests(verbose=False):
    """Run the integration-test files through pytest; return ``True`` on success."""
    test_files = [
        "test_chat_models.py",
    ]
    return _run_pytest_files(test_files, "Integration tests", verbose)


def run_throughput_tests():
    """Run the throughput scripts directly with the current interpreter.

    Files that are absent from the working directory are skipped with a
    notice. If none of the listed files exists the suite is reported as
    failed, because a suite that executed nothing must not count as passing.

    Returns:
        ``True`` when at least one file ran and every executed file
        succeeded, ``False`` otherwise.
    """
    test_files = [
        "test_throughput_metrics_unit.py",
        "test_throughput_metrics.py",
    ]

    success = True
    executed = 0
    for test_file in test_files:
        if not Path(test_file).exists():
            print(f"SKIPPED: {test_file} not found in {Path.cwd()}")
            continue
        executed += 1
        if not run_command([sys.executable, test_file], f"Throughput tests: {test_file}"):
            success = False

    if executed == 0:
        print("ERROR: no throughput test files were found")
        return False
    return success


def run_linting():
    """Run the black and isort checks on the current working directory.

    Both checks always run; the return value is ``True`` only when both pass.
    """
    commands = [
        ([sys.executable, "-m", "black", "--check", ".", "--line-length", "240"], "Black formatting check"),
        ([sys.executable, "-m", "isort", "--check-only", "."], "Import sorting check"),
    ]
    results = [run_command(cmd, description) for cmd, description in commands]
    return all(results)


def main(argv=None):
    """Parse the command line, run the requested suite(s) and exit.

    Args:
        argv: Optional argument list; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Exits with status 0 when everything passed and 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Run lmms-eval test suite")
    parser.add_argument("suite", choices=["unit", "integration", "all", "throughput", "lint"], help="Test suite to run")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = parser.parse_args(argv)

    # The test files are addressed by bare file name, so run from this
    # script's own directory regardless of where it was invoked.
    test_dir = Path(__file__).resolve().parent
    os.chdir(test_dir)

    print(f"Running {args.suite} test suite...")
    print(f"Working directory: {test_dir}")

    runners = {
        "lint": run_linting,
        "unit": lambda: run_unit_tests(verbose=args.verbose),
        "integration": lambda: run_integration_tests(verbose=args.verbose),
        "throughput": run_throughput_tests,
    }

    if args.suite == "all":
        print("Running all test suites...")
        selected = ["lint", "unit", "integration", "throughput"]
    else:
        selected = [args.suite]

    # Build the full result list first so that every selected suite runs
    # even when an earlier one has already failed.
    results = [runners[name]() for name in selected]
    success = all(results)

    if success:
        print(f"\n✅ All {args.suite} tests passed!")
        sys.exit(0)
    else:
        print(f"\n❌ Some {args.suite} tests failed!")
        sys.exit(1)


if __name__ == "__main__":
    main()
"""
Unit tests for core API components
"""

import unittest
from unittest.mock import Mock, patch

import pytest


class TestAPIComponents(unittest.TestCase):
    """Test core API components.

    ``lmms_eval`` is imported inside each test so that a broken import
    fails only the affected test rather than the collection of the module.
    """

    def test_instance_creation(self):
        """Instance exposes its constructor arguments and metadata fields."""
        from lmms_eval.api.instance import Instance

        instance = Instance(
            request_type="generate_until",
            arguments=("test context", {"max_tokens": 100}),
            idx=0,
            metadata={"task": "test_task", "doc_id": 0, "repeats": 1},
        )

        self.assertEqual(instance.request_type, "generate_until")
        self.assertEqual(instance.idx, 0)
        # The three values below come from the ``metadata`` dict.
        self.assertEqual(instance.task_name, "test_task")
        self.assertEqual(instance.doc_id, 0)
        self.assertEqual(instance.repeats, 1)

    def test_model_registry(self):
        """register_model stores the decorated class under the given name."""
        from lmms_eval.api.registry import MODEL_REGISTRY, register_model

        # patch.dict snapshots the registry and restores it on exit, so the
        # throw-away entry does not leak into other tests in this process.
        # NOTE(review): assumes MODEL_REGISTRY is a dict-like mapping - confirm.
        with patch.dict(MODEL_REGISTRY):

            @register_model("test_model")
            class TestModel:
                pass

            self.assertIn("test_model", MODEL_REGISTRY)
            self.assertEqual(MODEL_REGISTRY["test_model"], TestModel)

    def test_metrics_registry(self):
        """register_metric adds the metric name to METRIC_REGISTRY."""
        from lmms_eval.api.registry import METRIC_REGISTRY, register_metric

        # Restores METRIC_REGISTRY afterwards.
        # NOTE(review): register_metric may also write to other registries
        # that are not visible from this test - verify and isolate if so.
        with patch.dict(METRIC_REGISTRY):

            @register_metric(
                metric="test_metric",
                higher_is_better=True,
                output_type="generate_until",
                aggregation="mean",
            )
            def test_metric_fn(items):
                return items

            self.assertIn("test_metric", METRIC_REGISTRY)

    def test_base_model_interface(self):
        """A minimal ``lmms`` subclass is instantiable and has default properties."""
        from lmms_eval.api.model import lmms

        class MockModel(lmms):
            def loglikelihood(self, requests):
                return [(0.5, True) for _ in requests]

            def generate_until(self, requests):
                return ["test response" for _ in requests]

            def generate_until_multi_round(self, requests):
                return ["test response" for _ in requests]

        model = MockModel()

        # Interface methods exist on the instance.
        self.assertTrue(hasattr(model, "loglikelihood"))
        self.assertTrue(hasattr(model, "generate_until"))
        self.assertTrue(hasattr(model, "generate_until_multi_round"))

        # Default property values of a freshly constructed model.
        self.assertEqual(model.rank, 0)
        self.assertEqual(model.world_size, 1)
        self.assertTrue(model.is_simple)

    def test_caching_functionality(self):
        """hash_args is deterministic and CacheHook(None) tolerates add_partial."""
        from lmms_eval.api.model import CacheHook, hash_args

        test_args = ("test", {"param": "value"})
        hash1 = hash_args("method", test_args)
        hash2 = hash_args("method", test_args)
        hash3 = hash_args("method", ("different", {"param": "value"}))

        self.assertEqual(hash1, hash2)  # Same inputs should hash the same
        self.assertNotEqual(hash1, hash3)  # Different inputs should hash differently

        cache_hook = CacheHook(None)
        # Should not crash when no backing store is supplied.
        cache_hook.add_partial("test_method", test_args, "result")

    @patch("lmms_eval.api.metrics.eval_logger")
    def test_metrics_calculation(self, mock_logger):
        """mean, median and exact_match_hf_evaluate return the expected values."""
        from lmms_eval.api.metrics import exact_match_hf_evaluate, mean, median

        test_values = [1, 2, 3, 4, 5]
        self.assertEqual(mean(test_values), 3.0)
        self.assertEqual(median(test_values), 3)

        predictions = ["hello", "world", "test"]
        references = ["hello", "world", "different"]
        result = exact_match_hf_evaluate(predictions, references)
        expected_accuracy = 2 / 3  # 2 out of 3 matches
        self.assertAlmostEqual(result["exact_match"], expected_accuracy, places=3)

    def test_utility_functions(self):
        """simple_parse_args_string and hash_string behave as expected."""
        from lmms_eval.utils import hash_string, simple_parse_args_string

        arg_string = "param1=value1,param2=value2,param3=123"
        parsed = simple_parse_args_string(arg_string)

        # The expected dict keeps "123" as a string.
        expected = {"param1": "value1", "param2": "value2", "param3": "123"}
        self.assertEqual(parsed, expected)

        test_string = "test string for hashing"
        hash1 = hash_string(test_string)
        hash2 = hash_string(test_string)
        hash3 = hash_string("different string")

        self.assertEqual(hash1, hash2)
        self.assertNotEqual(hash1, hash3)
        self.assertIsInstance(hash1, str)


class TestTaskManagement(unittest.TestCase):
    """Test task management functionality"""

    def test_task_creation(self):
        """The task modules import cleanly (smoke test only).

        Building a real task would require task files on disk, so this test
        only checks that the two entry points can be imported.
        """
        try:
            from lmms_eval.api.task import Task
            from lmms_eval.tasks import TaskManager
        except ImportError as e:
            self.fail(f"Failed to import task modules: {e}")

        # Assert on the imported names rather than on a constant.
        self.assertIsNotNone(Task)
        self.assertIsNotNone(TaskManager)

    def test_task_registry(self):
        """register_task adds the task name to TASK_REGISTRY."""
        from lmms_eval.api.registry import TASK_REGISTRY, register_task

        # Restores TASK_REGISTRY afterwards so the dummy task does not leak.
        # NOTE(review): assumes TASK_REGISTRY is a dict-like mapping - confirm.
        with patch.dict(TASK_REGISTRY):

            @register_task("test_task")
            class TestTask:
                pass

            self.assertIn("test_task", TASK_REGISTRY)


if __name__ == "__main__":
    unittest.main()
"""
Integration tests for chat models with throughput metrics
"""

import time
import unittest
from unittest.mock import MagicMock, Mock, patch

import pytest


def _throughput_from_timing(e2e_latency, ttft, output_tokens):
    """Return ``(tpot, inference_speed)`` for one generation.

    TPOT is ``(e2e_latency - ttft) / (output_tokens - 1)`` and the speed is
    ``1 / tpot``. With one token or fewer there is no inter-token interval,
    so TPOT falls back to the end-to-end latency and the speed is ``0``; the
    speed is also ``0`` whenever TPOT is not positive.
    """
    if output_tokens > 1:
        tpot = (e2e_latency - ttft) / (output_tokens - 1)
        inference_speed = 1 / tpot if tpot > 0 else 0
    else:
        tpot = e2e_latency
        inference_speed = 0
    return tpot, inference_speed


class TestChatModelThroughput(unittest.TestCase):
    """Test throughput metrics integration in chat models"""

    @patch("lmms_eval.models.chat.openai_compatible.eval_logger")
    @patch("lmms_eval.models.chat.openai_compatible.OpenAI")
    def test_openai_compatible_metrics(self, mock_openai, mock_logger):
        """generate_until on the OpenAI-compatible model logges inference metrics.

        Both the ``OpenAI`` client class and the module logger are patched,
        so no request leaves the process.
        """
        # Canned completion returned by the fake client.
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.usage = Mock()
        mock_response.usage.completion_tokens = 10
        mock_response.usage.prompt_tokens = 5

        mock_client = Mock()
        mock_client.chat.completions.create.return_value = mock_response
        mock_openai.return_value = mock_client

        # Import after mocking
        from lmms_eval.models.chat.openai_compatible import OpenAICompatible

        model = OpenAICompatible(model_version="test-model")
        model.client = mock_client

        # Create mock request
        mock_request = Mock()
        mock_request.args = (
            "test context",
            lambda x: [{"role": "user", "content": "test"}],
            {"max_new_tokens": 100},
            0,
            "test_task",
            "test",
        )

        model.generate_until([mock_request])

        # Verify metrics logging was called
        mock_logger.info.assert_called()
        # Skip calls without positional arguments instead of raising IndexError.
        logged_messages = [str(logged.args[0]) for logged in mock_logger.info.call_args_list if logged.args]

        metrics_logged = any("Inference metrics" in message for message in logged_messages)
        self.assertTrue(metrics_logged, "Throughput metrics should be logged")

    def test_timing_integration(self):
        """Metrics derived from a really timed (10 ms) fake generation are positive."""

        class MockModel:
            def __init__(self):
                self.generate_call_count = 0

            def generate_with_timing(self):
                """Simulate model generation with timing"""
                self.generate_call_count += 1
                start_time = time.time()
                time.sleep(0.01)  # Simulate processing
                end_time = time.time()

                e2e_latency = end_time - start_time
                output_tokens = 25
                ttft = e2e_latency * 0.1  # TTFT is estimated as 10% of the latency

                tpot, inference_speed = _throughput_from_timing(e2e_latency, ttft, output_tokens)

                return {
                    "e2e_latency": e2e_latency,
                    "tpot": tpot,
                    "inference_speed": inference_speed,
                    "output_tokens": output_tokens,
                }

        mock_model = MockModel()
        result = mock_model.generate_with_timing()

        # Verify metrics are calculated
        for key in ("e2e_latency", "tpot", "inference_speed", "output_tokens"):
            self.assertIn(key, result)

        # Verify reasonable values
        self.assertGreater(result["e2e_latency"], 0)
        self.assertGreater(result["tpot"], 0)
        self.assertGreater(result["inference_speed"], 0)
        self.assertEqual(result["output_tokens"], 25)

    def test_batch_processing_metrics(self):
        """Per-response averages of a batch yield positive TPOT and speed."""

        def calculate_batch_metrics(batch_responses, e2e_latency):
            """Calculate metrics for a batch of responses.

            Tokens are approximated by whitespace-separated words. An empty
            batch yields an empty dict.
            """
            batch_size = len(batch_responses)
            if batch_size == 0:
                return {}

            total_tokens = sum(len(response.split()) for response in batch_responses)
            avg_tokens_per_response = total_tokens / batch_size
            avg_latency_per_response = e2e_latency / batch_size
            ttft_estimate = avg_latency_per_response * 0.1

            tpot, inference_speed = _throughput_from_timing(avg_latency_per_response, ttft_estimate, avg_tokens_per_response)

            return {
                "total_tokens": total_tokens,
                "avg_tpot": tpot,
                "avg_speed": inference_speed,
                "batch_size": batch_size,
            }

        # Test with sample batch
        batch_responses = [
            "This is a test response with several words",
            "Another response that is slightly longer than the first",
            "Short response",
        ]
        e2e_latency = 1.5

        metrics = calculate_batch_metrics(batch_responses, e2e_latency)

        # Verify batch metrics
        self.assertEqual(metrics["batch_size"], 3)
        self.assertGreater(metrics["total_tokens"], 0)
        self.assertGreater(metrics["avg_tpot"], 0)
        self.assertGreater(metrics["avg_speed"], 0)

    def test_timing_accuracy(self):
        """TPOT and speed computed from a controlled one-second interval."""
        # Patch the clock only around the two reads. The side_effect list has
        # exactly two values, so any further time.time() call made while the
        # patch is active would exhaust it and raise.
        with patch("time.time", side_effect=[0.0, 1.0]):
            start_time = time.time()
            end_time = time.time()
        e2e_latency = end_time - start_time

        self.assertEqual(e2e_latency, 1.0)

        # Test TPOT calculation with known timing
        tpot, inference_speed = _throughput_from_timing(e2e_latency, ttft=0.1, output_tokens=20)

        expected_tpot = (1.0 - 0.1) / (20 - 1)  # 0.047
        expected_speed = 1 / expected_tpot  # 21.11

        self.assertAlmostEqual(tpot, expected_tpot, places=3)
        self.assertAlmostEqual(inference_speed, expected_speed, places=1)


if __name__ == "__main__":
    unittest.main()
"""
Unit tests for inference throughput metrics implementation
"""

import time
import unittest
from unittest.mock import Mock, patch

import pytest


def compute_throughput(e2e_latency, ttft, num_output_tokens):
    """Return ``(tpot, inference_speed)`` for one generation.

    Args:
        e2e_latency: End-to-end latency in seconds.
        ttft: Time to first token in seconds.
        num_output_tokens: Number of generated tokens.

    Returns:
        A ``(tpot, inference_speed)`` tuple. TPOT is
        ``(e2e_latency - ttft) / (num_output_tokens - 1)`` and the speed is
        ``1 / tpot``. With one token or fewer there is no inter-token
        interval, so TPOT falls back to ``e2e_latency`` and the speed is
        ``0``; the speed is also ``0`` whenever TPOT is not positive.
    """
    if num_output_tokens > 1:
        tpot = (e2e_latency - ttft) / (num_output_tokens - 1)
        inference_speed = 1 / tpot if tpot > 0 else 0
    else:
        tpot = e2e_latency
        inference_speed = 0
    return tpot, inference_speed


class TestThroughputMetrics(unittest.TestCase):
    """Test cases for TPOT and inference speed calculations"""

    def test_tpot_calculation(self):
        """Test TPOT calculation with known values"""
        tpot, inference_speed = compute_throughput(e2e_latency=2.5, ttft=0.25, num_output_tokens=50)

        expected_tpot = (2.5 - 0.25) / (50 - 1)  # 0.0459
        expected_speed = 1 / expected_tpot  # 21.78

        self.assertAlmostEqual(tpot, expected_tpot, places=4)
        self.assertAlmostEqual(inference_speed, expected_speed, places=1)

    def test_tpot_single_token(self):
        """Test TPOT calculation with single token output"""
        e2e_latency = 1.0
        tpot, inference_speed = compute_throughput(e2e_latency, ttft=0.1, num_output_tokens=1)

        self.assertEqual(tpot, e2e_latency)
        self.assertEqual(inference_speed, 0)

    def test_tpot_zero_tokens(self):
        """Test TPOT calculation with zero tokens"""
        e2e_latency = 1.0
        tpot, inference_speed = compute_throughput(e2e_latency, ttft=0.1, num_output_tokens=0)

        self.assertEqual(tpot, e2e_latency)
        self.assertEqual(inference_speed, 0)

    def test_ttft_estimation(self):
        """Test TTFT estimation logic"""
        e2e_latency = 2.0
        batch_size = 4

        # Estimate TTFT as 10% of total time for batch processing
        ttft_estimate = e2e_latency * 0.1 / batch_size

        # Compare against the literal with a tolerance; exact float equality
        # on a product/quotient is brittle.
        self.assertAlmostEqual(ttft_estimate, 0.05, places=12)

    def test_batch_metrics_calculation(self):
        """Test batch-level metrics calculation"""
        e2e_latency = 3.0
        generated_tokens = [10, 15, 20, 25]  # tokens per response
        batch_size = len(generated_tokens)

        total_tokens = sum(generated_tokens)
        avg_tokens_per_response = total_tokens / batch_size
        avg_latency_per_response = e2e_latency / batch_size

        # Test calculations
        self.assertEqual(total_tokens, 70)
        self.assertEqual(avg_tokens_per_response, 17.5)
        self.assertEqual(avg_latency_per_response, 0.75)

        # Test TPOT calculation for batch
        ttft_estimate = avg_latency_per_response * 0.1
        tpot, inference_speed = compute_throughput(avg_latency_per_response, ttft_estimate, avg_tokens_per_response)

        expected_tpot = (0.75 - 0.075) / (17.5 - 1)  # 0.0409
        expected_speed = 1 / expected_tpot  # 24.45

        self.assertAlmostEqual(tpot, expected_tpot, places=4)
        self.assertAlmostEqual(inference_speed, expected_speed, places=1)


class TestTimingMeasurement(unittest.TestCase):
    """Test cases for timing measurement accuracy"""

    def test_timing_precision(self):
        """Test that timing measurement is reasonably accurate"""
        sleep_duration = 0.05  # 50ms

        # perf_counter is monotonic, so a wall-clock adjustment during the
        # sleep cannot distort the measurement.
        start_time = time.perf_counter()
        time.sleep(sleep_duration)
        measured_duration = time.perf_counter() - start_time

        # The sleep must account for (almost) the whole requested interval.
        self.assertGreaterEqual(measured_duration, sleep_duration - 0.02)
        # The upper bound is deliberately loose: a busy CI runner can delay
        # the wake-up well beyond a few milliseconds.
        self.assertLess(measured_duration, sleep_duration + 1.0)

    def test_zero_latency_handling(self):
        """Test handling of edge cases with zero latency"""
        e2e_latency = 0.0

        # Should not crash with zero latency
        ttft = e2e_latency * 0.1
        tpot, inference_speed = compute_throughput(e2e_latency, ttft, num_output_tokens=10)

        self.assertEqual(tpot, 0.0)
        self.assertEqual(inference_speed, 0)


if __name__ == "__main__":
    unittest.main()