|
| 1 | +"""NanoVLM evaluation model for lmms-eval. |
| 2 | +
|
| 3 | +NanoVLM (SigLIP2 + MLP projector + Qwen3-0.6B) is a lightweight VLM |
| 4 | +trained with lmms-engine. This wrapper supports async multi-GPU inference: |
| 5 | +it loads model replicas on N GPUs and dispatches work via a job queue so |
| 6 | +workers run independently without synchronization overhead. |
| 7 | +
|
| 8 | +Single-GPU fallback is automatic when only one device is visible. |
| 9 | +Use ``worker_gpus`` or ``worker_count`` model_args to control GPU selection. |
| 10 | +""" |
| 11 | + |
| 12 | +import os |
| 13 | +import queue |
| 14 | +import threading |
| 15 | +import time |
| 16 | +from dataclasses import dataclass |
| 17 | +from typing import Dict, List, Optional, Tuple, Union |
| 18 | + |
| 19 | +# Register NanoVLM with transformers Auto classes |
| 20 | +import lmms_engine.models.nanovlm # noqa: F401 |
| 21 | +import torch |
| 22 | +from loguru import logger as eval_logger |
| 23 | +from PIL import Image |
| 24 | +from tqdm import tqdm |
| 25 | +from transformers import AutoImageProcessor, AutoModelForImageTextToText, AutoTokenizer |
| 26 | + |
| 27 | +from lmms_eval.api.instance import Instance |
| 28 | +from lmms_eval.api.model import lmms |
| 29 | +from lmms_eval.api.registry import register_model |
| 30 | +from lmms_eval.protocol import ChatMessages |
| 31 | + |
| 32 | + |
| 33 | +@dataclass |
| 34 | +class _NanoVLMWorker: |
| 35 | + model: AutoModelForImageTextToText |
| 36 | + tokenizer: AutoTokenizer |
| 37 | + image_processor: AutoImageProcessor |
| 38 | + device: torch.device |
| 39 | + image_token_count: int |
| 40 | + image_token_id: int |
| 41 | + |
| 42 | + |
| 43 | +@register_model("nanovlm") |
| 44 | +class NanoVLM(lmms): |
| 45 | + is_simple = False |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + pretrained: str = "LMMs-Lab-Speedrun/NanoVLM_Init", |
| 50 | + device: Optional[str] = "cuda", |
| 51 | + batch_size: Optional[Union[int, str]] = 1, |
| 52 | + attn_implementation: Optional[str] = None, |
| 53 | + system_prompt: Optional[str] = "You are a helpful assistant.", |
| 54 | + use_cache: bool = False, |
| 55 | + worker_gpus: Optional[str] = None, |
| 56 | + worker_count: Optional[int] = None, |
| 57 | + **kwargs, |
| 58 | + ) -> None: |
| 59 | + super().__init__() |
| 60 | + |
| 61 | + if int(os.environ.get("WORLD_SIZE", "1")) > 1: |
| 62 | + raise ValueError("NanoVLM manages multi-GPU dispatch internally. Please run without accelerate/torchrun multi-process launch.") |
| 63 | + |
| 64 | + if kwargs: |
| 65 | + eval_logger.warning(f"Ignoring unsupported kwargs for nanovlm: {sorted(kwargs.keys())}") |
| 66 | + |
| 67 | + self.pretrained = pretrained |
| 68 | + self.system_prompt = system_prompt |
| 69 | + self.use_cache = use_cache |
| 70 | + self._attn_implementation = attn_implementation |
| 71 | + |
| 72 | + worker_devices = self._resolve_worker_devices(device=device, worker_gpus=worker_gpus, worker_count=worker_count) |
| 73 | + self._workers: List[_NanoVLMWorker] = [self._load_worker(dev) for dev in worker_devices] |
| 74 | + |
| 75 | + # Public attributes expected by the lmms-eval framework |
| 76 | + self.model = self._workers[0].model |
| 77 | + self.tokenizer = self._workers[0].tokenizer |
| 78 | + self.config = self._workers[0].model.config |
| 79 | + self.device = self._workers[0].device |
| 80 | + self.batch_size = int(batch_size) |
| 81 | + self.max_length = 4096 |
| 82 | + self.eot_token_id = self._workers[0].tokenizer.eos_token_id |
| 83 | + |
| 84 | + eval_logger.info(f"NanoVLM loaded: {len(self._workers)} worker(s) on {worker_devices}, " f"image_token_count={self._workers[0].image_token_count}, use_cache={self.use_cache}") |
| 85 | + |
| 86 | + # ------------------------------------------------------------------ |
| 87 | + # Initialization helpers |
| 88 | + # ------------------------------------------------------------------ |
| 89 | + |
| 90 | + def _resolve_worker_devices(self, device: Optional[str], worker_gpus: Optional[str], worker_count: Optional[int]) -> List[str]: |
| 91 | + if device == "cpu": |
| 92 | + return ["cpu"] |
| 93 | + if worker_gpus: |
| 94 | + selected = [gpu.strip() for gpu in worker_gpus.split(",") if gpu.strip()] |
| 95 | + return [f"cuda:{gpu}" if not gpu.startswith("cuda:") else gpu for gpu in selected] |
| 96 | + if not torch.cuda.is_available(): |
| 97 | + return ["cpu"] |
| 98 | + available = [f"cuda:{i}" for i in range(torch.cuda.device_count())] |
| 99 | + if worker_count is None: |
| 100 | + return available |
| 101 | + return available[: min(worker_count, len(available))] |
| 102 | + |
| 103 | + def _load_worker(self, device_name: str) -> _NanoVLMWorker: |
| 104 | + model_kwargs: Dict[str, object] = {"torch_dtype": torch.bfloat16, "device_map": device_name} |
| 105 | + if self._attn_implementation: |
| 106 | + model_kwargs["attn_implementation"] = self._attn_implementation |
| 107 | + |
| 108 | + eval_logger.info(f"Loading NanoVLM worker on {device_name}") |
| 109 | + model = AutoModelForImageTextToText.from_pretrained(self.pretrained, **model_kwargs).eval() |
| 110 | + tokenizer = AutoTokenizer.from_pretrained(self.pretrained) |
| 111 | + image_processor = AutoImageProcessor.from_pretrained(self.pretrained) |
| 112 | + |
| 113 | + config = model.config |
| 114 | + image_token_count = getattr(config, "image_token_count", 256) |
| 115 | + image_token_id = getattr(config, "image_token_id", tokenizer.convert_tokens_to_ids("<|image_pad|>")) |
| 116 | + |
| 117 | + return _NanoVLMWorker( |
| 118 | + model=model, |
| 119 | + tokenizer=tokenizer, |
| 120 | + image_processor=image_processor, |
| 121 | + device=torch.device(device_name), |
| 122 | + image_token_count=image_token_count, |
| 123 | + image_token_id=image_token_id, |
| 124 | + ) |
| 125 | + |
| 126 | + # ------------------------------------------------------------------ |
| 127 | + # Inference internals |
| 128 | + # ------------------------------------------------------------------ |
| 129 | + |
| 130 | + def _expand_image_tokens(self, input_ids: List[int], image_token_id: int, image_token_count: int) -> List[int]: |
| 131 | + """Expand each single image_token_id to image_token_count copies.""" |
| 132 | + expanded = [] |
| 133 | + for token_id in input_ids: |
| 134 | + if token_id == image_token_id: |
| 135 | + expanded.extend([image_token_id] * image_token_count) |
| 136 | + else: |
| 137 | + expanded.append(token_id) |
| 138 | + return expanded |
| 139 | + |
| 140 | + def _process_single(self, worker: _NanoVLMWorker, hf_messages: List[dict], images: List) -> Tuple[torch.Tensor, dict]: |
| 141 | + """Tokenize with chat template, expand image tokens, and process images.""" |
| 142 | + token_ids = worker.tokenizer.apply_chat_template(hf_messages, tokenize=True, add_generation_prompt=True) |
| 143 | + token_ids = self._expand_image_tokens(token_ids, worker.image_token_id, worker.image_token_count) |
| 144 | + input_ids = torch.tensor([token_ids], dtype=torch.long) |
| 145 | + |
| 146 | + image_inputs = {} |
| 147 | + if images: |
| 148 | + pil_images = [] |
| 149 | + for img in images: |
| 150 | + if isinstance(img, Image.Image): |
| 151 | + pil_images.append(img.convert("RGB")) |
| 152 | + elif isinstance(img, str): |
| 153 | + pil_images.append(Image.open(img).convert("RGB")) |
| 154 | + else: |
| 155 | + pil_images.append(img) |
| 156 | + processed = worker.image_processor(images=pil_images, return_tensors="pt") |
| 157 | + for k, v in processed.items(): |
| 158 | + image_inputs[k] = v |
| 159 | + |
| 160 | + return input_ids, image_inputs |
| 161 | + |
| 162 | + def _run_single_request(self, worker: _NanoVLMWorker, request: Instance) -> Tuple[str, float, int]: |
| 163 | + """Run inference for a single request on a specific worker. Returns (answer, elapsed, n_tokens).""" |
| 164 | + context, doc_to_messages, gen_kwargs, doc_id, task, split = request.args |
| 165 | + chat_messages = doc_to_messages(self.task_dict[task][split][doc_id]) |
| 166 | + chat_messages = ChatMessages(messages=chat_messages) |
| 167 | + |
| 168 | + images, videos, audios = chat_messages.extract_media() |
| 169 | + hf_messages = chat_messages.to_hf_messages() |
| 170 | + |
| 171 | + if not hf_messages or hf_messages[0]["role"] != "system": |
| 172 | + hf_messages.insert(0, {"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}) |
| 173 | + |
| 174 | + input_ids, image_inputs = self._process_single(worker, hf_messages, images) |
| 175 | + |
| 176 | + input_ids = input_ids.to(worker.device) |
| 177 | + for k, v in image_inputs.items(): |
| 178 | + if isinstance(v, torch.Tensor): |
| 179 | + image_inputs[k] = v.to(worker.device) |
| 180 | + |
| 181 | + max_new_tokens = gen_kwargs.get("max_new_tokens", 16) |
| 182 | + temperature = gen_kwargs.get("temperature", 0) |
| 183 | + do_sample = temperature > 0 |
| 184 | + |
| 185 | + gen_kwargs_call = { |
| 186 | + "max_new_tokens": max_new_tokens, |
| 187 | + "do_sample": do_sample, |
| 188 | + "eos_token_id": worker.tokenizer.eos_token_id, |
| 189 | + "pad_token_id": worker.tokenizer.pad_token_id or worker.tokenizer.eos_token_id, |
| 190 | + "use_cache": self.use_cache, |
| 191 | + } |
| 192 | + if do_sample: |
| 193 | + gen_kwargs_call["temperature"] = temperature |
| 194 | + gen_kwargs_call["top_p"] = gen_kwargs.get("top_p", 1.0) |
| 195 | + |
| 196 | + start_time = time.time() |
| 197 | + with torch.inference_mode(): |
| 198 | + output_ids = worker.model.generate(input_ids=input_ids, **image_inputs, **gen_kwargs_call) |
| 199 | + elapsed = time.time() - start_time |
| 200 | + |
| 201 | + generated_ids = output_ids[0][input_ids.shape[1] :] |
| 202 | + answer = worker.tokenizer.decode(generated_ids, skip_special_tokens=True).strip() |
| 203 | + return answer, elapsed, len(generated_ids) |
| 204 | + |
| 205 | + # ------------------------------------------------------------------ |
| 206 | + # lmms-eval interface (abstract method implementations) |
| 207 | + # ------------------------------------------------------------------ |
| 208 | + |
| 209 | + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: |
| 210 | + # Required by abc.abstractmethod in base class; NanoVLM is generate-only. |
| 211 | + raise NotImplementedError("NanoVLM does not support loglikelihood scoring") |
| 212 | + |
| 213 | + def generate_until(self, requests: List[Instance]) -> List[str]: |
| 214 | + """Generate answers for all requests using async multi-GPU dispatch. |
| 215 | +
|
| 216 | + Each worker (one per GPU) pulls jobs from a shared queue and runs |
| 217 | + inference independently. With a single GPU this reduces to standard |
| 218 | + sequential processing. |
| 219 | + """ |
| 220 | + results: List[Optional[str]] = [None] * len(requests) |
| 221 | + job_queue: "queue.Queue[Tuple[int, Instance]]" = queue.Queue() |
| 222 | + for idx, request in enumerate(requests): |
| 223 | + job_queue.put((idx, request)) |
| 224 | + |
| 225 | + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="NanoVLM Responding") |
| 226 | + lock = threading.Lock() |
| 227 | + errors: List[Exception] = [] |
| 228 | + total_elapsed = 0.0 |
| 229 | + total_tokens = 0 |
| 230 | + |
| 231 | + def worker_loop(worker: _NanoVLMWorker) -> None: |
| 232 | + nonlocal total_elapsed, total_tokens |
| 233 | + while True: |
| 234 | + if errors: |
| 235 | + return |
| 236 | + try: |
| 237 | + idx, request = job_queue.get_nowait() |
| 238 | + except queue.Empty: |
| 239 | + return |
| 240 | + try: |
| 241 | + answer, elapsed, n_tokens = self._run_single_request(worker, request) |
| 242 | + with lock: |
| 243 | + results[idx] = answer |
| 244 | + total_elapsed += elapsed |
| 245 | + total_tokens += n_tokens |
| 246 | + pbar.update(1) |
| 247 | + except Exception as exc: |
| 248 | + eval_logger.error(f"Worker on {worker.device} failed: {exc}") |
| 249 | + with lock: |
| 250 | + errors.append(exc) |
| 251 | + return |
| 252 | + |
| 253 | + threads = [threading.Thread(target=worker_loop, args=(w,), daemon=True) for w in self._workers] |
| 254 | + for t in threads: |
| 255 | + t.start() |
| 256 | + for t in threads: |
| 257 | + t.join() |
| 258 | + pbar.close() |
| 259 | + |
| 260 | + if errors: |
| 261 | + raise errors[0] |
| 262 | + |
| 263 | + if any(r is None for r in results): |
| 264 | + raise RuntimeError(f"NanoVLM completed {sum(1 for r in results if r is not None)} / {len(requests)} requests") |
| 265 | + |
| 266 | + if total_elapsed > 0: |
| 267 | + eval_logger.info(f"NanoVLM inference: {total_tokens} tokens in {total_elapsed:.1f}s ({total_tokens / total_elapsed:.1f} tok/s)") |
| 268 | + |
| 269 | + return results |
| 270 | + |
| 271 | + def generate_until_multi_round(self, requests) -> List[str]: |
| 272 | + # Required by abc.abstractmethod in base class; not needed for current benchmarks. |
| 273 | + raise NotImplementedError("NanoVLM does not support multi-round generation") |
0 commit comments