Skip to content

Commit 668c93b

Browse files
authored
SUT Options defined at test level (#897)
* class method sut_options
* trim test wrapper
* put back test
* sut_options is an instance method
* First try: SUTOptions is separate from TestItem
* finish refactor
* Remove another sut_options from testitem
* Fix plugins
1 parent 8a78e5c commit 668c93b

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

46 files changed

+285
-320
lines changed

demo_plugin/modelgauge/suts/demo_01_yes_no_sut.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from modelgauge.prompt import ChatPrompt, TextPrompt
22
from modelgauge.prompt_formatting import format_chat
3-
from modelgauge.sut import PromptResponseSUT, SUTResponse
3+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
44
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt
55
from modelgauge.sut_decorator import modelgauge_sut
66
from modelgauge.sut_registry import SUTS
@@ -24,10 +24,10 @@ class DemoYesNoResponse(BaseModel):
2424
class DemoYesNoSUT(PromptResponseSUT[DemoYesNoRequest, DemoYesNoResponse]):
2525
"""This SUT demonstrates the bare minimum behavior of a SUT: Use the input Prompt to determine the response."""
2626

27-
def translate_text_prompt(self, prompt: TextPrompt) -> DemoYesNoRequest:
27+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoYesNoRequest:
2828
return DemoYesNoRequest(text=prompt.text)
2929

30-
def translate_chat_prompt(self, prompt: ChatPrompt) -> DemoYesNoRequest:
30+
def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoYesNoRequest:
3131
return DemoYesNoRequest(text=format_chat(prompt))
3232

3333
def evaluate(self, request: DemoYesNoRequest) -> DemoYesNoResponse:

demo_plugin/modelgauge/suts/demo_02_secrets_and_options_sut.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
2-
from modelgauge.prompt import ChatPrompt, SUTOptions, TextPrompt
2+
from modelgauge.prompt import ChatPrompt, TextPrompt
33
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
4-
from modelgauge.sut import PromptResponseSUT, SUTResponse
4+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
55
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt
66
from modelgauge.sut_decorator import modelgauge_sut
77
from modelgauge.sut_registry import SUTS
@@ -46,12 +46,12 @@ def __init__(self, uid: str, api_key: DemoApiKey):
4646
def _load_client(self) -> "RandomWordsClient":
4747
return RandomWordsClient(api_key=self.api_key)
4848

49-
def translate_text_prompt(self, prompt: TextPrompt) -> DemoRandomWordsRequest:
50-
return self._translate(prompt.text, prompt.options)
49+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoRandomWordsRequest:
50+
return self._translate(prompt.text, options)
5151

52-
def translate_chat_prompt(self, prompt: ChatPrompt) -> DemoRandomWordsRequest:
52+
def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoRandomWordsRequest:
5353
# All we care about are the words in the chat history, not who said them.
54-
return self._translate(_words_in_chat(prompt), prompt.options)
54+
return self._translate(_words_in_chat(prompt), options)
5555

5656
def _translate(self, text, options: SUTOptions) -> DemoRandomWordsRequest:
5757
return DemoRandomWordsRequest(

demo_plugin/modelgauge/suts/demo_03_sut_with_args.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from modelgauge.prompt import ChatPrompt, TextPrompt
2-
from modelgauge.sut import PromptResponseSUT, SUTResponse
2+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
33
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt
44
from modelgauge.sut_decorator import modelgauge_sut
55
from modelgauge.sut_registry import SUTS
@@ -26,10 +26,10 @@ def __init__(self, uid: str, response_text: str):
2626
super().__init__(uid)
2727
self.response_text = response_text
2828

29-
def translate_text_prompt(self, prompt: TextPrompt) -> DemoConstantRequest:
29+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> DemoConstantRequest:
3030
return DemoConstantRequest(configured_response=self.response_text)
3131

32-
def translate_chat_prompt(self, prompt: ChatPrompt) -> DemoConstantRequest:
32+
def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> DemoConstantRequest:
3333
return DemoConstantRequest(configured_response=self.response_text)
3434

3535
def evaluate(self, request: DemoConstantRequest) -> DemoConstantResponse:

plugins/amazon/modelgauge/suts/aws_bedrock_client.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from modelgauge.prompt import TextPrompt
1111
from modelgauge.retry_decorator import retry
1212
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
13-
from modelgauge.sut import PromptResponseSUT, SUTResponse
13+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
1414
from modelgauge.sut_capabilities import AcceptsTextPrompt
1515
from modelgauge.sut_decorator import modelgauge_sut
1616
from modelgauge.sut_registry import SUTS
@@ -122,12 +122,12 @@ def _load_client(self):
122122
aws_secret_access_key=self.secret_access_key,
123123
)
124124

125-
def translate_text_prompt(self, prompt: TextPrompt) -> BedrockRequest:
125+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> BedrockRequest:
126126
inference_config = BedrockRequest.InferenceConfig(
127-
maxTokens=prompt.options.max_tokens,
128-
temperature=prompt.options.temperature,
129-
topP=prompt.options.top_p,
130-
stopSequences=prompt.options.stop_sequences,
127+
maxTokens=options.max_tokens,
128+
temperature=options.temperature,
129+
topP=options.top_p,
130+
stopSequences=options.stop_sequences,
131131
)
132132

133133
return BedrockRequest(

plugins/amazon/tests/test_aws_bedrock_client.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pytest
22
from unittest.mock import patch
33

4-
from modelgauge.prompt import SUTOptions, TextPrompt
5-
from modelgauge.sut import SUTResponse
4+
from modelgauge.prompt import TextPrompt
5+
from modelgauge.sut import SUTOptions, SUTResponse
66
from modelgauge.typed_data import is_typeable
77

88
from modelgauge.suts.aws_bedrock_client import (
@@ -44,8 +44,8 @@ def _make_response(response_text):
4444

4545
def test_translate_text_prompt(fake_sut):
4646
default_options = SUTOptions()
47-
prompt = TextPrompt(text="some-text", options=default_options)
48-
request = fake_sut.translate_text_prompt(prompt)
47+
prompt = TextPrompt(text="some-text")
48+
request = fake_sut.translate_text_prompt(prompt, default_options)
4949

5050
assert isinstance(request, BedrockRequest)
5151
assert request.modelId == FAKE_MODEL_ID

plugins/anthropic/modelgauge/suts/anthropic_api.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from modelgauge.general import APIException
1212
from modelgauge.prompt import ChatRole, TextPrompt
1313
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
14-
from modelgauge.sut import PromptResponseSUT, SUTResponse
14+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
1515
from modelgauge.sut_capabilities import AcceptsTextPrompt
1616
from modelgauge.sut_decorator import modelgauge_sut
1717
from modelgauge.sut_registry import SUTS
@@ -54,16 +54,16 @@ def _load_client(self) -> Anthropic:
5454
max_retries=7,
5555
)
5656

57-
def translate_text_prompt(self, prompt: TextPrompt) -> AnthropicRequest:
57+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> AnthropicRequest:
5858
messages = [OpenAIChatMessage(content=prompt.text, role=_ROLE_MAP[ChatRole.user])]
5959
return AnthropicRequest(
6060
model=self.model,
6161
messages=messages,
62-
max_tokens=prompt.options.max_tokens,
63-
stop_sequences=prompt.options.stop_sequences,
64-
temperature=prompt.options.temperature,
65-
top_k=prompt.options.top_k_per_token,
66-
top_p=prompt.options.top_p,
62+
max_tokens=options.max_tokens,
63+
stop_sequences=options.stop_sequences,
64+
temperature=options.temperature,
65+
top_k=options.top_k_per_token,
66+
top_p=options.top_p,
6767
)
6868

6969
def evaluate(self, request: AnthropicRequest) -> AnthropicMessage:

plugins/anthropic/tests/test_anthropic_api.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from unittest.mock import patch
44

55
from modelgauge.general import APIException
6-
from modelgauge.prompt import SUTOptions, TextPrompt
7-
from modelgauge.sut import SUTResponse
6+
from modelgauge.prompt import TextPrompt
7+
from modelgauge.sut import SUTOptions, SUTResponse
88

99
from modelgauge.suts.anthropic_api import AnthropicRequest, AnthropicApiKey, AnthropicSUT
1010
from modelgauge.suts.openai_client import OpenAIChatMessage
@@ -24,7 +24,7 @@ def simple_anthropic_request():
2424
def test_anthropic_api_translate_request_default_sut_options(fake_sut):
2525
prompt = TextPrompt(text="some-text")
2626

27-
request = fake_sut.translate_text_prompt(prompt)
27+
request = fake_sut.translate_text_prompt(prompt, SUTOptions())
2828

2929
assert isinstance(request, AnthropicRequest)
3030
assert request.model == "fake-model"
@@ -47,9 +47,9 @@ def test_anthropic_api_translate_request_non_default_sut_options(fake_sut):
4747
stop_sequences=["stop"],
4848
top_p=0.5,
4949
)
50-
prompt = TextPrompt(text="some-text", options=options)
50+
prompt = TextPrompt(text="some-text")
5151

52-
request = fake_sut.translate_text_prompt(prompt)
52+
request = fake_sut.translate_text_prompt(prompt, options)
5353

5454
assert request.max_tokens == 200
5555
assert request.temperature == 0.5

plugins/azure/modelgauge/suts/azure_client.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from modelgauge.general import APIException
99
from modelgauge.prompt import TextPrompt
1010
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
11-
from modelgauge.sut import PromptResponseSUT, SUTResponse
11+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
1212
from modelgauge.sut_capabilities import AcceptsTextPrompt
1313
from modelgauge.sut_decorator import modelgauge_sut
1414
from modelgauge.sut_registry import SUTS
@@ -105,16 +105,16 @@ def __init__(self, uid: str, endpoint_url: str, api_key: AzureApiKey):
105105
self.endpoint_url = endpoint_url
106106
self.api_key = api_key.value
107107

108-
def translate_text_prompt(self, prompt: TextPrompt) -> AzureChatRequest:
108+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> AzureChatRequest:
109109
messages = [AzureChatRequest.Message(content=prompt.text, role="user")]
110110
return AzureChatRequest(
111111
messages=messages,
112-
max_tokens=prompt.options.max_tokens,
113-
stop=prompt.options.stop_sequences,
114-
temperature=prompt.options.temperature,
115-
top_p=prompt.options.top_p,
116-
frequency_penalty=prompt.options.frequency_penalty,
117-
presence_penalty=prompt.options.presence_penalty,
112+
max_tokens=options.max_tokens,
113+
stop=options.stop_sequences,
114+
temperature=options.temperature,
115+
top_p=options.top_p,
116+
frequency_penalty=options.frequency_penalty,
117+
presence_penalty=options.presence_penalty,
118118
)
119119

120120
def evaluate(self, request: AzureChatRequest) -> AzureChatResponse:

plugins/google/modelgauge/suts/google_genai_client.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from modelgauge.prompt import TextPrompt
1515
from modelgauge.retry_decorator import retry
1616
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
17-
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTResponse # usort: skip
17+
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTOptions, SUTResponse # usort: skip
1818
from modelgauge.sut_capabilities import AcceptsTextPrompt
1919
from modelgauge.sut_decorator import modelgauge_sut
2020
from modelgauge.sut_registry import SUTS
@@ -101,15 +101,15 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
101101
def _load_client(self) -> genai.GenerativeModel:
102102
return genai.GenerativeModel(self.model_name)
103103

104-
def translate_text_prompt(self, prompt: TextPrompt) -> GoogleGenAiRequest:
104+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GoogleGenAiRequest:
105105
generation_config = GoogleGenAiConfig(
106-
stop_sequences=prompt.options.stop_sequences,
107-
max_output_tokens=prompt.options.max_tokens,
108-
temperature=prompt.options.temperature,
109-
top_p=prompt.options.top_p,
110-
top_k=prompt.options.top_k_per_token,
111-
presence_penalty=prompt.options.presence_penalty,
112-
frequency_penalty=prompt.options.frequency_penalty,
106+
stop_sequences=options.stop_sequences,
107+
max_output_tokens=options.max_tokens,
108+
temperature=options.temperature,
109+
top_p=options.top_p,
110+
top_k=options.top_k_per_token,
111+
presence_penalty=options.presence_penalty,
112+
frequency_penalty=options.frequency_penalty,
113113
)
114114
return GoogleGenAiRequest(
115115
contents=prompt.text, generation_config=generation_config, safety_settings=self.safety_settings

plugins/google/tests/test_google_genai_client.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from google.generativeai.types import HarmCategory, HarmBlockThreshold, generation_types # type: ignore
77

88
from modelgauge.general import APIException
9-
from modelgauge.prompt import SUTOptions, TextPrompt
10-
from modelgauge.sut import REFUSAL_RESPONSE, SUTResponse
9+
from modelgauge.prompt import TextPrompt
10+
from modelgauge.sut import REFUSAL_RESPONSE, SUTOptions, SUTResponse
1111
from modelgauge.suts.google_genai_client import ( # type: ignore
1212
GEMINI_HARM_CATEGORIES,
1313
GoogleAiApiKey,
@@ -108,7 +108,7 @@ def mock_model(mock_model_patch, fake_raw_response):
108108

109109
def test_google_genai_translate_request_default_options(google_default_sut):
110110
prompt = TextPrompt(text="some-text")
111-
request = google_default_sut.translate_text_prompt(prompt)
111+
request = google_default_sut.translate_text_prompt(prompt, SUTOptions())
112112
assert request == GoogleGenAiRequest(
113113
contents="some-text",
114114
generation_config=GoogleGenAiConfig(
@@ -129,7 +129,7 @@ def test_google_genai_translate_request_default_options_disabled_safety(google_d
129129
for harm in GEMINI_HARM_CATEGORIES:
130130
safety_settings[harm] = HarmBlockThreshold.BLOCK_NONE
131131

132-
request = google_disabled_safety_sut.translate_text_prompt(prompt)
132+
request = google_disabled_safety_sut.translate_text_prompt(prompt, SUTOptions())
133133

134134
assert request == GoogleGenAiRequest(
135135
contents="some-text",
@@ -147,13 +147,11 @@ def test_google_genai_translate_request_default_options_disabled_safety(google_d
147147

148148

149149
def test_google_genai_translate_request_generation_options(google_default_sut):
150-
prompt = TextPrompt(
151-
text="some-text",
152-
options=SUTOptions(
153-
stop_sequences=["stop"], max_tokens=200, temperature=0.5, top_k_per_token=5, frequency_penalty=0.5
154-
),
150+
prompt = TextPrompt(text="some-text")
151+
options = SUTOptions(
152+
stop_sequences=["stop"], max_tokens=200, temperature=0.5, top_k_per_token=5, frequency_penalty=0.5
155153
)
156-
request = google_default_sut.translate_text_prompt(prompt)
154+
request = google_default_sut.translate_text_prompt(prompt, options)
157155
assert request == GoogleGenAiRequest(
158156
contents="some-text",
159157
generation_config=GoogleGenAiConfig(

plugins/huggingface/modelgauge/suts/huggingface_api.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
77
from modelgauge.prompt import TextPrompt
88
from modelgauge.secret_values import InjectSecret
9-
from modelgauge.sut import PromptResponseSUT, SUTResponse
9+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse
1010
from modelgauge.sut_capabilities import AcceptsTextPrompt
1111
from modelgauge.sut_decorator import modelgauge_sut
1212
from modelgauge.sut_registry import SUTS
@@ -37,12 +37,10 @@ def __init__(self, uid: str, api_url: str, token: HuggingFaceInferenceToken):
3737
self.token = token.value
3838
self.api_url = api_url
3939

40-
def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatRequest:
40+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatRequest:
4141
return HuggingFaceChatRequest(
4242
inputs=prompt.text,
43-
parameters=HuggingFaceChatParams(
44-
max_new_tokens=prompt.options.max_tokens, temperature=prompt.options.temperature
45-
),
43+
parameters=HuggingFaceChatParams(max_new_tokens=options.max_tokens, temperature=options.temperature),
4644
)
4745

4846
@tenacity.retry(stop=stop_after_attempt(7), wait=wait_random_exponential())

plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
99
from modelgauge.prompt import TextPrompt
1010
from modelgauge.secret_values import InjectSecret
11-
from modelgauge.sut import PromptResponseSUT, SUTResponse, TokenProbability, TopTokens
11+
from modelgauge.sut import PromptResponseSUT, SUTOptions, SUTResponse, TokenProbability, TopTokens
1212
from modelgauge.sut_capabilities import AcceptsTextPrompt, ProducesPerTokenLogProbabilities
1313
from modelgauge.sut_decorator import modelgauge_sut
1414
from modelgauge.sut_registry import SUTS
@@ -76,14 +76,14 @@ def _create_client(self):
7676

7777
self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)
7878

79-
def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatCompletionRequest:
79+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> HuggingFaceChatCompletionRequest:
8080
logprobs = False
81-
if prompt.options.top_logprobs is not None:
81+
if options.top_logprobs is not None:
8282
logprobs = True
8383
return HuggingFaceChatCompletionRequest(
8484
messages=[ChatMessage(role="user", content=prompt.text)],
8585
logprobs=logprobs,
86-
**prompt.options.model_dump(),
86+
**options.model_dump(),
8787
)
8888

8989
def evaluate(self, request: HuggingFaceChatCompletionRequest) -> HuggingFaceChatCompletionOutput:

plugins/huggingface/tests/test_huggingface_api.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from unittest.mock import ANY, patch
33

44
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
5-
from modelgauge.prompt import SUTOptions, TextPrompt
6-
from modelgauge.sut import SUTResponse
5+
from modelgauge.prompt import TextPrompt
6+
from modelgauge.sut import SUTOptions, SUTResponse
77
from modelgauge.suts.huggingface_api import (
88
HuggingFaceChatParams,
99
HuggingFaceChatRequest,
@@ -17,22 +17,16 @@ def fake_sut():
1717
return HuggingFaceSUT("fake_uid", "https://fake_url.com", HuggingFaceInferenceToken("fake_token"))
1818

1919

20-
def _make_prompt(text="some text prompt", sut_options=None):
21-
if sut_options is None:
22-
sut_options = SUTOptions()
23-
return TextPrompt(text=text, options=sut_options)
24-
25-
2620
def _make_sut_request(text, **params):
2721
return HuggingFaceChatRequest(inputs=text, parameters=HuggingFaceChatParams(**params))
2822

2923

3024
def test_huggingface_api_translate_text_prompt_request(fake_sut):
3125
prompt_text = "some text prompt"
3226
sut_options = SUTOptions(max_tokens=5, temperature=1.0, random="should be ignored")
33-
prompt = _make_prompt(prompt_text, sut_options)
27+
prompt = TextPrompt(text=prompt_text)
3428

35-
request = fake_sut.translate_text_prompt(prompt)
29+
request = fake_sut.translate_text_prompt(prompt, sut_options)
3630

3731
assert isinstance(request, HuggingFaceChatRequest)
3832
assert request.inputs == prompt_text

0 commit comments

Comments (0)