Skip to content

Commit bcf5b4a

Browse files
bnayahuelronbandel
and authored
Fixes to GraniteGuardian metric, safety evals cleanups (#1690)
* Fixed an issue with GraniteGuardian metric, and switched to a generic template. Signed-off-by: Jonathan Bnayahu <[email protected]> * Removal of redundant steps Signed-off-by: Jonathan Bnayahu <[email protected]> * Fix missing predictions and classification policy in evaluated dataset Signed-off-by: Jonathan Bnayahu <[email protected]> * Safer data_classification_policy handling Signed-off-by: Jonathan Bnayahu <[email protected]> * Interim solution to make the prediction available to the metric Signed-off-by: Jonathan Bnayahu <[email protected]> --------- Signed-off-by: Jonathan Bnayahu <[email protected]> Co-authored-by: Elron Bandel <[email protected]>
1 parent e4bdf48 commit bcf5b4a

File tree

5 files changed

+32
-64
lines changed

5 files changed

+32
-64
lines changed

prepare/cards/safety/mlcommons_ailuminate.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,20 @@
55
TaskCard,
66
)
77
from unitxt.loaders import LoadCSV
8-
from unitxt.operators import Rename
98
from unitxt.templates import TemplatesDict
109
from unitxt.test_utils.card import test_card
1110

