 
 from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder
 
-# logging library is not automatically supported by Torchscript
-import warnings
-
 
-@dataclass(frozen=True)
+@dataclass
 class T5Conf:
     encoder_only: bool = False
     linear_head: bool = False
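Because T5Conf is now a plain (non-frozen) dataclass, a config instance can be adjusted after construction. A minimal usage sketch, using only the two fields visible in this hunk (the full config defines additional fields not shown here):

    conf = T5Conf(encoder_only=True)
    conf.linear_head = True  # plain @dataclass allows this; @dataclass(frozen=True) would raise FrozenInstanceError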
@@ -215,7 +212,6 @@ def prepare_inputs_for_generation( |
             "return_past_key_values": return_past_key_values,
         }
 
-    @torch.jit.export
     def get_encoder(self) -> T5Encoder:
         return self.encoder
 
@@ -292,8 +288,6 @@ def forward( |
 
             # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
             if decoder_tokens is None:
-                batch_size = encoder_output.size()[0]
-                encoder_output_device = encoder_output.device
                 decoder_tokens = (
                     torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
                 )
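For the branch above: when no decoder tokens are passed in, decoding is seeded with a single padding-index token per batch element, matching the encoder output's batch size and device. A standalone sketch of that seeding step, with an assumed pad index and dummy encoder output shape (illustrative values only, not taken from this file):

    import torch

    padding_idx = 0                           # assumed pad id, for illustration
    encoder_output = torch.randn(4, 16, 512)  # dummy (batch, seq_len, embedding_dim) tensor
    batch_size = encoder_output.size()[0]
    decoder_tokens = (
        torch.ones((batch_size, 1), device=encoder_output.device, dtype=torch.long) * padding_idx
    )
    print(decoder_tokens.shape)  # torch.Size([4, 1]): one start-of-decoding token per sequence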
@@ -323,7 +317,7 @@ def forward( |
             # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
             # same word embeddings, which is always the case in our t5 implementation.
             # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-            decoder_output = decoder_output * (self.embedding_dim**-0.5)
+            decoder_output = decoder_output * (self.embedding_dim ** -0.5)
             decoder_output = self.lm_head(decoder_output)
             decoder_outputs["decoder_output"] = decoder_output
 
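The rescaling above follows the tied-embedding convention referenced in the comment: when the output projection shares weights with the input embeddings, decoder hidden states are scaled by embedding_dim ** -0.5 before the vocabulary projection. A self-contained sketch with dummy sizes and a plain nn.Linear standing in for the model's lm_head:

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 32128     # illustrative sizes
    lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    decoder_output = torch.randn(4, 1, embedding_dim)           # dummy decoder hidden states
    decoder_output = decoder_output * (embedding_dim ** -0.5)   # rescale when embeddings are tied
    logits = lm_head(decoder_output)                            # (batch, seq_len, vocab_size)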