Skip to content

Commit 61d3e10

Browse files
authored
Merge pull request #94 from trustyai-explainability/main
[pull] main from trustyai-explainability:main
2 parents 0ef3d76 + 6356648 commit 61d3e10

File tree

14 files changed: 469 additions and 24 deletions.

.gitignore

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ __pycache__/
1010
dist/
1111
.env
1212
_providers.d/
13-
13+
meta/
14+
scan_out/
15+
**.env
16+
*.csv
17+
*.ipynb
18+
*.json
19+
*.cpu
1420
# Hermeto outputs (generated during testing)
1521
hermeto-output*/
1622
hermeto*.env

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,11 @@ repos:
1313
- id: ruff
1414
args: [--fix, --exit-non-zero-on-fix]
1515
- id: ruff-format
16+
- repo: local
17+
hooks:
18+
- id: mypy
19+
name: mypy type check
20+
entry: mypy src/
21+
language: system
22+
pass_filenames: false
23+
types: [python]

src/llama_stack_provider_trustyai_garak/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,8 @@
2121

2222
# SDG variables
2323
DEFAULT_SDG_FLOW_ID = "major-sage-742"
24+
DEFAULT_SDG_MAX_CONCURRENCY = 10
25+
DEFAULT_SDG_NUM_SAMPLES = 0
26+
DEFAULT_SDG_NUM_SAMPLES_BLOCK_NAME = "replicate_rows"
27+
DEFAULT_SDG_MAX_TOKENS = 0
28+
DEFAULT_SDG_MAX_TOKENS_BLOCK_NAME = "generate_adversarial_prompt"

