notebooks/openvino/vision_language_quantization.ipynb (301 changes: 134 additions & 167 deletions)
@@ -44,8 +44,8 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install \"optimum-intel[openvino]\" datasets num2words\n",
"! pip install torchvision"
"! pip install \"optimum-intel[openvino]\" datasets num2words torchvision\n",
"! pip install git+https://github.com/huggingface/optimum-benchmark.git"
]
},
{
@@ -200,7 +200,7 @@
" },\n",
" dataset=dataset,\n",
" num_samples=num_samples,\n",
")\n"
")"
]
},
{
@@ -321,189 +321,156 @@
"id": "c7ef9297",
"metadata": {},
"source": [
"### 5c: Compare performance on different Intel Hardware platforms"
"### Step 5c: Compare performance on different Intel Hardware platforms"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94f18ce3",
"id": "23406fcb",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"from optimum.intel import OVModelForVisualCausalLM\n",
"import matplotlib\n",
"matplotlib.use(\"Agg\")\n",
"import matplotlib.pyplot as plt\n",
"from huggingface_hub import create_repo, upload_file\n",
"from optimum_benchmark import (\n",
" Benchmark,\n",
" BenchmarkConfig,\n",
" BenchmarkReport,\n",
" InferenceConfig,\n",
" OpenVINOConfig,\n",
" ProcessConfig,\n",
" PyTorchConfig,\n",
")\n",
"from optimum_benchmark.logging_utils import setup_logging\n",
"from openvino.runtime import Core\n",
"\n",
"class InferRequestWrapper:\n",
" \"\"\"\n",
" A helper class to track pipeline components' inference time.\n",
" \"\"\"\n",
" def __init__(self, request, infer_time_values):\n",
" self.request = request\n",
" self.infer_time_values = infer_time_values\n",
" self._start_async_time = None\n",
"\n",
" def reset_state(self):\n",
" self.request.reset_state()\n",
"\n",
" def get_tensor(self, name):\n",
" return self.request.get_tensor(name)\n",
"\n",
" def __call__(self, *args, **kwargs):\n",
" start_time = time.perf_counter()\n",
" result = self.request(*args, **kwargs)\n",
" end_time = time.perf_counter()\n",
" self.infer_time_values.append(end_time - start_time)\n",
" return result\n",
"\n",
" def start_async(self, *args, **kwargs):\n",
" assert self._start_async_time is None, \"start_async is already in progress\"\n",
" self._start_async_time = time.perf_counter()\n",
" return self.request.start_async(*args, **kwargs)\n",
"\n",
" def wait(self):\n",
" assert self._start_async_time is not None, \"start_async must be called before wait\"\n",
" result = self.request.wait()\n",
" self.infer_time_values.append(time.perf_counter() - self._start_async_time)\n",
" self._start_async_time = None\n",
" return result\n",
"\n",
"\n",
"def benchmark(model, inputs, model_dir: str, nb_pass=10, warmup=4,max_tokens=50):\n",
" \"\"\"\n",
" Benchmark an OV visual causal LM model.\n",
"\n",
" Returns a dict with:\n",
" - avg_latency_sec\n",
" - image_throughput\n",
" - first_token_throughput\n",
" - second_token_throughput\n",
" - model_size_mb\n",
" \"\"\"\n",
"\n",
" # --- Patch OpenVINO InferRequest objects to track inference time ---\n",
" model.compile()\n",
" lm_model_time_values = []\n",
" vision_embed_time_values = []\n",
" first_token_latencies = []\n",
" model.language_model.request = InferRequestWrapper(model.language_model.request, lm_model_time_values)\n",
" model.vision_embeddings.request = InferRequestWrapper(model.vision_embeddings.request, vision_embed_time_values)\n",
" \n",
" # --- Warmup ---\n",
" for _ in range(warmup):\n",
" _ = model.generate(**inputs)\n",
"\n",
" lm_model_time_values.clear()\n",
" vision_embed_time_values.clear()\n",
"\n",
" # --- Timed inference ---\n",
" start = time.perf_counter()\n",
" for _ in range(nb_pass):\n",
" last_infer_count = len(lm_model_time_values)\n",
" outputs = model.generate(**inputs,max_new_tokens=max_tokens)\n",
" first_token_latencies.append(lm_model_time_values[last_infer_count])\n",
" end = time.perf_counter()\n",
"\n",
" # --- Unpatch InferRequest objects ---\n",
" model.language_model.request = model.language_model.request.request\n",
" model.vision_embeddings.request = model.vision_embeddings.request.request\n",
"\n",
" # --- Throughput calculations ---\n",
" avg_latency = (end - start) / nb_pass\n",
" \n",
" avg_vision_embed_time = sum(vision_embed_time_values) / len(vision_embed_time_values)\n",
" avg_first_token_latency = sum(first_token_latencies) / len(first_token_latencies)\n",
" avg_second_token_latency = (sum(lm_model_time_values) - sum(first_token_latencies)) / \\\n",
" (len(lm_model_time_values) - len(first_token_latencies))\n",
"\n",
" batch_size = inputs[\"pixel_values\"].shape[0] if \"pixel_values\" in inputs else 1\n",
" image_throughput = batch_size / avg_vision_embed_time\n",
"\n",
" # --- Model size ---\n",
" model_size_bytes = sum(\n",
" os.path.getsize(os.path.join(model_dir, f))\n",
" for f in os.listdir(model_dir)\n",
" if f.startswith(\"openvino_\")\n",
"setup_logging(level=\"INFO\", prefix=\"MAIN-PROCESS\")\n",
"\n",
"launcher_config = ProcessConfig()\n",
"scenario_config = InferenceConfig(\n",
" memory=True,\n",
" latency=True,\n",
" generate_kwargs={\"max_new_tokens\": 16, \"min_new_tokens\": 16},\n",
" input_shapes={\"batch_size\": 1, \"sequence_length\": 16, \"num_images\": 1},\n",
")\n",
"\n",
"configs = {\n",
" \"pytorch\": PyTorchConfig(device=\"cpu\", model=model_id, no_weights=True),\n",
" \"openvino\": OpenVINOConfig(device=\"cpu\", model=model_id, no_weights=True),\n",
" \"openvino-8bit-woq\": OpenVINOConfig(\n",
" device=\"cpu\",\n",
" model=model_id,\n",
" no_weights=True,\n",
" quantization_config={\"bits\": 8, \"num_samples\": 1, \"weight_only\": True},\n",
" ),\n",
"}\n",
"\n",
"for config_name, backend_config in configs.items():\n",
" benchmark_config = BenchmarkConfig(\n",
" name=f\"{config_name}\",\n",
" launcher=launcher_config,\n",
" scenario=scenario_config,\n",
" backend=backend_config,\n",
" )\n",
" model_size_mb = model_size_bytes / (1024**2)\n",
"\n",
" return {\n",
" \"avg_latency_sec\": avg_latency,\n",
" \"image_throughput\": image_throughput,\n",
" \"first_token_throughput\": 1 / avg_first_token_latency,\n",
" \"second_token_throughput\": 1 / avg_second_token_latency,\n",
" \"model_size_mb\": model_size_mb,\n",
" }\n"
]
},
{
"cell_type": "markdown",
"id": "f111381d",
"metadata": {},
"source": [
"#### Run benchmark"
" benchmark_report = Benchmark.launch(benchmark_config)\n",
" benchmark_report.save_json(f\"{config_name}_report.json\")\n",
" benchmark_config.save_json(f\"{config_name}_config.json\")\n",
"\n",
"reports = {}\n",
"for config_name in configs.keys():\n",
" reports[config_name] = BenchmarkReport.from_json(f\"{config_name}_report.json\")\n",
"\n",
"# Plotting results\n",
"_, ax = plt.subplots()\n",
"ax.boxplot(\n",
" [reports[config_name].prefill.latency.values for config_name in reports.keys()],\n",
" tick_labels=reports.keys(),\n",
" showfliers=False,\n",
")\n",
"plt.xticks(rotation=10)\n",
"ax.set_ylabel(\"Latency (s)\")\n",
"ax.set_xlabel(\"Configurations\")\n",
"ax.set_title(\"Prefill Latencies\")\n",
"plt.savefig(\"prefill_latencies_boxplot.png\")\n",
"\n",
"_, ax = plt.subplots()\n",
"ax.bar(\n",
" list(reports.keys()),\n",
" [reports[config_name].decode.throughput.value for config_name in reports.keys()],\n",
" color=[\"C0\", \"C1\", \"C2\", \"C3\", \"C4\", \"C5\"],\n",
")\n",
"plt.xticks(rotation=10)\n",
"ax.set_xlabel(\"Configurations\")\n",
"ax.set_title(\"Decoding Throughput\")\n",
"ax.set_ylabel(\"Throughput (tokens/s)\")\n",
"plt.savefig(\"decode_throughput_barplot.png\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "967e558d",
"id": "1f1720e5",
"metadata": {},
"outputs": [],
"source": [
"#Check for available hardware platforms\n",
"# Print results\n",
"import json\n",
"import pandas as pd\n",
"\n",
"from openvino.runtime import Core\n",
"# List of config names\n",
"config_names = list(configs.keys())\n",
"\n",
"core = Core()\n",
"devices = core.available_devices\n",
"device_list = []\n",
"\n",
"for device in devices:\n",
" try:\n",
" # Use FULL_DEVICE_NAME if available, else fallback to device ID\n",
" name = core.get_property(device, \"FULL_DEVICE_NAME\")\n",
" except:\n",
" name = device\n",
" device_list.append(device) # keep the device ID for model loading\n",
" print(f\"{device}: {name}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d5b8341",
"metadata": {},
"outputs": [],
"source": [
"# --- Local models ---\n",
"models = {\n",
" \"SmolVLM2-256M (full)\": fp32_model_path,\n",
" \"SmolVLM2-256M-int8\": int8_model_path\n",
"}\n",
"# Stages we want to include in the table\n",
"stages = [\"load_model\", \"first_generate\", \"prefill\", \"generate\", \"decode\"]\n",
"\n",
"table_rows = []\n",
"\n",
"# --- Run benchmark ---\n",
"for model_name, model_dir in models.items():\n",
" for device in device_list:\n",
" print(f\"\\nBenchmarking {model_name} on {device}...\")\n",
"\n",
" # Load model for the specific device\n",
" model_ov = OVModelForVisualCausalLM.from_pretrained(\n",
" model_dir, export=False, device=device\n",
" )\n",
"\n",
" # Run benchmark\n",
" results = benchmark(model_ov, inputs, model_dir=model_dir)\n",
"\n",
" # Print results\n",
" print(\n",
" f\"Latency: {results['avg_latency_sec']:.4f}s | \"\n",
" f\"Image throughput: {results['image_throughput']:.2f} im/s | \"\n",
" f\"First token throughput: {results['first_token_throughput']:.2f} t/s | \"\n",
" f\"Second token throughput: {results['second_token_throughput']:.2f} t/s | \"\n",
" f\"Model size: {results['model_size_mb']:.2f} MB\"\n",
" )\n"
"for config_name in config_names:\n",
" report_file = f\"{config_name}_report.json\"\n",
" with open(report_file, \"r\") as f:\n",
" report_data = json.load(f)\n",
" \n",
" row = {\"Configuration\": config_name}\n",
" \n",
" for stage in stages:\n",
" stage_data = report_data.get(stage, {})\n",
"\n",
" # Latency (mean)\n",
" latency_mean = stage_data.get(\"latency\", {}).get(\"mean\")\n",
" row[f\"{stage} Latency (s)\"] = round(latency_mean, 3) if latency_mean is not None else \"N/A\"\n",
" \n",
" # Throughput (value + unit)\n",
" throughput_data = stage_data.get(\"throughput\")\n",
" if throughput_data:\n",
" throughput_value = throughput_data.get(\"value\")\n",
" throughput_unit = throughput_data.get(\"unit\", \"\")\n",
" row[f\"{stage} Throughput\"] = f\"{throughput_value:.3f} {throughput_unit}\" if throughput_value else \"N/A\"\n",
" else:\n",
" row[f\"{stage} Throughput\"] = \"N/A\"\n",
" \n",
" # Max RAM\n",
" memory_max = stage_data.get(\"memory\", {}).get(\"max_ram\")\n",
" row[f\"{stage} Memory (MB)\"] = round(memory_max, 2) if memory_max is not None else \"N/A\"\n",
" \n",
" table_rows.append(row)\n",
"\n",
"# Build the DataFrame\n",
"df = pd.DataFrame(table_rows)\n",
"\n",
"# Optional: reorder columns for readability\n",
"columns_order = [\"Configuration\"]\n",
"for stage in stages:\n",
" columns_order += [\n",
" f\"{stage} Latency (s)\",\n",
" f\"{stage} Throughput\",\n",
" f\"{stage} Memory (MB)\"\n",
" ]\n",
"df = df[columns_order]\n",
"\n",
"df"
]
},
{
@@ -529,7 +529,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -543,7 +543,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
"version": "3.9.18"
}
},
"nbformat": 4,