Skip to content

Commit 35e7684

Browse files
authored
Merge pull request #62 from enoch3712/61-extract-batch-feature
extract batch feature
2 parents b237386 + 5768f3e commit 35e7684

File tree

7 files changed

+455
-21
lines changed

7 files changed

+455
-21
lines changed

extract_thinker/batch_job.py

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import asyncio
2+
from typing import Any, List, Type, Iterator, Optional
3+
from pydantic import BaseModel
4+
from openai import OpenAI
5+
from instructor.batch import BatchJob as InstructorBatchJob
6+
import json
7+
import os
8+
9+
SLEEP_TIME = 60
10+
11+
class BatchJob:
    """Manage an OpenAI batch-completion job built from Instructor messages.

    On construction this writes a batch input file (.jsonl) via Instructor,
    rewrites each line into OpenAI's /v1/batches request format, uploads the
    file, and creates the batch job. Progress is then polled asynchronously
    via :meth:`get_status` / :meth:`get_result`.
    """

    def __init__(
        self,
        messages_batch: Iterator[List[dict]],
        model: str,
        response_model: Type[BaseModel],
        file_path: str,
        output_path: str,
        api_key: Optional[str] = None
    ):
        """
        Args:
            messages_batch: Iterator of chat-message lists, one per request.
            model: OpenAI model name used for every request in the batch.
            response_model: Pydantic model each response is parsed into.
            file_path: Path where the batch input .jsonl is written.
            output_path: Path where the batch output file is saved.
            api_key: OpenAI API key. Defaults to the OPENAI_API_KEY
                environment variable, resolved at call time (the previous
                signature froze ``os.getenv`` at import time).

        Raises:
            ValueError: If the file upload or the batch creation fails.
        """
        self.response_model = response_model
        self.output_path = output_path
        self.file_path = file_path
        self.model = model
        # Resolve the key per call, not at import time, so environment
        # changes made after this module is imported are honored.
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)
        self.batch_id = None
        self.file_id = None

        # Create the batch job input file (.jsonl)
        InstructorBatchJob.create_from_messages(
            messages_batch=messages_batch,
            model=model,
            file_path=file_path,
            response_model=response_model
        )

        self._add_method_to_file()

        # Upload file and create batch job
        self.file_id = self._upload_file()
        if not self.file_id:
            raise ValueError("Failed to upload file")

        self.batch_id = self._create_batch_job()
        if not self.batch_id:
            raise ValueError("Failed to create batch job")

    def _add_method_to_file(self) -> None:
        """Transform the JSONL file to match OpenAI's batch request format.

        Instructor writes ``{"custom_id": ..., "params": {...}}`` lines;
        the Batch API expects ``{"custom_id", "method", "url", "body"}``.
        The file is rewritten in place.
        """
        with open(self.file_path, 'r') as file:
            lines = file.readlines()

        with open(self.file_path, 'w') as file:
            for line in lines:
                data = json.loads(line)

                new_data = {
                    "custom_id": data["custom_id"],
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": data["params"]["model"],
                        "messages": data["params"]["messages"],
                        "max_tokens": data["params"]["max_tokens"],
                        "temperature": data["params"]["temperature"],
                        "tools": data["params"]["tools"],
                        "tool_choice": data["params"]["tool_choice"]
                    }
                }
                file.write(json.dumps(new_data) + '\n')

    def _upload_file(self) -> Optional[str]:
        """Upload the JSONL file to OpenAI.

        Returns:
            The uploaded file's id, or ``None`` on failure (best-effort:
            errors are reported, and __init__ converts None into a raise).
        """
        try:
            with open(self.file_path, "rb") as file:
                response = self.client.files.create(
                    file=file,
                    purpose="batch"
                )
                return response.id
        except Exception as e:
            print(f"Error uploading file: {e}")
            return None

    def _create_batch_job(self) -> Optional[str]:
        """Create a batch job via OpenAI API.

        Returns:
            The new batch id, or ``None`` on failure.
        """
        try:
            batch = self.client.batches.create(
                input_file_id=self.file_id,
                endpoint="/v1/chat/completions",
                completion_window="24h"
            )
            return batch.id
        except Exception as e:
            print(f"Error creating batch job: {e}")
            return None

    async def get_status(self) -> str:
        """
        Get the current status of the batch job.
        Returns: queued, processing, completed, or failed
        """
        try:
            # retrieve() is blocking; run it off the event loop.
            batch = await asyncio.to_thread(
                self.client.batches.retrieve,
                self.batch_id
            )
            return self._map_status(batch.status)
        except Exception as e:
            print(f"Error getting batch status: {e}")
            return "failed"

    def _map_status(self, api_status: str) -> str:
        """Maps OpenAI API status to simplified status.

        Unknown statuses map to 'failed' so callers never block forever
        on an unrecognized state.
        """
        status_mapping = {
            'validating': 'queued',
            'in_progress': 'processing',
            'finalizing': 'processing',
            'completed': 'completed',
            'failed': 'failed',
            'expired': 'failed',
            'cancelling': 'processing',
            'cancelled': 'failed'
        }
        return status_mapping.get(api_status, 'failed')

    async def get_result(self) -> BaseModel:
        """
        Wait for job completion and return the first parsed result.

        Polls every SLEEP_TIME seconds until the batch completes, downloads
        the output file, and parses it with Instructor into
        ``self.response_model`` instances.

        Returns:
            The first successfully parsed ``response_model`` instance.

        Raises:
            ValueError: If the batch fails, produces no output file, or
                no result could be parsed into the response model.
        """
        try:
            # Wait until the batch is complete
            while True:
                status = await self.get_status()
                if status == 'completed':
                    break
                elif status == 'failed':
                    raise ValueError("Batch job failed")
                await asyncio.sleep(SLEEP_TIME)

            # Get batch details
            batch = await asyncio.to_thread(
                self.client.batches.retrieve,
                self.batch_id
            )

            if not batch.output_file_id:
                raise ValueError("No output file ID found")

            # Download the output file
            response = await asyncio.to_thread(
                self.client.files.content,
                batch.output_file_id
            )

            # Save the output file
            with open(self.output_path, 'w') as f:
                f.write(response.text)

            # Use Instructor to parse the results
            parsed, unparsed = InstructorBatchJob.parse_from_file(
                file_path=self.output_path,
                response_model=self.response_model
            )

            # Fail loudly with context instead of an opaque IndexError
            # when nothing could be parsed into the response model.
            if not parsed:
                raise ValueError(
                    f"No results parsed into {self.response_model.__name__}; "
                    f"{len(unparsed)} unparsed result(s)"
                )

            return parsed[0]

        except Exception as e:
            raise ValueError(f"Failed to process output file: {e}")
        finally:
            self._cleanup_files()

    async def cancel(self) -> bool:
        """Cancel the current batch job and confirm cancellation.

        Returns:
            True if the cancel request succeeded, False otherwise.
        """
        if not self.batch_id:
            print("No batch job to cancel.")
            return False

        try:
            await asyncio.to_thread(
                self.client.batches.cancel,
                self.batch_id
            )
            print("Batch job canceled successfully.")
            self._cleanup_files()
            return True
        except Exception as e:
            print(f"Error cancelling batch: {e}")
            return False

    def _cleanup_files(self):
        """Remove temporary files and batch directory if empty"""
        try:
            if os.path.exists(self.file_path):
                os.remove(self.file_path)
            if os.path.exists(self.output_path):
                os.remove(self.output_path)

            # Try to remove parent directory if empty
            batch_dir = os.path.dirname(self.file_path)
            if os.path.exists(batch_dir) and not os.listdir(batch_dir):
                os.rmdir(batch_dir)
        except Exception as e:
            print(f"Warning: Failed to cleanup batch files: {e}")

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        try:
            self._cleanup_files()
        except Exception:
            # Never propagate from __del__ — during interpreter shutdown
            # even the inner handler's print/os lookups can raise.
            pass

0 commit comments

Comments
 (0)