28 changes: 28 additions & 0 deletions .github/workflows/run-unit-tests.yml
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
name: "Run Unit Tests"

on:
  pull_request:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-22.04"]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          make first-time-setup
      - name: Run unit tests
        run: |
          make test
4 changes: 2 additions & 2 deletions aiperf/common/config/user_config.py
@@ -37,14 +37,14 @@ class UserConfig(BaseConfig):
             ),
             group=Groups.ENDPOINT,
         ),
-    ] = EndPointConfig()
+    ]

     endpoint: Annotated[
         EndPointConfig,
         Field(
             description="Endpoint configuration",
         ),
-    ]
+    ] = EndPointConfig()

     input: Annotated[
         InputConfig,
4 changes: 2 additions & 2 deletions tests/config/test_base_config.py
@@ -3,13 +3,13 @@

 from enum import Enum

-from pydantic import BaseModel
 from pydantic.fields import FieldInfo

 from aiperf.common.config import BaseConfig
+from aiperf.common.models import AIPerfBaseModel


-class NestedConfig(BaseModel):
+class NestedConfig(AIPerfBaseModel):
     field1: str
     field2: int
35 changes: 27 additions & 8 deletions tests/data_exporters/test_exporter_manager.py
@@ -1,12 +1,13 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

 from aiperf.common.config import EndPointConfig, OutputConfig, UserConfig
 from aiperf.common.enums import EndpointType
+from aiperf.common.enums.data_exporter_enums import DataExporterType
 from aiperf.common.models import MetricResult
 from aiperf.common.models.record_models import ProfileResults
 from aiperf.data_exporter.exporter_manager import ExporterManager
@@ -47,12 +48,27 @@ class TestExporterManager:
     async def test_export(
         self, endpoint_config, output_config, sample_records, mock_user_config
     ):
-        mock_exporter_instance = MagicMock()
-        mock_exporter_class = MagicMock(return_value=mock_exporter_instance)
+        exporter_types = [
+            DataExporterType.CONSOLE_ERROR,
+            DataExporterType.CONSOLE,
+            DataExporterType.JSON,
+        ]
+        mock_exporter_instances = []
+        mock_exporter_classes = {}

-        with patch(
-            "aiperf.common.factories.DataExporterFactory.get_all_classes",
-            return_value=[mock_exporter_class],
+        for exporter_type in exporter_types:
+            instance = MagicMock()
+            instance.export = AsyncMock()
+            mock_class = MagicMock(return_value=instance)
+            mock_exporter_classes[exporter_type] = mock_class
+            mock_exporter_instances.append(instance)
+
+        with patch.object(
+            __import__(
+                "aiperf.common.factories", fromlist=["DataExporterFactory"]
+            ).DataExporterFactory,
+            "_registry",
+            mock_exporter_classes,
         ):
             manager = ExporterManager(
                 results=ProfileResults(
@@ -66,5 +82,8 @@ async def test_export(
                 input_config=mock_user_config,
             )
             await manager.export_all()
-            mock_exporter_class.assert_called_once()
-            mock_exporter_instance.export.assert_called_once()
+            for mock_class, mock_instance in zip(
+                mock_exporter_classes.values(), mock_exporter_instances, strict=False
+            ):
+                mock_class.assert_called_once()
+                mock_instance.export.assert_awaited_once()
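The registry patch in the test above can be reproduced in isolation. The sketch below is a minimal reconstruction of the pattern, not aiperf code: ToyFactory, run_all, and the "json" key are hypothetical stand-ins, and the only assumption carried over from the diff is that the factory keeps its registered classes in a dict-valued _registry class attribute.

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class ToyFactory:
    # Hypothetical stand-in for DataExporterFactory: registered exporter
    # classes live in a dict-valued _registry class attribute.
    _registry: dict = {}

    @classmethod
    def get_all_classes(cls):
        return list(cls._registry.values())


async def run_all():
    # Instantiate every registered class and await its async export(),
    # roughly the behavior the test above exercises via export_all().
    for exporter_cls in ToyFactory.get_all_classes():
        await exporter_cls().export()


instance = MagicMock()
instance.export = AsyncMock()
mock_class = MagicMock(return_value=instance)

# patch.object swaps the class attribute for the duration of the with
# block and restores it afterwards, so no real exporters are constructed.
with patch.object(ToyFactory, "_registry", {"json": mock_class}):
    asyncio.run(run_all())

mock_class.assert_called_once()
instance.export.assert_awaited_once()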
40 changes: 35 additions & 5 deletions tests/data_exporters/test_json_exporter.py
@@ -41,16 +41,42 @@ def sample_records(self):
     def mock_user_config(self):
         return UserConfig(model_names=["test-model"])

+    @pytest.fixture
+    def mock_results(self, sample_records):
+        class MockResults:
+            def __init__(self, metrics):
+                self.metrics = metrics
+                self.start_ns = None
+                self.end_ns = None
+
+            @property
+            def records(self):
+                return self.metrics
+
+            @property
+            def has_results(self):
+                return bool(self.metrics)
+
+            @property
+            def was_cancelled(self):
+                return False
+
+            @property
+            def error_summary(self):
+                return []
+
+        return MockResults(sample_records)
+
     @pytest.mark.asyncio
     async def test_json_exporter_creates_expected_json(
-        self, sample_records, mock_user_config
+        self, mock_results, mock_user_config
     ):
         with tempfile.TemporaryDirectory() as temp_dir:
             output_dir = Path(temp_dir)
             mock_user_config.output.artifact_directory = output_dir

             exporter_config = ExporterConfig(
-                results=sample_records,
+                results=mock_results,
                 input_config=mock_user_config,
             )

@@ -63,9 +89,13 @@ async def test_json_exporter_creates_expected_json(
             with open(expected_file) as f:
                 data = json.load(f)

-            assert "Test Metric" in data
-            assert data["Test Metric"]["unit"] == "ms"
-            assert data["Test Metric"]["avg"] == 123.0
+            assert "records" in data
+            records = data["records"]
+            assert isinstance(records, dict)
+            assert len(records) == 1
+            assert "Test Metric" in records
+            assert records["Test Metric"]["unit"] == "ms"
+            assert records["Test Metric"]["avg"] == 123.0

             assert "input_config" in data
             assert isinstance(data["input_config"], dict)
3 changes: 2 additions & 1 deletion tests/logging/test_logging_mixins.py
@@ -94,10 +94,11 @@ def test_end_to_end_logging(self, caplog, method_name, level, should_be_logged):
             assert f"Test message {level}" not in caplog.text

     def test_lazy_evaluation(self, caplog):
-        """Test that lazy evaluation prevents expensive operations."""
+        """Test that lazy evaluation prevents expensive operations for logs at the wrong level."""
         caplog.set_level(_INFO)

         mock_class = MockClass()
+        mock_class.logger.set_level(_INFO)
         expensive_operation_called = False

         def expensive_operation():
@@ -4,6 +4,8 @@
 Tests for the ProcessingStatsStreamer class.
 """

+from unittest.mock import AsyncMock
+
 import pytest

 from aiperf.common.messages.inference_messages import ParsedInferenceResultsMessage
@@ -66,6 +68,8 @@ async def test_all_records_received(

     streamer.final_request_count = 10
     streamer.processing_stats.total_expected_requests = 10
+    if hasattr(streamer, "pub_client"):
+        streamer.pub_client.publish = AsyncMock()

     for _ in range(10):
         await streamer.stream_record(sample_record.record)
@@ -88,6 +92,8 @@ async def test_report_records_task(
     streamer.processing_stats.processed = 0
     streamer.processing_stats.errors = 0
     streamer.processing_stats.total_expected_requests = 10
+    if hasattr(streamer, "pub_client"):
+        streamer.pub_client.publish = AsyncMock()

     for _ in range(10):
         await streamer.stream_record(sample_record.record)
31 changes: 22 additions & 9 deletions tests/services/records_manager/test_streaming_post_processor.py
@@ -4,8 +4,11 @@
 Tests for the streaming post processor base class.
 """

+import asyncio
+
 import pytest

+from aiperf.common.enums import StreamingPostProcessorType
 from aiperf.common.enums.timing_enums import CreditPhase
 from aiperf.common.factories import StreamingPostProcessorFactory
 from aiperf.common.messages.inference_messages import ParsedInferenceResultsMessage
@@ -14,7 +17,16 @@
     BaseStreamingPostProcessor,
 )
 from aiperf.services.records_manager.records_manager import RecordsManager
 from tests.utils.async_test_utils import async_fixture


+@pytest.fixture(autouse=True)
+def patch_streaming_post_processor_factory():
+    StreamingPostProcessorFactory._registry.clear()
+    StreamingPostProcessorFactory.register(StreamingPostProcessorType.JSONL)(
+        MockStreamingPostProcessor
+    )
+    yield
+    StreamingPostProcessorFactory._registry.clear()
+

 class MockStreamingPostProcessor(BaseStreamingPostProcessor):
@@ -45,18 +57,19 @@ async def test_basic_streaming_functionality(
     records_manager: RecordsManager,
     sample_record: ParsedInferenceResultsMessage,
 ):
-    # Clear the registry to avoid conflicts with other tests
-    StreamingPostProcessorFactory._registry.clear()
-    StreamingPostProcessorFactory.register("test")(MockStreamingPostProcessor)
-
     records_manager = await async_fixture(records_manager)
     await records_manager._initialize_streaming_post_processors()
-    proc = records_manager.streaming_post_processors[0]
+    proc = next(
+        p
+        for p in records_manager.streaming_post_processors
+        if isinstance(p, MockStreamingPostProcessor)
+    )
+    # Test hack: manually start the background processing task
+    # This is necessary because the test does not go through the full lifecycle of RecordsManager
+    # and its streaming post processors.
+    proc._task = asyncio.create_task(proc._stream_records_task())
     assert proc.service_id == records_manager.service_id
     assert proc.records_queue.maxsize == 100_000
-    assert len(proc.processed_records) == 0
-    assert proc.stream_record_call_count == 0
-    await proc.wait_for_start()

     for _ in range(10):
         await records_manager._on_parsed_inference_results(
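The "test hack" comment in the diff above — starting the processor's background task by hand because the test never runs the full service lifecycle — can be illustrated with a self-contained sketch. Everything here is hypothetical (ToyProcessor and its attributes are not aiperf APIs); only the queue-plus-background-task shape is taken from the diff.

import asyncio


class ToyProcessor:
    # Hypothetical queue-draining processor: records go onto a bounded
    # queue and a background task consumes them.
    def __init__(self) -> None:
        self.records_queue: asyncio.Queue = asyncio.Queue(maxsize=100_000)
        self.processed_records: list = []
        self._task: asyncio.Task | None = None

    async def _stream_records_task(self) -> None:
        while True:
            record = await self.records_queue.get()
            self.processed_records.append(record)
            self.records_queue.task_done()


async def main() -> None:
    proc = ToyProcessor()
    # Normally a service lifecycle would start this task; starting it
    # manually is the same shortcut the test above takes.
    proc._task = asyncio.create_task(proc._stream_records_task())

    for i in range(10):
        await proc.records_queue.put(i)
    await proc.records_queue.join()  # block until the task has drained the queue

    assert proc.processed_records == list(range(10))
    proc._task.cancel()


asyncio.run(main())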
6 changes: 3 additions & 3 deletions tests/services/test_inference_result_parser.py
@@ -24,9 +24,9 @@ def mock_tokenizer():
 def sample_turn():
     return Turn(
         role="user",
-        text=[
-            Text(content=["Hello world", "Test case"]),
-            Text(content=["Another input", "Final message"]),
+        texts=[
+            Text(contents=["Hello world", "Test case"]),
+            Text(contents=["Another input", "Final message"]),
         ],
     )
4 changes: 2 additions & 2 deletions tests/test_background_task.py
@@ -33,15 +33,15 @@ async def _run_task(self):
@pytest.mark.asyncio
async def test_background_task(time_traveler: TimeTraveler):
    task_class = ExampleTaskClass()

    assert not task_class.running, "Task should not be running before starting"
    await task_class.initialize()
    await task_class.start()
    for _ in range(3):  # yield a few times to ensure the task got scheduled
        await time_traveler.yield_to_event_loop()
    async with task_class.lock:
        assert task_class.running, "Task should be running after starting"

    await task_class.stop()
    for _ in range(3):  # yield a few times to ensure the task got scheduled
        await time_traveler.yield_to_event_loop()
    async with task_class.lock:
        assert not task_class.running, "Task should not be running after stopping"