1+ from pydantic import BaseModel , Field
12import asyncio
23from typing import List , Dict
34from .openai_api import OpenAIClient
45from .logging import Logger # Using correct logging abstraction
56
7+ class BinaryClassifyListInput (BaseModel ):
8+ list_to_classify : List [str ] = Field (..., description = "The list of items to classify" )
9+ criteria : str = Field (..., description = "The criteria for binary classification" )
10+ max_tokens : int = Field (1000 , description = "The maximum number of tokens to generate" )
11+ temperature : float = Field (0.0 , description = "Sampling temperature for the OpenAI model" )
12+
613class BinaryClassifyListAgent :
7- def __init__ (self , list_to_classify : List [str ], criteria : str , max_tokens : int = 1000 , temperature : float = 0.0 ):
8- self .list_to_classify = list_to_classify
9- self .criteria = criteria
10- self .max_tokens = max_tokens
11- self .temperature = temperature
14+ """
15+ A class to classify items in a list based on binary criteria using the OpenAI API.
16+
17+ Attributes:
18+ list_to_classify (List[str]): The list of items to classify.
19+ criteria (str): The criteria for binary classification.
20+ max_tokens (int): The maximum number of tokens to generate.
21+ temperature (float): Sampling temperature for the OpenAI model.
22+ openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API.
23+ logger (Logger): An instance of Logger to log classification requests and responses.
24+
25+ Methods:
26+ classify_list(): Classifies the entire list of items.
27+ classify_item(user_prompt): Classifies a single item based on the criteria.
28+ """
29+
30+ def __init__ (self , data : BinaryClassifyListInput ):
31+ """
32+ Constructs all the necessary attributes for the BinaryClassifyListAgent object.
33+
34+ Args:
35+ data (BinaryClassifyListInput): An instance of BinaryClassifyListInput containing
36+ the list of items, criteria, max_tokens, and temperature.
37+ """
38+ self .list_to_classify = data .list_to_classify
39+ self .criteria = data .criteria
40+ self .max_tokens = data .max_tokens
41+ self .temperature = data .temperature
1242 self .openai_client = OpenAIClient ()
1343 self .logger = Logger ()
1444
1545 async def classify_list (self ) -> List [Dict ]:
46+ """
47+ Classifies the entire list based on the provided items and criteria.
48+
49+ Returns:
50+ List[Dict]: A list of dictionaries with the classification results.
51+ """
1652 tasks = []
1753 for item in self .list_to_classify :
1854 user_prompt = f"Based on the following criteria '{ self .criteria } ', classify the item '{ item } ' as true or false."
@@ -22,6 +58,15 @@ async def classify_list(self) -> List[Dict]:
2258 return results
2359
2460 async def classify_item (self , user_prompt : str ) -> Dict :
61+ """
62+ Classifies a single item based on the criteria.
63+
64+ Args:
65+ user_prompt (str): The prompt describing the classification criteria and item.
66+
67+ Returns:
68+ Dict: A dictionary with the classification result.
69+ """
2570 system_prompt = "You are an assistant tasked with binary classification of items."
2671
2772 self .logger .info (f"Classifying item: { user_prompt } " ) # Logging the classification request
0 commit comments