|
| 1 | +# Copyright (c) ModelScope Contributors. All rights reserved. |
1 | 2 | from dataclasses import dataclass, field |
2 | 3 | from typing import Any, Dict, List, Literal, Optional |
3 | 4 |
|
@@ -86,6 +87,25 @@ def _get_new_tokens(i): |
86 | 87 | encoded['attention_mask'] = attention_mask |
87 | 88 | return encoded |
88 | 89 |
|
| 90 | + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: |
| 91 | + if not self.is_training: |
| 92 | + return inputs |
| 93 | + |
| 94 | + input_ids = inputs['input_ids'] |
| 95 | + pixel_values = inputs.get('pixel_values') |
| 96 | + image_grid_thw = inputs.get('image_grid_thw') |
| 97 | + base_model = self.get_base_model(model) |
| 98 | + inputs_embeds = base_model.model.embed_tokens(input_ids) |
| 99 | + |
| 100 | + if pixel_values is not None: |
| 101 | + pixel_values = pixel_values.to(base_model.vit.dtype) |
| 102 | + image_embeds = base_model.vit(pixel_values, image_grid_thw) |
| 103 | + image_embeds = image_embeds.to(input_ids.device, non_blocking=True) |
| 104 | + image_mask, _ = base_model.get_placeholder_mask( |
| 105 | + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) |
| 106 | + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) |
| 107 | + return {'inputs_embeds': inputs_embeds} |
| 108 | + |
89 | 109 | def _pad_3d_position_ids(self, |
90 | 110 | position_ids: List[torch.Tensor], |
91 | 111 | padding_value: float = 0., |
|
0 commit comments