Skip to content

Commit b9936c3

Browse files
updated the rest of the core and examples to use pydantic and added doc strings
1 parent 1b51ce4 commit b9936c3

27 files changed

+863
-143
lines changed

examples/binary_classify_list_example.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import asyncio
2-
from core.binary_classify_list_agent import BinaryClassifyListAgent
2+
from core.binary_classify_list_agent import BinaryClassifyListAgent, BinaryClassifyListInput
33

44
async def run_binary_classify_list_example():
5-
items_to_classify = ['Apple', 'Chocolate', 'Carrot']
6-
criteria = 'Classify each item as either healthy (true) or unhealthy (false)'
7-
agent = BinaryClassifyListAgent(list_to_classify=items_to_classify, criteria=criteria)
5+
input_data = BinaryClassifyListInput(
6+
list_to_classify=['Apple', 'Chocolate', 'Carrot'],
7+
criteria='Classify each item as either healthy (true) or unhealthy (false)',
8+
max_tokens=1000,
9+
temperature=0.0
10+
)
11+
12+
agent = BinaryClassifyListAgent(input_data)
813
classified_items = await agent.classify_list()
914

10-
print("Original list:", items_to_classify)
15+
print("Original list:", input_data.list_to_classify)
1116
print("Binary classified results:", classified_items)
1217

1318
if __name__ == "__main__":

examples/chain_of_thought_example.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import asyncio
2-
from core.chain_of_thought_agent import ChainOfThoughtAgent
2+
from core.chain_of_thought_agent import ChainOfThoughtAgent, ChainOfThoughtInput
33

44
async def run_chain_of_thought_example():
5-
question = 'What is the square root of 144?'
6-
agent = ChainOfThoughtAgent(question=question)
5+
input_data = ChainOfThoughtInput(
6+
question='What is the square root of 144?',
7+
max_tokens=1000,
8+
temperature=0.0
9+
)
10+
11+
agent = ChainOfThoughtAgent(input_data)
712
result = await agent.chain_of_thought()
813

9-
print("Question:", question)
14+
print("Question:", input_data.question)
1015
print("Chain of Thought Reasoning:", result)
1116

1217
if __name__ == "__main__":

examples/classify_list_example.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
2-
from core.classify_list_agent import ClassifyListAgent
2+
from core.classify_list_agent import ClassifyListAgent, ClassifyListInput
33

44
async def run_classify_list_example():
5-
items_to_classify = ['Apple', 'Chocolate', 'Carrot']
6-
classification_criteria = 'Classify each item as healthy or unhealthy snack'
7-
agent = ClassifyListAgent(list_to_classify=items_to_classify, classification_criteria=classification_criteria)
8-
classified_items = await agent.classify_list()
5+
input_data = ClassifyListInput(
6+
list_to_classify=["Apple", "Banana", "Carrot"],
7+
classification_criteria="Classify each item as a fruit or vegetable.",
8+
max_tokens=1000
9+
)
10+
11+
agent = ClassifyListAgent(input_data)
12+
classifications = await agent.classify_list()
913

10-
print("Original list:", items_to_classify)
11-
print("Classified results:", classified_items)
14+
print("Original list:", input_data.list_to_classify)
15+
print("Classified results:", classifications)
1216

1317
if __name__ == "__main__":
14-
asyncio.run(run_classify_list_example())
18+
asyncio.run(run_classify_list_example())

examples/filter_list_example.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
import asyncio
2-
from core.filter_list_agent import FilterListAgent
2+
from core.filter_list_agent import FilterListAgent, FilterListInput
33

44
async def run_filter_list_example():
5-
goal = "Remove items that are unhealthy snacks."
6-
items_to_filter = [
7-
"Apple",
8-
"Chocolate bar",
9-
"Carrot",
10-
"Chips",
11-
"Orange"
12-
]
13-
14-
agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter)
5+
input_data = FilterListInput(
6+
goal="Remove items that are unhealthy snacks.",
7+
items_to_filter=[
8+
"Apple",
9+
"Chocolate bar",
10+
"Carrot",
11+
"Chips",
12+
"Orange"
13+
],
14+
max_tokens=500,
15+
temperature=0.0
16+
)
17+
18+
agent = FilterListAgent(input_data)
1519
filtered_results = await agent.filter()
1620

