
Commit 31da1f6

support lmdeploy

1 parent 5c8c265 commit 31da1f6

17 files changed: +1148 −6 lines

xtuner/chat/__init__.py

Whitespace-only changes.

xtuner/chat/backend/__init__.py

+5
@@ -0,0 +1,5 @@
from .encoder import VisionEncoderForDeploy
from .huggingface import HFBackend
from .lmdeploy import LMDeployBackend

__all__ = ['VisionEncoderForDeploy', 'HFBackend', 'LMDeployBackend']
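
The package exposes two interchangeable chat backends (HF and LMDeploy) plus a deploy-time vision encoder. As a rough illustration of the intended backend-agnostic usage, here is a minimal sketch that only relies on the interface declared in base.py below; the `run_chat` helper is an assumption for illustration and is not part of this commit:

from xtuner.chat.backend import HFBackend, LMDeployBackend


def run_chat(backend, messages, generation_config=None):
    # Hypothetical helper: works with either HFBackend or LMDeployBackend,
    # assuming both follow the BaseBackend interface (create_streamer, chat).
    streamer = backend.create_streamer(iterable=False)
    return backend.chat(messages, streamer=streamer,
                        generation_config=generation_config)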

xtuner/chat/backend/base.py

+26
@@ -0,0 +1,26 @@
from abc import abstractmethod

from xtuner.types import HybridChatTemplate


class BaseBackend():

    @property
    def chat_template(self) -> HybridChatTemplate:
        pass

    @abstractmethod
    def create_streamer(self, iterable=False):
        pass

    @abstractmethod
    def chat(self, messages, streamer=None, generation_config=None):
        pass

    # @abstractmethod
    # def response_with_function_call(self, response: str):
    #     pass

    # @abstractmethod
    # def response_with_code_interpreter(self, response: str):
    #     pass
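
BaseBackend only sketches the interface the concrete backends are expected to share. For orientation, a minimal sketch of a conforming subclass is shown below; the `EchoBackend` class is a hypothetical illustration, not part of this commit, and real backends (HFBackend, LMDeployBackend) wrap an actual model:

from xtuner.chat.backend.base import BaseBackend
from xtuner.types import HybridChatTemplate


class EchoBackend(BaseBackend):
    """Hypothetical backend that just echoes the last message."""

    def __init__(self, chat_template: HybridChatTemplate):
        self._chat_template = chat_template

    @property
    def chat_template(self) -> HybridChatTemplate:
        return self._chat_template

    def create_streamer(self, iterable=False):
        # No real token streaming here; a concrete backend would return a
        # streamer object compatible with its generation loop.
        return None

    def chat(self, messages, streamer=None, generation_config=None):
        # Echo the content of the last message instead of generating text.
        return messages[-1]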

xtuner/chat/backend/encoder.py

+308
@@ -0,0 +1,308 @@
import base64
import os
from io import BytesIO
from typing import List, Literal, Optional, Union

import requests
import torch
from peft import PeftModel
from PIL import Image
from torch import nn
from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel

from xtuner.dataset.utils import expand2square


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Load an image from a base64-encoded string."""
    return Image.open(BytesIO(base64.b64decode(image)))


def load_image(image_url: str) -> Image.Image:
    """Load an image from a URL, a local path, or a base64 data URI (as used
    by OpenAI GPT-4V)."""

    headers = {
        'User-Agent':
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
        '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    if image_url.startswith('http'):
        response = requests.get(image_url, headers=headers)
        response.raise_for_status()

        # Open the image using PIL
        img = Image.open(BytesIO(response.content))
    elif image_url.startswith('data:image'):
        img = load_image_from_base64(image_url.split(',')[1])
    else:
        img = Image.open(image_url)

    return img


ModelHub = Literal['huggingface', 'modelscope']


class VisionEncoderForDeploy(nn.Module):

    def __init__(self,
                 model_name_or_path: str,
                 projector_name_or_path: str,
                 adapter_name_or_path: Optional[str] = None,
                 select_layer: int = -2,
                 hub: ModelHub = 'huggingface',
                 device='cuda'):

        super().__init__()

        # model_path = self._parse_model_path(xtuner_model_name_or_path, hub)
        # visual_encoder_path = self._parse_visual_encoder_path(
        #     model_path, visual_encoder_name_or_path, hub
        # )
        # projector_path = self._parse_projector_path(model_path)

        # # parse visual encoder adapter path.
        # vis_enc_adapter_path = self._parse_vis_enc_adapter_path(model_path)

        self.select_layer = select_layer
        self.image_processor = CLIPImageProcessor.from_pretrained(
            model_name_or_path)
        print(f'Load Image Processor From {model_name_or_path}')

        visual_encoder = CLIPVisionModel.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16)
        print(f'Load Visual Encoder From {model_name_or_path}')

        # When the adapter path is None, the visual encoder has no adapter.
        if adapter_name_or_path:
            self.visual_encoder = PeftModel.from_pretrained(
                visual_encoder, adapter_name_or_path)
            print(f'Load Visual Encoder Adapter From {adapter_name_or_path}')
        else:
            self.visual_encoder = visual_encoder

        self.projector = AutoModel.from_pretrained(
            projector_name_or_path,
            torch_dtype=torch.float16,
            trust_remote_code=True)
        print(f'Load Projector from {projector_name_or_path}')

        self.dtype = torch.float16
        self.device = device
        self.to(self.device)
        self.to(self.dtype)

    def process_img(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """Preprocess the input image, including expanding it to a square and
        normalizing it.

        Args:
            image (Union[str, Image.Image]): The image to be preprocessed,
                given as a PIL image or as a URL/path/base64 string.

        Returns:
            torch.Tensor: The preprocessed image tensor.
        """

        if isinstance(image, str):
            image = load_image(image)

        if not isinstance(image, Image.Image):
            raise TypeError(f"Don't support {type(image).__name__}, "
                            'the image type must be `PIL.Image`.')

        processor = self.image_processor
        image_mean = processor.image_mean

        background_color = tuple(int(x * 255) for x in image_mean)
        squared_img = expand2square(image, background_color)

        processed = processor.preprocess(squared_img, return_tensors='pt')
        img_tensor = processed['pixel_values'][0]  # shape: (3, h, w)

        # Before this line, `img_tensor` is on the CPU.
        img_tensor = img_tensor.to(self.device).to(self.dtype)
        return img_tensor

    @torch.no_grad()
    def forward(self, images: List[Union[str,
                                         Image.Image]]) -> List[torch.Tensor]:
        """Obtain the corresponding embeddings based on the images.

        Args:
            images (List[Union[str, Image.Image]]): The input images, given
                as PIL images or as URL/path/base64 strings.

        Returns:
            List[torch.Tensor]: The list of extracted features from images.
                The data layout for each tensor should be (tokens, dims).
        """

        num_imgs = len(images)

        img_tensors = [self.process_img(img) for img in images]

        # Determine if all image sizes are consistent.
        # TODO (pppppM): Confirm when the image size will be inconsistent
        shape_consistent = all(x.shape == img_tensors[0].shape
                               for x in img_tensors)

        from transformers.modeling_outputs import BaseModelOutputWithPooling

        if shape_consistent:
            # Batch inference when all image sizes are consistent.
            # img_tensors[0] shape: (3, h, w)
            # tensor shape: (num_imgs, 3, h, w)
            tensor = torch.stack(img_tensors, dim=0)

            enc_out = self.visual_encoder(tensor, output_hidden_states=True)
            enc_out: BaseModelOutputWithPooling

            # feat shape: (num_imgs, tokens, dims)
            feat = self.projector(
                enc_out.hidden_states[self.select_layer][:, 1:])

            # Split along the batch dimension.
            # The feature of each image corresponds to a tensor.
            # len(features): num_imgs, features[0] shape: (1, tokens, dims)
            features = torch.chunk(feat, num_imgs, dim=0)

            # Each image feature's layout should be (tokens, dims).
            features = [x.flatten(0, 1) for x in features]

        else:
            features = []
            for tensor in img_tensors:
                tensor: torch.Tensor
                # The visual encoder requires a data layout of (bs, c, h, w).
                # tensor shape: (3, h, w); batch_tensor shape: (1, 3, h, w)
                batch_tensor = tensor.unsqueeze(0)
                enc_out = self.visual_encoder(
                    batch_tensor, output_hidden_states=True)
                enc_out: BaseModelOutputWithPooling
                # feat shape: (1, tokens, dims)
                feat = self.projector(
                    enc_out.hidden_states[self.select_layer][:, 1:])
                features.append(feat)

        return features

    def _parse_model_path(self, name_or_path: str, hub: ModelHub) -> str:
        """Parse and get the directory path of the model. It supports loading
        the model from a local directory or downloading it from the hub.

        Args:
            name_or_path (str): The directory path or name of the model.
            hub (str): The hub to download models from.

        Returns:
            str: The local directory path of the model.

        Raises:
            NotImplementedError: If the input hub is not supported currently.
        """

        if os.path.isdir(name_or_path):
            model_path = name_or_path
        else:
            if hub == 'huggingface':
                from huggingface_hub import snapshot_download
                model_path = snapshot_download(repo_id=name_or_path)
            elif hub == 'modelscope':
                from modelscope import snapshot_download
                model_path = snapshot_download(model_id=name_or_path)
            else:
                raise NotImplementedError(
                    'Only supports downloading models from `Huggingface` or '
                    '`Modelscope`.')

        return model_path

    def _parse_visual_encoder_path(self, model_path: str,
                                   visual_encoder_name_or_path: str,
                                   hub: ModelHub) -> str:
        """Parse and get the directory path of the visual encoder. It
        supports loading the visual encoder from a local directory,
        downloading it from the hub, or finding it in the XTuner model
        directory.

        Args:
            model_path (str): The directory path of the model.
            visual_encoder_name_or_path (Optional[str]): The directory path
                or name of the visual encoder.
            hub (str): The hub to download models from.

        Returns:
            str: The local directory path of the visual encoder.

        Raises:
            NotImplementedError: If the input hub is not supported currently.
        """

        if 'visual_encoder' in os.listdir(model_path):
            assert visual_encoder_name_or_path is None
            visual_encoder_path = os.path.join(model_path, 'visual_encoder')
        elif os.path.isdir(visual_encoder_name_or_path):
            visual_encoder_path = visual_encoder_name_or_path
        else:
            if hub == 'huggingface':
                from huggingface_hub import snapshot_download
                visual_encoder_path = snapshot_download(
                    repo_id=visual_encoder_name_or_path)
            elif hub == 'modelscope':
                from modelscope import snapshot_download
                visual_encoder_path = snapshot_download(
                    model_id=visual_encoder_name_or_path)
            else:
                raise NotImplementedError(
                    'Only supports downloading models from `Huggingface` or '
                    '`Modelscope`.')

        return visual_encoder_path

    def _parse_projector_path(self, model_path: str) -> Optional[str]:
        """Parse the path of the `projector` model according to the model
        path.

        Args:
            model_path (str): The path to the model directory.

        Raises:
            ValueError: If the 'projector' directory is not found in the
                `model_path`.

        Returns:
            Optional[str]: The full path of the 'projector' directory if it
                exists, else raises ValueError.
        """
        if 'projector' in os.listdir(model_path):
            projector_path = os.path.join(model_path, 'projector')
        else:
            # Raise an exception if the 'projector' directory is not found.
            raise ValueError('Projector directory not found in given path')
        return projector_path

    def _parse_vis_enc_adapter_path(self, model_path: str) -> Optional[str]:
        """Parse the model path and return the path to the
        'visual_encoder_adapter' directory.

        Args:
            model_path (str): The path to the model directory.

        Returns:
            Optional[str]: The full path of the 'visual_encoder_adapter'
                directory if it exists, else None.
        """
        if 'visual_encoder_adapter' in os.listdir(model_path):
            adapter_path = os.path.join(model_path, 'visual_encoder_adapter')
        else:
            # Return None if the 'visual_encoder_adapter' directory is not
            # found.
            adapter_path = None
        return adapter_path


if __name__ == '__main__':
    img = load_image('llava.jpeg')
    model = VisionEncoderForDeploy('xtuner/llava-internlm-7b',
                                   'openai/clip-vit-large-patch14-336')

    model.cuda()
    model.eval()
    outputs = model([img])
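
process_img relies on expand2square from xtuner.dataset.utils to pad non-square images with the processor's mean color before CLIP preprocessing. The helper itself is not part of this commit; below is a minimal sketch of its assumed LLaVA-style behavior, for reference only:

from PIL import Image


def expand2square(pil_img: Image.Image, background_color) -> Image.Image:
    """Pad a PIL image to a square canvas filled with `background_color`.

    Sketch of the assumed behavior of xtuner.dataset.utils.expand2square;
    the real implementation may differ in detail.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # Center the original image on the square canvas.
    if width > height:
        result.paste(pil_img, (0, (side - height) // 2))
    else:
        result.paste(pil_img, ((side - width) // 2, 0))
    return result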
