Skip to content

Commit b97c0eb

Browse files
committed
use public prompt sets instead of prompt sets that require auth
1 parent 8ba3e6b commit b97c0eb

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

plugins/validation_tests/test_object_creation.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
2+
import re
23

34
import pytest
45
from flaky import flaky # type: ignore
56
from modelgauge.base_test import PromptResponseTest
67
from modelgauge.caching import SqlDictCache
78
from modelgauge.config import load_secrets_from_config
89
from modelgauge.dependency_helper import FromSourceDependencyHelper
10+
from modelgauge.external_data import WebData
911
from modelgauge.load_plugins import load_plugins
1012
from modelgauge.locales import EN_US # see "workaround" below
1113
from modelgauge.prompt import SUTOptions, TextPrompt
@@ -26,11 +28,21 @@
2628
_FAKE_SECRETS = fake_all_secrets()
2729

2830

29-
@pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items()])
30-
def test_all_tests_construct_and_record_init(test_name):
31-
test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS)
32-
assert hasattr(test, "initialization_record"), "Test is probably missing @modelgauge_test() decorator."
33-
assert isinstance(test.initialization_record, InitializationRecord)
31+
# Some tests require such large downloads / complex processing
32+
# that we don't want to do that even on expensive_tests.
33+
# If your Test is timing out, consider adding it here.
34+
TOO_SLOW = {"real_toxicity_prompts", "bbq"}
35+
36+
37+
def ensure_public_dependencies(dependencies):
38+
for k, d in dependencies.items():
39+
if isinstance(d, WebData):
40+
if "practice_prompt" in d.source_url or "heldback_prompt" in d.source_url:
41+
new_dependency = WebData(
42+
source_url=re.sub("(practice|heldback|official)_prompt", "demo_prompt", d.source_url)
43+
)
44+
dependencies[k] = new_dependency
45+
return dependencies
3446

3547

3648
@pytest.fixture(scope="session")
@@ -39,17 +51,18 @@ def shared_run_dir(tmp_path_factory):
3951
return tmp_path_factory.mktemp("run_data")
4052

4153

42-
# Some tests require such large downloads / complex processing
43-
# that we don't want to do that even on expensive_tests.
44-
# If your Test is timing out, consider adding it here.
45-
TOO_SLOW = {"real_toxicity_prompts", "bbq"}
54+
@pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items()])
55+
def test_all_tests_construct_and_record_init(test_name):
56+
test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS)
57+
assert hasattr(test, "initialization_record"), "Test is probably missing @modelgauge_test() decorator."
58+
assert isinstance(test.initialization_record, InitializationRecord)
4659

4760

4861
@expensive_tests
4962
@pytest.mark.timeout(30)
5063
@flaky
5164
@pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items() if key not in TOO_SLOW])
52-
def test_all_tests_make_test_items(test_name, shared_run_dir):
65+
def test_all_tests_make_test_items(tmp_path, test_name, shared_run_dir):
5366
test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS)
5467

5568
# TODO remove when localized files are handled better
@@ -59,9 +72,11 @@ def test_all_tests_make_test_items(test_name, shared_run_dir):
5972

6073
if isinstance(test, PromptResponseTest):
6174
test_data_path = os.path.join(shared_run_dir, test.__class__.__name__)
75+
dependencies = ensure_public_dependencies(test.get_dependencies())
76+
6277
dependency_helper = FromSourceDependencyHelper(
6378
test_data_path,
64-
test.get_dependencies(),
79+
dependencies,
6580
required_versions={},
6681
)
6782

0 commit comments

Comments
 (0)