-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgarak_remote_eval.py
More file actions
402 lines (332 loc) · 18.4 KB
/
garak_remote_eval.py
File metadata and controls
402 lines (332 loc) · 18.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
from ..compat import (
EvaluateResponse,
BenchmarkConfig,
RunEvalRequest,
JobStatusRequest,
JobCancelRequest,
JobResultRequest,
ProviderSpec,
Api,
Job,
JobStatus,
OpenAIFilePurpose,
ListFilesRequest,
RetrieveFileContentRequest,
)
import os
import logging
import json
import asyncio
from ..config import GarakRemoteConfig
from ..base_eval import GarakEvalBase
from llama_stack_provider_trustyai_garak import shield_scan
from ..errors import GarakError, GarakConfigError, GarakValidationError
from ..utils import as_bool
from dotenv import load_dotenv
# Module logger, named after this module.
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file (if present) at import time.
load_dotenv()

# Prefix for job IDs issued by this adapter; it is stripped again when the
# KFP run name is built in run_eval.
JOB_ID_PREFIX: str = "garak-job-"
class GarakRemoteEvalAdapter(GarakEvalBase):
    """Remote Garak evaluation adapter for running scans on Kubeflow Pipelines.

    ``run_eval`` submits a KFP pipeline run and records a local ``Job`` plus
    per-job metadata (KFP run ID, timestamps, result file IDs). ``job_status``
    refreshes a job by polling its KFP run.
    """

    # Message used when KFP rejects the credentials or no kubeconfig login exists.
    _KFP_LOGIN_HINT = (
        "Unable to connect to Kubeflow Pipelines. Please check if you are logged in to the correct cluster. "
        "If you are logged in, please check if the token & pipeline endpoint is valid. "
        "If you are not logged in, please run `oc login` command first."
    )

    def __init__(self, config: GarakRemoteConfig, deps: dict[Api, ProviderSpec]):
        super().__init__(config, deps)
        self._config: GarakRemoteConfig = config
        # The KFP client is created lazily on first use (see _ensure_kfp_client).
        self.kfp_client = None
        # Serializes access to self._jobs / self._job_metadata across coroutines.
        self._jobs_lock = asyncio.Lock()

    def _ensure_garak_installed(self) -> None:
        """Override: Skip garak check for remote provider - it runs in container."""
        logger.debug("Skipping garak installation check for remote provider (runs in container)")

    def _get_all_probes(self) -> set[str]:
        """Override: Skip probe enumeration for remote provider - validation happens in pod.

        Returns empty set to allow any probe names in benchmark metadata.
        Remote validation will occur when the scan runs in the Kubernetes pod.
        """
        logger.debug("Skipping probe enumeration for remote provider (validated in container)")
        return set()  # Allow any probes; validation happens in the pod

    async def initialize(self) -> None:
        """Initialize the remote Garak provider."""
        self._initialize()
        self._initialized = True
        logger.info("Initialized Garak remote provider.")

    def _ensure_kfp_client(self) -> None:
        """Create the KFP client if it has not been created yet."""
        if not self.kfp_client:
            self._create_kfp_client()

    def _create_kfp_client(self) -> None:
        """Create ``self.kfp_client`` from the Kubeflow settings in the config.

        ``self._verify_ssl`` may be a bool or a string. A string is passed to
        KFP as ``ssl_ca_cert``, with ``verify_ssl=True``.

        Raises:
            GarakError: If the KFP SDK is not installed, no token can be
                obtained, or the client cannot be created.
        """
        try:
            from kfp import Client
            from kfp_server_api.exceptions import ApiException
        except ImportError as e:
            raise GarakError("Kubeflow Pipelines SDK not available. Install with: pip install -e .[remote]") from e

        try:
            if isinstance(self._verify_ssl, str):
                ssl_cert = self._verify_ssl
                verify_ssl = True
            else:
                ssl_cert = None
                verify_ssl = self._verify_ssl
            # Use token from config if provided, otherwise get from kubeconfig
            token = self._get_token()
            self.kfp_client = Client(
                host=self._config.kubeflow_config.pipelines_endpoint,
                existing_token=token,
                verify_ssl=verify_ssl,
                ssl_ca_cert=ssl_cert,
            )
        except GarakError:
            # _get_token already raised a specific, actionable error; don't bury
            # it under the generic connection message below.
            raise
        except ApiException as e:
            raise GarakError(self._KFP_LOGIN_HINT) from e
        except Exception as e:
            # _get_token signals a missing kubeconfig login with a *kubernetes*
            # ApiException(401), a different class from the kfp_server_api
            # ApiException handled above; give it the same login hint.
            if getattr(e, "status", None) == 401:
                raise GarakError(self._KFP_LOGIN_HINT) from e
            raise GarakError("Unable to connect to Kubeflow Pipelines.") from e

    def _get_token(self) -> str:
        """Get authentication token from config or the local kubernetes config.

        Prefers ``kubeflow_config.pipelines_api_token``. Otherwise the token is
        taken from the ``authorization`` entry of the loaded kubeconfig.

        Raises:
            kubernetes.client.exceptions.ApiException: Status 401, if the
                kubeconfig cannot be loaded or has no ``authorization`` entry.
            GarakError: If the kubernetes client is not installed or the token
                cannot be read for any other reason.
        """
        api_token = self._config.kubeflow_config.pipelines_api_token
        if api_token:
            logger.info("Using KUBEFLOW_PIPELINES_TOKEN from config")
            return api_token.get_secret_value()

        logger.info("Using authentication token from kubernetes config")
        # Import in a separate try: the except clauses below reference these
        # names, and a failed import inside the same try would surface as
        # UnboundLocalError instead of the intended GarakError.
        try:
            from .kfp_utils.utils import _load_kube_config
            from kubernetes.client.exceptions import ApiException
            from kubernetes.config.config_exception import ConfigException
        except ImportError as e:
            raise GarakError("Kubernetes client is not installed. Install with: pip install kubernetes") from e

        try:
            config = _load_kube_config()
            # The authorization value is split on spaces and the last part kept
            # (e.g. "Bearer <token>" -> "<token>").
            token = str(config.api_key["authorization"].split(" ")[-1])
        except (KeyError, ConfigException) as e:
            raise ApiException(401, "Unauthorized, try running command like `oc login` first") from e
        except ImportError as e:
            raise GarakError("Kubernetes client is not installed. Install with: pip install kubernetes") from e
        except Exception as e:
            raise GarakError(
                f"Unable to get authentication token from kubernetes config: {e}. Please run `oc login` and try again."
            ) from e
        return token

    async def run_eval(self, request: RunEvalRequest) -> Job:
        """Run an evaluation for a specific benchmark as a KFP pipeline run.

        Args:
            request: Run eval request containing benchmark_id and benchmark_config

        Returns:
            A ``Job`` in ``scheduled`` state whose metadata holds ``created_at``
            and ``kfp_run_id``.

        Raises:
            GarakValidationError: If benchmark_id or benchmark_config are invalid
            GarakConfigError: If configuration is invalid
            GarakError: If KFP pipeline creation fails
        """
        if not self._initialized:
            await self.initialize()

        benchmark_id = request.benchmark_id
        benchmark_config: BenchmarkConfig = request.benchmark_config
        if not benchmark_id or not isinstance(benchmark_id, str):
            raise GarakValidationError("benchmark_id must be a non-empty string")
        if not benchmark_config:
            raise GarakValidationError("benchmark_config cannot be None")

        garak_config, provider_params = await self._validate_run_eval_request(benchmark_id, benchmark_config)

        job_id = self._get_job_id(prefix=JOB_ID_PREFIX)
        job = Job(job_id=job_id, status=JobStatus.scheduled)
        async with self._jobs_lock:
            self._jobs[job_id] = job
            self._job_metadata[job_id] = {}  # Initialize metadata dict

        try:
            timeout = provider_params.get("timeout", self._config.timeout)
            cmd_config: dict = await self._build_command(benchmark_config, garak_config, provider_params)

            # Validate config before creating pipeline
            if not self._config.kubeflow_config.namespace:
                raise GarakConfigError("kubeflow_config.namespace is not configured")
            if not self._config.llama_stack_url:
                raise GarakConfigError("llama_stack_url is not configured")

            # NOTE(review): this sets a process-wide environment variable (only
            # if not already set); presumably read by the pipeline definitions
            # in kfp_utils when building component images — confirm.
            garak_base_image = self._config.kubeflow_config.garak_base_image
            if garak_base_image and garak_base_image.strip() and not os.environ.get("KUBEFLOW_GARAK_BASE_IMAGE"):
                os.environ["KUBEFLOW_GARAK_BASE_IMAGE"] = garak_base_image
                logger.info(f"KUBEFLOW_GARAK_BASE_IMAGE set to {garak_base_image}")

            experiment_name = self._config.kubeflow_config.experiment_name or "trustyai-garak"
            llama_stack_url = self._config.llama_stack_url.strip().rstrip("/")
            if not llama_stack_url:
                raise GarakConfigError("llama_stack_url cannot be empty after normalization")

            from .kfp_utils.pipeline import garak_scan_pipeline
            from ..core.pipeline_steps import redact_api_keys

            self._ensure_kfp_client()
            # The command config goes through redact_api_keys before it is
            # serialized into the pipeline arguments.
            sanitised_config = redact_api_keys(cmd_config)
            disable_cache = as_bool(provider_params.get("disable_cache", False))
            run = self.kfp_client.create_run_from_pipeline_func(
                garak_scan_pipeline,
                arguments={
                    "command": json.dumps(sanitised_config),
                    "llama_stack_url": llama_stack_url,
                    "job_id": job_id,
                    "eval_threshold": float(garak_config.run.eval_threshold),
                    "timeout_seconds": int(timeout),
                    "verify_ssl": str(self._verify_ssl),
                    "art_intents": provider_params.get("art_intents", False),
                    "policy_file_id": provider_params.get("policy_file_id", ""),
                    "policy_format": provider_params.get("policy_format", "csv"),
                    "intents_file_id": provider_params.get("intents_file_id", ""),
                    "intents_format": provider_params.get("intents_format", "csv"),
                    "sdg_model": provider_params.get("sdg_model", ""),
                    "sdg_api_base": provider_params.get("sdg_api_base", ""),
                    "sdg_flow_id": provider_params.get("sdg_flow_id", ""),
                },
                run_name=f"garak-{benchmark_id.split('::')[-1]}-{job_id.removeprefix(JOB_ID_PREFIX)}",
                namespace=self._config.kubeflow_config.namespace,
                experiment_name=experiment_name,
                enable_caching=not disable_cache,
            )

            async with self._jobs_lock:
                self._job_metadata[job_id] = {
                    "created_at": self._convert_datetime_to_str(run.run_info.created_at),
                    "kfp_run_id": run.run_id,
                }
                # Copy so the returned Job does not share the dict that
                # job_status later mutates.
                metadata = dict(self._job_metadata[job_id])
            # Return Job object with metadata (Job model patched in compat.py to allow extra fields)
            return Job(job_id=job_id, status=job.status, metadata=metadata)
        except Exception as e:
            logger.error(f"Error running eval for {benchmark_id}: {e}")
            async with self._jobs_lock:
                job.status = JobStatus.failed
                # The entry may be gone if shutdown() cleared state meanwhile;
                # don't let a KeyError here mask the original error.
                if (job_meta := self._job_metadata.get(job_id)) is not None:
                    job_meta["error"] = str(e)
            raise

    def _map_kfp_run_state_to_job_status(self, run_state) -> JobStatus:
        """Map a KFP ``V2beta1RuntimeState`` to the equivalent ``JobStatus``.

        Unknown states are logged and reported as ``scheduled``.
        """
        from kfp_server_api.models import V2beta1RuntimeState

        if run_state in [V2beta1RuntimeState.RUNTIME_STATE_UNSPECIFIED, V2beta1RuntimeState.PENDING]:
            return JobStatus.scheduled
        elif run_state in [V2beta1RuntimeState.RUNNING, V2beta1RuntimeState.CANCELING, V2beta1RuntimeState.PAUSED]:
            return JobStatus.in_progress
        elif run_state in [V2beta1RuntimeState.SUCCEEDED, V2beta1RuntimeState.SKIPPED]:
            return JobStatus.completed
        elif run_state == V2beta1RuntimeState.FAILED:
            return JobStatus.failed
        elif run_state == V2beta1RuntimeState.CANCELED:
            return JobStatus.cancelled
        else:
            logger.warning(f"KFP run has an unknown status: {run_state}, mapping to scheduled")
            return JobStatus.scheduled

    async def job_status(self, request: JobStatusRequest) -> Job:
        """Refresh a job's state from its KFP run and return it.

        When the run has completed, this also merges the result file-ID
        mapping from the Files API into the job metadata.

        Args:
            request: Job status request containing benchmark_id and job_id

        Returns:
            The job with its current status and a copy of its metadata. Unknown
            jobs, jobs without a KFP run ID, and KFP lookup errors are reported
            as ``failed`` with an ``error`` entry in the metadata.
        """
        from kfp_server_api.models import V2beta1Run

        job_id = request.job_id
        async with self._jobs_lock:
            job = self._jobs.get(job_id)
            if not job:
                logger.warning(f"Job {job_id} not found")
                return Job(job_id=job_id, status=JobStatus.failed, metadata={"error": "Job not found"})
            metadata: dict = self._job_metadata.get(job_id, {})
            if "kfp_run_id" not in metadata:
                logger.warning(f"Job {job_id} has no kfp run id")
                return Job(job_id=job_id, status=JobStatus.failed, metadata={"error": "No KFP run ID found"})
            kfp_run_id = metadata["kfp_run_id"]

        try:
            self._ensure_kfp_client()
            # NOTE(review): get_run is called synchronously inside a coroutine
            # and blocks the event loop for the duration of the HTTP request.
            run: V2beta1Run = self.kfp_client.get_run(kfp_run_id)
            new_status = self._map_kfp_run_state_to_job_status(run.state)
            async with self._jobs_lock:
                job.status = new_status
                if job.status in [JobStatus.completed, JobStatus.failed, JobStatus.cancelled]:
                    self._job_metadata[job_id]["finished_at"] = self._convert_datetime_to_str(run.finished_at)
                    if job.status == JobStatus.completed:
                        await self._load_file_id_mapping(job_id)
                # Copy so callers never hold a reference to internal state.
                return_metadata = dict(self._job_metadata.get(job_id, {}))
                return_status = job.status
        except Exception as e:
            logger.error(f"Error getting KFP run {kfp_run_id}: {e}")
            return Job(job_id=job_id, status=JobStatus.failed, metadata={"error": f"Error getting KFP run: {str(e)}"})

        # Return Job object with metadata (Job model patched in compat.py to allow extra fields)
        return Job(job_id=job_id, status=return_status, metadata=return_metadata)

    async def _load_file_id_mapping(self, job_id: str) -> None:
        """Merge the job's result file-ID mapping from the Files API into its metadata.

        Must be called with ``self._jobs_lock`` already held; it does not take
        the lock itself. Failures are logged and swallowed so that a missing
        mapping never fails the status request.
        """
        # Retrieve file_id_mapping from Files API using predictable filename
        # Two-phase mapping: parse_results uploads enriched {job_id}_mapping.json,
        # garak_scan uploads raw {job_id}_mapping_raw.json as fallback
        mapping_filename = f"{job_id}_mapping.json"
        raw_mapping_filename = f"{job_id}_mapping_raw.json"
        try:
            job_metadata = self._job_metadata[job_id]
            logger.debug(f"Searching for mapping file: {mapping_filename}")
            if "mapping_file_id" not in job_metadata:
                # List files and find the mapping file by name. The enriched
                # mapping wins: a raw match is recorded but the search continues.
                files_list = await self.file_api.openai_list_files(ListFilesRequest(purpose=OpenAIFilePurpose.BATCH))
                for file_obj in files_list.data:
                    if not hasattr(file_obj, "filename"):
                        continue
                    if file_obj.filename == mapping_filename:
                        job_metadata["mapping_file_id"] = file_obj.id
                        logger.debug(f"Found enriched mapping: {mapping_filename} (ID: {file_obj.id})")
                        break
                    elif file_obj.filename == raw_mapping_filename:
                        job_metadata["mapping_file_id"] = file_obj.id
                        logger.debug(f"Found raw mapping (fallback): {raw_mapping_filename} (ID: {file_obj.id})")

            if mapping_file_id := job_metadata.get("mapping_file_id"):
                # Retrieve the mapping file via Files API
                mapping_content = await self.file_api.openai_retrieve_file_content(
                    RetrieveFileContentRequest(file_id=mapping_file_id)
                )
                if mapping_content:
                    file_id_mapping: dict = json.loads(mapping_content.body.decode("utf-8"))
                    if not isinstance(file_id_mapping, dict):
                        raise ValueError(f"Invalid file mapping format: {type(file_id_mapping)}")
                    # Store the file IDs in job metadata
                    job_metadata.update(file_id_mapping)
                    logger.debug(f"Successfully retrieved {len(file_id_mapping)} file IDs from Files API")
                else:
                    logger.warning(f"Empty mapping file content for file ID: {mapping_file_id}")
            else:
                logger.warning(
                    f"Could not find mapping file '{mapping_filename}' or '{raw_mapping_filename}' in Files API. "
                    f"This might be expected if the pipeline is still running or failed."
                )
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON from mapping file: {e}")
        except Exception as e:
            logger.warning(f"Could not retrieve mapping from Files API for job {job_id}: {e}")

    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
        """Get the result of a job (remote-specific: updates job state from KFP first).

        Args:
            request: Job result request containing benchmark_id and job_id
        """
        # Update job status from KFP before getting results
        await self.job_status(JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id))
        return await super().job_result(request, prefix=f"{request.job_id}_")

    async def job_cancel(self, request: JobCancelRequest) -> None:
        """Terminate the KFP run behind a job if the job is still active.

        The job status is refreshed first. Jobs that are already completed,
        failed or cancelled are left alone. Errors from KFP while terminating
        are logged, not raised.

        Args:
            request: Job cancel request containing benchmark_id and job_id
        """
        job_id = request.job_id
        # check/update the current status of the job
        current_status = await self.job_status(JobStatusRequest(benchmark_id=request.benchmark_id, job_id=job_id))
        if current_status.status in [JobStatus.completed, JobStatus.failed, JobStatus.cancelled]:
            logger.warning(f"Job {job_id} is not running. Can't cancel.")
            return

        async with self._jobs_lock:
            kfp_run_id = self._job_metadata.get(job_id, {}).get("kfp_run_id")
        if kfp_run_id:
            try:
                self._ensure_kfp_client()
                self.kfp_client.terminate_run(kfp_run_id)
            except Exception as e:
                logger.error(f"Error cancelling KFP run {kfp_run_id}: {e}")

    async def shutdown(self) -> None:
        """Clean up resources when shutting down.

        Tries to cancel every job still scheduled or in progress (each one is
        best-effort), then clears all job state and closes the shield-scan
        orchestrator.
        """
        logger.info("Shutting down Garak provider")
        # Get snapshot of jobs to cancel
        async with self._jobs_lock:
            job_ids_to_cancel = [
                job_id
                for job_id, job in self._jobs.items()
                if job.status in [JobStatus.in_progress, JobStatus.scheduled]
            ]
        # Cancel each job independently so one failure cannot skip the cleanup below.
        for job_id in job_ids_to_cancel:
            try:
                await self.job_cancel(JobCancelRequest(benchmark_id="placeholder", job_id=job_id))
            except Exception as e:
                logger.error(f"Error cancelling job {job_id} during shutdown: {e}")
        # Clear all jobs and job metadata
        async with self._jobs_lock:
            self._jobs.clear()
            self._job_metadata.clear()
        # Close the shield scanning HTTP client
        shield_scan.simple_shield_orchestrator.close()