Skip to content

Commit b55532f

Browse files
yoavkatz, pawelknes, elronbandel
authored
Add option to store template instruction in user role and not system role and added granite thinking example (#1667)
* support for asynchronous requests in wml chat Signed-off-by: Paweł Knes <pawel.knes@ibm.com> * Test branch to evaluate impact of different format * Avoid removal of needed import Signed-off-by: Yoav Katz <katz@il.ibm.com> * Made inference_using_ibm_watsonx_ai work with env variables out of the box Signed-off-by: Yoav Katz <katz@il.ibm.com> * Renamed repeat_instruction_per_turn to place_instruction_in_user_turns to highlight where instructions are placed. Signed-off-by: Yoav Katz <katz@il.ibm.com> * Added example of granite thinking Signed-off-by: Yoav Katz <katz@il.ibm.com> * Added documentation for 'place_instruction_in_user_turns' Removed 'add_target_prefix' which is not related. Improved example. Signed-off-by: Yoav Katz <katz@il.ibm.com> * Added granite thinking with MMLU Signed-off-by: Yoav Katz <katz@il.ibm.com> * Updated prints Signed-off-by: Yoav Katz <katz@il.ibm.com> * Added example of inference with cross provider without load_dataset Signed-off-by: Yoav Katz <katz@il.ibm.com> * Improved example. Signed-off-by: Yoav Katz <katz@il.ibm.com> * Updated multi format example Signed-off-by: Yoav Katz <katz@il.ibm.com> * Simplfied doc Signed-off-by: Yoav Katz <katz@il.ibm.com> --------- Signed-off-by: Paweł Knes <pawel.knes@ibm.com> Signed-off-by: Yoav Katz <katz@il.ibm.com> Co-authored-by: Paweł Knes <pawel.knes@ibm.com> Co-authored-by: Elron Bandel <elronbandel@gmail.com>
1 parent 9895a3d commit b55532f

File tree

6 files changed

+345
-69
lines changed

6 files changed

+345
-69
lines changed

docs/docs/examples.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ Related documentation: :ref:`Templates tutorial <adding_template>`, :ref:`Format
134134
Evaluate the impact of different formats and system prompts
135135
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
136136

137-
This example demonstrates how different formats and system prompts affect the input provided to a llama3 chat model and evaluate their impact on the obtained scores.
137+
This example demonstrates how different formats and system prompts affect the input provided to a granite chat model and evaluate their impact on the obtained scores.
138138

139139
`Example code <https://github.com/IBM/unitxt/blob/main/examples/evaluate_different_formats.py>`__
140140

"""Compare the impact of different formats and system prompts on scores.

Runs the same classification card (LEDGAR contractual-clause
classification) through several inference engines — a cross-provider
watsonx engine, a WML chat engine and a WML generation engine — with
three formats, and collects the f1 score, its confidence interval, the
inference duration and the type of the produced prompt ("source") into a
pandas DataFrame that is printed as a markdown table at the end.
"""

import json
import time

import pandas as pd
from unitxt.api import evaluate, load_dataset
from unitxt.inference import (
    CrossProviderInferenceEngine,
    WMLInferenceEngineChat,
    WMLInferenceEngineGeneration,
)

# NOTE(review): cross_provider_rits is constructed but never added to
# model_list below — presumably kept only as a construction example;
# confirm this is intentional.
print("Creating cross_provider_rits ...")
cross_provider_rits = CrossProviderInferenceEngine(
    model="granite-3-8b-instruct", max_tokens=32, provider="rits", temperature=0
)

print("Creating cross_provider_watsonx ...")
cross_provider_watsonx = CrossProviderInferenceEngine(
    model="granite-3-8b-instruct", max_tokens=32, provider="watsonx", temperature=0
)
print("Creating wml_gen ...")
wml_gen = WMLInferenceEngineGeneration(
    model_name="ibm/granite-3-8b-instruct", max_new_tokens=32, temperature=0
)
print("Creating wml_chat ...")
wml_chat = WMLInferenceEngineChat(
    model_name="ibm/granite-3-8b-instruct", max_tokens=32, temperature=0
)

# One result row per (model, format, system_prompt) combination.
df = pd.DataFrame(
    columns=[
        "model",
        "format",
        "system_prompt",
        "f1_micro",
        "ci_low",
        "ci_high",
        "duration",
        "num_instances",
        "type_of_input",
    ]
)

model_list = [
    (cross_provider_watsonx, "cross-provider-watsonx"),
    (wml_chat, "wml-chat"),
    (wml_gen, "wml-gen"),
]

# This example compares the impact of different formats on a classification dataset
#
# formats.chat_api - creates a list of OpenAI messages, where the instruction appears in the system prompt.
#
# [
#   {
#       "role": "system",
#       "content": "Classify the contractual clauses of the following text to one of these options: Records, Warranties... "
#   },
#   {
#       "role": "user",
#       "content": "text: Each Credit Party shall maintain..."
#   },
#   {
#       "role": "assistant",
#       "content": "The contractual clauses is Records"
#   },
#   {
#       "role": "user",
#       "content": "text: Executive agrees to be employed with the Company...."
#   }
# ]
#
# formats.chat_api[place_instruction_in_user_turns=True] - creates a list of OpenAI messages, where the instruction appears in each user turn prompt.
#
# [
#   {
#       "role": "user",
#       "content": "Classify the contractual clauses of the following text to one of these options: ...
#                   text: Each Credit Party shall maintain...."
#   },
#   {
#       "role": "assistant",
#       "content": "The contractual clauses is Records"
#   },
#   {
#       "role": "user",
#       "content": "Classify the contractual clauses of the following text to one of these options: ...
#                   text: Executive agrees to be employed with the Company...
#   }
# ]
#
# formats.empty - pass inputs as a single string
#
# "Classify the contractual clauses of the following text to one of these options: Records, Warranties,.
#  text: Each Credit Party shall maintain...
#  The contractual clauses is Records
#
#  text: Executive agrees to be employed with the Company,...
#  The contractual clauses is "

for model, model_name in model_list:
    print(model_name)
    card = "cards.ledgar"
    template = "templates.classification.multi_class.instruction"
    # `format_name` (not `format`) to avoid shadowing the builtin format().
    for format_name in [
        "formats.chat_api[place_instruction_in_user_turns=True]",
        "formats.chat_api",
        "formats.empty",
    ]:
        for system_prompt in [
            "system_prompts.empty",
        ]:
            # The WML generation engine accepts only plain-string prompts,
            # and the WML chat engine accepts only chat-message prompts, so
            # skip the incompatible (engine, format) combinations.
            if model_name == "wml-gen" and "formats.chat_api" in format_name:
                continue
            if model_name == "wml-chat" and "formats.chat_api" not in format_name:
                continue
            dataset = load_dataset(
                card=card,
                format=format_name,
                system_prompt=system_prompt,
                template=template,
                num_demos=5,
                demos_pool_size=100,
                loader_limit=1000,
                max_test_instances=128,
                split="test",
            )
            # str for formats.empty, list of chat messages for chat formats.
            type_of_input = type(dataset[0]["source"])

            print("Starting inference...")
            start = time.perf_counter()
            predictions = model(dataset)
            end = time.perf_counter()
            duration = end - start
            print("End of inference...")

            results = evaluate(predictions=predictions, data=dataset)

            print(
                f"Sample input and output for format '{format_name}' and system prompt '{system_prompt}':"
            )

            print("Example prompt:")

            print(json.dumps(results.instance_scores[0]["source"], indent=4))

            print("Example prediction:")

            print(json.dumps(results.instance_scores[0]["prediction"], indent=4))

            global_scores = results.global_scores
            df.loc[len(df)] = [
                model_name,
                format_name,
                system_prompt,
                global_scores["score"],
                global_scores["score_ci_low"],
                global_scores["score_ci_high"],
                duration,
                len(predictions),
                type_of_input,
            ]

df = df.round(decimals=2)
print(df.to_markdown())
"""Evaluate a QA task with and without granite "thinking" mode.

Builds a one-question arithmetic QA dataset, formats it with
HFSystemFormat for granite-3.3 (toggling the model's `thinking`
chat-template flag), runs inference via a cross-provider engine, and
prints per-instance results for both settings.
"""

from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
from unitxt.formats import HFSystemFormat
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.processors import ExtractWithRegex, PostProcess
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate

logger = get_logger()

# Set up question answer pairs in a dictionary
test_set = [
    {
        "question": "If I had 32 apples, I lost 5 apples, and gain twice more as many as I have. How many do I have at the end",
        "answer": "81",
    },
]


# define the QA task
task = Task(
    input_fields={"question": str},
    reference_fields={"answer": str},
    prediction_type=str,
    metrics=["metrics.accuracy"],
)

# The engine configuration does not depend on the loop variable, so it is
# created once. Caching is disabled so both runs hit the live model.
model = CrossProviderInferenceEngine(
    model="granite-3-3-8b-instruct", provider="rits", use_cache=False
)

# Create a simple template that formats the input.
# Add lowercase normalization as a post processor.

for thinking in [True, False]:
    postprocessors = ["processors.lower_case"]
    if thinking:
        # In thinking mode the final answer is wrapped in <response> tags,
        # so extract it before scoring. Fixed the truncated closing tag
        # (was "</response" without the ">").
        postprocessors.append(
            PostProcess(
                ExtractWithRegex(regex=r"<response>(.*)</response>"),
                process_references=False,
            )
        )

    template = InputOutputTemplate(
        instruction="Answer the following question with the single numeric answer. Do not answer in complete sentences. Just return the answer.",
        input_format="{question}",
        output_format="{answer}",
        postprocessors=postprocessors,
    )
    dataset = create_dataset(
        task=task,
        test_set=test_set,
        template=template,
        split="test",
        # place_instruction_in_user_turns=True puts the template instruction
        # in each user turn instead of the system role.
        format=HFSystemFormat(
            model_name="ibm-granite/granite-3.3-8b-instruct",
            chat_kwargs_dict={"thinking": thinking},
            place_instruction_in_user_turns=True,
        ),
    )

    predictions = model(dataset)

    results = evaluate(predictions=predictions, data=dataset)

    print("Instance Results when Thinking=", thinking)
    print(results.instance_scores)
"""Evaluate MMLU (abstract algebra) with and without granite "thinking".

Loads the MMLU abstract-algebra card with a multiple-choice template,
formats it for granite-3.3 while toggling the `thinking` chat-template
flag, runs inference, reports predictions that could not be parsed to a
single A-D letter, and prints global scores for both settings.
"""

from unitxt.api import evaluate, load_dataset
from unitxt.formats import HFSystemFormat
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.processors import ExtractWithRegex, PostProcess
from unitxt.templates import MultipleChoiceTemplate

for thinking in [True, False]:
    postprocessors = ["processors.first_character"]
    if thinking:
        # In thinking mode the final answer is wrapped in <response> tags;
        # extract it first, then take the first character as the letter.
        # Fixed the truncated closing tag (was "</response" without ">").
        postprocessors = [
            PostProcess(
                ExtractWithRegex(regex=r"<response>(.*)</response>"),
                process_references=False,
            ),
            "processors.first_character",
        ]

    template = MultipleChoiceTemplate(
        input_format="""The following are multiple choice questions (with answers) about {topic}.
{question}
Answers:
{choices}
The response should be returned as a single letter: A, B, C, or D. Do not answer in sentences. Only return the single letter answer.""",
        target_field="answer",
        choices_separator="\n",
        postprocessors=postprocessors,
    )
    dataset = load_dataset(
        card="cards.mmlu.abstract_algebra",
        template=template,
        split="test",
        # place_instruction_in_user_turns=True puts the template instruction
        # in each user turn instead of the system role.
        format=HFSystemFormat(
            model_name="ibm-granite/granite-3.3-8b-instruct",
            chat_kwargs_dict={"thinking": thinking},
            place_instruction_in_user_turns=True,
        ),
        loader_limit=100,
    )

    model = CrossProviderInferenceEngine(
        model="granite-3-3-8b-instruct", provider="rits", temperature=0
    )

    predictions = model(dataset)

    results = evaluate(predictions=predictions, data=dataset)

    print("Instance Results when Thinking=", thinking)

    # Surface raw predictions whose post-processed value is not a valid
    # single-letter answer, to diagnose formatting failures.
    for instance in results.instance_scores:
        if instance["processed_prediction"] not in ["A", "B", "C", "D"]:
            print(
                "Problematic prediction (could not be parsed to a acceptable single letter answer):"
            )
            print(instance["prediction"])

    print("Global Results when Thinking=", thinking)
    print(results.global_scores.summary)
"""Run a cross-provider inference engine directly, without load_dataset.

Demonstrates feeding a hand-built OpenAI-style chat instance straight
into CrossProviderInferenceEngine for several providers and printing
each prediction alongside its source.
"""

from unitxt.inference import CrossProviderInferenceEngine
from unitxt.text_utils import print_dict

if __name__ == "__main__":
    for provider in ["watsonx", "rits", "watsonx-sdk"]:
        print()
        print("------------------------------------------------ ")
        print("PROVIDER:", provider)

        engine = CrossProviderInferenceEngine(
            model="granite-3-3-8b-instruct", provider=provider, temperature=0
        )

        # A single ready-made instance: one user chat turn plus the
        # data-classification metadata the engine requires.
        instances = [
            {
                "source": [{"content": "Hello, how are you?", "role": "user"}],
                "data_classification_policy": ["public"],
            }
        ]

        # Run inference and print each (input, prediction) pair.
        outputs = engine(instances)
        for instance, output in zip(instances, outputs):
            merged = dict(instance, prediction=output)

            print_dict(merged, keys_to_print=["source", "prediction"])

0 commit comments

Comments
 (0)