
Commit 02451cd

patrickvonplaten authored and LysandreJik committed
Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089)
* add wav2vec2CTC and deprecate for maskedlm
* remove from docs
1 parent 800f385 commit 02451cd

File tree: 8 files changed, +100 −10 lines changed
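For downstream users, the practical effect of this commit is a class rename: checkpoints previously loaded with `Wav2Vec2ForMaskedLM` should now be loaded with `Wav2Vec2ForCTC`, which exposes the same CTC-style head. A minimal migration sketch (the checkpoint name is taken from the diffs below; everything else is ordinary `transformers` usage):

    # Drop-in replacement for the deprecated Wav2Vec2ForMaskedLM at inference time.
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

    tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")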

docs/source/model_doc/wav2vec2.rst (+2 −2)
@@ -58,8 +58,8 @@ Wav2Vec2Model
     :members: forward


-Wav2Vec2ForMaskedLM
+Wav2Vec2ForCTC
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-.. autoclass:: transformers.Wav2Vec2ForMaskedLM
+.. autoclass:: transformers.Wav2Vec2ForCTC
     :members: forward

src/transformers/__init__.py (+2)
@@ -367,6 +367,7 @@
     _import_structure["models.wav2vec2"].extend(
         [
             "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForCTC",
             "Wav2Vec2ForMaskedLM",
             "Wav2Vec2Model",
             "Wav2Vec2PreTrainedModel",
@@ -1813,6 +1814,7 @@
     )
     from .models.wav2vec2 import (
         WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+        Wav2Vec2ForCTC,
         Wav2Vec2ForMaskedLM,
         Wav2Vec2Model,
         Wav2Vec2PreTrainedModel,
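Both registrations follow the library's lazy-import layout: public names are listed as strings in `_import_structure` and only resolved to real objects on first attribute access. A stripped-down sketch of that pattern (module and name layout here are hypothetical, not the actual transformers internals; assumes the code lives in a package `__init__.py`):

    # Illustrative sketch of a string-keyed lazy-import table using a
    # PEP 562 module-level __getattr__.
    import importlib

    _import_structure = {"modeling_wav2vec2": ["Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM"]}

    # Reverse map: public name -> submodule that defines it.
    _name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}

    def __getattr__(name):
        # Import the defining submodule only when the name is first requested.
        if name in _name_to_module:
            module = importlib.import_module("." + _name_to_module[name], __name__)
            return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")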

src/transformers/models/wav2vec2/__init__.py (+2)
@@ -29,6 +29,7 @@
     _import_structure["modeling_wav2vec2"] = [
         "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
         "Wav2Vec2ForMaskedLM",
+        "Wav2Vec2ForCTC",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
@@ -41,6 +42,7 @@
 if is_torch_available():
     from .modeling_wav2vec2 import (
         WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+        Wav2Vec2ForCTC,
         Wav2Vec2ForMaskedLM,
         Wav2Vec2Model,
         Wav2Vec2PreTrainedModel,

src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py (+2 −2)
@@ -20,7 +20,7 @@
 import fairseq
 import torch

-from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
+from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging


 logging.set_verbosity_info()
@@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path
     """
     Copy/paste/tweak model's weights to transformers design.
     """
-    hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
+    hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())

     model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
         [checkpoint_path], arg_overrides={"data": dict_path}
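Given the signature visible in the second hunk, the converter instantiates a fresh `Wav2Vec2ForCTC`, loads the fairseq checkpoint, and copies the weights across. A hedged sketch of calling it directly; all three paths are placeholders, and the script's command-line wiring is not shown in this diff:

    from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch import (
        convert_wav2vec2_checkpoint,
    )

    # Placeholder paths: a fairseq wav2vec2 checkpoint, an output folder for
    # the converted Hugging Face model, and the fairseq target dictionary.
    convert_wav2vec2_checkpoint(
        checkpoint_path="wav2vec_small_960h.pt",
        pytorch_dump_folder_path="./wav2vec2-base-960h-hf",
        dict_path="./dict.ltr.txt",
    )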

src/transformers/models/wav2vec2/modeling_wav2vec2.py (+80 −1)
@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2 model. """


+import warnings
 from typing import Optional, Tuple

 import torch
@@ -24,7 +25,7 @@

 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

+        warnings.warn(
+            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
+        )
+
         self.wav2vec2 = Wav2Vec2Model(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
@@ -729,3 +734,77 @@ def forward(
             return output

         return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
+    WAV_2_VEC_2_START_DOCSTRING,
+)
+class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            TODO(PVP): Fill out when adding training
+
+        Returns:
+
+        Example::
+
+            >>> import torch
+            >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+            >>> from datasets import load_dataset
+            >>> import soundfile as sf
+
+            >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+            >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+            >>> def map_to_array(batch):
+            ...     speech, _ = sf.read(batch["file"])
+            ...     batch["speech"] = speech
+            ...     return batch
+
+            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+            >>> ds = ds.map(map_to_array)
+
+            >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+
+            >>> predicted_ids = torch.argmax(logits, dim=-1)
+            >>> transcription = tokenizer.decode(predicted_ids[0])
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
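The deprecated class keeps working, but constructing it now emits a `FutureWarning`. A small sketch of how calling code can surface or assert that warning, e.g. in a test (the default-initialized config here is a placeholder):

    import warnings
    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Wav2Vec2ForMaskedLM(Wav2Vec2Config())  # deprecated path: emits FutureWarning

    assert any(issubclass(w.category, FutureWarning) for w in caught)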

src/transformers/utils/dummy_pt_objects.py (+5)
@@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class Wav2Vec2ForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class Wav2Vec2ForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
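The dummy object keeps `from transformers import Wav2Vec2ForCTC` importable in a torch-free environment while failing with an actionable error on instantiation. A self-contained sketch of the same pattern; the `requires_pytorch` helper below is a stand-in, not the real one:

    # Stand-in reproduction of the dummy-object pattern (illustrative only).
    def requires_pytorch(obj):
        # The real helper raises an error explaining that PyTorch is needed.
        raise ImportError(f"{type(obj).__name__} requires the PyTorch library, but it was not found.")

    class Wav2Vec2ForCTC:
        def __init__(self, *args, **kwargs):
            requires_pytorch(self)  # fail fast at construction, not at import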

tests/test_modeling_wav2vec2.py (+5 −5)
@@ -29,7 +29,7 @@
 if is_torch_available():
     import torch

-    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
+    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer


 class Wav2Vec2ModelTester:
@@ -204,7 +204,7 @@ def test_model_from_pretrained(self):

 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
+    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -289,7 +289,7 @@ def map_to_array(batch):
         return ds["speech"][:num_samples]

     def test_inference_masked_lm_normal(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

@@ -307,7 +307,7 @@ def test_inference_masked_lm_normal(self):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_normal_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

@@ -330,7 +330,7 @@ def test_inference_masked_lm_normal_batched(self):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_robust_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

         input_speech = self._load_datasamples(4)

utils/check_repo.py (+2)
@@ -118,6 +118,7 @@
     "TFMT5EncoderModel",
     "TFOpenAIGPTDoubleHeadsModel",
     "TFT5EncoderModel",
+    "Wav2Vec2ForCTC",
     "XLMForQuestionAnswering",
     "XLMProphetNetDecoder",
     "XLMProphetNetEncoder",
@@ -370,6 +371,7 @@ def find_all_documented_objects():
     "TFBartPretrainedModel",
     "TextDataset",
     "TextDatasetForNextSentencePrediction",
+    "Wav2Vec2ForMaskedLM",
     "glue_compute_metrics",
     "glue_convert_examples_to_features",
     "glue_output_modes",
