import base64
import os
from io import BytesIO
from typing import List, Literal, Optional, Union

import requests
import torch
from peft import PeftModel
from PIL import Image
from torch import nn
from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from xtuner.dataset.utils import expand2square


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Load an image from base64-encoded data."""
    return Image.open(BytesIO(base64.b64decode(image)))


def load_image(image_url: str) -> Image.Image:
    """Load an image from a URL, a local path or an OpenAI GPT-4V style
    base64 data URI."""

    headers = {
        'User-Agent':
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
        '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    if image_url.startswith('http'):
        response = requests.get(image_url, headers=headers)
        response.raise_for_status()

        # Open the downloaded bytes with PIL
        img = Image.open(BytesIO(response.content))
    elif image_url.startswith('data:image'):
        img = load_image_from_base64(image_url.split(',')[1])
    else:
        img = Image.open(image_url)

    return img

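# Usage sketch for the loaders above (URLs and paths are illustrative only):
#     img = load_image('https://example.com/cat.png')           # remote URL
#     img = load_image('data:image/png;base64,iVBORw0KGgo...')  # GPT-4V style data URI
#     img = load_image('./images/cat.png')                      # local file
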
ModelHub = Literal['huggingface', 'modelscope']


class VisionEncoderForDeploy(nn.Module):

    def __init__(self,
                 model_name_or_path: str,
                 projector_name_or_path: str,
                 adapter_name_or_path: Optional[str] = None,
                 select_layer: int = -2,
                 hub: ModelHub = 'huggingface',
                 device: str = 'cuda'):

        super().__init__()

        # model_path = self._parse_model_path(xtuner_model_name_or_path, hub)
        # visual_encoder_path = self._parse_visual_encoder_path(
        #     model_path, visual_encoder_name_or_path, hub
        # )
        # projector_path = self._parse_projector_path(model_path)

        # # parse visual encoder adapter path.
        # vis_enc_adapter_path = self._parse_vis_enc_adapter_path(model_path)

        # Hidden state used as the visual feature; -2 selects the penultimate
        # layer, the common LLaVA-style choice.
        self.select_layer = select_layer
        self.image_processor = CLIPImageProcessor.from_pretrained(
            model_name_or_path)
        print(f'Load Image Processor From {model_name_or_path}')

        visual_encoder = CLIPVisionModel.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16)
        print(f'Load Visual Encoder From {model_name_or_path}')

        # When the path is None, no visual encoder adapter is loaded.
        if adapter_name_or_path:
            self.visual_encoder = PeftModel.from_pretrained(
                visual_encoder, adapter_name_or_path)
            print(f'Load Visual Encoder Adapter From {adapter_name_or_path}')
        else:
            self.visual_encoder = visual_encoder

        self.projector = AutoModel.from_pretrained(
            projector_name_or_path,
            torch_dtype=torch.float16,
            trust_remote_code=True)
        print(f'Load Projector From {projector_name_or_path}')

        self.dtype = torch.float16
        self.device = device
        self.to(self.device)
        self.to(self.dtype)

    def process_img(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """Preprocess the input image, including expanding it to a square and
        normalizing it.

        Args:
            image (Union[str, Image.Image]): The image to preprocess. A string
                is treated as a URL, local path or data URI and loaded first.

        Returns:
            torch.Tensor: The preprocessed image tensor of shape (3, h, w).
        """

        if isinstance(image, str):
            image = load_image(image)

        if not isinstance(image, Image.Image):
            raise TypeError(f"Don't support {type(image).__name__}, "
                            'the image type must be `PIL.Image`.')

        processor = self.image_processor
        image_mean = processor.image_mean

        # Pad the image to a square, filling with the per-channel mean color.
        background_color = tuple(int(x * 255) for x in image_mean)
        squared_img = expand2square(image, background_color)

        processed = processor.preprocess(squared_img, return_tensors='pt')
        img_tensor = processed['pixel_values'][0]  # shape: (3, h, w)

        # Before this line, `img_tensor` is on CPU.
        img_tensor = img_tensor.to(self.device).to(self.dtype)
        return img_tensor

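    # Note on process_img() above (sizes are illustrative): a 640x480 input is
    # first padded to 640x640 with the processor's mean color, then resized
    # and normalized by CLIPImageProcessor, e.g. to (3, 336, 336) for a
    # ViT-L/14-336 encoder; the exact resolution comes from the processor
    # config.
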
    @torch.no_grad()
    def forward(self, images: List[Union[str,
                                         Image.Image]]) -> List[torch.Tensor]:
        """Obtain the corresponding embeddings based on the images.

        Args:
            images (List[Union[str, Image.Image]]): The input images, given
                either as URLs/paths or as PIL images.

        Returns:
            List[torch.Tensor]: The list of features extracted from the
                images. The data layout of each tensor is (tokens, dims).
        """

        num_imgs = len(images)

        img_tensors = [self.process_img(img) for img in images]

        # Determine if all image sizes are consistent.
        # TODO (pppppM): Confirm when the image size will be inconsistent
        shape_consistent = all(x.shape == img_tensors[0].shape
                               for x in img_tensors)

        if shape_consistent:
            # Batch inference when all image sizes are consistent.
            # img_tensors[0] shape: (3, h, w)
            # tensor shape: (num_imgs, 3, h, w)
            tensor = torch.stack(img_tensors, dim=0)

            enc_out = self.visual_encoder(tensor, output_hidden_states=True)
            enc_out: BaseModelOutputWithPooling

            # Drop the CLS token ([:, 1:]) before projecting.
            # feat shape: (num_imgs, tokens, dims)
            feat = self.projector(
                enc_out.hidden_states[self.select_layer][:, 1:])

            # Split along the batch dimension.
            # The feature of each image corresponds to one tensor.
            # len(features): num_imgs, features[0] shape: (1, tokens, dims)
            features = torch.chunk(feat, num_imgs, dim=0)

            # Each per-image feature's layout should be (tokens, dims).
            features = [x.flatten(0, 1) for x in features]

        else:
            features = []
            for tensor in img_tensors:
                tensor: torch.Tensor
                # The visual encoder requires a data layout of (bs, c, h, w).
                # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w)
                batch_tensor = tensor.unsqueeze(0)
                enc_out = self.visual_encoder(
                    batch_tensor, output_hidden_states=True)
                enc_out: BaseModelOutputWithPooling
                # feat shape: (1, tokens, dims) -> (tokens, dims), matching
                # the layout of the batched branch.
                feat = self.projector(
                    enc_out.hidden_states[self.select_layer][:, 1:])
                features.append(feat.flatten(0, 1))

        return features

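    # Usage sketch for forward() above (file names are illustrative): a single
    # call can mix input types, e.g.
    #     feats = encoder(['cat.png', 'https://example.com/dog.jpg', pil_img])
    # because process_img() accepts URLs, local paths, data URIs and PIL
    # images alike.
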
    def _parse_model_path(self, name_or_path: str, hub: ModelHub) -> str:
        """Parse and get the directory path of the model. It supports loading
        the model from a local directory or downloading it from a hub.

        Args:
            name_or_path (str): The directory path or name of the model.
            hub (str): The hub to download models from.

        Returns:
            str: The local directory path of the model.

        Raises:
            NotImplementedError: If the input hub is not currently supported.
        """

        if os.path.isdir(name_or_path):
            model_path = name_or_path
        else:
            if hub == 'huggingface':
                from huggingface_hub import snapshot_download
                model_path = snapshot_download(repo_id=name_or_path)
            elif hub == 'modelscope':
                from modelscope import snapshot_download
                model_path = snapshot_download(model_id=name_or_path)
            else:
                raise NotImplementedError(
                    'Only supports downloading models from `Huggingface` or '
                    '`Modelscope`.')

        return model_path

    def _parse_visual_encoder_path(self, model_path: str,
                                   visual_encoder_name_or_path: str,
                                   hub: ModelHub) -> str:
        """Parse and get the directory path of the visual encoder. It supports
        loading the visual encoder from a local directory, downloading it from
        a hub, or finding it inside the XTuner model directory.

        Args:
            model_path (str): The directory path of the model.
            visual_encoder_name_or_path (Optional[str]): The directory path or
                name of the visual encoder.
            hub (str): The hub to download models from.

        Returns:
            str: The local directory path of the visual encoder.

        Raises:
            NotImplementedError: If the input hub is not currently supported.
        """

        if 'visual_encoder' in os.listdir(model_path):
            assert visual_encoder_name_or_path is None
            visual_encoder_path = os.path.join(model_path, 'visual_encoder')
        elif os.path.isdir(visual_encoder_name_or_path):
            visual_encoder_path = visual_encoder_name_or_path
        else:
            if hub == 'huggingface':
                from huggingface_hub import snapshot_download
                visual_encoder_path = snapshot_download(
                    repo_id=visual_encoder_name_or_path)
            elif hub == 'modelscope':
                from modelscope import snapshot_download
                visual_encoder_path = snapshot_download(
                    model_id=visual_encoder_name_or_path)
            else:
                raise NotImplementedError(
                    'Only supports downloading models from `Huggingface` or '
                    '`Modelscope`.')

        return visual_encoder_path

    def _parse_projector_path(self, model_path: str) -> str:
        """Parse the path of the `projector` model according to the model
        path.

        Args:
            model_path (str): The path to the model directory.

        Raises:
            ValueError: If the 'projector' directory is not found in
                `model_path`.

        Returns:
            str: The full path of the 'projector' directory.
        """
        if 'projector' in os.listdir(model_path):
            projector_path = os.path.join(model_path, 'projector')
        else:
            # Raise an exception if the 'projector' directory is not found.
            raise ValueError('Projector directory not found in given path')
        return projector_path

    def _parse_vis_enc_adapter_path(self, model_path: str) -> Optional[str]:
        """Parse the model path and return the path to the
        'visual_encoder_adapter' directory.

        Args:
            model_path (str): The path to the model directory.

        Returns:
            Optional[str]: The full path of the 'visual_encoder_adapter'
                directory if it exists, else None.
        """
        if 'visual_encoder_adapter' in os.listdir(model_path):
            adapter_path = os.path.join(model_path, 'visual_encoder_adapter')
        else:
            # Return None if the 'visual_encoder_adapter' directory is
            # not found.
            adapter_path = None
        return adapter_path


if __name__ == '__main__':
    img = load_image('llava.jpeg')
    # NOTE: the repo ids below are illustrative. `model_name_or_path` should
    # point to the CLIP visual encoder and `projector_name_or_path` to the
    # exported projector weights.
    model = VisionEncoderForDeploy('openai/clip-vit-large-patch14-336',
                                   'xtuner/llava-internlm-7b')

    model.cuda()
    model.eval()
    outputs = model([img])
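    # Minimal sanity check of the output layout (tokens, dims); the exact
    # token count depends on the CLIP variant and the projector used.
    for feat in outputs:
        print(feat.shape, feat.dtype)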