@@ -13,6 +13,7 @@
 import uuid
 from collections import Counter
 from datetime import datetime
+from itertools import islice
 from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
@@ -55,6 +56,11 @@
 logger = get_logger()
 
 
+def batched(lst, n):
+    it = iter(lst)
+    while batch := list(islice(it, n)):
+        yield batch
+
 class StandardAPIParamsMixin(Artifact):
     model: str
     frequency_penalty: Optional[float] = None
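A quick illustration (not part of the diff) of what the new batched() helper does: it lazily chunks any iterable via islice and the walrus operator, which is what lets infer() below drop the eager Dataset.to_list() conversion and the manual slicing.

# Illustrative only: expected behavior of the batched() helper above.
for batch in batched(range(7), 3):
    print(batch)  # [0, 1, 2], then [3, 4, 5], then [6]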
@@ -227,12 +233,8 @@ def infer(
             result = self._mock_infer(dataset)
         else:
             if self.use_cache:
-                if isinstance(dataset, Dataset):
-                    dataset = dataset.to_list()
-                dataset_batches = [dataset[i:i + self.cache_batch_size]
-                                   for i in range(0, len(dataset), self.cache_batch_size)]
                 result = []
-                for batch_num, batch in enumerate(dataset_batches):
+                for batch_num, batch in enumerate(batched(dataset, self.cache_batch_size)):
                     cached_results = []
                     missing_examples = []
                     for i, item in enumerate(batch):
@@ -243,16 +245,19 @@ def infer(
                         else:
                             missing_examples.append((i, item))  # each element is (index in batch, example)
                     # infer on the missing examples only, without their indices
-                    logger.info(f"Inferring batch {batch_num} / {len(dataset_batches)}")
-                    inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
-                    # recombine into (index, prediction) pairs
-                    inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
-                    # add the missing examples to the cache
-                    for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
-                        if prediction is None:
-                            continue
-                        cache_key = self._get_cache_key(item)
-                        self._cache[cache_key] = prediction
+                    logger.info(f"Inferring batch {batch_num} / {len(dataset) // self.cache_batch_size}")
+                    if len(missing_examples) > 0:
+                        inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
+                        # recombine into (index, prediction) pairs
+                        inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
+                        # add the missing examples to the cache
+                        for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
+                            if prediction is None:
+                                continue
+                            cache_key = self._get_cache_key(item)
+                            self._cache[cache_key] = prediction
+                    else:
+                        inferred_results = []
 
                     # Combine cached and inferred results in original order
                     batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
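Both cached_results and inferred_results hold (index-in-batch, prediction) tuples, so a single sort restores the original example order before the indices are stripped. A minimal sketch of that invariant, with made-up values:

# Made-up values, illustrating the merge step above: cache hits and fresh
# inferences each remember their position in the batch, so sorting on the
# index restores the original order.
cached_results = [(0, "cat"), (2, "dog")]     # found in the cache
inferred_results = [(1, "fox"), (3, "owl")]   # just inferred
batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
assert batch_predictions == ["cat", "fox", "dog", "owl"]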
@@ -1798,6 +1803,10 @@ class RITSInferenceEngine(
     label: str = "rits"
     data_classification_policy = ["public", "proprietary"]
 
+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4"
+    }
+
     def get_default_headers(self):
         return {"RITS_API_KEY": self.credentials["api_key"]}
 
@@ -1818,8 +1827,10 @@ def get_base_url_from_model_name(model_name: str):
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
         )
 
-    @staticmethod
-    def _get_model_name_for_endpoint(model_name: str):
+    @classmethod
+    def _get_model_name_for_endpoint(cls, model_name: str):
+        if model_name in cls.model_names_dict:
+            return cls.model_names_dict[model_name]
         return (
             model_name.split("/")[-1]
             .lower()
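With the switch to a classmethod, explicit overrides in model_names_dict take precedence over the generic normalization (whose remaining steps are truncated in this diff). An illustrative call, assuming the mapping added above:

# Illustrative, assuming the model_names_dict entry added above; models
# without an override fall through to the generic split/lowercase chain.
RITSInferenceEngine._get_model_name_for_endpoint("microsoft/phi-4")
# -> "microsoft-phi-4"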
@@ -2959,15 +2970,12 @@ def prepare_engine(self):
             capacity=self.max_requests_per_second,
         )
         self.inference_type = "litellm"
-        import litellm
         from litellm import acompletion
-        from litellm.caching.caching import Cache
 
-        litellm.cache = Cache(type="disk")
 
         self._completion = acompletion
         # Initialize a semaphore to limit concurrency
-        self._semaphore = asyncio.Semaphore(self.max_requests_per_second)
+        self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))
 
     async def _infer_instance(
         self, index: int, instance: Dict[str, Any]
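asyncio.Semaphore expects an integer count, so a fractional max_requests_per_second now gets rounded before it is used as a concurrency cap. A minimal sketch of the pattern (illustrative, not from the PR):

import asyncio

# Illustrative: a float rate limit such as 2.6 becomes 3 semaphore slots.
max_requests_per_second = 2.6
semaphore = asyncio.Semaphore(round(max_requests_per_second))

async def call_with_limit(coro):
    # At most round(max_requests_per_second) coroutines enter at once.
    async with semaphore:
        return await coro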