@@ -22,43 +22,59 @@ def main() -> None:
2222 from tinker_cookbook .tokenizer_utils import get_tokenizer
2323
2424 parser = argparse .ArgumentParser ()
25- parser .add_argument ("--path" , required = True )
25+ parser .add_argument ("--path" , required = True , action = "append" , help = "One or more URI paths to evaluate concurrently" )
2626 parser .add_argument ("--base-model" , default = "Qwen/Qwen2.5-0.5B" )
2727 parser .add_argument ("--base-url" , default = os .getenv ("TINKER_BASE_URL" , os .getenv ("BASE_URL" , "http://127.0.0.1:8000" )))
2828 parser .add_argument ("--data" , default = "gsm8k_test.json" )
2929 parser .add_argument ("--gpu-memory-utilization" , type = float , default = 0.85 )
30+ parser .add_argument ("--microbatch-size" , type = int , default = 10 , help = "Number of evaluation problems to dispatch per micro-batch" )
3031 parser .add_argument ("--min-accuracy" , type = float , default = 0.0 , help = "exit nonzero if accuracy falls below this fraction" )
3132 args = parser .parse_args ()
3233
3334 with open (args .data ) as f :
3435 data = json .load (f )
3536
37+ paths = args .path if isinstance (args .path , list ) else [args .path ]
3638 client = ServiceClient (api_key = os .getenv ("TINKER_API_KEY" , "tml-dummy-key" ), base_url = args .base_url )
37- sampler = client .create_sampling_client (args . path )
39+ samplers = [ client .create_sampling_client (p ) for p in paths ]
3840 tokenizer = get_tokenizer (args .base_model )
3941
4042 sampling_params = types .SamplingParams (temperature = 0.0 , max_tokens = 256 )
4143 start = time .time ()
42-
43- outputs = []
44- for datum in data :
45- prompt_tokens = tokenizer .encode (datum ["prompt" ], add_special_tokens = False )
46- seqs = sampler .sample (
47- prompt = types .ModelInput .from_ints (tokens = prompt_tokens ),
48- num_samples = 1 ,
49- sampling_params = sampling_params ,
50- ).result ().sequences
51- outputs .append (tokenizer .decode (seqs [0 ].tokens ) if seqs else "" )
5244
53- elapsed = time .time () - start
54- correct = sum (int (extract (text ) == datum ["gold" ]) for datum , text in zip (data , outputs , strict = True ))
55- accuracy = correct / len (data )
45+ import asyncio
46+
async def run_evals():
    """Sample one completion per problem from every sampler, in micro-batches.

    Reads ``args.microbatch_size``, ``data``, ``samplers``, ``tokenizer``,
    ``types`` and ``sampling_params`` from the enclosing scope.

    Returns:
        A list with one entry per sampler (same order as ``samplers``); each
        entry is the list of decoded completions in the same order as
        ``data``. A problem whose response has no sequences yields ``""``.

    Raises:
        ValueError: if ``args.microbatch_size`` is not a positive integer.
    """
    batch_size = args.microbatch_size
    if batch_size < 1:
        # A zero step makes range() fail with an unrelated message, and a
        # negative step silently produces no outputs at all; fail clearly.
        raise ValueError(f"--microbatch-size must be a positive integer, got {batch_size}")

    outputs_by_sampler = [[] for _ in samplers]
    for offset in range(0, len(data), batch_size):
        chunk = data[offset : offset + batch_size]
        # Tokenize each prompt once per micro-batch; every sampler receives
        # the same tokens, so encoding per sampler is redundant work.
        token_lists = [
            tokenizer.encode(datum["prompt"], add_special_tokens=False)
            for datum in chunk
        ]
        # Dispatch the micro-batch to all samplers at once so the paths are
        # evaluated concurrently rather than one sampler after another.
        requests = [
            sampler.sample_async(
                prompt=types.ModelInput.from_ints(tokens=tokens),
                num_samples=1,
                sampling_params=sampling_params,
            )
            for sampler in samplers
            for tokens in token_lists
        ]
        responses = await asyncio.gather(*requests)
        # gather() returns results in submission order: sampler-major, then
        # problem order within the chunk.
        for flat_idx, response in enumerate(responses):
            seqs = response.sequences
            text = tokenizer.decode(seqs[0].tokens) if seqs else ""
            outputs_by_sampler[flat_idx // len(chunk)].append(text)
    return outputs_by_sampler
5666
57- print ("***************************************************************" )
58- print (f"[SAMPLER] { args .path } 0-shot GSM8K acc = { accuracy :.1%} on { len (data )} problems in { elapsed :.1f} s" )
59- print ("***************************************************************" )
60- if accuracy < args .min_accuracy :
61- raise SystemExit (f"GSM8K accuracy { accuracy :.1%} is below the required { args .min_accuracy :.1%} " )
67+ outputs_by_sampler = asyncio .run (run_evals ())
68+
69+ elapsed = time .time () - start
70+ for path , outputs in zip (paths , outputs_by_sampler , strict = True ):
71+ correct = sum (int (extract (text ) == datum ["gold" ]) for datum , text in zip (data , outputs , strict = True ))
72+ accuracy = correct / len (data )
73+ print ("***************************************************************" )
74+ print (f"[SAMPLER] { path } 0-shot GSM8K acc = { accuracy :.1%} on { len (data )} problems in { elapsed :.1f} s" )
75+ print ("***************************************************************" )
76+ if accuracy < args .min_accuracy :
77+ raise SystemExit (f"GSM8K accuracy { accuracy :.1%} for { path } is below the required { args .min_accuracy :.1%} " )
6278
6379
6480if __name__ == "__main__" :
0 commit comments