Skip to content

Commit 18e16c1

Browse files
authored
use public prompt sets instead of prompt sets that require auth (#865)
* use public prompt sets instead of prompt sets that require auth * disable a noisy test until we fix it properly * find the public demo prompt set that corresponds to a non-public prompt set * less janky still way to turn a token-protected prompt set file into a publicly-accessible one * download public files instead of private or protected * fixed test to reflect new function behavior * appease mypy * remove test removed elsewhere; remove unneeded tmp_path fixture * remove unsupported model * clear cache and lock file to try to prevent errors like ' Hash for pydantic-core (2.27.2) from archive pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl not found in known hashes (was: sha256:ed4964723e97cdf8c70abebd1495001f511491b8eeac817b033db1af28a86bb5)' * refresh lock file with the right version of Poetry * restore gemini 1.0 pro, which is really supposed to be supported.
1 parent 18ae521 commit 18e16c1

File tree

6 files changed

+117
-36
lines changed

6 files changed

+117
-36
lines changed

.github/workflows/python-app.yml

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ jobs:
7373
run: pipx install "poetry == 1.8.5"
7474

7575
- name: Install dependencies
76-
run: poetry lock && poetry install --no-interaction --with dev --extras all_plugins
76+
run: |
77+
poetry cache clear --all .
78+
rm -f poetry.lock
79+
poetry install --no-interaction --with dev --extras all_plugins
7780
7881
- name: Lint formatting
7982
run: poetry run black --check .

plugins/google/modelgauge/suts/google_genai_client.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@
22
from typing import Dict, List, Optional
33

44
import google.generativeai as genai # type: ignore
5-
from google.api_core.exceptions import InternalServerError, ResourceExhausted, RetryError, TooManyRequests
6-
from google.generativeai.types import HarmCategory, HarmBlockThreshold # type: ignore
7-
from pydantic import BaseModel
5+
from google.api_core.exceptions import (
6+
InternalServerError,
7+
ResourceExhausted,
8+
RetryError,
9+
TooManyRequests,
10+
)
11+
from google.generativeai.types import HarmBlockThreshold, HarmCategory # type: ignore
812

913
from modelgauge.general import APIException
1014
from modelgauge.prompt import TextPrompt
1115
from modelgauge.retry_decorator import retry
1216
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
13-
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTResponse
17+
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTResponse # usort: skip
1418
from modelgauge.sut_capabilities import AcceptsTextPrompt
1519
from modelgauge.sut_decorator import modelgauge_sut
1620
from modelgauge.sut_registry import SUTS
21+
from pydantic import BaseModel
1722

1823
FinishReason = genai.protos.Candidate.FinishReason
1924
GEMINI_HARM_CATEGORIES = [
@@ -191,7 +196,7 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
191196
return {harm: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE for harm in GEMINI_HARM_CATEGORIES}
192197

193198

194-
gemini_models = ["gemini-1.5-flash", "gemini-1.0-pro", "gemini-1.5-pro"]
199+
gemini_models = ["gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro"]
195200
for model in gemini_models:
196201
SUTS.register(GoogleGenAiDefaultSUT, model, model, InjectSecret(GoogleAiApiKey))
197202
SUTS.register(

plugins/validation_tests/test_object_creation.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import os
2+
import re
23

34
import pytest
45
from flaky import flaky # type: ignore
56
from modelgauge.base_test import PromptResponseTest
67
from modelgauge.caching import SqlDictCache
78
from modelgauge.config import load_secrets_from_config
89
from modelgauge.dependency_helper import FromSourceDependencyHelper
10+
from modelgauge.external_data import WebData
911
from modelgauge.load_plugins import load_plugins
1012
from modelgauge.locales import EN_US # see "workaround" below
1113
from modelgauge.prompt import SUTOptions, TextPrompt
14+
from modelgauge.prompt_sets import demo_prompt_set_url
1215
from modelgauge.record_init import InitializationRecord
1316
from modelgauge.sut import PromptResponseSUT, SUTResponse
1417
from modelgauge.sut_capabilities import AcceptsTextPrompt
@@ -22,15 +25,18 @@
2225

2326
# Ensure all the plugins are available during testing.
2427
load_plugins()
25-
# Some tests need to download a file from modellab, which requires a real auth token
28+
2629
_FAKE_SECRETS = fake_all_secrets()
2730

2831

29-
@pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items()])
30-
def test_all_tests_construct_and_record_init(test_name):
31-
test = TESTS.make_instance(test_name, secrets=_FAKE_SECRETS)
32-
assert hasattr(test, "initialization_record"), "Test is probably missing @modelgauge_test() decorator."
33-
assert isinstance(test.initialization_record, InitializationRecord)
32+
def ensure_public_dependencies(dependencies):
33+
"""Some tests are defined with dependencies that require an auth token to download them.
34+
In this test context, we substitute public files instead."""
35+
for k, d in dependencies.items():
36+
if isinstance(d, WebData):
37+
new_dependency = WebData(source_url=demo_prompt_set_url(d.source_url), headers=None)
38+
dependencies[k] = new_dependency
39+
return dependencies
3440

3541

3642
@pytest.fixture(scope="session")
@@ -59,9 +65,10 @@ def test_all_tests_make_test_items(test_name, shared_run_dir):
5965

6066
if isinstance(test, PromptResponseTest):
6167
test_data_path = os.path.join(shared_run_dir, test.__class__.__name__)
68+
dependencies = ensure_public_dependencies(test.get_dependencies())
6269
dependency_helper = FromSourceDependencyHelper(
6370
test_data_path,
64-
test.get_dependencies(),
71+
dependencies,
6572
required_versions={},
6673
)
6774

poetry.lock

+23-23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/modelgauge/prompt_sets.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from pathlib import Path
12
from typing import Any, Optional
3+
from urllib.parse import urlparse
24

35
from modelgauge.locales import EN_US
46
from modelgauge.secret_values import OptionalSecret, SecretDescription
@@ -73,3 +75,38 @@ def validate_token_requirement(prompt_set: str, token=None) -> bool:
7375
if token:
7476
return True
7577
raise ValueError(f"Prompt set {prompt_set} requires a token from MLCommons.")
78+
79+
80+
def demo_prompt_set_from_private_prompt_set(prompt_set: str) -> str:
81+
"""In a test environment, we replace the practice or official prompt sets
82+
(which require auth) with matching demo prompt sets (which are public).
83+
This function returns the demo counterpart to a given practice or official prompt set."""
84+
found_locale = ""
85+
for prompt_set_type, prompt_sets in PROMPT_SETS.items():
86+
for locale, prompt_set_file_base_name in prompt_sets.items():
87+
print(f"target {prompt_set} looking at {prompt_set_file_base_name}")
88+
if prompt_set_file_base_name == prompt_set:
89+
found_locale = locale
90+
break
91+
92+
if found_locale:
93+
return PROMPT_SETS["demo"].get(found_locale, "")
94+
return prompt_set
95+
96+
97+
def prompt_set_from_url(source_url) -> str:
98+
"""Given the source_url from a WebData object, returns the bare prompt set name
99+
without an extension or hostname"""
100+
try:
101+
chunks = urlparse(source_url)
102+
filename = Path(chunks.path).stem
103+
return filename
104+
except Exception as exc:
105+
return source_url
106+
107+
108+
def demo_prompt_set_url(url: str) -> str:
109+
source_prompt_set = prompt_set_from_url(url)
110+
target_prompt_set = demo_prompt_set_from_private_prompt_set(source_prompt_set)
111+
target_url = url.replace(source_prompt_set, target_prompt_set)
112+
return target_url

tests/modelgauge_tests/test_prompt_sets.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import pytest
22
from modelgauge.prompt_sets import (
33
PROMPT_SETS,
4+
demo_prompt_set_from_private_prompt_set,
5+
demo_prompt_set_url,
46
prompt_set_file_base_name,
7+
prompt_set_from_url,
58
validate_prompt_set,
69
) # usort: skip
710

@@ -33,3 +36,29 @@ def test_validate_prompt_set():
3336
assert validate_prompt_set(s, "en_us", PROMPT_SETS)
3437
with pytest.raises(ValueError):
3538
validate_prompt_set("should raise")
39+
40+
41+
def test_demo_prompt_set_from_private_prompt_set():
42+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["practice"]["en_us"]) == PROMPT_SETS["demo"]["en_us"]
43+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["practice"]["fr_fr"]) == PROMPT_SETS["demo"]["fr_fr"]
44+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["official"]["en_us"]) == PROMPT_SETS["demo"]["en_us"]
45+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["official"]["fr_fr"]) == PROMPT_SETS["demo"]["fr_fr"]
46+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["demo"]["en_us"]) == PROMPT_SETS["demo"]["en_us"]
47+
assert demo_prompt_set_from_private_prompt_set(PROMPT_SETS["demo"]["fr_fr"]) == PROMPT_SETS["demo"]["fr_fr"]
48+
assert demo_prompt_set_from_private_prompt_set("bogus") == "bogus"
49+
50+
51+
def test_prompt_set_from_url():
52+
assert prompt_set_from_url("https://www.example.com/path/to/file.csv") == "file"
53+
assert prompt_set_from_url("https://www.example.com/thing.css") == "thing"
54+
assert prompt_set_from_url("degenerate string") == "degenerate string"
55+
assert prompt_set_from_url("https://www.example.com") == ""
56+
assert prompt_set_from_url("https://www.example.com/") == ""
57+
58+
59+
def test_demo_prompt_set_url():
60+
base = "https://www.example.com/path/to/"
61+
for l in ("en_us", "fr_fr"):
62+
for t in ("practice", "official"):
63+
base_url = f"{base}{PROMPT_SETS[t][l]}.csv"
64+
assert demo_prompt_set_url(base_url) == f"{base}{PROMPT_SETS["demo"][l]}.csv"

0 commit comments

Comments
 (0)