Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start implementing assesment for unitxt assitant #1625

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
language: system
# Adjust the files pattern to match your needs
files: ^src/.*\.py$
exclude: .*/(metric|dataset|hf_utils)\.py$
exclude: .*/(metric|dataset|hf_utils|unitxt/assistant/.*)\.py$
# Optional: Specify types or exclude files
types: [python]

Expand Down
Empty file.
58 changes: 58 additions & 0 deletions src/unitxt/assistant/assessment/assistant_inference_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import csv
import logging
import os
from typing import Any, Dict, List, Union

from unitxt import create_dataset, evaluate
from unitxt.assistant.app import Assistant
from unitxt.assistant.assessment.my_pretty_table import print_generic_table
from unitxt.dataset import Dataset
from unitxt.inference import InferenceEngine, TextGenerationInferenceOutput

logger = logging.getLogger("assistance-inference-engine")


class AssistantInferenceEngine(InferenceEngine):
def prepare_engine(self):
self.assistant = Assistant()

def _infer(self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False) -> Union[List[str], List[TextGenerationInferenceOutput]]:
sources = [x if isinstance(x, str) else x["source"] for x in dataset]
messages = [[{"role": "user", "content": s}] for s in sources]
generators = [self.assistant.generate_response(m) for m in messages]
return ["".join(g) for g in generators]


if __name__ == "__main__":
dataset_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "unitxt_assistant_qa_dataset.csv")
with open(dataset_file_path, encoding="utf-8") as file:
reader = csv.DictReader(file)
dataset = list(reader)

dataset = [{k: [v] if k == "answers" else v for k, v in line.items()} for line in dataset]

criteria = "metrics.llm_as_judge.direct.criteria.answer_completeness"
metrics = [
f"metrics.llm_as_judge.direct.rits.llama3_1_70b[criteria={criteria}, context_fields=[answers]]"
]

dataset = create_dataset(
task="tasks.qa.open",
test_set=dataset,
metrics=metrics,
)
dataset = dataset["test"]
model = AssistantInferenceEngine()
predictions = model(dataset)

results = evaluate(predictions=predictions, data=dataset)

res_dict = [{"score": r["score"]["instance"]["score"], "source": r["source"], "prediction": r["prediction"],
"target": r["target"],
"judgement" :results[0]["score"]["instance"]["answer_completeness_positional_bias_assessment"]}
for r in results]
col_width = {"score": 5, "source": 25, "prediction": 50, "target": 25, "judgement": 80}
print_generic_table(headers=col_width.keys(), data=res_dict, col_widths=col_width)

logger.info("Global Scores:")
logger.info(results.global_scores.summary)
53 changes: 53 additions & 0 deletions src/unitxt/assistant/assessment/my_pretty_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
import textwrap
from typing import List, Optional

# Setup logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)



def print_generic_table(headers: list, data: List, col_widths: Optional[dict] = None):
"""Prints a table with the given headers, column widths (with a default uniform width), and data.

Args:
headers (list): A list of column headers.
col_widths (dict, optional): A dictionary with column names as keys and their respective widths as values.
If not provided, all columns will have a uniform width of 20.
data (list): A list of dictionaries, where each dictionary represents a row with column names as keys.
"""
# Set default uniform width if col_widths is not provided
if col_widths is None:
col_widths = {header: 20 for header in headers}

# Calculate the total table width based on column widths
table_width = sum(col_widths.values()) + len(col_widths) * 3 + 2 # Adjust for separators and | at left and right

# Print separator before the table
logger.info("=" * table_width)

# Create the header row
header_row = " | ".join([f"{header:<{col_widths[header]}}" for header in headers])

# logger.info the header
logger.info(f"| {header_row} |")
logger.info("=" * table_width) # Separator line after the header

# Loop through the data and print each row
for row in data:
# Wrap text to fit within column widths and prepare wrapped rows
wrapped_columns = {col: textwrap.fill(str(row[col]), width=col_widths[col]) for col in headers}

# Split wrapped columns into multiple lines
wrapped_lines = {col: wrapped_columns[col].split("\n") for col in headers}

# Find the maximum number of lines across all columns
max_lines = max(len(wrapped_lines[col]) for col in headers)

# Print each line
for i in range(max_lines):
line = "| " + " | ".join([f"{wrapped_lines[col][i] if i < len(wrapped_lines[col]) else '':<{col_widths[col]}}"
for col in headers]) + " |"
logger.info(line)
logger.info("=" * table_width) # Separator line after each row
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
id,question,answers
1,"Which operator allows users to define operations in Python?","Using LiteralEval, ExecuteExpression or FilterByExpression"
2,"How can I specify a custom directory as the catalog root when adding a new artifact?","Use the 'catalog_path' parameter in the 'add_to_catalog' function."
3,"Can I configure Unitxt to search for the catalog in a specific directory using an environment variable?","Yes, set the 'UNITXT_CATALOGS' environment variable."
Loading