@@ -1,9 +1,10 @@
 import re
 
 import torch
+import torch.nn.functional as F
 
 from xtuner.v1.data_proto import SequenceContext
-from xtuner.v1.loss import CELossContext
+from xtuner.v1.loss import BaseLossContext
 from xtuner.v1.model.base import ModelOutputs
 
 from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
         hidden_states[visual_pos_masks, :] = local_this
         return hidden_states
 
-    def forward(
+    def forward(  # type: ignore[override]
         self,
         seq_ctx: SequenceContext,  # todo(@yehaochen): support intra layer micro-batch
-        loss_ctx: CELossContext,
+        loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
     ) -> ModelOutputs:
         input_ids = seq_ctx.input_ids
         position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(
 
         hidden_states = self.norm(hidden_states)
 
-        loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
-        output["loss"] = loss
-        output["logits"] = logits
-        output["extra_info"] = extra_info
-        return ModelOutputs(**output)  # type: ignore[typeddict-item]
+        if loss_ctx is None:
+            # Inference mode
+            logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
+            output["logits"] = logits
+        else:
+            # Training mode
+            loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"])  # type: ignore[call-overload]
+            output["loss"] = loss
+            output["logits"] = logits
+            output["extra_info"] = extra_info
+
+        return ModelOutputs(**output)
 
 
 class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):
|
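For reference, a minimal sketch of how the reworked forward would be exercised in each mode. The `model`, `seq_ctx`, and `lm_loss_ctx` objects here are illustrative placeholders; the actual construction of a `SequenceContext` or a loss context is not shown in this diff.

# Training: loss_ctx is a dict keyed by head name; the "lm" entry is
# forwarded to self.lm_head, which computes the loss alongside logits.
outputs = model(seq_ctx=seq_ctx, loss_ctx={"lm": lm_loss_ctx})
loss = outputs["loss"]

# Inference: omit loss_ctx (it defaults to None); the model skips loss
# computation and projects hidden states to logits via F.linear.
with torch.no_grad():
    outputs = model(seq_ctx=seq_ctx)
    logits = outputs["logits"]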