Skip to content

Commit cf74509

Browse files
Merge pull request #103 from opendatahub-io/main
sync: main to incubation
2 parents f2c22b0 + b2f5ade commit cf74509

7 files changed

Lines changed: 189 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"kfp-server-api>=2.14.6",
2020
"boto3>=1.35.88",
2121
# eval-hub integration
22-
"eval-hub-sdk[adapter]>=0.1.4",
22+
"eval-hub-sdk[adapter]==0.1.4",
2323
"pandas>=2.3.3",
2424
"Jinja2>=3.1.6",
2525
]

src/llama_stack_provider_trustyai_garak/evalhub/garak_adapter.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -416,11 +416,18 @@ def _run_simple(
416416
)
417417
)
418418

419+
env: dict[str, str] = {}
420+
hf_cache = (config.parameters or {}).get("hf_cache_path", "")
421+
if hf_cache:
422+
env["HF_HUB_CACHE"] = hf_cache
423+
logger.info("Using HF cache from mounted path: %s", hf_cache)
424+
419425
result = run_garak_scan(
420426
config_file=config_file,
421427
timeout_seconds=timeout,
422428
log_file=log_file,
423429
report_prefix=report_prefix,
430+
env=env if env else None,
424431
)
425432

426433
# AVID conversion
@@ -554,6 +561,7 @@ def _run_via_kfp(
554561
"sdg_max_concurrency": ip.get("sdg_max_concurrency", DEFAULT_SDG_MAX_CONCURRENCY),
555562
"sdg_num_samples": ip.get("sdg_num_samples", DEFAULT_SDG_NUM_SAMPLES),
556563
"sdg_max_tokens": ip.get("sdg_max_tokens", DEFAULT_SDG_MAX_TOKENS),
564+
"hf_cache_path": benchmark_config.get("hf_cache_path", ""),
557565
}
558566
if model_auth_secret:
559567
pipeline_args["model_auth_secret_name"] = model_auth_secret
@@ -595,7 +603,6 @@ def _run_via_kfp(
595603
),
596604
)
597605
)
598-
s3_bucket = kfp_config.s3_bucket or os.getenv("AWS_S3_BUCKET", "")
599606
creds = (
600607
self._read_s3_credentials_from_secret(
601608
kfp_config.s3_secret_name,
@@ -616,11 +623,13 @@ def _run_via_kfp(
616623
kfp_config.s3_secret_name,
617624
kfp_config.namespace,
618625
)
626+
s3_bucket = kfp_config.s3_bucket or creds.pop("bucket", "") or os.getenv("AWS_S3_BUCKET", "")
627+
s3_endpoint = kfp_config.s3_endpoint or creds.pop("endpoint_url", "") or None
619628
self._download_results_from_s3(
620629
s3_bucket,
621630
s3_prefix,
622631
scan_dir,
623-
endpoint_url=kfp_config.s3_endpoint or None,
632+
endpoint_url=s3_endpoint,
624633
**creds,
625634
)
626635

@@ -770,6 +779,8 @@ def _decode(key: str) -> str:
770779
"access_key": _decode("AWS_ACCESS_KEY_ID"),
771780
"secret_key": _decode("AWS_SECRET_ACCESS_KEY"),
772781
"region": _decode("AWS_DEFAULT_REGION"),
782+
"bucket": _decode("AWS_S3_BUCKET"),
783+
"endpoint_url": _decode("AWS_S3_ENDPOINT"),
773784
}
774785
except Exception as exc:
775786
logger.warning("Could not read S3 credentials from secret %s/%s: %s", namespace, secret_name, exc)
@@ -1289,6 +1300,22 @@ def _parse_results(
12891300
)
12901301
overall_summary = combined.get("scores", {}).get("_overall", {}).get("aggregated_results", {})
12911302

1303+
overall_asr = overall_summary.get("attack_success_rate")
1304+
if overall_asr is not None:
1305+
try:
1306+
overall_asr = float(overall_asr)
1307+
except (TypeError, ValueError):
1308+
overall_asr = None
1309+
if overall_asr is not None:
1310+
metrics.append(
1311+
EvaluationResult(
1312+
metric_name="attack_success_rate",
1313+
metric_value=overall_asr,
1314+
metric_type="percentage",
1315+
num_samples=overall_summary.get("total_attempts"),
1316+
)
1317+
)
1318+
12921319
# Convert to EvaluationResult format (one per probe)
12931320
for probe_name, score_data in combined["scores"].items():
12941321
if probe_name == "_overall":

src/llama_stack_provider_trustyai_garak/evalhub/kfp_pipeline.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -441,6 +441,7 @@ def garak_scan(
441441
config_json: str,
442442
s3_prefix: str,
443443
timeout_seconds: int,
444+
hf_cache_path: str,
444445
prompts_dataset: dsl.Input[dsl.Dataset],
445446
) -> NamedTuple("Outputs", [("success", bool), ("return_code", int)]):
446447
"""Run a Garak scan and upload output to S3 via Data Connection credentials.
@@ -469,6 +470,54 @@ def garak_scan(
469470
)
470471
from llama_stack_provider_trustyai_garak.errors import GarakError
471472

473+
if hf_cache_path and hf_cache_path.strip():
474+
from llama_stack_provider_trustyai_garak.evalhub.s3_utils import create_s3_client
475+
476+
if hf_cache_path.startswith("s3://"):
477+
parts = hf_cache_path[len("s3://") :].split("/", 1)
478+
bucket = parts[0]
479+
prefix = parts[1] if len(parts) > 1 else ""
480+
else:
481+
bucket = os.environ.get("AWS_S3_BUCKET", "")
482+
prefix = hf_cache_path.lstrip("/")
483+
484+
if not bucket:
485+
raise GarakError(
486+
"Cannot determine S3 bucket for HF cache. "
487+
"Provide a full s3://bucket/prefix URI in hf_cache_path, "
488+
"or set AWS_S3_BUCKET."
489+
)
490+
491+
if not prefix:
492+
log.warning(
493+
"hf_cache_path has no sub-prefix; downloading all objects from bucket '%s'.",
494+
bucket,
495+
)
496+
497+
hf_cache_dir = Path(tempfile.mkdtemp(prefix="hf-cache-"))
498+
s3 = create_s3_client()
499+
downloaded = 0
500+
501+
paginator = s3.get_paginator("list_objects_v2")
502+
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
503+
for obj in page.get("Contents", []):
504+
rel = obj["Key"][len(prefix) :].lstrip("/")
505+
if not rel:
506+
continue
507+
dest = hf_cache_dir / rel
508+
dest.parent.mkdir(parents=True, exist_ok=True)
509+
s3.download_file(bucket, obj["Key"], str(dest))
510+
downloaded += 1
511+
512+
os.environ["HF_HUB_CACHE"] = str(hf_cache_dir)
513+
log.info(
514+
"Populated HF cache from s3://%s/%s -> %s (%d files)",
515+
bucket,
516+
prefix,
517+
hf_cache_dir,
518+
downloaded,
519+
)
520+
472521
scan_dir = Path(tempfile.mkdtemp(prefix="garak-scan-"))
473522

474523
prompts_path = Path(prompts_dataset.path)
@@ -639,6 +688,7 @@ def evalhub_garak_pipeline(
639688
sdg_max_concurrency: int = DEFAULT_SDG_MAX_CONCURRENCY,
640689
sdg_num_samples: int = DEFAULT_SDG_NUM_SAMPLES,
641690
sdg_max_tokens: int = DEFAULT_SDG_MAX_TOKENS,
691+
hf_cache_path: str = "",
642692
):
643693
"""Six-step pipeline: validate, resolve taxonomy, SDG, prepare prompts, scan, write outputs.
644694
@@ -736,6 +786,7 @@ def evalhub_garak_pipeline(
736786
config_json=config_json,
737787
s3_prefix=s3_prefix,
738788
timeout_seconds=timeout_seconds,
789+
hf_cache_path=hf_cache_path,
739790
prompts_dataset=prep_task.outputs["prompts_dataset"],
740791
)
741792
scan_task.set_caching_options(False)

src/llama_stack_provider_trustyai_garak/remote/garak_remote_eval.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,7 @@ async def run_eval(self, request: RunEvalRequest) -> Job:
211211
provider_params.get("sdg_max_tokens", DEFAULT_SDG_MAX_TOKENS),
212212
DEFAULT_SDG_MAX_TOKENS,
213213
),
214+
"hf_cache_path": provider_params.get("hf_cache_path", ""),
214215
},
215216
run_name=f"garak-{benchmark_id.split('::')[-1]}-{job_id.removeprefix(JOB_ID_PREFIX)}",
216217
namespace=self._config.kubeflow_config.namespace,

