Skip to content

Commit 5a9ba52

Browse files
authored
Fix type error when the visual encoder is not CLIP (#496)
* Fix type error when the visual encoder is not CLIP.
* Update.
1 parent adcbf27 commit 5a9ba52

File tree

5 files changed

+10
-5
lines changed

5 files changed

+10
-5
lines changed

xtuner/dataset/llava.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def __getitem__(self, index):
9898
image, return_tensors='pt')['pixel_values'][0]
9999
data_dict['pixel_values'] = image
100100
else:
101-
crop_size = self.image_processor.crop_size
101+
if hasattr(self.image_processor, 'crop_size'):
102+
crop_size = self.image_processor.crop_size
103+
else:
104+
crop_size = self.image_processor.size
102105
data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
103106
crop_size['width'])
104107
return data_dict

xtuner/engine/hooks/evaluate_chat_hook.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def _eval_images(self,
129129
input_ids.append(IMAGE_TOKEN_INDEX)
130130
input_ids = torch.tensor(input_ids).to(device)
131131
visual_outputs = model.visual_encoder(
132-
image.unsqueeze(0), output_hidden_states=True)
132+
image.unsqueeze(0).to(model.visual_encoder.dtype),
133+
output_hidden_states=True)
133134
pixel_values = model.projector(
134135
visual_outputs.hidden_states[model.visual_select_layer][:, 1:])
135136

xtuner/model/llava.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def _build_from_cfg_or_module(self, cfg_or_mod):
169169
def forward(self, data, data_samples=None, mode='loss'):
170170
if 'pixel_values' in data:
171171
visual_outputs = self.visual_encoder(
172-
data['pixel_values'], output_hidden_states=True)
172+
data['pixel_values'].to(self.visual_encoder.dtype),
173+
output_hidden_states=True)
173174
pixel_values = self.projector(
174175
visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
175176
data['pixel_values'] = pixel_values

xtuner/tools/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def main():
306306
image, tuple(int(x * 255) for x in image_processor.image_mean))
307307
image = image_processor.preprocess(
308308
image, return_tensors='pt')['pixel_values'][0]
309-
image = image.cuda().unsqueeze(0)
309+
image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
310310
visual_outputs = visual_encoder(image, output_hidden_states=True)
311311
pixel_values = projector(
312312
visual_outputs.hidden_states[args.visual_select_layer][:, 1:])

xtuner/tools/mmbench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def main():
445445
image, tuple(int(x * 255) for x in image_processor.image_mean))
446446
image = image_processor.preprocess(
447447
image, return_tensors='pt')['pixel_values'][0]
448-
image = image.cuda().unsqueeze(0)
448+
image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
449449
visual_outputs = visual_encoder(image, output_hidden_states=True)
450450
pixel_values = projector(
451451
visual_outputs.hidden_states[args.visual_select_layer][:, 1:])

0 commit comments

Comments (0)