-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathtest_registry.py
More file actions
144 lines (109 loc) · 4.19 KB
/
Copy pathtest_registry.py
File metadata and controls
144 lines (109 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from olive.data.constants import (
DataComponentType,
DefaultDataContainer,
)
from olive.data.registry import Registry
class TestRegistryRegister:
    """Checks that components registered through ``Registry`` can be fetched back."""

    def test_register_dataset_component(self):
        """A LOAD_DATASET component registered by name is returned by the dataset getter."""

        # register under an explicit name
        @Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg")
        def dataset_fn():
            return "dataset"

        # the getter must hand back the very same object
        assert Registry.get_load_dataset_component("test_dataset_reg") is dataset_fn

    def test_register_pre_process_component(self):
        """A pre-process component registered by name is returned by its getter."""

        @Registry.register_pre_process(name="test_pre_process_reg")
        def pre_process_fn(data):
            return data

        fetched = Registry.get_pre_process_component("test_pre_process_reg")
        assert fetched is pre_process_fn

    def test_register_post_process_component(self):
        """A post-process component registered by name is returned by its getter."""

        @Registry.register_post_process(name="test_post_process_reg")
        def post_process_fn(data):
            return data

        fetched = Registry.get_post_process_component("test_post_process_reg")
        assert fetched is post_process_fn

    def test_register_dataloader_component(self):
        """A dataloader component registered by name is returned by its getter."""

        @Registry.register_dataloader(name="test_dataloader_reg")
        def dataloader_fn(data):
            return data

        fetched = Registry.get_dataloader_component("test_dataloader_reg")
        assert fetched is dataloader_fn

    def test_register_case_insensitive(self):
        """A component registered with a mixed-case name is found via the lower-cased name."""

        @Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg")
        def mixed_case_fn():
            return None

        # look up with an all-lowercase spelling of the registered name
        fetched = Registry.get_load_dataset_component("casesensitivetest_reg")
        assert fetched is mixed_case_fn

    def test_register_uses_class_name_when_no_name(self):
        """Without an explicit name, the component is found under the decorated object's own name."""

        # NOTE: the function name below doubles as the lookup key, so it must not be renamed.
        @Registry.register(DataComponentType.LOAD_DATASET)
        def unique_named_test_func_reg():
            return None

        fetched = Registry.get_load_dataset_component("unique_named_test_func_reg")
        assert fetched is unique_named_test_func_reg
class TestRegistryGet:
    """Checks the generic lookup entry points ``Registry.get_component`` and ``Registry.get``."""

    def test_get_component(self):
        """``get_component`` returns the object registered under the given sub-type value and name."""
        sub_type = DataComponentType.LOAD_DATASET.value

        # setup: register a component to look up
        @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg")
        def registered_fn():
            return None

        # execute & assert
        assert Registry.get_component(sub_type, "test_get_comp_reg") is registered_fn

    def test_get_by_subtype(self):
        """``get`` returns the object registered under the given sub-type value and name."""
        sub_type = DataComponentType.LOAD_DATASET.value

        # setup: register a component to look up
        @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg")
        def registered_fn():
            return None

        # execute & assert
        assert Registry.get(sub_type, "test_get_subtype_reg") is registered_fn
class TestRegistryDefaultComponents:
    """Checks that every default-component getter on ``Registry`` yields a non-None value."""

    def test_get_default_load_dataset(self):
        """The default load-dataset component is not None."""
        assert Registry.get_default_load_dataset_component() is not None

    def test_get_default_pre_process(self):
        """The default pre-process component is not None."""
        assert Registry.get_default_pre_process_component() is not None

    def test_get_default_post_process(self):
        """The default post-process component is not None."""
        assert Registry.get_default_post_process_component() is not None

    def test_get_default_dataloader(self):
        """The default dataloader component is not None."""
        assert Registry.get_default_dataloader_component() is not None
class TestRegistryContainer:
    """Checks that ``Registry.get_container`` yields a non-None value."""

    def test_get_container_default(self):
        """Passing ``None`` as the container name still yields a container."""
        container = Registry.get_container(None)
        assert container is not None

    def test_get_container_by_name(self):
        """Passing the ``DATA_CONTAINER`` name yields a container."""
        container_name = DefaultDataContainer.DATA_CONTAINER.value
        container = Registry.get_container(container_name)
        assert container is not None