17-
print("Original list:", items_to_filter)
21+
print("Original list:", input_data.items_to_filter)
1822
print("Filtered results:")
1923
for result in filtered_results:
2024
print(result)
2125

2226
if __name__ == "__main__":
23-
asyncio.run(run_filter_list_example())
27+
asyncio.run(run_filter_list_example())
Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import asyncio
2-
from core.grounded_answer_agent import GroundedAnswerAgent
2+
from core.grounded_answer_agent import GroundedAnswerAgent, GroundedAnswerInput
33

44
async def run_grounded_answer_example():
5-
question = "What is the capital of France?"
6-
context = "France is a country in Western Europe. Paris is its capital and largest city."
7-
instructions = "Ensure the answer is grounded only in the provided context."
8-
agent = GroundedAnswerAgent(question=question, context=context, instructions=instructions)
9-
result = await agent.answer()
5+
input_data = GroundedAnswerInput(
6+
question="What is the capital of France?",
7+
context="France is a country in Western Europe known for its wine and cuisine. The capital is a major global center for art, fashion, and culture.",
8+
instructions="",
9+
max_tokens=1000
10+
)
11+
12+
agent = GroundedAnswerAgent(input_data)
13+
answer = await agent.answer()
1014

11-
print("Question:", question)
12-
print("Result:", result)
15+
print("Question:", input_data.question)
16+
print("Answer:", answer)
1317

1418
if __name__ == "__main__":
15-
asyncio.run(run_grounded_answer_example())
19+
asyncio.run(run_grounded_answer_example())

examples/map_list_example.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import asyncio
2-
from core.map_list_agent import MapListAgent
2+
from core.map_list_agent import MapListAgent, MapListInput
33

44
async def run_map_list_example():
5-
items_to_map = ['Apple', 'Banana', 'Carrot']
6-
transformation = 'Convert all items to uppercase'
7-
agent = MapListAgent(list_to_map=items_to_map, transformation=transformation)
5+
input_data = MapListInput(
6+
list_to_map=['Apple', 'Banana', 'Carrot'],
7+
transformation='Convert all items to uppercase',
8+
max_tokens=1000
9+
)
10+
11+
agent = MapListAgent(input_data)
812
transformed_items = await agent.map_list()
913

10-
print("Original list:", items_to_map)
14+
print("Original list:", input_data.list_to_map)
1115
print("Transformed list:", transformed_items)
1216

1317
if __name__ == "__main__":

examples/project_list_example.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import asyncio
2-
from core.project_list_agent import ProjectListAgent
2+
from core.project_list_agent import ProjectListAgent, ProjectListInput
33

44
async def run_project_list_example():
5-
items_to_project = ['Apple', 'Banana', 'Carrot']
6-
projection_rule = 'Project these items as their vitamin content'
7-
agent = ProjectListAgent(list_to_project=items_to_project, projection_rule=projection_rule)
5+
input_data = ProjectListInput(
6+
list_to_project=['Apple', 'Banana', 'Carrot'],
7+
projection_rule='Project these items as their vitamin content',
8+
max_tokens=1000
9+
)
10+
11+
agent = ProjectListAgent(input_data)
812
projected_items = await agent.project_list()
913

10-
print("Original list:", items_to_project)
14+
print("Original list:", input_data.list_to_project)
1115
print("Projected results:", projected_items)
1216

1317
if __name__ == "__main__":

examples/reduce_list_example.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
2-
from core.reduce_list_agent import ReduceListAgent
2+
from core.reduce_list_agent import ReduceListAgent, ReduceListInput
33

44
async def run_reduce_list_example():
5-
items_to_reduce = ['Banana', 'Apple', 'Carrot']
6-
reduction_goal = 'Reduce these items to a single word representing their nutritional value'
7-
agent = ReduceListAgent(list_to_reduce=items_to_reduce, reduction_goal=reduction_goal)
5+
input_data = ReduceListInput(
6+
list_to_reduce=["Apple", "Banana", "Carrot"],
7+
reduction_goal="Reduce each item to its first letter.",
8+
max_tokens=1000
9+
)
10+
11+
agent = ReduceListAgent(input_data)
812
reduced_items = await agent.reduce_list()
913

10-
print("Original list:", items_to_reduce)
14+
print("Original list:", input_data.list_to_reduce)
1115
print("Reduced results:", reduced_items)
1216

