Skip to content

Commit b5d3440

Browse files
committed
Add sglang runtime
1 parent 57fe347 commit b5d3440

3 files changed

Lines changed: 181 additions & 8 deletions

File tree

lmms_eval/models/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,7 @@
7878
"vora": "VoRA",
7979
}
8080

81-
AVAILABLE_CHAT_TEMPLATE_MODELS = {
82-
"llava_hf": "LlavaHf",
83-
"qwen2_5_vl": "Qwen2_5_VL",
84-
"openai_compatible": "OpenAICompatible",
85-
"vllm": "VLLM",
86-
}
81+
AVAILABLE_CHAT_TEMPLATE_MODELS = {"llava_hf": "LlavaHf", "qwen2_5_vl": "Qwen2_5_VL", "openai_compatible": "OpenAICompatible", "vllm": "VLLM", "sglang": "Sglang"}
8782

8883

8984
def get_model(model_name, force_simple: bool = False):

lmms_eval/models/chat/sglang.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import json
2+
import warnings
3+
from typing import List, Optional, Tuple, Union
4+
5+
import PIL
6+
from accelerate import Accelerator, DistributedType
7+
from sglang import Engine
8+
from tqdm import tqdm
9+
from transformers import AutoProcessor
10+
11+
from lmms_eval import utils
12+
from lmms_eval.api.instance import Instance
13+
from lmms_eval.api.model import lmms
14+
from lmms_eval.api.registry import register_model
15+
from lmms_eval.models.model_utils.load_video import load_video_decord
16+
from lmms_eval.protocol import ChatMessages
17+
18+
warnings.filterwarnings("ignore")
19+
20+
from loguru import logger as eval_logger
21+
22+
23+
@register_model("sglang_runtime")
class Sglang(lmms):
    """Chat-style lmms-eval model that generates with an in-process SGLang ``Engine``.

    Prompts are rendered with the Hugging Face processor's chat template and
    sent to the engine together with the request's visual inputs.

    Args:
        model_version: Model path or hub id given to both ``Engine`` and
            ``AutoProcessor.from_pretrained``.
        tensor_parallel_size: Tensor-parallel degree, forwarded to the engine
            as ``tp_size``.
        gpu_memory_utilization: Forwarded to the engine as
            ``mem_fraction_static``.
        batch_size: Number of requests sent to the engine per ``generate`` call.
        max_frame_num: Maximum number of frames decoded from each video.
        threads: Stored on the instance; not otherwise used in this class.
        trust_remote_code: Accepted for interface compatibility.
            NOTE(review): currently not forwarded to the engine or processor.
        chat_template: Optional chat template overriding the processor default.
        **kwargs: Extra engine arguments. String values that look like JSON
            objects (``"{...}"``) are parsed into dictionaries first.
    """

    is_simple = False

    def __init__(
        self,
        model_version: str = "Qwen/Qwen2.5-VL-3B-Instruct",
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.8,
        batch_size: int = 1,
        max_frame_num: int = 32,
        threads: int = 16,  # Threads to use for decoding visuals
        trust_remote_code: Optional[bool] = True,
        chat_template: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model_version = model_version
        self.max_frame_num = max_frame_num
        self.threads = threads
        self.chat_template = chat_template

        # Convert any string arguments that start with { and end with } to dictionaries.
        # Iterate over a snapshot because values are replaced in place.
        for key, value in list(kwargs.items()):
            if isinstance(value, str) and value.strip().startswith("{") and value.strip().endswith("}"):
                try:
                    kwargs[key] = json.loads(value)
                except json.JSONDecodeError:
                    eval_logger.warning(f"Failed to parse JSON-like string for argument '{key}': {value}")

        # Set up the SGLang engine. SGLang names the tensor-parallel degree
        # `tp_size`; `tensor_parallel_size` is the vLLM spelling.
        self.client = Engine(model_path=model_version, tp_size=tensor_parallel_size, mem_fraction_static=gpu_memory_utilization, **kwargs)
        self.processor = AutoProcessor.from_pretrained(model_version)
        # Back the `tokenizer` property; fall back to the processor itself when
        # it does not expose a separate `tokenizer` attribute.
        self._tokenizer = getattr(self.processor, "tokenizer", self.processor)

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
        self.accelerator = accelerator
        self._rank = self.accelerator.local_process_index
        self._world_size = self.accelerator.num_processes

        self.device = self.accelerator.device
        self.batch_size_per_gpu = int(batch_size)

    @property
    def config(self):
        """Return ``self._config``.

        NOTE(review): ``_config`` is never assigned in this class, so this
        raises ``AttributeError`` unless the base class sets it — confirm.
        """
        return self._config

    @property
    def tokenizer(self):
        """Return the tokenizer taken from the loaded processor."""
        return self._tokenizer

    @property
    def model(self):
        """Return the SGLang engine instance."""
        return self.client

    @property
    def batch_size(self):
        """Return the number of requests sent per engine call."""
        return self.batch_size_per_gpu

    @property
    def rank(self):
        """Return this process's local index as reported by Accelerate."""
        return self._rank

    @property
    def world_size(self):
        """Return the number of processes reported by Accelerate."""
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Encode ``string`` to token ids.

        Args:
            string: Text to encode.
            left_truncate_len: If truthy, keep only the last this-many tokens.
            add_special_tokens: Passed to the tokenizer; ``None`` means ``False``.

        Returns:
            The (possibly left-truncated) list of token ids.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode token ids back to text with the processor's tokenizer."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always. (An ``assert False`` would be stripped
                under ``python -O`` and silently return ``None``.)
        """
        raise NotImplementedError("TODO, not implemented")

    def generate_until(self, requests) -> List[str]:
        """Generate one response per request, in request order.

        Each request's ``arguments`` is unpacked as
        ``(doc_to_messages, gen_kwargs, doc_id, task, split)``. Missing
        generation settings default to ``max_new_tokens=1024`` (capped at
        4096), ``temperature=0`` and ``top_p=0.95``.

        Args:
            requests: Sequence of request instances to answer.

        Returns:
            The generated text for every request.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        batch_size = self.batch_size_per_gpu
        for start in range(0, len(requests), batch_size):
            batch_requests = requests[start : start + batch_size]
            batched_messages = []
            batched_params = []
            image_data = []
            for request in batch_requests:
                doc_to_messages, gen_kwargs, doc_id, task, split = request.arguments
                chat_messages = ChatMessages(**{"messages": doc_to_messages(self.task_dict[task][split][doc_id])})

                # Build per-request sampling params without mutating the
                # (possibly shared) gen_kwargs dict. SGLang's key for the
                # generation budget is `max_new_tokens`, not `max_tokens`.
                batched_params.append(
                    {
                        "temperature": gen_kwargs.get("temperature", 0),
                        "max_new_tokens": min(gen_kwargs.get("max_new_tokens", 1024), 4096),
                        "top_p": gen_kwargs.get("top_p", 0.95),
                    }
                )

                # NOTE(review): ChatMessages.to_hf_messages reads the key
                # "enforce_images" (plural), so this flag is currently ignored
                # and videos stay as video entries in the prompt. Kept as-is
                # because the frame-expansion branch there must be fixed first.
                video_kwargs = {"enforce_image": True, "num_frames": self.max_frame_num}
                batched_messages.append(chat_messages.to_hf_messages(video_kwargs))

                # One media list per request so it lines up with its prompt.
                images, videos, _ = chat_messages.extract_media()
                request_media = list(images)
                for video in videos:
                    # NOTE(review): frames are passed through as returned by
                    # load_video_decord — confirm the engine accepts that type.
                    request_media.extend(load_video_decord(video, max_frames_num=self.max_frame_num))
                image_data.append(request_media)

            # Render prompts as text and open the assistant turn so the model
            # answers instead of continuing the user message.
            texts = self.processor.apply_chat_template(
                batched_messages,
                chat_template=self.chat_template,
                tokenize=False,
                add_generation_prompt=True,
            )
            outputs = self.client.generate(
                prompt=texts,
                sampling_params=batched_params,
                image_data=image_data if any(image_data) else None,
            )

            response_text = [o["text"] for o in outputs]

            assert len(response_text) == len(batch_requests)
            res.extend(response_text)
            pbar.update(len(batch_requests))

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("TODO: Implement multi-round generation for Sglang")

lmms_eval/protocol.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ def extract_media(self):
5555

5656
return images, videos, audios
5757

58-
def to_hf_messages(self):
58+
def to_hf_messages(self, video_kwargs: Dict[str, str] = None):
59+
if video_kwargs is None:
60+
video_kwargs = {}
61+
enforce_images = video_kwargs.pop("enforce_images", False)
62+
num_frames = video_kwargs.pop("num_frames", 32)
5963
hf_messages = []
6064
for message in self.messages:
6165
hf_message = {"role": message.role, "content": []}
@@ -65,7 +69,12 @@ def to_hf_messages(self):
6569
elif content.type == "image":
6670
hf_message["content"].append({"type": "image", "image": content.url})
6771
elif content.type == "video":
68-
hf_message["content"].append({"type": "video", "video": content.url})
72+
# Note this is a hacky way if you want to do video in multi-images way
73+
if enforce_images:
74+
for f in num_frames:
75+
hf_message["content"].append({"type": "image"})
76+
else:
77+
hf_message["content"].append({"type": "video", "video": content.url})
6978
elif content.type == "audio":
7079
hf_message["content"].append({"type": "audio", "audio": content.url})
7180
hf_messages.append(hf_message)

0 commit comments

Comments
 (0)