# Copyright (c) 2025 Binbin Zhang([email protected])

import math
import json
from dataclasses import dataclass, field
from typing import Dict
from torch.utils.data import Dataset
from transformers.trainer_pt_utils import LabelSmoother
import torch
import torch.nn.functional as F
import torchaudio
import transformers
import whisper


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."})
    test_data_path: str = field(default=None,
                                metadata={"help": "Path to the test data."})
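

# Illustrative note, not in the original file: each line of `data_path` is one
# JSON object (JSONL), read line by line in SpeechDataset.__init__ below. Based
# on the keys accessed in __getitem__, a record looks roughly like this
# (paths and text are made-up examples):
#
#   {"wav": "/path/to/utt1.wav", "txt": "hello world",
#    "instruction": "Transcribe the speech", "content": "hello world"}
#
# `wav` and `txt` are required; `instruction` and `content` are optional.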
class SpeechDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer: transformers.PreTrainedTokenizer,
        config,  # model config
        inference: bool = False,
    ):
        super(SpeechDataset, self).__init__()
        print("Formatting inputs...")
        self.tokenizer = tokenizer
        self.config = config
        self.inference = inference
        self.raw_data = []
        with open(data_path, "r") as f:
            for line in f:
                self.raw_data.append(json.loads(line))

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        IGNORE_TOKEN_ID = LabelSmoother.ignore_index
        msg = self.raw_data[i]
        # Load the waveform and resample to 16 kHz if necessary.
        audio, sample_rate = torchaudio.load(msg['wav'])
        if sample_rate != 16000:
            audio = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
        if self.config.encoder_type == 'whisper':
            # Whisper expects a fixed 30 s window, so pad/trim the waveform and
            # keep the true number of mel frames in `mel_len`.
            mel_len = math.ceil(
                float(audio.size(1)) / 16000 * self.config.frames_per_second)
            audio = whisper.pad_or_trim(audio[0])
            mel = whisper.log_mel_spectrogram(audio)
        else:
            # Note: We use 16-bit quantization by default in WeNet.
            audio = audio * (1 << 15)
            mel = torchaudio.compliance.kaldi.fbank(audio,
                                                    num_mel_bins=80,
                                                    frame_length=25,
                                                    frame_shift=10,
                                                    dither=0.0,
                                                    energy_floor=0.0,
                                                    sample_frequency=16000)
            mel = mel.transpose(0, 1)  # (80, T)
            # Pad or truncate the fbank features to a fixed number of frames.
            if mel.size(1) < self.config.max_mel_size:
                mel_len = mel.size(1)
                mel = F.pad(mel, (0, self.config.max_mel_size - mel.size(1)),
                            value=0.0)
            else:  # hard truncation
                mel_len = self.config.max_mel_size
                mel = mel[:, :self.config.max_mel_size]
        # Placeholder ids for the speech positions; their targets are ignored
        # in the loss.
        ids_audio = [0] * self.config.max_speech_token_size
        tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
        if 'instruction' in msg:
            instruction = msg['instruction']
        elif self.inference and self.config.decode_instruction != '':
            instruction = self.config.decode_instruction
        else:
            instruction = 'Transcribe the speech'
        chat = [{"role": "user", "content": instruction}]
        # `content`: the answer according to the audio and instruction
        # `txt`: the transcription of the audio
        # If there is no content, the default `content` is the same as `txt`.
        content = msg['content'] if 'content' in msg else msg['txt']
        if self.inference:
            kwargs = {'add_generation_prompt': True}
        else:
            chat.append({"role": "assistant", "content": content})
            kwargs = {
                'padding': 'max_length',
                'max_length': self.config.model_max_length -
                self.config.max_speech_token_size,
                'truncation': True,
                'add_generation_prompt': False,
            }
        ids_text = self.tokenizer.apply_chat_template(chat,
                                                      tokenize=True,
                                                      **kwargs)
        ids = ids_audio + ids_text
        tgt = tgt_audio + ids_text
        input_ids = torch.tensor(ids, dtype=torch.int)
        target_ids = torch.tensor(tgt, dtype=torch.int)
        target_ids[target_ids == self.tokenizer.pad_token_id] = IGNORE_TOKEN_ID
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        # Tokenize the transcript separately for the auxiliary CTC branch.
        ctc_tokens = self.tokenizer(msg['txt'],
                                    padding='max_length',
                                    max_length=100,
                                    truncation=True,
                                    return_tensors='pt')
        ctc_ids = ctc_tokens['input_ids'][0]
        ctc_ids_len = ctc_tokens['attention_mask'].sum().item()
        ret = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'mel': mel,
            'mel_len': mel_len,
        }
        if not self.inference:
            ret['labels'] = target_ids
            ret['ctc_ids'] = ctc_ids
            ret['ctc_ids_len'] = ctc_ids_len
        return ret
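

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original file. It only shows how
# SpeechDataset might be exercised by hand. The tokenizer name, the config
# field values, and the synthetic wav/JSONL below are assumptions made for
# illustration; the real values come from the repo's model config and
# training scripts.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config whose attribute names mirror those accessed above.
    config = SimpleNamespace(
        encoder_type='whisper',    # any other value selects the fbank branch
        frames_per_second=100,     # assumed 10 ms mel hop -> 100 frames/s
        max_mel_size=3000,         # only used by the fbank branch
        max_speech_token_size=8,   # number of placeholder speech tokens
        model_max_length=512,
        decode_instruction='',
    )

    # Any chat-template tokenizer with a pad token works; Qwen2 is just an
    # example choice, not necessarily the model used by this repo.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "Qwen/Qwen2-0.5B-Instruct")

    # Synthesize a one-second 16 kHz sine wave so the example is self-contained.
    sr = 16000
    t = torch.arange(sr, dtype=torch.float32) / sr
    wav = 0.1 * torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)
    torchaudio.save("demo_utt.wav", wav, sr)

    # Write a one-line JSONL file in the format described above.
    with open("demo_data.jsonl", "w") as f:
        f.write(json.dumps({"wav": "demo_utt.wav", "txt": "hello world"}) + "\n")

    dataset = SpeechDataset("demo_data.jsonl", tokenizer, config,
                            inference=False)
    sample = dataset[0]
    print({k: getattr(v, 'shape', v) for k, v in sample.items()})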