1
+ import asyncio
2
+ from typing import Any , List , Type , Iterator , Optional
3
+ from pydantic import BaseModel
4
+ from openai import OpenAI
5
+ from instructor .batch import BatchJob as InstructorBatchJob
6
+ import json
7
+ import os
8
+
9
+ SLEEP_TIME = 60
10
+
11
class BatchJob:
    """Run an OpenAI batch job over instructor-formatted chat messages.

    Construction is eager: it writes the JSONL request file via instructor,
    rewrites each line into OpenAI's ``/v1/batches`` request format, uploads
    the file, and starts the batch job. Use :meth:`get_status` /
    :meth:`get_result` to poll and collect output, :meth:`cancel` to abort.
    Temporary files are removed after the result is fetched or on cancel.
    """

    def __init__(
        self,
        messages_batch: Iterator[List[dict]],
        model: str,
        response_model: Type[BaseModel],
        file_path: str,
        output_path: str,
        api_key: Optional[str] = None,
    ):
        """Create the request file, upload it, and start the batch job.

        Args:
            messages_batch: Iterator of chat-message lists, one per request.
            model: OpenAI model name used for every request in the batch.
            response_model: Pydantic model instructor parses responses into.
            file_path: Where the JSONL request file is written.
            output_path: Where the downloaded batch output is saved.
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY
                environment variable when omitted.

        Raises:
            ValueError: If the file upload or batch creation fails.
        """
        self.response_model = response_model
        self.output_path = output_path
        self.file_path = file_path
        self.model = model
        # Resolve the key at call time, not import time: a default of
        # os.getenv(...) in the signature would be frozen when the module
        # is first imported and miss env vars set afterwards.
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.batch_id = None
        self.file_id = None

        # Create the batch job input file (.jsonl) in instructor's format.
        InstructorBatchJob.create_from_messages(
            messages_batch=messages_batch,
            model=model,
            file_path=file_path,
            response_model=response_model,
        )

        # Instructor's file lacks the method/url envelope the Batch API wants.
        self._add_method_to_file()

        # Upload the request file and start the batch job.
        self.file_id = self._upload_file()
        if not self.file_id:
            raise ValueError("Failed to upload file")

        self.batch_id = self._create_batch_job()
        if not self.batch_id:
            raise ValueError("Failed to create batch job")

    def _add_method_to_file(self) -> None:
        """Rewrite the JSONL file in place into OpenAI's batch request format.

        Each instructor line ``{"custom_id": ..., "params": {...}}`` becomes
        ``{"custom_id": ..., "method": "POST", "url": ..., "body": {...}}``.
        """
        with open(self.file_path, 'r') as file:
            lines = file.readlines()

        with open(self.file_path, 'w') as file:
            for line in lines:
                data = json.loads(line)

                new_data = {
                    "custom_id": data["custom_id"],
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": data["params"]["model"],
                        "messages": data["params"]["messages"],
                        "max_tokens": data["params"]["max_tokens"],
                        "temperature": data["params"]["temperature"],
                        "tools": data["params"]["tools"],
                        "tool_choice": data["params"]["tool_choice"],
                    },
                }
                file.write(json.dumps(new_data) + '\n')

    def _upload_file(self) -> Optional[str]:
        """Upload the JSONL request file to OpenAI.

        Returns:
            The uploaded file's id, or None on failure (logged, not raised;
            __init__ converts the None into a ValueError).
        """
        try:
            with open(self.file_path, "rb") as file:
                response = self.client.files.create(
                    file=file,
                    purpose="batch",
                )
            return response.id
        except Exception as e:
            print(f"Error uploading file: {e}")
            return None

    def _create_batch_job(self) -> Optional[str]:
        """Create the batch job via the OpenAI Batch API.

        Returns:
            The batch id, or None on failure (logged, not raised;
            __init__ converts the None into a ValueError).
        """
        try:
            batch = self.client.batches.create(
                input_file_id=self.file_id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
            )
            return batch.id
        except Exception as e:
            print(f"Error creating batch job: {e}")
            return None

    async def get_status(self) -> str:
        """Return the batch job's simplified status.

        Returns:
            One of ``queued``, ``processing``, ``completed``, or ``failed``
            (any retrieval error is reported as ``failed``).
        """
        try:
            # The OpenAI client is synchronous; run it off the event loop.
            batch = await asyncio.to_thread(
                self.client.batches.retrieve,
                self.batch_id,
            )
            return self._map_status(batch.status)
        except Exception as e:
            print(f"Error getting batch status: {e}")
            return "failed"

    def _map_status(self, api_status: str) -> str:
        """Map an OpenAI API batch status onto the simplified status set.

        Unknown statuses conservatively map to ``failed``.
        """
        status_mapping = {
            'validating': 'queued',
            'in_progress': 'processing',
            'finalizing': 'processing',
            'completed': 'completed',
            'failed': 'failed',
            'expired': 'failed',
            'cancelling': 'processing',
            'cancelled': 'failed',
        }
        return status_mapping.get(api_status, 'failed')

    async def get_result(self) -> BaseModel:
        """Block until the batch completes and return the first parsed result.

        Polls every ``SLEEP_TIME`` seconds, downloads the output file, saves
        it to ``output_path``, and parses it with instructor against
        ``response_model``. Temporary files are removed afterwards, even on
        failure.

        Returns:
            The first successfully parsed ``response_model`` instance.
            (Results that fail to parse are discarded.)

        Raises:
            ValueError: If the job fails, produces no output file, or no
                result could be parsed.
        """
        try:
            # Wait until the batch is complete.
            while True:
                status = await self.get_status()
                if status == 'completed':
                    break
                elif status == 'failed':
                    raise ValueError("Batch job failed")
                await asyncio.sleep(SLEEP_TIME)

            # Get batch details to learn the output file id.
            batch = await asyncio.to_thread(
                self.client.batches.retrieve,
                self.batch_id,
            )

            if not batch.output_file_id:
                raise ValueError("No output file ID found")

            # Download the output file (sync client -> worker thread).
            response = await asyncio.to_thread(
                self.client.files.content,
                batch.output_file_id,
            )

            # Save the output file for instructor to parse.
            with open(self.output_path, 'w') as f:
                f.write(response.text)

            # Use instructor to parse the results.
            parsed, unparsed = InstructorBatchJob.parse_from_file(
                file_path=self.output_path,
                response_model=self.response_model,
            )

            # Fail loudly instead of surfacing a bare IndexError.
            if not parsed:
                raise ValueError("No results could be parsed")

            return parsed[0]

        except Exception as e:
            raise ValueError(f"Failed to process output file: {e}")
        finally:
            self._cleanup_files()

    async def cancel(self) -> bool:
        """Cancel the current batch job and clean up local files.

        Returns:
            True if the cancel request succeeded, False otherwise.
        """
        if not self.batch_id:
            print("No batch job to cancel.")
            return False

        try:
            await asyncio.to_thread(
                self.client.batches.cancel,
                self.batch_id,
            )
            print("Batch job canceled successfully.")
            self._cleanup_files()
            return True
        except Exception as e:
            print(f"Error cancelling batch: {e}")
            return False

    def _cleanup_files(self):
        """Remove the request/output files and their directory if empty."""
        try:
            if os.path.exists(self.file_path):
                os.remove(self.file_path)
            if os.path.exists(self.output_path):
                os.remove(self.output_path)

            # Try to remove the parent directory if it is now empty.
            batch_dir = os.path.dirname(self.file_path)
            if os.path.exists(batch_dir) and not os.listdir(batch_dir):
                os.rmdir(batch_dir)
        except Exception as e:
            print(f"Warning: Failed to cleanup batch files: {e}")

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        # __init__ may have raised before these attributes were assigned;
        # guard so GC does not emit spurious AttributeError warnings.
        if hasattr(self, "file_path") and hasattr(self, "output_path"):
            self._cleanup_files()
0 commit comments