|
11 | 11 | from modelgauge.load_plugins import load_plugins
|
12 | 12 | from modelgauge.locales import EN_US # see "workaround" below
|
13 | 13 | from modelgauge.prompt import SUTOptions, TextPrompt
|
| 14 | +from modelgauge.prompt_sets import demo_prompt_set_url |
14 | 15 | from modelgauge.record_init import InitializationRecord
|
15 | 16 | from modelgauge.sut import PromptResponseSUT, SUTResponse
|
16 | 17 | from modelgauge.sut_capabilities import AcceptsTextPrompt
|
|
24 | 25 |
|
25 | 26 | # Ensure all the plugins are available during testing.
|
26 | 27 | load_plugins()
|
27 |
| -# Some tests need to download a file from modellab, which requires a real auth token |
| 28 | + |
28 | 29 | _FAKE_SECRETS = fake_all_secrets()
|
29 | 30 |
|
30 | 31 |
|
31 | 32 | def ensure_public_dependencies(dependencies):
|
| 33 | + """Some tests are defined with dependencies that require an auth token to download them. |
| 34 | + In this test context, we substitute public files instead.""" |
32 | 35 | for k, d in dependencies.items():
|
33 | 36 | if isinstance(d, WebData):
|
34 |
| - if "practice_prompt" in d.source_url or "heldback_prompt" in d.source_url: |
35 |
| - new_dependency = WebData( |
36 |
| - source_url=re.sub("(practice|heldback|official)_prompt", "demo_prompt", d.source_url) |
37 |
| - ) |
38 |
| - dependencies[k] = new_dependency |
| 37 | + new_dependency = WebData(source_url=demo_prompt_set_url(d.source_url), headers=None) |
| 38 | + dependencies[k] = new_dependency |
39 | 39 | return dependencies
|
40 | 40 |
|
41 | 41 |
|
@@ -63,26 +63,24 @@ def test_all_tests_construct_and_record_init(test_name):
|
63 | 63 | @flaky
|
64 | 64 | @pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items() if key not in TOO_SLOW])
|
65 | 65 | def test_all_tests_make_test_items(tmp_path, test_name, shared_run_dir):
|
66 |
| - # test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS) |
67 |
| - |
68 |
| - # # TODO remove when localized files are handled better |
69 |
| - # # workaround |
70 |
| - # if isinstance(test, BaseSafeTestVersion1) and test.locale != EN_US: |
71 |
| - # return |
72 |
| - |
73 |
| - # if isinstance(test, PromptResponseTest): |
74 |
| - # test_data_path = os.path.join(shared_run_dir, test.__class__.__name__) |
75 |
| - # dependencies = ensure_public_dependencies(test.get_dependencies()) |
76 |
| - |
77 |
| - # dependency_helper = FromSourceDependencyHelper( |
78 |
| - # test_data_path, |
79 |
| - # dependencies, |
80 |
| - # required_versions={}, |
81 |
| - # ) |
82 |
| - |
83 |
| - # test_items = test.make_test_items(dependency_helper) |
84 |
| - # assert len(test_items) > 0 |
85 |
| - pass # to silence the broken test until we fix it properly |
| 66 | + test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS) |
| 67 | + |
| 68 | + # TODO remove when localized files are handled better |
| 69 | + # workaround |
| 70 | + if isinstance(test, BaseSafeTestVersion1) and test.locale != EN_US: |
| 71 | + return |
| 72 | + |
| 73 | + if isinstance(test, PromptResponseTest): |
| 74 | + test_data_path = os.path.join(shared_run_dir, test.__class__.__name__) |
| 75 | + dependencies = ensure_public_dependencies(test.get_dependencies()) |
| 76 | + dependency_helper = FromSourceDependencyHelper( |
| 77 | + test_data_path, |
| 78 | + dependencies, |
| 79 | + required_versions={}, |
| 80 | + ) |
| 81 | + |
| 82 | + test_items = test.make_test_items(dependency_helper) |
| 83 | + assert len(test_items) > 0 |
86 | 84 |
|
87 | 85 |
|
88 | 86 | @pytest.mark.parametrize("sut_name", [key for key, _ in SUTS.items()])
|
|
0 commit comments