
Commit f5995c3

address review comments and test hiclass implementation
1 parent 35c9704

File tree: 4 files changed (+247, −233 lines)


multimodal/vl2l/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies = [
     "typer",
     "scikit-learn",
     "tabulate",
+    "hiclass",
 ]
 dynamic = ["version"]
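
The new "hiclass" dependency is what the commit message refers to testing: it provides hierarchical classification metrics (hierarchical precision, recall, and F1) that give partial credit for correctly predicted ancestor categories. A minimal sketch of how such metrics could be computed is below; the label paths are made up for illustration, and this is not the repository's evaluation code.

import numpy as np
from hiclass.metrics import f1, precision, recall

# Illustrative root-to-leaf category paths of shape (n_samples, n_levels);
# not real dataset labels.
y_true = np.array([
    ["Apparel", "Shoes", "Sneakers"],
    ["Home", "Kitchen", "Cookware"],
])
y_pred = np.array([
    ["Apparel", "Shoes", "Boots"],
    ["Home", "Kitchen", "Cookware"],
])

# Hierarchical metrics reward correct ancestors even when the leaf is wrong.
print("hierarchical precision:", precision(y_true, y_pred))
print("hierarchical recall:", recall(y_true, y_pred))
print("hierarchical F1:", f1(y_true, y_pred))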

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/cli.py

Lines changed: 50 additions & 52 deletions
@@ -6,15 +6,15 @@
 from datetime import timedelta
 from enum import StrEnum, auto
 from pathlib import Path
-from typing import Annotated
+from typing import Annotated, Literal

 import mlperf_loadgen as lg
 from loguru import logger
-from pydantic import BaseModel, DirectoryPath, Field, field_validator
+from pydantic import BaseModel, DirectoryPath, Field, FilePath, field_validator
 from pydantic_typer import Typer
 from typer import Option

-from .evaluation import Evaluator
+from .evaluation import run_evaluation
 from .task import ShopifyGlobalCatalogue

 app = Typer()
@@ -179,7 +179,9 @@ class TestSettings(BaseModel):
         int,
         Field(
             description="""The minimum testing query count.
-            The benchmark runs until this value has been met.""",
+            The benchmark runs until this value has been met.
+            if min_query_count is less than the total number of samples in the dataset,
+            only the first min_query_count samples will be used during testing.""",
         ),
     ] = 100
@@ -348,22 +350,6 @@ def to_lgtype(self) -> tuple[lg.TestSettings, lg.LogSettings]:
         log_settings = self.logging.to_lgtype()
         return (test_settings, log_settings)

-class Evaluation(BaseModel):
-    """Evaluate the results of the accuracy scenario."""
-    enable_evaluation: Annotated[
-        bool,
-        Field(
-            description="Evaluate the results of the accuracy scenario.",
-        ),
-    ] = False
-
-    filename: Annotated[
-        Path,
-        Field(
-            description="Location of the accuracy file.",
-        ),
-    ] = Path("./output/mlperf_log_accuracy.json")
-

 class Model(BaseModel):
     """Specifies the model to use for the VL2L benchmark."""
@@ -391,6 +377,11 @@ class Dataset(BaseModel):
         ),
     ] = None

+    split: Annotated[
+        Literal["train","test"],
+        Field(description="choose between train or test split"),
+    ] = "train"
+

 class Verbosity(StrEnum):
     """The verbosity level of the logger."""
@@ -422,15 +413,28 @@ class Endpoint(BaseModel):
         Field(description="The API key to authenticate the inference requests."),
     ] = ""

+@app.command()
+def evaluate(
+    filename: Annotated[
+        FilePath,
+        Option(
+            help="Location of the accuracy file.",
+        ),
+    ],
+    dataset: Dataset,
+) -> None:
+    """Evaluate the accuracy of the VLM responses."""
+    logger.info("Evaluating the accuracy file")
+    run_evaluation(filename=filename, dataset=dataset)
+

 @app.command()
-def main(
+def benchmark(
     *,
     settings: Settings,
     model: Model,
     dataset: Dataset,
     endpoint: Endpoint,
-    evaluation: Evaluation,
     random_seed: Annotated[
         int,
         Option(help="The seed for the random number generator used by the benchmark."),
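
With the Evaluation options model removed, evaluation becomes its own evaluate subcommand that wraps run_evaluation, alongside the renamed benchmark command. A rough sketch of exercising it in-process (e.g. from a test) is below; the dotted --dataset.split option name assumes pydantic-typer's default naming for nested models, which this diff does not show.

from typer.testing import CliRunner

from mlperf_inference_multimodal_vl2l.cli import app

runner = CliRunner()

# Evaluate a previously produced accuracy log against the test split.
# The --dataset.split spelling is an assumption about pydantic-typer's option naming.
result = runner.invoke(
    app,
    [
        "evaluate",
        "--filename", "./output/mlperf_log_accuracy.json",
        "--dataset.split", "test",
    ],
)
assert result.exit_code == 0, result.output
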
@@ -441,33 +445,27 @@ def main(
     ] = Verbosity.INFO,
 ) -> None:
     """Main CLI for running the VL2L benchmark."""
-    if evaluation.enable_evaluation:
-        logger.info("Evaluating the accuracy file")
-        evaluator = Evaluator(filename=evaluation.filename, dataset_cli=dataset)
-        evaluator.run_evaluation()
-    else:
-        logger.remove()
-        logger.add(sys.stdout, level=verbosity.value.upper())
-        logger.info("Running VL2L benchmark with settings: {}", settings)
-        logger.info("Running VL2L benchmark with model: {}", model)
-        logger.info("Running VL2L benchmark with dataset: {}", dataset)
-        logger.info(
-            "Running VL2L benchmark with OpenAI API endpoint: {}",
-            endpoint)
-        logger.info("Running VL2L benchmark with random seed: {}", random_seed)
-        test_settings, log_settings = settings.to_lgtype()
-        task = ShopifyGlobalCatalogue(
-            dataset_cli=dataset,
-            model_cli=model,
-            endpoint_cli=endpoint,
-            scenario=settings.test.scenario,
-            min_query_count=settings.test.min_query_count,
-            random_seed=random_seed,
-        )
-        sut = task.construct_sut()
-        qsl = task.construct_qsl()
-        logger.info("Starting the VL2L benchmark with LoadGen...")
-        lg.StartTestWithLogSettings(sut, qsl, test_settings, log_settings)
-        logger.info("The VL2L benchmark with LoadGen completed.")
-        lg.DestroyQSL(qsl)
-        lg.DestroySUT(sut)
+    logger.remove()
+    logger.add(sys.stdout, level=verbosity.value.upper())
+    logger.info("Running VL2L benchmark with settings: {}", settings)
+    logger.info("Running VL2L benchmark with model: {}", model)
+    logger.info("Running VL2L benchmark with dataset: {}", dataset)
+    logger.info(
+        "Running VL2L benchmark with OpenAI API endpoint: {}",
+        endpoint)
+    logger.info("Running VL2L benchmark with random seed: {}", random_seed)
+    test_settings, log_settings = settings.to_lgtype()
+    task = ShopifyGlobalCatalogue(
+        dataset_cli=dataset,
+        model_cli=model,
+        endpoint_cli=endpoint,
+        settings = settings.test,
+        random_seed=random_seed,
+    )
+    sut = task.construct_sut()
+    qsl = task.construct_qsl()
+    logger.info("Starting the VL2L benchmark with LoadGen...")
+    lg.StartTestWithLogSettings(sut, qsl, test_settings, log_settings)
+    logger.info("The VL2L benchmark with LoadGen completed.")
+    lg.DestroyQSL(qsl)
+    lg.DestroySUT(sut)
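
With main renamed to benchmark and the evaluation parameter dropped, the run path no longer branches on enable_evaluation, and ShopifyGlobalCatalogue now receives the whole settings.test object instead of separate scenario and min_query_count arguments. Continuing the sketch above, a benchmark run might be driven like this; the option names for the nested Settings/Dataset models are again assumptions, not shown in this diff.

# Hypothetical benchmark invocation with a fixed seed on the train split,
# reusing the CliRunner set up in the earlier sketch.
result = runner.invoke(
    app,
    [
        "benchmark",
        "--dataset.split", "train",
        "--random-seed", "42",
        "--verbosity", "info",
    ],
)
assert result.exit_code == 0, result.output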
