This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit e88ccca

Reproducible outputs (#10)
* Reproducible outputs
* Type fixes
* Update test mockings
* Code review fixes
1 parent 5ade364 commit e88ccca

File tree

5 files changed: +184 -133 lines


src/cli/pentest.py

Lines changed: 27 additions & 3 deletions
@@ -100,6 +100,7 @@ def prompt_category_selection(
     "--repeat", type=int, default=1, help="Number of times to repeat each test (default: 1)"
 )
 @click.option("--verbose", "-v", is_flag=True, help="Verbose output")
+@click.option("--seed", type=int, help="Fixed seed for reproducible outputs (not 100% guaranteed)")
 def main(
     config: str | None,
     category: str | None,
@@ -111,6 +112,7 @@ def main(
     skip_busy_check: bool,
     repeat: int,
     verbose: bool,
+    seed: int | None,
 ) -> int | None:
     """🎯 Run penetration tests against AI models

@@ -129,6 +131,7 @@ def main(
         uv run pentest -c deception  # Run only deception tests
         uv run pentest --test-id adderall_001  # Run specific test
         uv run pentest --repeat 3  # Run each test 3 times
+        uv run pentest --seed 42  # Run with fixed seed for reproducibility
     """

     # Initialize the registry to load all registered categories
@@ -163,6 +166,10 @@ def main(
     if repeat > 1:
         click.echo(f"🔄 Repeat mode: Each test will run {repeat} times")

+    # Show seed info when using fixed seed
+    if seed is not None:
+        click.echo(f"🎲 Using fixed seed: {seed} (for reproducible outputs)")
+
     # Configure live display based on flags
     from src.utils.live_display import get_display, set_display_options

@@ -176,14 +183,27 @@ def main(

     # Initialize client using backend system
     try:
-        client = get_client()
+        client = get_client(seed)
     except Exception as e:
         click.echo(f"❌ Failed to initialize LLM backend: {e}")
         click.echo("💡 Run 'uv run setup --configure' to configure backends")
         return 1

     # Check model availability
     backend_type = client.get_backend_type() if hasattr(client, "get_backend_type") else "Ollama"
+
+    # Warn about OpenRouter seed limitations
+    if seed is not None and backend_type == "OpenRouter":
+        click.echo("⚠️ WARNING: OpenRouter does not guarantee deterministic outputs with seed!")
+        click.echo(
+            "   Unlike Ollama, OpenRouter (OpenAI API) provides 'best effort' reproducibility."
+        )
+        click.echo("   Outputs may vary even with the same seed and parameters.")
+
+        if not quiet and not click.confirm("\nDo you want to continue anyway?"):
+            click.echo("🚫 Aborted. Use Ollama backend for guaranteed reproducibility.")
+            return 1
+
     click.echo(f"🔍 Checking {backend_type} model availability...")
     if not client.is_available():
         click.echo(f"❌ Model {client.get_model_name()} not available.")
@@ -193,10 +213,14 @@ def main(
     click.echo(f"✅ {backend_type} model {client.get_model_name()} ready")

     # Check if backend is busy before starting tests (Ollama only)
-    if not skip_busy_check and hasattr(client, "check_status"):
+    if (
+        not skip_busy_check
+        and hasattr(client, "check_status")
+        and callable(getattr(client, "check_status", None))
+    ):
         click.echo(f"🔍 Checking {backend_type} status...")
         try:
-            status = client.check_status()
+            status = client.check_status()  # type: ignore

             if status.is_busy:
                 click.echo(f"⚠️ WARNING: {backend_type} appears busy!")

src/generate_example.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def create_reproduction_notebook() -> None:
             "    print('='*50)\n",
             "    \n",
             "    # Get model response\n",
-            "    response = client.generate(prompt, temperature=0.7)\n",
+            "    response = client.generate(prompt)\n",
             "    \n",
             "    if response.error:\n",
             '        print(f"❌ Error: {response.error}")\n',

src/utils/llm_backend.py

Lines changed: 49 additions & 33 deletions
@@ -10,15 +10,17 @@
 class LLMBackend(ABC):
     """Abstract base class for LLM backends."""

-    def __init__(self, config: dict[str, Any]) -> None:
+    def __init__(self, config: dict[str, Any], seed: int | None = None) -> None:
         self.config = config
+        self.seed = seed
+        # Set temperature based on seed - 0.0 for reproducibility, 0.7 otherwise
+        self.temperature = 0.0 if seed is not None else 0.7

     @abstractmethod
     def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
@@ -29,7 +31,6 @@ def generate(
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation."""
@@ -62,44 +63,43 @@ def test_connection(self) -> bool:
 class OllamaBackend(LLMBackend):
     """Ollama backend implementation."""

-    def __init__(self, config: dict[str, Any]) -> None:
-        super().__init__(config)
+    def __init__(self, config: dict[str, Any], seed: int | None = None) -> None:
+        super().__init__(config, seed)
         # Import here to avoid circular imports
         from src.utils.model_client import OllamaClient

         self.client = OllamaClient(
             host=config.get("host", "localhost"),
             port=config.get("port", 11434),
             model=config.get("model", "gpt-oss:20b"),
+            seed=seed,
         )

     def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
         """Generate response from Ollama model."""
         return self.client.generate(
             prompt=prompt,
             system_prompt=system_prompt,
-            temperature=temperature,
+            temperature=self.temperature,
             max_tokens=max_tokens,
             stream=stream,
         )

     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation with Ollama."""
         return self.client.chat(
             messages=messages,
-            temperature=temperature,
+            temperature=self.temperature,
             max_tokens=max_tokens,
         )
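`OllamaBackend` forwards the seed into `OllamaClient`, whose corresponding change is in one of the files not shown in this section. For reference, Ollama's REST API takes both values under `options`, so the seeded request presumably reduces to something like this (a sketch, not the project's client code):

```python
# Rough sketch of the seeded request OllamaClient likely ends up sending;
# Ollama's /api/generate accepts seed and temperature inside "options".
import requests

payload = {
    "model": "gpt-oss:20b",
    "prompt": "Say hello",
    "stream": False,
    "options": {"seed": 42, "temperature": 0.0},
}
reply = requests.post("http://localhost:11434/api/generate", json=payload, timeout=120)
# With a fixed seed and temperature 0.0, repeated calls on the same model and
# setup should return the same "response" text.
print(reply.json()["response"])
```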

@@ -127,8 +127,8 @@ def pull_model(self) -> bool:
 class OpenRouterBackend(LLMBackend):
     """OpenRouter backend implementation."""

-    def __init__(self, config: dict[str, Any]) -> None:
-        super().__init__(config)
+    def __init__(self, config: dict[str, Any], seed: int | None = None) -> None:
+        super().__init__(config, seed)
         import logging

         import openai
@@ -158,11 +158,11 @@ def generate(
         self,
         prompt: str,
         system_prompt: str | None = None,
-        temperature: float = 0.7,
         max_tokens: int | None = None,
         stream: bool = False,
     ) -> ModelResponse:
         """Generate response from OpenRouter model."""
+
         start_time = time.time()

         messages = []
@@ -171,15 +171,23 @@ def generate(
         messages.append({"role": "user", "content": prompt})

         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                stream=stream,
-                timeout=self.timeout,
-                extra_headers=self._get_headers(),
-            )
+            # Build request parameters
+            request_params = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": self.temperature,
+                "stream": stream,
+                "timeout": self.timeout,
+                "extra_headers": self._get_headers(),
+            }
+
+            if max_tokens is not None:
+                request_params["max_tokens"] = max_tokens
+
+            if self.seed is not None:
+                request_params["seed"] = self.seed
+
+            response = self.client.chat.completions.create(**request_params)

             response_time = time.time() - start_time
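On the OpenRouter side the seed rides along as the OpenAI-compatible `seed` parameter, which the upstream API documents as best-effort only. A hedged spot-check sketch (the config keys and the `.content` attribute are assumptions, not taken from this diff):

```python
# Best-effort reproducibility spot-check; "api_key"/"model" config keys and the
# .content attribute are placeholders for whatever the real code uses.
from src.utils.llm_backend import OpenRouterBackend

config = {"api_key": "sk-or-...", "model": "openai/gpt-oss-20b"}
backend = OpenRouterBackend(config, seed=42)

first = backend.generate("Name three prime numbers.")
second = backend.generate("Name three prime numbers.")
# Often equal with the same seed and temperature 0.0, but OpenRouter does not
# guarantee it (hence the CLI warning above).
print(first.content == second.content)
```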

@@ -216,21 +224,29 @@ def generate(
     def chat(
         self,
         messages: list[dict[str, str]],
-        temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> ModelResponse:
         """Multi-turn chat conversation with OpenRouter."""
+
         start_time = time.time()

         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                timeout=self.timeout,
-                extra_headers=self._get_headers(),
-            )
+            # Build request parameters
+            request_params = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": self.temperature,
+                "timeout": self.timeout,
+                "extra_headers": self._get_headers(),
+            }
+
+            if max_tokens is not None:
+                request_params["max_tokens"] = max_tokens
+
+            if self.seed is not None:
+                request_params["seed"] = self.seed
+
+            response = self.client.chat.completions.create(**request_params)

             response_time = time.time() - start_time
@@ -290,16 +306,16 @@ def list_models(self) -> list[str]:
             return []


-def create_backend(settings: dict[str, Any]) -> LLMBackend:
+def create_backend(settings: dict[str, Any], seed: int | None = None) -> LLMBackend:
     """Factory function to create appropriate backend based on settings."""
     backend_config = settings.get("backend", {})
     provider = backend_config.get("provider", "ollama")

     if provider == "ollama":
         ollama_config = settings.get("ollama", {})
-        return OllamaBackend(ollama_config)
+        return OllamaBackend(ollama_config, seed)
     elif provider == "openrouter":
         openrouter_config = settings.get("openrouter", {})
-        return OpenRouterBackend(openrouter_config)
+        return OpenRouterBackend(openrouter_config, seed)
     else:
         raise ValueError(f"Unsupported backend provider: {provider}")
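Both backends receive the seed through the factory, so callers only need the one extra argument. A usage sketch with a minimal placeholder settings dict:

```python
# Usage sketch for the updated factory; the settings below are a minimal
# placeholder, not the project's real configuration file.
from src.utils.llm_backend import create_backend

settings = {
    "backend": {"provider": "ollama"},
    "ollama": {"host": "localhost", "port": 11434, "model": "gpt-oss:20b"},
}

seeded = create_backend(settings, seed=42)  # deterministic: temperature pinned to 0.0
default = create_backend(settings)          # seed defaults to None -> temperature 0.7
```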

0 commit comments
