
Commit e954d5c

fix forward error
1 parent b36cbb5 commit e954d5c

File tree

5 files changed: +164, -109 lines

xtuner/dataset/hybrid/collate.py

Lines changed: 19 additions & 12 deletions
@@ -16,9 +16,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
     pixel_values = []
     cumulative_len = []
     image_ranges = []
-    image_belong = []
+    image_belongs = []
     position_ids = []
-
+
     for i, data in enumerate(instances):
         input_ids.append(torch.LongTensor(data['input_ids']))
         labels.append(torch.LongTensor(data['labels']))
@@ -27,28 +27,33 @@ def hybrid_collate_fn(instances: Sequence[Dict],
         if 'cumulative_len' in data:
             cumulative_len.append(torch.IntTensor(data['cumulative_len']))

-        image_belong.append(i)
-        pixel_values.extend(data['pixel_values'])
-        image_ranges.extend(torch.IntTensor(data['image_ranges']))
-
+
+        _values = data['pixel_values']
+        _ranges = data['image_ranges']
+
+        assert len(_values) == len(_ranges)
+        for v, rng in zip(_values, _ranges):
+            pixel_values.append(v)
+            image_ranges.append(rng)
+            image_belongs.append(i)
+
     if len(pixel_values) > 0:
         assert len(image_ranges) > 0
-        assert len(image_belong) > 0
+        assert len(image_belongs) > 0

         pixel_values = torch.stack(pixel_values)
-        image_ranges = torch.stack(image_ranges)
-        image_belong = torch.IntTensor(image_belong)
+        # image_belongs = torch.IntTensor(image_belongs)
     else:
         pixel_values = None
         image_ranges = None
-        image_belong = None
+        image_belongs = None

     if len(instances) > 1:
         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=pad_index)
         labels = pad_sequence(
             labels, batch_first=True, padding_value=IGNORE_INDEX)
-        position_ids = pad_sequence(labels, batch_first=True, padding_value=0)
+        position_ids = pad_sequence(position_ids, batch_first=True, padding_value=0)
     else:
         input_ids = torch.stack(input_ids)
         labels = torch.stack(labels)
@@ -57,6 +62,7 @@ def hybrid_collate_fn(instances: Sequence[Dict],
     if len(cumulative_len) == 0:
         cumulative_len = None

+    # breakpoint()
     data_dict = {
         'input_ids': input_ids,
         'position_ids': position_ids,
@@ -65,8 +71,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
         'pixel_values': pixel_values,
         'cumulative_len': cumulative_len,
         'image_ranges': image_ranges,
-        'image_belong': image_belong
+        'image_belongs': image_belongs
     }
+

     if return_hf_format:
         return data_dict
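
A minimal sketch of the bookkeeping this change introduces (the toy instances below are made up): every image now contributes one aligned entry to pixel_values, image_ranges, and image_belongs, so entry j in all three lists describes the same image and image_belongs[j] records which sample in the batch owns it.

# Toy illustration, not the real dataset output: sample 0 has two images,
# sample 1 has one. After the per-image loop the three lists stay index-aligned.
instances = [
    {'pixel_values': ['img_a', 'img_b'], 'image_ranges': [(5, 10), (20, 25)]},
    {'pixel_values': ['img_c'], 'image_ranges': [(3, 8)]},
]

pixel_values, image_ranges, image_belongs = [], [], []
for i, data in enumerate(instances):
    for v, rng in zip(data['pixel_values'], data['image_ranges']):
        pixel_values.append(v)
        image_ranges.append(rng)
        image_belongs.append(i)  # batch index of the sample owning this image

print(image_belongs)  # [0, 0, 1]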

xtuner/dataset/hybrid/hybrid.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

xtuner/model/hybrid.py

Lines changed: 100 additions & 28 deletions
@@ -4,17 +4,18 @@
 import torch
 from mmengine.model import BaseModel
 from peft import LoraConfig
+from mmengine import print_log
 from torch import nn
-
+import math
 from xtuner.registry import BUILDER
 from xtuner.utils.config import build_from_cfg_or_obj
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
 from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
                     get_peft_model_state_dict, prepare_for_llm_lora,
                     prepare_for_vision_lora,
                     smart_tokenizer_and_embedding_resize)
-
-
+import torch.distributed as dist
+from mmengine import runner
 class HybridFinetune(BaseModel):

     def __init__(
@@ -106,54 +107,125 @@ def forward(self, data, data_samples=None, mode='loss'):
         """Overload parent class method, only support training."""

         if mode == 'loss':
-            return self.compute_loss(data, data_samples)
+            return self.compute_loss(data)
         else:
             raise NotImplementedError(
                 f"{type(self)}'s forward is only supported for use during "
                 'training. If you want to get predictions or chat, please '
                 "directly use `llm`'s forward.")

-    def compute_loss(self, data, data_samples=None):
-
+
+
+    def _get_vision_embeds_and_ranges(self, data):
+
         input_ids = data['input_ids']
-        labels = data['labels']
-        position_ids = data['position_ids']
-        attention_mask = data['attention_mask']
         pixel_values = data['pixel_values']
         img_rngs = data['image_ranges']
-        img_belong = data['image_belong']
-
-        input_embeds = self.llm.get_input_embeddings()(input_ids)
+        img_belongs = data['image_belongs']
+
+        bs, tokens = input_ids.shape
+
+        img_embeds = []
+        ranges_in_flat_batch = []

         if pixel_values is not None:
+            assert isinstance(pixel_values, torch.Tensor)
+            assert len(img_rngs) == len(img_belongs) == pixel_values.size(0)
+
+            batch_total_imgs = len(img_rngs)
+
             visual_outputs = self.visual_encoder(
                 pixel_values, output_hidden_states=True)
-            img_embeds = self.projector(
+            features = self.projector(
                 visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
-
-            empty_embs = torch.zeros_like(input_embeds)
-            for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong):
-                left, right = rng
-                if emb.size(0) == right - left:
-                    empty_embs[b_id, left:right, :] = emb
-                elif not emb.size(0) == right - left and left == 0:
-                    empty_embs[b_id, left:right, :] = emb[-right:]
-                elif not emb.size(
-                        0) == right - left and right == empty_embs.size(1):
-                    empty_embs[b_id, left:right, :] = emb[:right - left]
+            batch_total_imgs, actual_img_tokens, _ = features.shape
+
+
+            for i in range(batch_total_imgs):
+                img_start, img_end = img_rngs[i]
+                expect_img_tokens = img_end - img_start
+                img_emb = features[i]
+                img_bs_ind = img_belongs[i]
+
+                if actual_img_tokens == expect_img_tokens:
+                    img_embeds.append(img_emb)
+                elif not actual_img_tokens == expect_img_tokens and img_start == 0:
+                    img_embeds.append(img_emb[actual_img_tokens - img_end:])
+                elif not actual_img_tokens == expect_img_tokens and img_end == tokens:
+                    img_embeds.append(img_emb[:expect_img_tokens])
                 else:
-                    breakpoint()
+                    raise RuntimeError
+
+                flat_offset = tokens * img_bs_ind
+
+                left = flat_offset + img_start
+                right = flat_offset + img_end
+                ranges_in_flat_batch.append((left, right))
+
+        return img_embeds, ranges_in_flat_batch
+
+
+    def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):
+
+        assert len(mm_embeds) == len(ranges)
+        if len(mm_embeds) == 0:
+            return flat_embeds
+
+        chunk_embeds = []
+        chunk_sizes = []
+        mm_chunk_ids = []
+
+        cursor = 0
+        _empty_embeds = torch.zeros_like(flat_embeds)
+        for (start, end), emb in zip(ranges, mm_embeds):
+            _empty_embeds[start:end] += emb
+            # if start - cursor > 0:
+            #     chunk_sizes.append(start - cursor)
+            #     cursor = start
+
+            # mm_chunk_ids.append(len(chunk_sizes))
+
+            # chunk_embeds.append(emb)
+            # chunk_sizes.append(end - start)
+            # cursor = end
+
+        # tokens = flat_embeds.size(0)
+        # if sum(chunk_sizes) < tokens:
+        #     chunk_sizes.append(tokens - sum(chunk_sizes))
+
+        # chunk_embs = list(torch.split(flat_embeds, chunk_sizes))
+        # for ind, mm_emb in zip(mm_chunk_ids, mm_embeds):
+        #     chunk_embs[ind] = mm_emb
+
+        # flat_embeds = torch.cat(chunk_embs, dim=0)
+        flat_embeds = flat_embeds * (_empty_embeds == 0)
+
+        return flat_embeds + _empty_embeds
+
+    def compute_loss(self, data):

-        non_img_mask = (empty_embs == 0)
-        input_embeds = input_embeds * non_img_mask + empty_embs
+        input_ids = data['input_ids']
+        labels = data['labels']
+        position_ids = data['position_ids']
+        attention_mask = data['attention_mask']
+
+        input_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        bs, tokens, dim = input_embeds.shape
+        flat_embeds = input_embeds.flatten(0, 1)
+
+        img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data)
+        flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs)
+        input_embeds = flat_embeds.reshape(bs, tokens, dim)

         outputs = self.llm(
             input_ids=None,
             position_ids=position_ids,
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
             labels=labels)
-
+
         loss_dict = {'loss': outputs.loss}
         return loss_dict
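
Roughly, the rewritten compute_loss flattens the (bs, tokens, dim) text embeddings to (bs * tokens, dim) so that a single (left, right) range can address an image slot in any sample via the offset tokens * sample_index. A self-contained sketch of that insertion step, with made-up shapes and values:

import torch

bs, tokens, dim = 2, 6, 4                      # assumed toy shapes
input_embeds = torch.randn(bs, tokens, dim)    # stand-in for the LLM embeddings

# One "image" of 3 projected tokens that should occupy tokens 2..5 of sample 1.
img_emb = torch.ones(3, dim)
img_start, img_end, img_bs_ind = 2, 5, 1

flat_embeds = input_embeds.flatten(0, 1)       # (bs * tokens, dim)
left = tokens * img_bs_ind + img_start
right = tokens * img_bs_ind + img_end

# Zero the text embeddings inside the image slot, then add the image features,
# mirroring the masking done in _insert_mm_embeddings.
empty = torch.zeros_like(flat_embeds)
empty[left:right] += img_emb
flat_embeds = flat_embeds * (empty == 0) + empty

merged = flat_embeds.reshape(bs, tokens, dim)
assert torch.equal(merged[1, 2:5], img_emb)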

xtuner/types/chat.py

Lines changed: 44 additions & 0 deletions
@@ -92,6 +92,30 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
         return chat_template.decorate_function_result(self.content)


+class CodeInterpreterCallMsg(BaseModel):
+
+    role: Literal['assistant']
+    content: str
+    conde_interpreter_call: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        return chat_template.decorate_code_interpreter_call(
+            self.content, self.conde_interpreter_call)
+
+
+
+class CodeInterpreterResultMsg(BaseModel):
+    role: Literal['function']
+    name: str
+    content: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return chat_template.decorate_code_internpreter_result(self.content)
+
+
+
 class Functions(BaseModel):

     # class Parameters(BaseModel):
@@ -108,6 +132,26 @@ class Functions(BaseModel):
     name: str
     description: Union[str, Dict]
     parameters: Union[str, Dict]
+
+
+
+class CodeInterpreter(BaseModel):
+
+    # class Parameters(BaseModel):
+
+    #     class Property(BaseModel):
+    #         type: str
+    #         description: str
+    #         enum: Optional[List] = None
+
+    #     type: Literal['object']
+    #     properties: Dict[str, Property]
+    #     required: List[str]
+
+    name: str
+    description: Union[str, Dict]
+
+


 HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg]
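
A quick usage sketch of the new pydantic message models, assuming they are imported from xtuner.types.chat (field values are invented; the field spelling conde_interpreter_call follows this commit, and rendering depends on the HybridChatTemplate in use):

call_msg = CodeInterpreterCallMsg(
    role='assistant',
    content='Let me run that snippet.',
    conde_interpreter_call='print(1 + 1)',  # spelling follows this commit
)

result_msg = CodeInterpreterResultMsg(
    role='function',
    name='code_interpreter',
    content='2',
)

# Rendering requires a concrete HybridChatTemplate instance, e.g.:
# text = call_msg.apply_chat_template(chat_template)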

xtuner/utils/config.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T],
             raise TypeError(
                 f'Expect an object of {accept}, but there is an object of '
                 f'{type(obj)}.')
-        return BUILDER.build(cfg_or_obj)
+        return obj

     else:
         raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got '
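
The one-line fix makes build_from_cfg_or_obj return the already-built object instead of feeding it back into BUILDER.build. A simplified stand-in (function name and signature trimmed for illustration, not the real API) showing the intended dispatch:

def build_or_passthrough(cfg_or_obj, accept, build):
    # A dict is treated as a config and built via the registry; an instance of
    # an accepted type is returned as-is (the behaviour this commit fixes).
    if isinstance(cfg_or_obj, dict):
        return build(cfg_or_obj)
    elif isinstance(cfg_or_obj, accept):
        return cfg_or_obj
    else:
        raise TypeError(f'cfg_or_obj must be a dict or {accept}, '
                        f'but got {type(cfg_or_obj)}')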
