Skip to content

Commit 69f1462

Browse files
authored
Merge pull request trustyai-explainability#136 from trustyai-explainability/fix-deep-merge
fix: deep-merge `garak_config` overrides for intents and make `intents_models` flexible
2 parents 840a966 + 248512b commit 69f1462

File tree

3 files changed

+547
-30
lines changed

3 files changed

+547
-30
lines changed

src/llama_stack_provider_trustyai_garak/evalhub/garak_adapter.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -968,8 +968,11 @@ def _apply_intents_model_config(
968968
969969
- **1 provided**: the configured role is used for all three.
970970
- **3 provided**: each role uses its own config.
971-
- **0 provided**: raises ``ValueError`` — the target model cannot be
972-
used as judge/attacker/evaluator.
971+
- **0 provided, but models pre-configured in garak_config**: the
972+
override is skipped and models from garak_config are used as-is.
973+
SDG params are extracted from flat keys (``sdg_model``,
974+
``sdg_api_base``).
975+
- **0 provided, no garak_config models**: raises ``ValueError``.
973976
- **2 provided**: raises ``ValueError`` — ambiguous which should fill
974977
the missing role.
975978
@@ -1003,12 +1006,23 @@ def _apply_intents_model_config(
10031006
}
10041007

10051008
if len(provided) == 0:
1009+
if garak_config.plugins and self._models_preconfigured_in_garak_config(garak_config.plugins):
1010+
logger.info(
1011+
"No intents_models provided but models are already "
1012+
"configured in garak_config — skipping intents_models "
1013+
"override. API keys will be resolved by "
1014+
"_resolve_config_api_keys in the KFP pod."
1015+
)
1016+
return self._extract_sdg_params(sdg_cfg, benchmark_config, profile)
1017+
10061018
raise ValueError(
1007-
"Intents benchmark requires at least one of "
1008-
"intents_models.judge, intents_models.attacker, or "
1009-
"intents_models.evaluator with a 'url' and 'name'. "
1010-
"The target model (config.model) cannot be used as "
1011-
"judge/attacker/evaluator."
1019+
"Intents benchmark requires model configuration for "
1020+
"judge/attacker/evaluator roles. Either:\n"
1021+
" 1. Provide intents_models with at least one role "
1022+
"(url + name), or\n"
1023+
" 2. Configure models directly in garak_config "
1024+
"(plugins.detectors.judge with detector_model_name and "
1025+
"detector_model_config.uri)."
10121026
)
10131027
if len(provided) == 2:
10141028
missing = {"judge", "attacker", "evaluator"} - set(provided)
@@ -1043,33 +1057,70 @@ def _apply_intents_model_config(
10431057
plugins = garak_config.plugins
10441058

10451059
plugins.detectors = plugins.detectors or {}
1046-
plugins.detectors["judge"] = {
1047-
"detector_model_type": "openai.OpenAICompatible",
1048-
"detector_model_name": judge_name,
1049-
"detector_model_config": {
1050-
"uri": judge_url,
1051-
"api_key": _PLACEHOLDER,
1052-
},
1053-
}
1060+
existing_judge = plugins.detectors.get("judge", {})
1061+
existing_judge["detector_model_type"] = existing_judge.get("detector_model_type") or "openai.OpenAICompatible"
1062+
existing_judge["detector_model_name"] = existing_judge.get("detector_model_name") or judge_name
1063+
existing_det_cfg = existing_judge.get("detector_model_config", {})
1064+
existing_det_cfg["uri"] = existing_det_cfg.get("uri") or judge_url
1065+
existing_det_cfg["api_key"] = _PLACEHOLDER
1066+
existing_judge["detector_model_config"] = existing_det_cfg
1067+
plugins.detectors["judge"] = existing_judge
10541068

10551069
if plugins.probes and plugins.probes.get("tap"):
10561070
tap_cfg = plugins.probes["tap"].get("TAPIntent", {})
10571071
if isinstance(tap_cfg, dict):
1058-
tap_cfg["attack_model_name"] = attacker_name
1059-
tap_cfg["attack_model_config"] = {
1060-
"uri": attacker_url,
1061-
"api_key": _PLACEHOLDER,
1062-
"max_tokens": tap_cfg.get("attack_model_config", {}).get("max_tokens", 500),
1063-
}
1064-
tap_cfg["evaluator_model_name"] = evaluator_name
1065-
tap_cfg["evaluator_model_config"] = {
1066-
"uri": evaluator_url,
1067-
"api_key": _PLACEHOLDER,
1068-
"max_tokens": tap_cfg.get("evaluator_model_config", {}).get("max_tokens", 10),
1069-
"temperature": tap_cfg.get("evaluator_model_config", {}).get("temperature", 0.0),
1070-
}
1072+
tap_cfg["attack_model_name"] = tap_cfg.get("attack_model_name") or attacker_name
1073+
existing_attack_cfg = tap_cfg.get("attack_model_config", {})
1074+
existing_attack_cfg.setdefault("max_tokens", 500)
1075+
existing_attack_cfg["uri"] = existing_attack_cfg.get("uri") or attacker_url
1076+
existing_attack_cfg["api_key"] = _PLACEHOLDER
1077+
tap_cfg["attack_model_config"] = existing_attack_cfg
1078+
1079+
tap_cfg["evaluator_model_name"] = tap_cfg.get("evaluator_model_name") or evaluator_name
1080+
existing_eval_cfg = tap_cfg.get("evaluator_model_config", {})
1081+
existing_eval_cfg.setdefault("max_tokens", 10)
1082+
existing_eval_cfg.setdefault("temperature", 0.0)
1083+
existing_eval_cfg["uri"] = existing_eval_cfg.get("uri") or evaluator_url
1084+
existing_eval_cfg["api_key"] = _PLACEHOLDER
1085+
tap_cfg["evaluator_model_config"] = existing_eval_cfg
1086+
10711087
plugins.probes["tap"]["TAPIntent"] = tap_cfg
10721088

1089+
return self._extract_sdg_params(sdg_cfg, benchmark_config, profile)
1090+
1091+
@staticmethod
1092+
def _models_preconfigured_in_garak_config(plugins: Any) -> bool:
1093+
"""Check if intents models are already configured in garak_config.
1094+
1095+
Returns True when the judge detector has a non-empty
1096+
``detector_model_name`` and ``detector_model_config.uri``.
1097+
If TAPIntent is present in the probes, also requires non-empty
1098+
attack and evaluator model names and URIs.
1099+
"""
1100+
detectors = plugins.detectors or {}
1101+
judge = detectors.get("judge", {})
1102+
if not judge.get("detector_model_name") or not judge.get("detector_model_config", {}).get("uri"):
1103+
return False
1104+
1105+
probes = plugins.probes or {}
1106+
tap_cfg = probes.get("tap", {}).get("TAPIntent")
1107+
if tap_cfg and isinstance(tap_cfg, dict):
1108+
for name_key, cfg_key in [
1109+
("attack_model_name", "attack_model_config"),
1110+
("evaluator_model_name", "evaluator_model_config"),
1111+
]:
1112+
if not tap_cfg.get(name_key) or not tap_cfg.get(cfg_key, {}).get("uri"):
1113+
return False
1114+
1115+
return True
1116+
1117+
@staticmethod
1118+
def _extract_sdg_params(
1119+
sdg_cfg: dict,
1120+
benchmark_config: dict,
1121+
profile: dict,
1122+
) -> dict[str, str]:
1123+
"""Extract SDG model params from intents_models.sdg or flat keys."""
10731124
sdg_params: dict[str, str] = {
10741125
"sdg_model": "",
10751126
"sdg_api_base": "",
@@ -1083,7 +1134,6 @@ def _apply_intents_model_config(
10831134
if sdg_model and sdg_api_base:
10841135
sdg_params["sdg_model"] = sdg_model
10851136
sdg_params["sdg_api_base"] = sdg_api_base
1086-
10871137
return sdg_params
10881138

10891139
@staticmethod

tests/test_config.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
KubeflowConfig,
1111
GarakScanConfig,
1212
)
13+
from llama_stack_provider_trustyai_garak.core.config_resolution import deep_merge_dicts
1314

1415

1516
class TestGarakInlineConfig:
@@ -398,3 +399,112 @@ def test_framework_profile_structure(self, profile_key, expected_probe_tag, expe
398399

399400
# (2) Validate taxonomy wiring
400401
assert garak_config["reporting"]["taxonomy"] == expected_taxonomy
402+
403+
404+
class TestDeepMergeDicts:
405+
"""Verify deep_merge_dicts honours leaf-level overrides without clobbering siblings."""
406+
407+
def test_override_single_leaf_preserves_siblings(self):
408+
base = {
409+
"plugins": {
410+
"probes": {
411+
"tap": {
412+
"TAPIntent": {
413+
"depth": 10,
414+
"width": 10,
415+
"branching_factor": 4,
416+
"attack_model_config": {"uri": "", "max_tokens": 500},
417+
}
418+
}
419+
}
420+
}
421+
}
422+
override = {
423+
"plugins": {
424+
"probes": {
425+
"tap": {
426+
"TAPIntent": {
427+
"depth": 20,
428+
}
429+
}
430+
}
431+
}
432+
}
433+
result = deep_merge_dicts(base, override)
434+
tap = result["plugins"]["probes"]["tap"]["TAPIntent"]
435+
assert tap["depth"] == 20
436+
assert tap["width"] == 10
437+
assert tap["branching_factor"] == 4
438+
assert tap["attack_model_config"]["max_tokens"] == 500
439+
440+
def test_override_nested_dict_leaf_preserves_sibling_keys(self):
441+
base = {
442+
"plugins": {
443+
"detectors": {
444+
"judge": {
445+
"detector_model_config": {"uri": "http://old", "api_key": "k1", "max_tokens": 200},
446+
"MulticlassJudge": {
447+
"system_prompt": "Default prompt",
448+
"score_key": "complied",
449+
"confidence_cutoff": 70,
450+
}
451+
}
452+
}
453+
}
454+
}
455+
override = {
456+
"plugins": {
457+
"detectors": {
458+
"judge": {
459+
"MulticlassJudge": {
460+
"system_prompt": "New custom prompt",
461+
}
462+
}
463+
}
464+
}
465+
}
466+
result = deep_merge_dicts(base, override)
467+
judge = result["plugins"]["detectors"]["judge"]
468+
assert judge["detector_model_config"]["uri"] == "http://old"
469+
assert judge["detector_model_config"]["max_tokens"] == 200
470+
mcj = judge["MulticlassJudge"]
471+
assert mcj["system_prompt"] == "New custom prompt"
472+
assert mcj["score_key"] == "complied"
473+
assert mcj["confidence_cutoff"] == 70
474+
475+
def test_override_does_not_mutate_base(self):
476+
base = {"a": {"b": 1, "c": 2}}
477+
override = {"a": {"b": 99}}
478+
result = deep_merge_dicts(base, override)
479+
assert result["a"]["b"] == 99
480+
assert result["a"]["c"] == 2
481+
assert base["a"]["b"] == 1, "Original base must not be mutated"
482+
483+
def test_adding_new_key_at_deep_level(self):
484+
base = {"plugins": {"detectors": {"judge": {"detector_model_name": "m1"}}}}
485+
override = {
486+
"plugins": {
487+
"detectors": {
488+
"judge": {
489+
"MulticlassJudge": {"system_prompt": "Added later"}
490+
}
491+
}
492+
}
493+
}
494+
result = deep_merge_dicts(base, override)
495+
judge = result["plugins"]["detectors"]["judge"]
496+
assert judge["detector_model_name"] == "m1"
497+
assert judge["MulticlassJudge"]["system_prompt"] == "Added later"
498+
499+
def test_dict_override_merges_preserving_siblings(self):
500+
base = {"run": {"generations": 5, "eval_threshold": 0.5}}
501+
override = {"run": {"generations": 100}}
502+
result = deep_merge_dicts(base, override)
503+
assert result["run"]["generations"] == 100
504+
assert result["run"]["eval_threshold"] == 0.5
505+
506+
def test_non_dict_override_replaces_entirely(self):
507+
base = {"run": {"generations": 5, "eval_threshold": 0.5}}
508+
override = {"run": "disabled"}
509+
result = deep_merge_dicts(base, override)
510+
assert result["run"] == "disabled"

0 commit comments

Comments (0)