@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2 model. """
 
 
+import warnings
 from typing import Optional, Tuple
 
 import torch
@@ -24,7 +25,7 @@
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        warnings.warn(
+            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
+        )
+
         self.wav2vec2 = Wav2Vec2Model(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
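The hunk above only adds the warning; the replacement class itself arrives in the next hunk. For callers, the migration the warning asks for is a drop-in class swap. A minimal illustration, reusing the checkpoint from the example below:

    from transformers import Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM

    # before: emits a FutureWarning as of this change
    model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

    # after: identical sub-modules (wav2vec2, dropout, lm_head), no warning
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")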
@@ -729,3 +734,77 @@ def forward(
             return output
 
         return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
+    WAV_2_VEC_2_START_DOCSTRING,
+)
+class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            TODO(PVP): Fill out when adding training
+
+        Returns:
+
+        Example::
+
+            >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+            >>> from datasets import load_dataset
+            >>> import soundfile as sf
+
+            >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+            >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+            >>> def map_to_array(batch):
+            ...     speech, _ = sf.read(batch["file"])
+            ...     batch["speech"] = speech
+            ...     return batch
+
+            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+            >>> ds = ds.map(map_to_array)
+
+            >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+
+            >>> predicted_ids = logits.argmax(dim=-1)
+            >>> transcription = tokenizer.decode(predicted_ids[0])
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
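The `labels` argument above is accepted but unused, and its docstring is still a TODO because this diff only covers inference. As a rough orientation for where training support would go, here is a minimal sketch of a CTC loss computed from the `logits` this head produces, using `torch.nn.functional.ctc_loss`. The conventions assumed here (labels padded with -100, the pad/blank token id, all input frames valid) are illustrative assumptions, not something this diff specifies.

    import torch
    import torch.nn.functional as F

    # Hypothetical sketch only -- not part of this diff.
    def ctc_loss_sketch(logits, labels, blank_id=0):
        # (batch, time, vocab) -> (time, batch, vocab) log-probabilities,
        # the layout torch's ctc_loss expects
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)

        batch_size, time_steps = logits.shape[0], logits.shape[1]
        # assumption: every frame is valid; a real implementation would
        # derive input lengths from the feature encoder's output lengths
        input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)

        # assumption: -100 marks label padding; concatenate valid targets
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(dim=-1)
        targets = labels.masked_select(labels_mask)

        return F.ctc_loss(
            log_probs, targets, input_lengths, target_lengths,
            blank=blank_id, zero_infinity=True,
        )

Hooking this in would mean computing the loss in `forward` when `labels is not None` and returning it via the `loss` field of `CausalLMOutput`.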