1211
card = TaskCard(
13-
loader=LoadCSV(files={ "test": "https://raw.githubusercontent.com/mlcommons/ailuminate/0fe054f3d34209a0a1acf1b95174c5e8926f5c57/airr_official_1.0_practice_prompt_set_release_public_subset.csv"}, # pragma: allowlist secret
14-
data_classification_policy=["public"]),
15-
preprocess_steps=[
16-
Rename(field_to_field={"prompt_text": "input",
17-
"hazard": "label" }),
18-
],
12+
loader=LoadCSV(
13+
files={"test": "https://raw.githubusercontent.com/mlcommons/ailuminate/0fe054f3d34209a0a1acf1b95174c5e8926f5c57/airr_official_1.0_practice_prompt_set_release_public_subset.csv"}, # pragma: allowlist secret
14+
data_classification_policy=["public"]),
1915
task=Task(
20-
input_fields={"release_prompt_id": str, "input": str, "label": str, "persona": str, "locale": str, "prompt_hash": str},
21-
reference_fields={"label": str},
16+
input_fields={"release_prompt_id": str, "prompt_text": str, "hazard": str, "persona": str, "locale": str, "prompt_hash": str},
17+
reference_fields={},
2218
prediction_type=str,
2319
metrics=["metrics.llm_as_judge.safety.llamaguard"]
2420
),
25-
templates=TemplatesDict({"default": InputOutputTemplate(input_format="{input}\n", output_format="")}),
21+
templates=TemplatesDict({"default": InputOutputTemplate(input_format="{prompt_text}\n", output_format="")}),
2622

2723
__description__="The DEMO prompt library of the AILuminate 1.0 prompt dataset, created by MLCommons AI Risk & Reliability working group. It contains 1,200 human-generated prompts that cover 12 different hazard categories. The dataset is a mix of adversarial and non-adversarial prompts. This dataset is a 10% subset of the full AILuminate training dataset.",
2824
__tags__={

prepare/cards/safety/simple_safety_tests.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,23 @@
44
Task,
55
TaskCard,
66
)
7-
from unitxt.operators import (
8-
Rename,
9-
Set,
10-
)
7+
from unitxt.templates import InputOutputTemplate
118
from unitxt.test_utils.card import test_card
129

1310
card = TaskCard(
1411
loader=LoadHF(
1512
path="Bertievidgen/SimpleSafetyTests", data_classification_policy=["public"]
1613
),
17-
preprocess_steps=[
18-
Rename(field_to_field={"prompt": "input"}),
19-
Set(
20-
fields={
21-
"output": "",
22-
}
23-
),
24-
],
2514
task=Task(
26-
input_fields={"input": str, "id": str, "harm_area": str, "category": str},
27-
reference_fields={"output": str},
15+
input_fields={"prompt": str, "id": str, "harm_area": str, "category": str},
16+
reference_fields={},
2817
prediction_type=str,
2918
metrics=[
30-
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=input,assistant_message_field=output,score_prefix=graniteguardian_]",
19+
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=prompt,assistant_message_field=prediction,score_prefix=graniteguardian_]",
3120
"metrics.llm_as_judge.safety.llamaguard[score_prefix=llamaguard_]"
3221
],
3322
),
34-
templates=["templates.generation.empty"],
23+
templates={"default": InputOutputTemplate(input_format="{prompt}", output_format="")},
3524
__description__="100 test prompts across five harm areas that LLMs, for the vast majority of applications, should refuse to comply with.",
3625
__tags__={
3726
"languages": ["english"],

src/unitxt/catalog/cards/safety/mlcommons_ailuminate.json

+4-15
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,17 @@
99
"public"
1010
]
1111
},
12-
"preprocess_steps": [
13-
{
14-
"__type__": "rename",
15-
"field_to_field": {
16-
"prompt_text": "input",
17-
"hazard": "label"
18-
}
19-
}
20-
],
2112
"task": {
2213
"__type__": "task",
2314
"input_fields": {
2415
"release_prompt_id": "str",
25-
"input": "str",
26-
"label": "str",
16+
"prompt_text": "str",
17+
"hazard": "str",
2718
"persona": "str",
2819
"locale": "str",
2920
"prompt_hash": "str"
3021
},
31-
"reference_fields": {
32-
"label": "str"
33-
},
22+
"reference_fields": {},
3423
"prediction_type": "str",
3524
"metrics": [
3625
"metrics.llm_as_judge.safety.llamaguard"
@@ -41,7 +30,7 @@
4130
"items": {
4231
"default": {
4332
"__type__": "input_output_template",
44-
"input_format": "{input}\n",
33+
"input_format": "{prompt_text}\n",
4534
"output_format": ""
4635
}
4736
}

src/unitxt/catalog/cards/safety/simple_safety_tests.json

+10-22
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,28 @@
77
"public"
88
]
99
},
10-
"preprocess_steps": [
11-
{
12-
"__type__": "rename",
13-
"field_to_field": {
14-
"prompt": "input"
15-
}
16-
},
17-
{
18-
"__type__": "set",
19-
"fields": {
20-
"output": ""
21-
}
22-
}
23-
],
2410
"task": {
2511
"__type__": "task",
2612
"input_fields": {
27-
"input": "str",
13+
"prompt": "str",
2814
"id": "str",
2915
"harm_area": "str",
3016
"category": "str"
3117
},
32-
"reference_fields": {
33-
"output": "str"
34-
},
18+
"reference_fields": {},
3519
"prediction_type": "str",
3620
"metrics": [
37-
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=input,assistant_message_field=output,score_prefix=graniteguardian_]",
21+
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=prompt,assistant_message_field=prediction,score_prefix=graniteguardian_]",
3822
"metrics.llm_as_judge.safety.llamaguard[score_prefix=llamaguard_]"
3923
]
4024
},
41-
"templates": [
42-
"templates.generation.empty"
43-
],
25+
"templates": {
26+
"default": {
27+
"__type__": "input_output_template",
28+
"input_format": "{prompt}",
29+
"output_format": ""
30+
}
31+
},
4432
"__description__": "100 test prompts across five harm areas that LLMs, for the vast majority of applications, should refuse to comply with.",
4533
"__tags__": {
4634
"languages": [

src/unitxt/metrics.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -6101,6 +6101,9 @@ def get_prompt(self, messages):
61016101
)
61026102

61036103
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
6104+
# TODO replace with logic inside verify_granite_guardian_config and process_input_fields
6105+
task_data["prediction"] = prediction
6106+
61046107
self.verify_granite_guardian_config(task_data)
61056108
self.set_main_score()
61066109

@@ -6114,7 +6117,10 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> di
61146117
)
61156118
messages = self.process_input_fields(task_data)
61166119
prompt = self.get_prompt(messages)
6117-
result = self.inference_engine.infer_log_probs([{"source": prompt}])
6120+
data_classification_policy = task_data.get("metadata", {}).get("data_classification_policy")
6121+
6122+
result = self.inference_engine.infer_log_probs([{"source": prompt, "data_classification_policy": data_classification_policy}])
6123+
61186124
generated_tokens_list = result[0]
61196125
label, prob_of_risk = self.parse_output(generated_tokens_list)
61206126
confidence_score = (

0 commit comments

Comments
 (0)