src/llama_stack_provider_trustyai_garak/remote/kfp_utils/components.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -397,6 +397,7 @@ def garak_scan(
397397
job_id: str,
398398
timeout_seconds: int,
399399
verify_ssl: str,
400+
hf_cache_path: str,
400401
prompts_dataset: dsl.Input[dsl.Dataset],
401402
) -> NamedTuple(
402403
"Outputs",
@@ -419,13 +420,19 @@ def garak_scan(
419420
)
420421
log = logging.getLogger("garak_scan")
421422

423+
import os
424+
422425
from llama_stack_client import LlamaStackClient
423426
from llama_stack_provider_trustyai_garak.utils import get_http_client_with_tls
424427
from llama_stack_provider_trustyai_garak.core.pipeline_steps import (
425428
setup_and_run_garak,
426429
redact_api_keys,
427430
)
428431

432+
if hf_cache_path and hf_cache_path.strip():
433+
os.environ["HF_HUB_CACHE"] = hf_cache_path
434+
log.info("Set HF_HUB_CACHE=%s for disconnected mode", hf_cache_path)
435+
429436
scan_dir = Path(tempfile.mkdtemp(prefix="garak-scan-"))
430437

431438
prompts_path = Path(prompts_dataset.path)

src/llama_stack_provider_trustyai_garak/remote/kfp_utils/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ def garak_scan_pipeline(
5353
sdg_max_concurrency: int = DEFAULT_SDG_MAX_CONCURRENCY,
5454
sdg_num_samples: int = DEFAULT_SDG_NUM_SAMPLES,
5555
sdg_max_tokens: int = DEFAULT_SDG_MAX_TOKENS,
56+
hf_cache_path: str = "",
5657
):
5758
"""Six-step pipeline: validate, resolve taxonomy, SDG, prepare prompts, scan, parse.
5859
@@ -120,6 +121,7 @@ def garak_scan_pipeline(
120121
job_id=job_id,
121122
timeout_seconds=timeout_seconds,
122123
verify_ssl=verify_ssl,
124+
hf_cache_path=hf_cache_path,
123125
prompts_dataset=prep_task.outputs["prompts_dataset"],
124126
)
125127
scan_task.set_caching_options(False)

tests/test_evalhub_adapter.py

Lines changed: 98 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -189,6 +189,88 @@ def create_oci_artifact(self, _spec):
189189
assert captured["timeout_seconds"] == 42
190190

191191

192+
def test_simple_mode_passes_hf_cache_env(monkeypatch, tmp_path):
    """A job carrying ``hf_cache_path`` makes the adapter hand ``HF_HUB_CACHE`` to ``run_garak_scan``.

    ``run_garak_scan``, ``convert_to_avid_report`` and ``GarakAdapter._parse_results``
    are replaced with stubs so only the ``env`` keyword argument is observed.
    """
    module = _load_evalhub_garak_adapter(monkeypatch)
    adapter = module.GarakAdapter()
    monkeypatch.setenv("GARAK_SCAN_DIR", str(tmp_path))

    seen: dict[str, object] = {}

    def _fake_run_garak_scan(config_file, timeout_seconds, report_prefix, env=None, log_file=None):
        # Record the env argument and create an empty-JSON report file at the prefix.
        seen["env"] = env
        report_file = report_prefix.with_suffix(".report.jsonl")
        report_file.write_text("{}", encoding="utf-8")
        return module.GarakScanResult(returncode=0, stdout="", stderr="", report_prefix=report_prefix)

    def _fake_parse_results(self, result, eval_threshold, art_intents=False):
        # Fixed tuple standing in for the real parsing output.
        return [], None, 0, {"total_attempts": 0}

    monkeypatch.setattr(module, "run_garak_scan", _fake_run_garak_scan)
    monkeypatch.setattr(module, "convert_to_avid_report", lambda _path: True)
    monkeypatch.setattr(module.GarakAdapter, "_parse_results", _fake_parse_results)

    class _Callbacks:
        def report_status(self, _update):
            return None

        def create_oci_artifact(self, _spec):
            return SimpleNamespace(reference="oci://ref", digest="sha256:test")

    target_model = SimpleNamespace(url="http://localhost:8000", name="test-model")
    job = SimpleNamespace(
        id="hf-cache-job",
        benchmark_id="trustyai_garak::quick",
        benchmark_index=0,
        model=target_model,
        parameters={"hf_cache_path": "/test_data/hf-cache"},
        exports=None,
    )

    adapter.run_benchmark_job(job, _Callbacks())

    expected_env = {"HF_HUB_CACHE": "/test_data/hf-cache"}
    assert seen["env"] == expected_env
231+
232+
233+
def test_simple_mode_no_hf_cache_passes_none_env(monkeypatch, tmp_path):
    """A job without ``hf_cache_path`` results in ``env=None`` being passed to ``run_garak_scan``.

    ``run_garak_scan``, ``convert_to_avid_report`` and ``GarakAdapter._parse_results``
    are replaced with stubs so only the ``env`` keyword argument is observed.
    """
    module = _load_evalhub_garak_adapter(monkeypatch)
    adapter = module.GarakAdapter()
    monkeypatch.setenv("GARAK_SCAN_DIR", str(tmp_path))

    seen: dict[str, object] = {}

    def _fake_run_garak_scan(config_file, timeout_seconds, report_prefix, env=None, log_file=None):
        # Record the env argument and create an empty-JSON report file at the prefix.
        seen["env"] = env
        report_file = report_prefix.with_suffix(".report.jsonl")
        report_file.write_text("{}", encoding="utf-8")
        return module.GarakScanResult(returncode=0, stdout="", stderr="", report_prefix=report_prefix)

    def _fake_parse_results(self, result, eval_threshold, art_intents=False):
        # Fixed tuple standing in for the real parsing output.
        return [], None, 0, {"total_attempts": 0}

    monkeypatch.setattr(module, "run_garak_scan", _fake_run_garak_scan)
    monkeypatch.setattr(module, "convert_to_avid_report", lambda _path: True)
    monkeypatch.setattr(module.GarakAdapter, "_parse_results", _fake_parse_results)

    class _Callbacks:
        def report_status(self, _update):
            return None

        def create_oci_artifact(self, _spec):
            return SimpleNamespace(reference="oci://ref", digest="sha256:test")

    target_model = SimpleNamespace(url="http://localhost:8000", name="test-model")
    job = SimpleNamespace(
        id="no-hf-cache-job",
        benchmark_id="trustyai_garak::quick",
        benchmark_index=0,
        model=target_model,
        parameters={},
        exports=None,
    )

    adapter.run_benchmark_job(job, _Callbacks())

    assert seen["env"] is None
272+
273+
192274
def test_parse_results_uses_overall_without_double_count(monkeypatch, tmp_path):
193275
module = _load_evalhub_garak_adapter(monkeypatch)
194276
adapter = module.GarakAdapter()
@@ -245,8 +327,11 @@ def test_parse_results_uses_overall_without_double_count(monkeypatch, tmp_path):
245327
)
246328
metrics, overall_score, num_examples, overall_summary = adapter._parse_results(result, 0.5)
247329

248-
assert len(metrics) == 1
249-
assert metrics[0].metric_name == "probe.alpha_asr"
330+
assert len(metrics) == 2
331+
assert metrics[0].metric_name == "attack_success_rate"
332+
assert metrics[0].metric_value == 30.0
333+
assert metrics[0].num_samples == 10
334+
assert metrics[1].metric_name == "probe.alpha_asr"
250335
assert overall_score == 30.0
251336
assert num_examples == 10
252337
assert overall_summary["tbsa"] == 4.1
@@ -1882,15 +1967,18 @@ def test_parse_results_with_art_intents(self, monkeypatch, tmp_path):
18821967
art_intents=True,
18831968
)
18841969

1885-
assert len(metrics) == 1
1886-
assert metrics[0].metric_name == "spo.SPOIntent_asr"
1970+
assert len(metrics) == 2
1971+
assert metrics[0].metric_name == "attack_success_rate"
18871972
assert metrics[0].metric_value == 30.0
1888-
assert metrics[0].num_samples is None
1889-
assert metrics[0].metadata["total_attempts"] == 20
1890-
assert metrics[0].metadata["unsafe_stubs"] == 3
1891-
assert metrics[0].metadata["safe_stubs"] == 7
1892-
assert "intent_breakdown" in metrics[0].metadata
1893-
assert metrics[0].metadata["intent_breakdown"]["S001"]["unsafe_stubs"] == 2
1973+
assert metrics[0].num_samples == 20
1974+
assert metrics[1].metric_name == "spo.SPOIntent_asr"
1975+
assert metrics[1].metric_value == 30.0
1976+
assert metrics[1].num_samples is None
1977+
assert metrics[1].metadata["total_attempts"] == 20
1978+
assert metrics[1].metadata["unsafe_stubs"] == 3
1979+
assert metrics[1].metadata["safe_stubs"] == 7
1980+
assert "intent_breakdown" in metrics[1].metadata
1981+
assert metrics[1].metadata["intent_breakdown"]["S001"]["unsafe_stubs"] == 2
18941982
assert overall_score == 30.0
18951983
assert num_examples == 20
18961984

0 commit comments

Comments
 (0)