
Commit 912dc2a

Add granite documents format (#1566)
* Add granite documents format

Signed-off-by: elronbandel <[email protected]>

* Update process method to use Optional for stream_name parameter

Signed-off-by: elronbandel <[email protected]>

---------

Signed-off-by: elronbandel <[email protected]>
Co-authored-by: OfirArviv <[email protected]>
1 parent 121c268 commit 912dc2a

File tree

4 files changed: +199 -2 lines changed

prepare/formats/models/granite.py

+6

@@ -0,0 +1,6 @@
+from unitxt.catalog import add_to_catalog
+from unitxt.formats import GraniteDocumentsFormat
+
+format = GraniteDocumentsFormat(model="ibm-granite/granite-3.1-8b-instruct")
+
+add_to_catalog(format, "formats.models.granite_3_1_documents", overwrite=True)
src/unitxt/catalog/formats/models/granite_3_1_documents.json

+4

@@ -0,0 +1,4 @@
+{
+    "__type__": "granite_documents_format",
+    "model": "ibm-granite/granite-3.1-8b-instruct"
+}
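Once registered, the format can be referenced by its catalog name. A minimal usage sketch (the card name below is a placeholder, not part of this commit; any card whose task exposes a "question" field plus "context" or "contexts" would work):

from unitxt.api import load_dataset

# "cards.my_rag_card" is hypothetical -- substitute a real catalog card.
dataset = load_dataset(
    card="cards.my_rag_card",
    format="formats.models.granite_3_1_documents",
    split="test",
)
print(dataset[0]["source"])  # the fully rendered Granite prompt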

src/unitxt/formats.py

+50

@@ -13,6 +13,7 @@
 
 from .dataclass import OptionalField
 from .dict_utils import dict_get
+from .error_utils import UnitxtError
 from .image_operators import image_to_data_url
 from .operator import InstanceOperator
 from .settings_utils import get_constants
@@ -25,6 +26,55 @@ class Format(InstanceOperator):
     pass
 
 
+class GraniteDocumentsFormat(Format):
+    model: str = "ibm-granite/granite-3.1-8b-instruct"
+    citations: bool = True
+    length: str = "long"
+
+    _requirements_list = ["transformers"]
+
+    def prepare(self):
+        super().prepare()
+        from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        inputs = instance["input_fields"]
+        if "question" not in inputs:
+            raise UnitxtError(
+                "GraniteDocumentsFormat works only for tasks with field: 'question'"
+            )
+        if "context" not in inputs and "contexts" not in inputs:
+            raise UnitxtError(
+                "GraniteDocumentsFormat works only for tasks with field: 'context' or 'contexts'"
+            )
+
+        if "context" in inputs:
+            texts = [inputs["context"]]
+        if "contexts" in inputs:
+            texts = inputs["contexts"]
+
+        documents = []
+        for text in texts:
+            documents.append({"title": "", "text": text})
+
+        question = inputs["question"]
+
+        instance["source"] = self.tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": question},
+            ],
+            documents=documents,
+            controls={"citations": self.citations, "length": self.length},
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        return instance
+
+
 def apply_capital_new_line_notation(text: str) -> str:
     r"""Transforms a given string by applying the Capital New Line Notation.
 
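The formatting itself is delegated to the Granite tokenizer's chat template. A standalone sketch of the equivalent apply_chat_template call (assumes transformers is installed and the tokenizer files can be fetched from the Hugging Face Hub):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")

# The documents= and controls= kwargs are consumed by the Granite 3.1 chat
# template, which renders them into the system and documents roles of the prompt.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "what is love?"}],
    documents=[{"title": "", "text": "love is love"}],
    controls={"citations": True, "length": "long"},
    add_generation_prompt=True,
    tokenize=False,
)
print(prompt)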
tests/library/test_formats.py

+139 -2

@@ -1,18 +1,29 @@
+from datetime import datetime
+
+from unitxt.api import load_dataset
 from unitxt.card import TaskCard
-from unitxt.formats import ChatAPIFormat, HFSystemFormat, SystemFormat
+from unitxt.collections_operators import Wrap
+from unitxt.formats import (
+    ChatAPIFormat,
+    GraniteDocumentsFormat,
+    HFSystemFormat,
+    SystemFormat,
+)
 from unitxt.loaders import LoadFromDictionary
+from unitxt.operators import Rename, Set
 from unitxt.settings_utils import get_constants
 from unitxt.standard import DatasetRecipe
 from unitxt.system_prompts import TextualSystemPrompt
 from unitxt.task import Task
-from unitxt.templates import InputOutputTemplate
+from unitxt.templates import InputOutputTemplate, MultiReferenceTemplate, TemplatesDict
 from unitxt.test_utils.operators import (
     check_operator,
 )
 
 from tests.library.test_image_operators import create_random_jpeg_image
 from tests.utils import UnitxtTestCase
 
+# Assume
 constants = get_constants()
 
 
@@ -327,6 +338,132 @@ def test_hf_system_format(self):
             tester=self,
         )
 
+    def test_granite_documents_format(self):
+        inputs = [
+            {
+                "input_fields": {
+                    "question": "what is love?",
+                    "contexts": ["love is love"],
+                },
+            },
+            {
+                "input_fields": {
+                    "question": "what is love?",
+                    "context": "love is love",
+                },
+            },
+        ]
+
+        system_format = GraniteDocumentsFormat()
+
+        today = datetime.today().strftime("%B %d, %Y")
+        targets = [
+            {
+                "input_fields": {
+                    "question": "what is love?",
+                    "contexts": ["love is love"],
+                },
+                "source": "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: "
+                + today
+                + '.\nYou are Granite, developed by IBM. Write the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.<|end_of_text|>\n<|start_of_role|>documents<|end_of_role|>Document 0\nlove is love<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>what is love?<|end_of_text|>\n<|start_of_role|>assistant {"citations": true, "length": "long"}<|end_of_role|>',
+            },
+            {
+                "input_fields": {
+                    "question": "what is love?",
+                    "context": "love is love",
+                },
+                "source": "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: "
+                + today
+                + '.\nYou are Granite, developed by IBM. Write the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.<|end_of_text|>\n<|start_of_role|>documents<|end_of_role|>Document 0\nlove is love<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>what is love?<|end_of_text|>\n<|start_of_role|>assistant {"citations": true, "length": "long"}<|end_of_role|>',
+            },
+        ]
+
+        check_operator(
+            operator=system_format,
+            inputs=inputs,
+            targets=targets,
+            tester=self,
+        )
+
+        data = {
+            "test": [
+                {
+                    "query": "What city is the largest in Texas?",
+                    "extracted_chunks": "Austin is the capital of Texas.\nHouston is the largest city in Texas but not the capital of it. ",
+                    "expected_answer": "Houston",
+                },
+                {
+                    "query": "What city is the capital of Texas?",
+                    "extracted_chunks": "Houston is the largest city in Texas but not the capital of it. ",
+                    "expected_answer": "Austin",
+                },
+            ]
+        }
+
+        card = TaskCard(
+            # Assumes the data contains 3 fields:
+            # query (string), extracted_chunks (string), expected_answer (string)
+            loader=LoadFromDictionary(data=data),
+            # Map these fields to the fields of the task.rag.response_generation task.
+            # See https://www.unitxt.ai/en/latest/catalog/catalog.tasks.rag.response_generation.html
+            preprocess_steps=[
+                Rename(field_to_field={"query": "question"}),
+                Wrap(field="extracted_chunks", inside="list", to_field="contexts"),
+                Wrap(
+                    field="expected_answer", inside="list", to_field="reference_answers"
+                ),
+                Set(
+                    fields={
+                        "contexts_ids": [],
+                    }
+                ),
+            ],
+            # Specify the task and the desired metrics (note that these are part of the default
+            # metrics for the task, so the metrics selection can be omitted).
+            task="tasks.rag.response_generation",
+            # Specify a default template
+            templates=TemplatesDict(
+                {
+                    "simple": MultiReferenceTemplate(
+                        instruction="Answer the question based on the information provided in the document given below.\n\n",
+                        input_format="Document: {contexts}\nQuestion: {question}",
+                        references_field="reference_answers",
+                    ),
+                }
+            ),
+        )
+
+        # Select recommended metrics according to your available resources.
+        metrics = [
+            "metrics.rag.response_generation.recommended.cpu_only.all",
+            # "metrics.rag.response_generation.recommended.small_llm.all",
+            # "metrics.rag.response_generation.recommended.llmaj_watsonx.all",
+            # "metrics.rag.response_generation.recommended.llmaj_rits.all",
+            # "metrics.rag.response_generation.recommended.llmaj_azure.all",
+        ]
+
+        # Verbalize the dataset using the template
+        dataset = load_dataset(
+            card=card,
+            template_card_index="simple",
+            format=GraniteDocumentsFormat(),
+            split="test",
+            max_test_instances=10,
+            metrics=metrics,
+        )
+
+        self.assertListEqual(
+            dataset["source"],
+            [
+                "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: "
+                + today
+                + '.\nYou are Granite, developed by IBM. Write the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.<|end_of_text|>\n<|start_of_role|>documents<|end_of_role|>Document 0\nAustin is the capital of Texas.\nHouston is the largest city in Texas but not the capital of it. <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What city is the largest in Texas?<|end_of_text|>\n<|start_of_role|>assistant {"citations": true, "length": "long"}<|end_of_role|>',
+                "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: "
+                + today
+                + '.\nYou are Granite, developed by IBM. Write the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.<|end_of_text|>\n<|start_of_role|>documents<|end_of_role|>Document 0\nHouston is the largest city in Texas but not the capital of it. <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What city is the capital of Texas?<|end_of_text|>\n<|start_of_role|>assistant {"citations": true, "length": "long"}<|end_of_role|>',
+            ],
+        )
+
     def test_system_format(self):
         instruction = "solve the math exercises"
 
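Beyond the happy path, the new validation branches can be exercised directly. A quick sketch (assumes the tokenizer download succeeds; in unitxt, prepare() may already run on instantiation, in which case the explicit call is redundant but harmless):

from unitxt.error_utils import UnitxtError
from unitxt.formats import GraniteDocumentsFormat

fmt = GraniteDocumentsFormat()
fmt.prepare()  # loads the tokenizer for fmt.model
try:
    # No "question" field, so process() should raise UnitxtError.
    fmt.process({"input_fields": {"contexts": ["love is love"]}})
except UnitxtError as e:
    print(e)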