Skip to content

Commit dead85a

Browse files
authored
fix inputs_embeds for hunyuanOCR (modelscope#7803)
* fix inputs_embeds for hunyuanOCR * fix typos
1 parent baaa18d commit dead85a

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

swift/template/templates/tencent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
12
from dataclasses import dataclass, field
23
from typing import Any, Dict, List, Literal, Optional
34

@@ -86,6 +87,25 @@ def _get_new_tokens(i):
8687
encoded['attention_mask'] = attention_mask
8788
return encoded
8889

90+
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Build ``inputs_embeds`` for training, merging image features in.

    Outside of training (``self.is_training`` is falsy) ``inputs`` is
    returned untouched. During training the token ids are embedded with
    ``base_model.model.embed_tokens``; when ``pixel_values`` is present the
    vision tower (``base_model.vit``) is run and its output is scattered
    into the positions reported by ``base_model.get_placeholder_mask``.

    Args:
        model: The (possibly wrapped) model; it is unwrapped through
            ``self.get_base_model``.
        inputs: Batch dict. Must contain ``'input_ids'``; may contain
            ``'pixel_values'`` and ``'image_grid_thw'``.

    Returns:
        ``inputs`` itself when not training, otherwise a new dict holding
        only ``'inputs_embeds'``.
        NOTE(review): the other entries of ``inputs`` are not forwarded in
        the training branch; presumably the caller merges them back --
        confirm against the call site.
    """
    if not self.is_training:
        return inputs

    input_ids = inputs['input_ids']
    pixel_values = inputs.get('pixel_values')
    image_grid_thw = inputs.get('image_grid_thw')
    base_model = self.get_base_model(model)
    inputs_embeds = base_model.model.embed_tokens(input_ids)

    if pixel_values is not None:
        # The vision tower may run in a different precision than the batch.
        pixel_values = pixel_values.to(base_model.vit.dtype)
        image_embeds = base_model.vit(pixel_values, image_grid_thw)
        # masked_scatter needs source and destination to share dtype (and
        # device), so align the image features with the text embeddings
        # rather than with input_ids, which only carries a device.
        image_embeds = image_embeds.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype, non_blocking=True)
        image_mask, _ = base_model.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    return {'inputs_embeds': inputs_embeds}
108+
89109
def _pad_3d_position_ids(self,
90110
position_ids: List[torch.Tensor],
91111
padding_value: float = 0.,

0 commit comments

Comments
 (0)