1317
if __name__ == "__main__":
14-
asyncio.run(run_reduce_list_example())
18+
asyncio.run(run_reduce_list_example())

examples/summarize_list_example.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import asyncio
2-
from core.summarize_list_agent import SummarizeListAgent
2+
from core.summarize_list_agent import SummarizeListAgent, SummarizeListInput
33

44
async def run_summarize_list_example():
5-
items_to_summarize = ['The quick brown fox jumps over the lazy dog.', 'Python is a popular programming language.']
6-
agent = SummarizeListAgent(list_to_summarize=items_to_summarize)
5+
input_data = SummarizeListInput(
6+
list_to_summarize=[
7+
'The quick brown fox jumps over the lazy dog.',
8+
'Python is a popular programming language.'
9+
],
10+
max_tokens=1000
11+
)
12+
13+
agent = SummarizeListAgent(input_data)
714
summaries = await agent.summarize_list()
815

9-
print("Original list:", items_to_summarize)
16+
print("Original list:", input_data.list_to_summarize)
1017
print("Summarized results:", summaries)
1118

1219
if __name__ == "__main__":

src/core/binary_classify_list_agent.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,54 @@
1+
from pydantic import BaseModel, Field
12
import asyncio
23
from typing import List, Dict
34
from .openai_api import OpenAIClient
45
from .logging import Logger # Using correct logging abstraction
56

7+
class BinaryClassifyListInput(BaseModel):
8+
list_to_classify: List[str] = Field(..., description="The list of items to classify")
9+
criteria: str = Field(..., description="The criteria for binary classification")
10+
max_tokens: int = Field(1000, description="The maximum number of tokens to generate")
11+
temperature: float = Field(0.0, description="Sampling temperature for the OpenAI model")
12+
613
class BinaryClassifyListAgent:
7-
def __init__(self, list_to_classify: List[str], criteria: str, max_tokens: int = 1000, temperature: float = 0.0):
8-
self.list_to_classify = list_to_classify
9-
self.criteria = criteria
10-
self.max_tokens = max_tokens
11-
self.temperature = temperature
14+
"""
15+
A class to classify items in a list based on binary criteria using the OpenAI API.
16+
17+
Attributes:
18+
list_to_classify (List[str]): The list of items to classify.
19+
criteria (str): The criteria for binary classification.
20+
max_tokens (int): The maximum number of tokens to generate.
21+
temperature (float): Sampling temperature for the OpenAI model.
22+
openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API.
23+
logger (Logger): An instance of Logger to log classification requests and responses.
24+
25+
Methods:
26+
classify_list(): Classifies the entire list of items.
27+
classify_item(user_prompt): Classifies a single item based on the criteria.
28+
"""
29+
30+
def __init__(self, data: BinaryClassifyListInput):
31+
"""
32+
Constructs all the necessary attributes for the BinaryClassifyListAgent object.
33+
34+
Args:
35+
data (BinaryClassifyListInput): An instance of BinaryClassifyListInput containing
36+
the list of items, criteria, max_tokens, and temperature.
37+
"""
38+
self.list_to_classify = data.list_to_classify
39+
self.criteria = data.criteria
40+
self.max_tokens = data.max_tokens
41+
self.temperature = data.temperature
1242
self.openai_client = OpenAIClient()
1343
self.logger = Logger()
1444

1545
async def classify_list(self) -> List[Dict]:
46+
"""
47+
Classifies the entire list based on the provided items and criteria.
48+
49+
Returns:
50+
List[Dict]: A list of dictionaries with the classification results.
51+
"""
1652
tasks = []
1753
for item in self.list_to_classify:
1854
user_prompt = f"Based on the following criteria '{self.criteria}', classify the item '{item}' as true or false."
@@ -22,6 +58,15 @@ async def classify_list(self) -> List[Dict]:
2258
return results
2359

2460
async def classify_item(self, user_prompt: str) -> Dict:
61+
"""
62+
Classifies a single item based on the criteria.
63+
64+
Args:
65+
user_prompt (str): The prompt describing the classification criteria and item.
66+
67+
Returns:
68+
Dict: A dictionary with the classification result.
69+
"""
2570
system_prompt = "You are an assistant tasked with binary classification of items."
2671

2772
self.logger.info(f"Classifying item: {user_prompt}") # Logging the classification request

0 commit comments

Comments
 (0)