This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 71b10ca
Code review fixes
1 parent df68171

3 files changed: +11, -35 lines

src/generate_example.py (1 addition, 1 deletion)

@@ -175,7 +175,7 @@ def create_reproduction_notebook() -> None:
     " print('='*50)\n",
     " \n",
     " # Get model response\n",
-    " response = client.generate(prompt, temperature=0.7)\n",
+    " response = client.generate(prompt)\n",
     " \n",
     " if response.error:\n",
     ' print(f"❌ Error: {response.error}")\n',
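For context, the notebook call site no longer sets temperature; the client resolves it from its seed at construction. A minimal sketch of the resulting usage, assuming the client class from model_client.py (the class name and constructor values below are illustrative, not taken from the repository):

# Hypothetical usage after this commit: temperature is no longer passed per call;
# the client chooses 0.0 (seed given) or 0.7 (no seed) when it is constructed.
client = OllamaClient(model="llama3", seed=42)   # seed set -> temperature 0.0
prompt = "Summarize the experiment setup."
response = client.generate(prompt)               # no temperature kwarg anymore

if response.error:
    print(f"❌ Error: {response.error}")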

src/utils/llm_backend.py (6 additions, 24 deletions)

@@ -13,13 +13,14 @@ class LLMBackend(ABC):
     def __init__(self, config: dict[str, Any], seed: int | None = None) -> None:
         self.config = config
         self.seed = seed
+        # Set temperature based on seed - 0.0 for reproducibility, 0.7 otherwise
+        self.temperature = 0.0 if seed is not None else 0.7
 
     @abstractmethod
     def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
@@ -30,7 +31,6 @@ def generate(
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation."""
@@ -79,37 +79,27 @@ def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
         """Generate response from Ollama model."""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
-
         return self.client.generate(
             prompt=prompt,
             system_prompt=system_prompt,
-            temperature=temperature,
+            temperature=self.temperature,
             max_tokens=max_tokens,
             stream=stream,
         )
 
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation with Ollama."""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
-
         return self.client.chat(
             messages=messages,
-            temperature=temperature,
+            temperature=self.temperature,
             max_tokens=max_tokens,
         )
 
@@ -168,14 +158,10 @@ def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
         """Generate response from OpenRouter model."""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
 
         start_time = time.time()
 
@@ -189,7 +175,7 @@ def generate(
         request_params = {
             "model": self.model,
             "messages": messages,
-            "temperature": temperature,
+            "temperature": self.temperature,
             "stream": stream,
             "timeout": self.timeout,
             "extra_headers": self._get_headers(),
@@ -238,13 +224,9 @@ def generate(
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation with OpenRouter."""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
 
         start_time = time.time()
 
@@ -253,7 +235,7 @@ def chat(
         request_params = {
             "model": self.model,
             "messages": messages,
-            "temperature": temperature,
+            "temperature": self.temperature,
             "timeout": self.timeout,
             "extra_headers": self._get_headers(),
         }
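Taken together, these hunks replace the per-call temperature parameter with an instance attribute fixed in __init__ from the seed. A minimal sketch of the rule this introduces, assuming a backend subclass that keeps the base-class constructor signature (the subclass name is illustrative):

# Hypothetical check of the new seed -> temperature rule.
deterministic = OllamaBackend(config={}, seed=123)
default = OllamaBackend(config={})

assert deterministic.temperature == 0.0  # seed provided: reproducible decoding
assert default.temperature == 0.7        # no seed: previous default sampling value

# Subclasses now forward the stored value instead of a parameter, e.g.
# self.client.generate(..., temperature=self.temperature, ...)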

src/utils/model_client.py (4 additions, 10 deletions)

@@ -38,6 +38,8 @@ def __init__(
         self.base_url = f"http://{host}:{port}"
         self.model = model
         self.seed = seed
+        # Set temperature based on seed - 0.0 for reproducibility, 0.7 otherwise
+        self.temperature = 0.0 if seed is not None else 0.7
 
     def _make_request(
         self, endpoint: str, data: dict[str, Any] | None = None, method: str = "POST"
@@ -205,14 +207,10 @@ def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
         """Generate response from model"""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
 
         start_time = time.time()
 
@@ -221,7 +219,7 @@ def generate(
             "prompt": prompt,
             "stream": stream,
             "options": {
-                "temperature": temperature,
+                "temperature": self.temperature,
             },
         }
 
@@ -262,13 +260,9 @@ def generate(
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation"""
-        # For reproducibility, use temperature=0 when seed is set
-        if self.seed is not None:
-            temperature = 0.0
 
         start_time = time.time()
 
@@ -277,7 +271,7 @@ def chat(
             "messages": messages,
             "stream": False,
             "options": {
-                "temperature": temperature,
+                "temperature": self.temperature,
             },
         }
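As in llm_backend.py, the low-level client now resolves temperature once in its constructor and injects it into the request payload's options block. A short sketch of the resulting payload, assuming the constructor parameters visible in the first hunk (class name and values are illustrative):

# Illustrative: with a seed supplied, the options sent to the model carry 0.0.
client = OllamaClient(host="localhost", port=11434, model="llama3", seed=7)

data = {
    "prompt": "Hello",
    "stream": False,
    "options": {
        "temperature": client.temperature,  # 0.0 here because seed=7 was given
    },
}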
