610 changes: 610 additions & 0 deletions test/common/test_config_utils.py

Large diffs are not rendered by default.

144 changes: 144 additions & 0 deletions test/data_container/test_registry.py
@@ -0,0 +1,144 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from olive.data.constants import (
DataComponentType,
DefaultDataContainer,
)
from olive.data.registry import Registry


class TestRegistryRegister:
def test_register_dataset_component(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg")
def my_dataset():
return "dataset"

# assert
result = Registry.get_load_dataset_component("test_dataset_reg")
assert result is my_dataset

def test_register_pre_process_component(self):
# setup & execute
@Registry.register_pre_process(name="test_pre_process_reg")
def my_pre_process(data):
return data

# assert
result = Registry.get_pre_process_component("test_pre_process_reg")
assert result is my_pre_process

def test_register_post_process_component(self):
# setup & execute
@Registry.register_post_process(name="test_post_process_reg")
def my_post_process(data):
return data

# assert
result = Registry.get_post_process_component("test_post_process_reg")
assert result is my_post_process

def test_register_dataloader_component(self):
# setup & execute
@Registry.register_dataloader(name="test_dataloader_reg")
def my_dataloader(data):
return data

# assert
result = Registry.get_dataloader_component("test_dataloader_reg")
assert result is my_dataloader

def test_register_case_insensitive(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg")
def my_func():
pass

# assert
result = Registry.get_load_dataset_component("casesensitivetest_reg")
assert result is my_func

def test_register_uses_function_name_when_no_name(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET)
def unique_named_test_func_reg():
pass

# assert
result = Registry.get_load_dataset_component("unique_named_test_func_reg")
assert result is unique_named_test_func_reg


class TestRegistryGet:
def test_get_component(self):
# setup
@Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg")
def my_func():
pass

# execute
result = Registry.get_component(DataComponentType.LOAD_DATASET.value, "test_get_comp_reg")

# assert
assert result is my_func

def test_get_by_subtype(self):
# setup
@Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg")
def my_func():
pass

# execute
result = Registry.get(DataComponentType.LOAD_DATASET.value, "test_get_subtype_reg")

# assert
assert result is my_func


class TestRegistryDefaultComponents:
def test_get_default_load_dataset(self):
# execute
result = Registry.get_default_load_dataset_component()

# assert
assert result is not None

def test_get_default_pre_process(self):
# execute
result = Registry.get_default_pre_process_component()

# assert
assert result is not None

def test_get_default_post_process(self):
# execute
result = Registry.get_default_post_process_component()

# assert
assert result is not None

def test_get_default_dataloader(self):
# execute
result = Registry.get_default_dataloader_component()

# assert
assert result is not None


class TestRegistryContainer:
def test_get_container_default(self):
# execute
result = Registry.get_container(None)

# assert
assert result is not None

def test_get_container_by_name(self):
# execute
result = Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value)

# assert
assert result is not None
216 changes: 216 additions & 0 deletions test/evaluator/test_metric_config.py
@@ -0,0 +1,216 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import pytest
from pydantic import ValidationError

from olive.evaluator.metric_config import (
LatencyMetricConfig,
MetricGoal,
SizeOnDiskMetricConfig,
ThroughputMetricConfig,
get_user_config_class,
)


class TestLatencyMetricConfig:
def test_defaults(self):
# execute
config = LatencyMetricConfig()

# assert
assert config.warmup_num == 10
assert config.repeat_test_num == 20
assert config.sleep_num == 0

def test_custom_values(self):
# execute
config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2)

# assert
assert config.warmup_num == 5
assert config.repeat_test_num == 100
assert config.sleep_num == 2


class TestThroughputMetricConfig:
def test_defaults(self):
# execute
config = ThroughputMetricConfig()

# assert
assert config.warmup_num == 10
assert config.repeat_test_num == 20
assert config.sleep_num == 0

def test_custom_values(self):
# execute
config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1)

# assert
assert config.warmup_num == 3
assert config.repeat_test_num == 50
assert config.sleep_num == 1


class TestSizeOnDiskMetricConfig:
def test_creation(self):
# execute
config = SizeOnDiskMetricConfig()

# assert
assert isinstance(config, SizeOnDiskMetricConfig)


class TestMetricGoal:
def test_threshold_type(self):
# execute
goal = MetricGoal(type="threshold", value=0.9)

# assert
assert goal.type == "threshold"
assert goal.value == 0.9

def test_min_improvement_type(self):
# execute
goal = MetricGoal(type="min-improvement", value=0.05)

# assert
assert goal.type == "min-improvement"
assert goal.value == 0.05

def test_max_degradation_type(self):
# execute
goal = MetricGoal(type="max-degradation", value=0.1)

# assert
assert goal.type == "max-degradation"
assert goal.value == 0.1

def test_percent_min_improvement_type(self):
# execute
goal = MetricGoal(type="percent-min-improvement", value=5.0)

# assert
assert goal.type == "percent-min-improvement"

def test_percent_max_degradation_type(self):
# execute
goal = MetricGoal(type="percent-max-degradation", value=10.0)

# assert
assert goal.type == "percent-max-degradation"

def test_invalid_type_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Metric goal type must be one of"):
MetricGoal(type="invalid_type", value=0.5)

def test_negative_value_for_min_improvement_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="min-improvement", value=-0.5)

def test_negative_value_for_max_degradation_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="max-degradation", value=-0.1)

def test_negative_value_for_percent_min_improvement_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="percent-min-improvement", value=-5.0)

def test_negative_value_for_percent_max_degradation_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="percent-max-degradation", value=-10.0)

def test_threshold_allows_negative_value(self):
# execute
goal = MetricGoal(type="threshold", value=-1.0)

# assert
assert goal.value == -1.0

def test_has_regression_goal_min_improvement(self):
# setup
goal = MetricGoal(type="min-improvement", value=0.05)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_percent_min_improvement(self):
# setup
goal = MetricGoal(type="percent-min-improvement", value=5.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_max_degradation_positive(self):
# setup
goal = MetricGoal(type="max-degradation", value=0.1)

# execute
result = goal.has_regression_goal()

# assert
assert result is True

def test_has_regression_goal_max_degradation_zero(self):
# setup
goal = MetricGoal(type="max-degradation", value=0.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_percent_max_degradation_positive(self):
# setup
goal = MetricGoal(type="percent-max-degradation", value=10.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is True

def test_has_regression_goal_threshold(self):
# setup
goal = MetricGoal(type="threshold", value=0.9)

# execute
result = goal.has_regression_goal()

# assert
assert result is False


class TestGetUserConfigClass:
def test_custom_metric_type(self):
# execute
cls = get_user_config_class("custom")
instance = cls()

# assert
assert hasattr(instance, "user_script")
assert hasattr(instance, "evaluate_func")

def test_unknown_metric_type(self):
# execute
cls = get_user_config_class("latency")
instance = cls()

# assert
assert hasattr(instance, "user_script")
Copilot AI commented on Apr 10, 2026:

The test name test_unknown_metric_type is misleading because it calls get_user_config_class("latency"), which is a known metric type. Please rename the test to reflect what is actually being validated (for example, that non-"custom" metric types still include the common user-config fields).

assert hasattr(instance, "inference_settings")
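
A possible rename along the lines of the comment above (a sketch only; the name test_non_custom_metric_type_has_common_fields and the exact assertions are illustrative suggestions, not part of this PR):

from olive.evaluator.metric_config import get_user_config_class


class TestGetUserConfigClass:
    def test_non_custom_metric_type_has_common_fields(self):
        # "latency" is a known, non-custom metric type; its user config
        # should still expose the common user-config fields.
        cls = get_user_config_class("latency")
        instance = cls()

        # assert
        assert hasattr(instance, "user_script")
        assert hasattr(instance, "inference_settings")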