|
14 | 14 | from modelgauge.prompt import TextPrompt
|
15 | 15 | from modelgauge.retry_decorator import retry
|
16 | 16 | from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
|
17 |
| -from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTResponse # usort: skip |
| 17 | +from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTOptions, SUTResponse # usort: skip |
18 | 18 | from modelgauge.sut_capabilities import AcceptsTextPrompt
|
19 | 19 | from modelgauge.sut_decorator import modelgauge_sut
|
20 | 20 | from modelgauge.sut_registry import SUTS
|
@@ -101,15 +101,15 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
|
def _load_client(self) -> genai.GenerativeModel:
    """Build and return the Google GenAI client for this SUT's configured model."""
    client = genai.GenerativeModel(self.model_name)
    return client
|
103 | 103 |
|
104 |
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GoogleGenAiRequest:
    """Translate a ModelGauge text prompt into a native Google GenAI request.

    The generic SUTOptions sampling knobs are mapped one-to-one onto Google's
    generation-config fields; this SUT's safety settings are attached as-is.

    :param prompt: the text prompt to send to the model.
    :param options: generic sampling options (stop sequences, max tokens, etc.).
    :return: a request object ready for the Google GenAI API.
    """
    # NOTE(review): SUTOptions' generic `top_k_per_token` maps to Google's `top_k`.
    config = GoogleGenAiConfig(
        stop_sequences=options.stop_sequences,
        max_output_tokens=options.max_tokens,
        temperature=options.temperature,
        top_p=options.top_p,
        top_k=options.top_k_per_token,
        presence_penalty=options.presence_penalty,
        frequency_penalty=options.frequency_penalty,
    )
    return GoogleGenAiRequest(
        contents=prompt.text,
        generation_config=config,
        safety_settings=self.safety_settings,
    )
|
|
0 commit comments