src/llama_stack_provider_trustyai_garak/core/pipeline_steps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ def run_sdg_generation(
269269
sdg_model: str,
270270
sdg_api_base: str,
271271
sdg_flow_id: str = "",
272+
sdg_max_concurrency: int = 0,
273+
sdg_num_samples: int = 0,
274+
sdg_max_tokens: int = 0,
272275
) -> pd.DataFrame:
273276
"""Run Synthetic Data Generation on a taxonomy. Returns the raw DataFrame.
274277
@@ -299,6 +302,9 @@ def run_sdg_generation(
299302
flow_id=effective_flow_id,
300303
api_key=effective_key,
301304
taxonomy=taxonomy_df,
305+
max_concurrency=sdg_max_concurrency,
306+
num_samples=sdg_num_samples,
307+
max_tokens=sdg_max_tokens,
302308
)
303309
logger.info(
304310
"SDG produced %d raw rows across %d categories",

src/llama_stack_provider_trustyai_garak/evalhub/garak_adapter.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,17 @@
5959
parse_digest_from_report_content,
6060
parse_generations_from_report_content,
6161
)
62-
from ..utils import get_scan_base_dir, as_bool
62+
from ..utils import get_scan_base_dir, as_bool, safe_int
6363
from ..constants import (
6464
DEFAULT_TIMEOUT,
6565
DEFAULT_MODEL_TYPE,
6666
DEFAULT_EVAL_THRESHOLD,
6767
EXECUTION_MODE_SIMPLE,
6868
EXECUTION_MODE_KFP,
6969
DEFAULT_SDG_FLOW_ID,
70+
DEFAULT_SDG_MAX_CONCURRENCY,
71+
DEFAULT_SDG_NUM_SAMPLES,
72+
DEFAULT_SDG_MAX_TOKENS,
7073
)
7174

7275
logger = logging.getLogger(__name__)
@@ -546,6 +549,9 @@ def _run_via_kfp(
546549
"sdg_model": ip.get("sdg_model", ""),
547550
"sdg_api_base": ip.get("sdg_api_base", ""),
548551
"sdg_flow_id": ip.get("sdg_flow_id", DEFAULT_SDG_FLOW_ID),
552+
"sdg_max_concurrency": ip.get("sdg_max_concurrency", DEFAULT_SDG_MAX_CONCURRENCY),
553+
"sdg_num_samples": ip.get("sdg_num_samples", DEFAULT_SDG_NUM_SAMPLES),
554+
"sdg_max_tokens": ip.get("sdg_max_tokens", DEFAULT_SDG_MAX_TOKENS),
549555
}
550556
if model_auth_secret:
551557
pipeline_args["model_auth_secret_name"] = model_auth_secret
@@ -933,6 +939,20 @@ def _build_config_from_spec(
933939
"intents_s3_key": benchmark_config.get("intents_s3_key", profile.get("intents_s3_key", "")),
934940
"intents_format": benchmark_config.get("intents_format", profile.get("intents_format", "csv")),
935941
"sdg_flow_id": benchmark_config.get("sdg_flow_id", profile.get("sdg_flow_id", DEFAULT_SDG_FLOW_ID)),
942+
"sdg_max_concurrency": safe_int(
943+
benchmark_config.get(
944+
"sdg_max_concurrency", profile.get("sdg_max_concurrency", DEFAULT_SDG_MAX_CONCURRENCY)
945+
),
946+
DEFAULT_SDG_MAX_CONCURRENCY,
947+
),
948+
"sdg_num_samples": safe_int(
949+
benchmark_config.get("sdg_num_samples", profile.get("sdg_num_samples", DEFAULT_SDG_NUM_SAMPLES)),
950+
DEFAULT_SDG_NUM_SAMPLES,
951+
),
952+
"sdg_max_tokens": safe_int(
953+
benchmark_config.get("sdg_max_tokens", profile.get("sdg_max_tokens", DEFAULT_SDG_MAX_TOKENS)),
954+
DEFAULT_SDG_MAX_TOKENS,
955+
),
936956
"disable_cache": as_bool(benchmark_config.get("disable_cache", False)),
937957
}
938958

src/llama_stack_provider_trustyai_garak/evalhub/kfp_pipeline.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131

3232
from kfp import dsl, kubernetes
3333

34-
from ..constants import DEFAULT_SDG_FLOW_ID
34+
from ..constants import (
35+
DEFAULT_SDG_FLOW_ID,
36+
DEFAULT_SDG_MAX_CONCURRENCY,
37+
DEFAULT_SDG_NUM_SAMPLES,
38+
DEFAULT_SDG_MAX_TOKENS,
39+
)
3540
from ..core.pipeline_steps import MODEL_AUTH_MOUNT_PATH
3641

3742
logger = logging.getLogger(__name__)
@@ -254,6 +259,9 @@ def sdg_generate(
254259
sdg_model: str,
255260
sdg_api_base: str,
256261
sdg_flow_id: str,
262+
sdg_max_concurrency: int,
263+
sdg_num_samples: int,
264+
sdg_max_tokens: int,
257265
taxonomy_dataset: dsl.Input[dsl.Dataset],
258266
sdg_dataset: dsl.Output[dsl.Dataset],
259267
):
@@ -306,6 +314,9 @@ def sdg_generate(
306314
sdg_model=sdg_model,
307315
sdg_api_base=sdg_api_base,
308316
sdg_flow_id=sdg_flow_id,
317+
sdg_max_concurrency=sdg_max_concurrency,
318+
sdg_num_samples=sdg_num_samples,
319+
sdg_max_tokens=sdg_max_tokens,
309320
)
310321
raw_df.to_csv(sdg_dataset.path, index=False)
311322
log.info("Wrote %d raw SDG rows to artifact", len(raw_df))
@@ -625,6 +636,9 @@ def evalhub_garak_pipeline(
625636
sdg_model: str = "",
626637
sdg_api_base: str = "",
627638
sdg_flow_id: str = DEFAULT_SDG_FLOW_ID,
639+
sdg_max_concurrency: int = DEFAULT_SDG_MAX_CONCURRENCY,
640+
sdg_num_samples: int = DEFAULT_SDG_NUM_SAMPLES,
641+
sdg_max_tokens: int = DEFAULT_SDG_MAX_TOKENS,
628642
):
629643
"""Six-step pipeline: validate, resolve taxonomy, SDG, prepare prompts, scan, write outputs.
630644
@@ -682,6 +696,9 @@ def evalhub_garak_pipeline(
682696
sdg_model=sdg_model,
683697
sdg_api_base=sdg_api_base,
684698
sdg_flow_id=sdg_flow_id,
699+
sdg_max_concurrency=sdg_max_concurrency,
700+
sdg_num_samples=sdg_num_samples,
701+
sdg_max_tokens=sdg_max_tokens,
685702
taxonomy_dataset=taxonomy_task.outputs["taxonomy_dataset"],
686703
)
687704
sdg_task.set_caching_options(True)

src/llama_stack_provider_trustyai_garak/remote/garak_remote_eval.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from ..base_eval import GarakEvalBase
2222
from llama_stack_provider_trustyai_garak import shield_scan
2323
from ..errors import GarakError, GarakConfigError, GarakValidationError
24-
from ..utils import as_bool
24+
from ..constants import DEFAULT_SDG_MAX_CONCURRENCY, DEFAULT_SDG_NUM_SAMPLES, DEFAULT_SDG_MAX_TOKENS
25+
from ..utils import as_bool, safe_int
2526
from dotenv import load_dotenv
2627

