from datetime import timedelta
from enum import StrEnum, auto
from pathlib import Path
-from typing import Annotated
+from typing import Annotated, Literal

import mlperf_loadgen as lg
from loguru import logger
-from pydantic import BaseModel, DirectoryPath, Field, field_validator
+from pydantic import BaseModel, DirectoryPath, Field, FilePath, field_validator
from pydantic_typer import Typer
from typer import Option

-from .evaluation import Evaluator
+from .evaluation import run_evaluation
from .task import ShopifyGlobalCatalogue

app = Typer()
@@ -179,7 +179,9 @@ class TestSettings(BaseModel):
        int,
        Field(
            description="""The minimum testing query count.
-            The benchmark runs until this value has been met.""",
+            The benchmark runs until this value has been met.
+            If min_query_count is less than the total number of samples in the dataset,
+            only the first min_query_count samples will be used during testing.""",
        ),
    ] = 100

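The docstring added above describes a truncation rule. Below is a minimal illustrative sketch of that behaviour, using a hypothetical helper and a plain list in place of the benchmark's real dataset loader:

def effective_samples(dataset: list, min_query_count: int) -> list:
    """Return the samples the benchmark would actually query (hypothetical helper)."""
    # Documented behaviour: when the dataset has more samples than min_query_count,
    # only the first min_query_count samples are used during testing.
    if min_query_count < len(dataset):
        return dataset[:min_query_count]
    return dataset

# A 1000-sample dataset with min_query_count=100 exercises only samples 0..99.
assert len(effective_samples(list(range(1000)), 100)) == 100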
@@ -348,22 +350,6 @@ def to_lgtype(self) -> tuple[lg.TestSettings, lg.LogSettings]:
        log_settings = self.logging.to_lgtype()
        return (test_settings, log_settings)

-class Evaluation(BaseModel):
-    """Evaluate the results of the accuracy scenario."""
-    enable_evaluation: Annotated[
-        bool,
-        Field(
-            description="Evaluate the results of the accuracy scenario.",
-        ),
-    ] = False
-
-    filename: Annotated[
-        Path,
-        Field(
-            description="Location of the accuracy file.",
-        ),
-    ] = Path("./output/mlperf_log_accuracy.json")
-

class Model(BaseModel):
    """Specifies the model to use for the VL2L benchmark."""
@@ -391,6 +377,11 @@ class Dataset(BaseModel):
        ),
    ] = None

+    split: Annotated[
+        Literal["train", "test"],
+        Field(description="The dataset split to use: train or test."),
+    ] = "train"
+

class Verbosity(StrEnum):
    """The verbosity level of the logger."""
@@ -422,15 +413,28 @@ class Endpoint(BaseModel):
        Field(description="The API key to authenticate the inference requests."),
    ] = ""

+@app.command()
+def evaluate(
+    filename: Annotated[
+        FilePath,
+        Option(
+            help="Location of the accuracy file.",
+        ),
+    ],
+    dataset: Dataset,
+) -> None:
+    """Evaluate the accuracy of the VLM responses."""
+    logger.info("Evaluating the accuracy file")
+    run_evaluation(filename=filename, dataset=dataset)
+

@app.command()
-def main(
+def benchmark(
    *,
    settings: Settings,
    model: Model,
    dataset: Dataset,
    endpoint: Endpoint,
-    evaluation: Evaluation,
    random_seed: Annotated[
        int,
        Option(help="The seed for the random number generator used by the benchmark."),
@@ -441,33 +445,27 @@ def main(
    ] = Verbosity.INFO,
) -> None:
    """Main CLI for running the VL2L benchmark."""
-    if evaluation.enable_evaluation:
-        logger.info("Evaluating the accuracy file")
-        evaluator = Evaluator(filename=evaluation.filename, dataset_cli=dataset)
-        evaluator.run_evaluation()
-    else:
-        logger.remove()
-        logger.add(sys.stdout, level=verbosity.value.upper())
-        logger.info("Running VL2L benchmark with settings: {}", settings)
-        logger.info("Running VL2L benchmark with model: {}", model)
-        logger.info("Running VL2L benchmark with dataset: {}", dataset)
-        logger.info(
-            "Running VL2L benchmark with OpenAI API endpoint: {}",
-            endpoint)
-        logger.info("Running VL2L benchmark with random seed: {}", random_seed)
-        test_settings, log_settings = settings.to_lgtype()
-        task = ShopifyGlobalCatalogue(
-            dataset_cli=dataset,
-            model_cli=model,
-            endpoint_cli=endpoint,
-            scenario=settings.test.scenario,
-            min_query_count=settings.test.min_query_count,
-            random_seed=random_seed,
-        )
-        sut = task.construct_sut()
-        qsl = task.construct_qsl()
-        logger.info("Starting the VL2L benchmark with LoadGen...")
-        lg.StartTestWithLogSettings(sut, qsl, test_settings, log_settings)
-        logger.info("The VL2L benchmark with LoadGen completed.")
-        lg.DestroyQSL(qsl)
-        lg.DestroySUT(sut)
+    logger.remove()
+    logger.add(sys.stdout, level=verbosity.value.upper())
+    logger.info("Running VL2L benchmark with settings: {}", settings)
+    logger.info("Running VL2L benchmark with model: {}", model)
+    logger.info("Running VL2L benchmark with dataset: {}", dataset)
+    logger.info(
+        "Running VL2L benchmark with OpenAI API endpoint: {}",
+        endpoint)
+    logger.info("Running VL2L benchmark with random seed: {}", random_seed)
+    test_settings, log_settings = settings.to_lgtype()
+    task = ShopifyGlobalCatalogue(
+        dataset_cli=dataset,
+        model_cli=model,
+        endpoint_cli=endpoint,
+        settings=settings.test,
+        random_seed=random_seed,
+    )
+    sut = task.construct_sut()
+    qsl = task.construct_qsl()
+    logger.info("Starting the VL2L benchmark with LoadGen...")
+    lg.StartTestWithLogSettings(sut, qsl, test_settings, log_settings)
+    logger.info("The VL2L benchmark with LoadGen completed.")
+    lg.DestroyQSL(qsl)
+    lg.DestroySUT(sut)
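With this change the single main entry point becomes two subcommands, benchmark and evaluate, so an accuracy log can be re-scored without re-running LoadGen. A minimal smoke-test sketch of that wiring follows; the import path vl2l.cli, the compatibility of pydantic-typer's Typer with typer.testing, and the dotted --dataset.split option spelling are all assumptions for illustration, not verified against the repository:

# Hypothetical smoke test for the new CLI layout (see assumptions above).
from typer.testing import CliRunner

from vl2l.cli import app  # assumed module path for the file shown in this diff

runner = CliRunner()

# The top-level help should now list both subcommands.
result = runner.invoke(app, ["--help"])
assert "benchmark" in result.output
assert "evaluate" in result.output

# Re-score an existing accuracy log produced by an earlier accuracy run.
result = runner.invoke(
    app,
    [
        "evaluate",
        "--filename", "./output/mlperf_log_accuracy.json",
        "--dataset.split", "test",  # assumed option spelling for the nested Dataset model
    ],
)
print(result.exit_code)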