 import torch
 from mmengine.model import BaseModel
 from peft import LoraConfig
+from mmengine import print_log
 from torch import nn
-
+import math
 from xtuner.registry import BUILDER
 from xtuner.utils.config import build_from_cfg_or_obj
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
 from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
                     get_peft_model_state_dict, prepare_for_llm_lora,
                     prepare_for_vision_lora,
                     smart_tokenizer_and_embedding_resize)
-
-
+import torch.distributed as dist
+from mmengine import runner
 class HybridFinetune(BaseModel):

     def __init__(
@@ -106,54 +107,125 @@ def forward(self, data, data_samples=None, mode='loss'):
         """Overload parent class method, only support training."""

         if mode == 'loss':
-            return self.compute_loss(data, data_samples)
+            return self.compute_loss(data)
         else:
             raise NotImplementedError(
                 f"{type(self)}'s forward is only supported for use during "
                 'training. If you want to get predictions or chat, please '
                 "directly use `llm`'s forward.")

-    def compute_loss(self, data, data_samples=None):
-
+
+
+    def _get_vision_embeds_and_ranges(self, data):
+
         input_ids = data['input_ids']
-        labels = data['labels']
-        position_ids = data['position_ids']
-        attention_mask = data['attention_mask']
         pixel_values = data['pixel_values']
         img_rngs = data['image_ranges']
-        img_belong = data['image_belong']
-
-        input_embeds = self.llm.get_input_embeddings()(input_ids)
+        img_belongs = data['image_belongs']
+
+        bs, tokens = input_ids.shape
+
+        img_embeds = []
+        ranges_in_flat_batch = []

         if pixel_values is not None:
+            assert isinstance(pixel_values, torch.Tensor)
+            assert len(img_rngs) == len(img_belongs) == pixel_values.size(0)
+
+            batch_total_imgs = len(img_rngs)
+
             visual_outputs = self.visual_encoder(
                 pixel_values, output_hidden_states=True)
-            img_embeds = self.projector(
+            features = self.projector(
                 visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
-
-            empty_embs = torch.zeros_like(input_embeds)
-            for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong):
-                left, right = rng
-                if emb.size(0) == right - left:
-                    empty_embs[b_id, left:right, :] = emb
-                elif not emb.size(0) == right - left and left == 0:
-                    empty_embs[b_id, left:right, :] = emb[-right:]
-                elif not emb.size(
-                        0) == right - left and right == empty_embs.size(1):
-                    empty_embs[b_id, left:right, :] = emb[:right - left]
+            batch_total_imgs, actual_img_tokens, _ = features.shape
+
+
+            for i in range(batch_total_imgs):
+                img_start, img_end = img_rngs[i]
+                expect_img_tokens = img_end - img_start
+                img_emb = features[i]
+                img_bs_ind = img_belongs[i]
+
+                if actual_img_tokens == expect_img_tokens:
+                    img_embeds.append(img_emb)
+                elif img_start == 0:
+                    # Image clipped at the start of the sequence: keep the
+                    # trailing `expect_img_tokens` features.
+                    img_embeds.append(img_emb[-expect_img_tokens:])
+                elif img_end == tokens:
+                    # Image clipped at the end of the sequence: keep the
+                    # leading `expect_img_tokens` features.
+                    img_embeds.append(img_emb[:expect_img_tokens])
                 else:
-                    breakpoint()
+                    raise RuntimeError(
+                        f'Image range ({img_start}, {img_end}) does not '
+                        f'match the {actual_img_tokens} projected tokens.')
+
+                flat_offset = tokens * img_bs_ind
+
+                left = flat_offset + img_start
+                right = flat_offset + img_end
+                ranges_in_flat_batch.append((left, right))
+
+        return img_embeds, ranges_in_flat_batch
+
+
+    def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):
+
+        assert len(mm_embeds) == len(ranges)
+        if len(mm_embeds) == 0:
+            return flat_embeds
+
+        # Scatter the multimodal embeddings into a zero buffer, then blank
+        # out the corresponding text-embedding slots and add the buffer.
+        _empty_embeds = torch.zeros_like(flat_embeds)
+        for (start, end), emb in zip(ranges, mm_embeds):
+            _empty_embeds[start:end] += emb
+
+        flat_embeds = flat_embeds * (_empty_embeds == 0)
+
+        return flat_embeds + _empty_embeds
+
+    def compute_loss(self, data):

-            non_img_mask = (empty_embs == 0)
-            input_embeds = input_embeds * non_img_mask + empty_embs
+        input_ids = data['input_ids']
+        labels = data['labels']
+        position_ids = data['position_ids']
+        attention_mask = data['attention_mask']
+
+        input_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        bs, tokens, dim = input_embeds.shape
+        flat_embeds = input_embeds.flatten(0, 1)
+
+        img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data)
+        flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs,
+                                                 flat_bs_img_rngs)
+        input_embeds = flat_embeds.reshape(bs, tokens, dim)

         outputs = self.llm(
             input_ids=None,
             position_ids=position_ids,
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
             labels=labels)
-
+
         loss_dict = {'loss': outputs.loss}
         return loss_dict

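For reference, a minimal standalone sketch of the clipping rule applied in `_get_vision_embeds_and_ranges` when an image span is cut off at a packed-sequence boundary. The helper name and shapes below are illustrative only, not part of xtuner.

import torch

# Toy shapes: each image yields 5 projected tokens, sequences are 8 tokens.
actual_img_tokens, dim, tokens = 5, 4, 8
img_emb = torch.randn(actual_img_tokens, dim)

def clip_image_features(img_emb, img_start, img_end, tokens):
    """Keep only the image features that fall inside the (start, end) range."""
    expect = img_end - img_start
    if img_emb.size(0) == expect:   # image fully inside the sequence
        return img_emb
    if img_start == 0:              # clipped on the left: keep the tail
        return img_emb[-expect:]
    if img_end == tokens:           # clipped on the right: keep the head
        return img_emb[:expect]
    raise RuntimeError('image range does not match the projected features')

assert clip_image_features(img_emb, 2, 7, tokens).size(0) == 5  # full image
assert clip_image_features(img_emb, 0, 3, tokens).size(0) == 3  # left-clipped
assert clip_image_features(img_emb, 6, 8, tokens).size(0) == 2  # right-clipped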
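And a self-contained sketch of the flat-batch insertion performed by `_get_vision_embeds_and_ranges` plus `_insert_mm_embeddings`: each per-sample (start, end) range is shifted by sample_index * tokens onto the flattened batch axis, and the image features are swapped in via a zero buffer and mask. Shapes and the names `flat_ranges`, `buffer`, `merged` are made up for illustration.

import torch

bs, tokens, dim = 2, 6, 4
input_embeds = torch.randn(bs, tokens, dim)              # text embeddings
img_embeds = [torch.ones(3, dim), 2 * torch.ones(2, dim)]
img_ranges = [(1, 4), (2, 4)]                            # per-sample token ranges
img_belongs = [0, 1]                                     # owning sample index

# 1) Map per-sample ranges onto the flattened (bs * tokens) axis.
flat_ranges = [(b * tokens + s, b * tokens + e)
               for (s, e), b in zip(img_ranges, img_belongs)]

# 2) Scatter image features into a zero buffer, then overwrite those slots.
flat_embeds = input_embeds.flatten(0, 1)
buffer = torch.zeros_like(flat_embeds)
for (left, right), emb in zip(flat_ranges, img_embeds):
    buffer[left:right] += emb
merged = flat_embeds * (buffer == 0) + buffer

merged = merged.reshape(bs, tokens, dim)
assert torch.equal(merged[0, 1:4], img_embeds[0])
assert torch.equal(merged[1, 2:4], img_embeds[1])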