1
1
import os
2
+ import re
2
3
3
4
import pytest
4
5
from flaky import flaky # type: ignore
5
6
from modelgauge .base_test import PromptResponseTest
6
7
from modelgauge .caching import SqlDictCache
7
8
from modelgauge .config import load_secrets_from_config
8
9
from modelgauge .dependency_helper import FromSourceDependencyHelper
10
+ from modelgauge .external_data import WebData
9
11
from modelgauge .load_plugins import load_plugins
10
12
from modelgauge .locales import EN_US # see "workaround" below
11
13
from modelgauge .prompt import SUTOptions , TextPrompt
26
28
_FAKE_SECRETS = fake_all_secrets ()
27
29
28
30
29
- @pytest .mark .parametrize ("test_name" , [key for key , _ in TESTS .items ()])
30
- def test_all_tests_construct_and_record_init (test_name ):
31
- test = TESTS .make_instance (test_name , secrets = _FAKE_SECRETS )
32
- assert hasattr (test , "initialization_record" ), "Test is probably missing @modelgauge_test() decorator."
33
- assert isinstance (test .initialization_record , InitializationRecord )
31
+ # Some tests require such large downloads / complex processing
32
+ # that we don't want to do that even on expensive_tests.
33
+ # If your Test is timing out, consider adding it here.
34
+ TOO_SLOW = {"real_toxicity_prompts" , "bbq" }
35
+
36
+
37
+ def ensure_public_dependencies (dependencies ):
38
+ for k , d in dependencies .items ():
39
+ if isinstance (d , WebData ):
40
+ if "practice_prompt" in d .source_url or "heldback_prompt" in d .source_url :
41
+ new_dependency = WebData (
42
+ source_url = re .sub ("(practice|heldback|official)_prompt" , "demo_prompt" , d .source_url )
43
+ )
44
+ dependencies [k ] = new_dependency
45
+ return dependencies
34
46
35
47
36
48
@pytest .fixture (scope = "session" )
@@ -39,17 +51,18 @@ def shared_run_dir(tmp_path_factory):
39
51
return tmp_path_factory .mktemp ("run_data" )
40
52
41
53
42
- # Some tests require such large downloads / complex processing
43
- # that we don't want to do that even on expensive_tests.
44
- # If your Test is timing out, consider adding it here.
45
- TOO_SLOW = {"real_toxicity_prompts" , "bbq" }
54
+ @pytest .mark .parametrize ("test_name" , [key for key , _ in TESTS .items ()])
55
+ def test_all_tests_construct_and_record_init (test_name ):
56
+ test = TESTS .make_instance (test_name , secrets = _FAKE_SECRETS )
57
+ assert hasattr (test , "initialization_record" ), "Test is probably missing @modelgauge_test() decorator."
58
+ assert isinstance (test .initialization_record , InitializationRecord )
46
59
47
60
48
61
@expensive_tests
49
62
@pytest .mark .timeout (30 )
50
63
@flaky
51
64
@pytest .mark .parametrize ("test_name" , [key for key , _ in TESTS .items () if key not in TOO_SLOW ])
52
- def test_all_tests_make_test_items (test_name , shared_run_dir ):
65
+ def test_all_tests_make_test_items (tmp_path , test_name , shared_run_dir ):
53
66
test = TESTS .make_instance (test_name , secrets = _FAKE_SECRETS )
54
67
55
68
# TODO remove when localized files are handled better
@@ -59,9 +72,11 @@ def test_all_tests_make_test_items(test_name, shared_run_dir):
59
72
60
73
if isinstance (test , PromptResponseTest ):
61
74
test_data_path = os .path .join (shared_run_dir , test .__class__ .__name__ )
75
+ dependencies = ensure_public_dependencies (test .get_dependencies ())
76
+
62
77
dependency_helper = FromSourceDependencyHelper (
63
78
test_data_path ,
64
- test . get_dependencies () ,
79
+ dependencies ,
65
80
required_versions = {},
66
81
)
67
82
0 commit comments