2728
load_dotenv()
@@ -198,6 +199,18 @@ async def run_eval(self, request: RunEvalRequest) -> Job:
198199
"sdg_model": provider_params.get("sdg_model", ""),
199200
"sdg_api_base": provider_params.get("sdg_api_base", ""),
200201
"sdg_flow_id": provider_params.get("sdg_flow_id", ""),
202+
"sdg_max_concurrency": safe_int(
203+
provider_params.get("sdg_max_concurrency", DEFAULT_SDG_MAX_CONCURRENCY),
204+
DEFAULT_SDG_MAX_CONCURRENCY,
205+
),
206+
"sdg_num_samples": safe_int(
207+
provider_params.get("sdg_num_samples", DEFAULT_SDG_NUM_SAMPLES),
208+
DEFAULT_SDG_NUM_SAMPLES,
209+
),
210+
"sdg_max_tokens": safe_int(
211+
provider_params.get("sdg_max_tokens", DEFAULT_SDG_MAX_TOKENS),
212+
DEFAULT_SDG_MAX_TOKENS,
213+
),
201214
},
202215
run_name=f"garak-{benchmark_id.split('::')[-1]}-{job_id.removeprefix(JOB_ID_PREFIX)}",
203216
namespace=self._config.kubeflow_config.namespace,

src/llama_stack_provider_trustyai_garak/remote/kfp_utils/components.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def sdg_generate(
228228
sdg_model: str,
229229
sdg_api_base: str,
230230
sdg_flow_id: str,
231+
sdg_max_concurrency: int,
232+
sdg_num_samples: int,
233+
sdg_max_tokens: int,
231234
taxonomy_dataset: dsl.Input[dsl.Dataset],
232235
sdg_dataset: dsl.Output[dsl.Dataset],
233236
):
@@ -269,6 +272,9 @@ def sdg_generate(
269272
sdg_model=sdg_model,
270273
sdg_api_base=sdg_api_base,
271274
sdg_flow_id=sdg_flow_id,
275+
sdg_max_concurrency=sdg_max_concurrency,
276+
sdg_num_samples=sdg_num_samples,
277+
sdg_max_tokens=sdg_max_tokens,
272278
)
273279
raw_df.to_csv(sdg_dataset.path, index=False)
274280
log.info("Wrote %d raw SDG rows to artifact", len(raw_df))

src/llama_stack_provider_trustyai_garak/remote/kfp_utils/pipeline.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from kfp import dsl
1414
from dotenv import load_dotenv
1515
import logging
16-
from ...constants import DEFAULT_SDG_FLOW_ID
16+
from ...constants import (
17+
DEFAULT_SDG_FLOW_ID,
18+
DEFAULT_SDG_MAX_CONCURRENCY,
19+
DEFAULT_SDG_NUM_SAMPLES,
20+
DEFAULT_SDG_MAX_TOKENS,
21+
)
1722

1823
from .components import (
1924
validate,
@@ -45,6 +50,9 @@ def garak_scan_pipeline(
4550
sdg_model: str = "",
4651
sdg_api_base: str = "",
4752
sdg_flow_id: str = DEFAULT_SDG_FLOW_ID,
53+
sdg_max_concurrency: int = DEFAULT_SDG_MAX_CONCURRENCY,
54+
sdg_num_samples: int = DEFAULT_SDG_NUM_SAMPLES,
55+
sdg_max_tokens: int = DEFAULT_SDG_MAX_TOKENS,
4856
):
4957
"""Six-step pipeline: validate, resolve taxonomy, SDG, prepare prompts, scan, parse.
5058
@@ -86,6 +94,9 @@ def garak_scan_pipeline(
8694
sdg_model=sdg_model,
8795
sdg_api_base=sdg_api_base,
8896
sdg_flow_id=sdg_flow_id,
97+
sdg_max_concurrency=sdg_max_concurrency,
98+
sdg_num_samples=sdg_num_samples,
99+
sdg_max_tokens=sdg_max_tokens,
89100
taxonomy_dataset=taxonomy_task.outputs["taxonomy_dataset"],
90101
)
91102
sdg_task.set_caching_options(True)

src/llama_stack_provider_trustyai_garak/resources/art_report.jinja2

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,8 @@
4242
<body>
4343
<div class="pf-v6-c-page" id="masthead-basic-example">
4444
<header class="pf-v6-c-masthead" id="masthead-basic-example-masthead">
45-
<div class="pf-v6-c-masthead__main">
46-
<div class="pf-v6-c-masthead__brand"></div>
47-
<div class="pf-v6-c-masthead__content">
48-
<h1 class="pf-v6-c-title pf-m-xl">Automated Red Teaming Report</h1>
49-
</div>
45+
<div class="pf-v6-c-masthead__content">
46+
<h1 class="pf-v6-c-title pf-m-xl">Automated Red Teaming Report</h1>
5047
</div>
5148
</header>
5249
<div class="pf-v6-c-page__sidebar">

0 commit comments

Comments
 (0)