-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathslopts_model.py
More file actions
358 lines (306 loc) · 19.9 KB
/
slopts_model.py
File metadata and controls
358 lines (306 loc) · 19.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import torch
import torch.nn.functional as F
from transformers import WhisperForConditionalGeneration
from transformers.models.whisper.modeling_whisper import WhisperDecoder, shift_tokens_right
from copy import deepcopy
from typing import Dict
def prepend_sos_token_last_dim(input_ids: torch.Tensor, sos_token: int):
    """Return a copy of ``input_ids`` with ``sos_token`` inserted first along the last dimension.

    The input tensor is left untouched. The result has the same dtype and device,
    and its last dimension is one element longer.
    """
    # One-element column holding the start token, shaped like input_ids except the last dim.
    sos_column = input_ids.new_full((*input_ids.shape[:-1], 1), sos_token)
    return torch.cat((sos_column, input_ids), dim=-1)
class SloptsDecoderPredictor(torch.nn.Module):
    """Trainable slot head: a slot ``WhisperDecoder`` plus a bias-free linear slot classifier.

    Kept as a separate module so its weights can be saved independently of the
    Whisper model. It defines no ``forward``; the owning model calls
    ``slot_decoder`` and ``out_slot_predictor`` directly.

    Args:
        config: decoder configuration handed to ``WhisperDecoder``. The decoder's
            ``config.vocab_size`` sets the number of classifier outputs.
        out_slot_predictor_size: input width of the linear classifier.
    """

    def __init__(self, config, out_slot_predictor_size):
        super().__init__()
        self.config = config
        decoder = WhisperDecoder(config)
        self.slot_decoder = decoder
        # Linear head mapping slot-decoder hidden states to one logit per slot label.
        self.out_slot_predictor = torch.nn.Linear(
            in_features=out_slot_predictor_size,
            out_features=decoder.config.vocab_size,
            bias=False,
        )
class SloptsModel(torch.nn.Module):
    """Frozen Whisper ASR model with a trainable slot decoder running alongside it.

    The slot decoder is a second ``WhisperDecoder`` whose vocabulary is the set
    of slot labels. At every position its input embedding is the sum of the
    slot-label embedding and Whisper's token embedding, and it cross-attends to
    the Whisper encoder output. The previous system dialogue act
    (``prev_sys_da``) is fed to the slot decoder as prompt tokens, and those
    prompt positions are cut from its output.

    During ``generate`` a token trie constrains which Whisper tokens may be
    emitted while a slot is open.
    """

    # Slot labels meaning "continue the slot opened earlier".
    _SLOT_CONT_TOKENS = frozenset({1, 2, 3, 4})
    # Slot labels meaning "open a new slot"; the next token is restricted to the trie root's children.
    _SLOT_START_TOKENS = frozenset({5, 6, 7, 8})

    def __init__(self, whisper_model_path: str, num_output_slots: int, slot_decoder_layers: int, trie, slot_sos_token: int = 10, slot_pad_token: int = 9):
        """Load the frozen Whisper model and build the slot decoder/predictor.

        Args:
            whisper_model_path: Hugging Face model id or local path of the Whisper checkpoint.
            num_output_slots: number of slot labels (vocabulary size of the slot decoder).
            slot_decoder_layers: number of decoder layers in the slot decoder.
            trie: token trie used by ``generate``. Its ``root`` and nodes must offer
                ``get_possible_next_nodes()`` and ``try_move_to_node(token_id)``, and
                nodes must expose ``is_end_of_word``.
            slot_sos_token: slot label used as the slot decoder's start token.
            slot_pad_token: slot label used for padding.
        """
        super().__init__()
        self.trie = trie
        self.num_output_slots = num_output_slots
        self.whisper = WhisperForConditionalGeneration.from_pretrained(whisper_model_path)
        # The Czech Whisper checkpoint ships with wrong special-token ids; override them.
        self.whisper.config.decoder_start_token_id = 50258
        self.whisper.config.eos_token_id = 50257
        # Whisper is frozen: eval mode (kept by ``train`` below) and no gradients.
        self.whisper.eval()
        for param in self.whisper.parameters():
            param.requires_grad = False
        # The slot decoder reuses Whisper's config, resized to the slot vocabulary.
        slot_config = deepcopy(self.whisper.config)
        del slot_config._name_or_path
        del slot_config.architectures
        slot_config.vocab_size = num_output_slots
        slot_config.decoder_layers = slot_decoder_layers
        slot_config.pad_token_id = slot_pad_token
        slot_config.decoder_start_token_id = slot_sos_token
        self.slopts_decoder_predictor = SloptsDecoderPredictor(slot_config, slot_config.d_model)
        # Shortcuts to the same submodule objects (parameters are shared, not copied).
        self.slot_decoder = self.slopts_decoder_predictor.slot_decoder
        self.out_slot_predictor = self.slopts_decoder_predictor.out_slot_predictor
        self.slot_sos_token = slot_sos_token
        self.slot_pad_token = slot_pad_token
        # Boolean mask over slot labels marking the "continue slot" labels.
        # A non-persistent buffer follows ``.to(device)`` without adding a state-dict key.
        slot_conts_mask = torch.zeros(num_output_slots, dtype=torch.bool)
        slot_conts_mask[sorted(self._SLOT_CONT_TOKENS)] = True
        self.register_buffer("slot_conts_mask", slot_conts_mask, persistent=False)

    def train(self, mode: bool = True):
        """Set train/eval mode on the trainable parts while keeping the frozen Whisper in eval mode.

        ``torch.nn.Module.train`` recurses into every child, so without this
        override a plain ``model.train()`` would undo ``self.whisper.eval()``.
        """
        super().train(mode)
        self.whisper.eval()
        return self

    @staticmethod
    def _check_paired_input_ids(decoder_input_ids, slot_input_ids):
        """Raise ValueError unless both id tensors are None, or both are given with equal shapes."""
        if decoder_input_ids is None and slot_input_ids is None:
            return
        decoder_shape = getattr(decoder_input_ids, "shape", None)
        slot_shape = getattr(slot_input_ids, "shape", None)
        if decoder_shape is None or slot_shape is None:
            raise ValueError("Decoder and slot input_ids must be provided together and must contain a sequence of the same length.")
        if decoder_shape != slot_shape:
            raise ValueError("Decoder and slot input_ids must have the same shape.")

    @staticmethod
    def _restrict_to_tokens(log_probs: torch.Tensor, allowed_tokens) -> None:
        """Set, in place, every entry of the 1-D ``log_probs`` not indexed by ``allowed_tokens`` to -inf.

        ``log_probs`` may be a row view of a larger tensor; the write goes through to it.
        """
        keep = torch.zeros(log_probs.shape[-1], dtype=torch.bool, device=log_probs.device)
        keep[allowed_tokens] = True
        log_probs[~keep] = -float("inf")

    def _run_slot_decoder(self, decoder_input_ids: torch.Tensor, slot_input_ids: torch.Tensor, whisper_encoder_output: torch.Tensor, prompt_len: int, device) -> torch.Tensor:
        """Run the slot decoder and return its hidden states with the prompt positions removed.

        ``slot_input_ids`` must already start with ``prompt_len`` prompt tokens.
        ``decoder_input_ids`` must not. It is left-padded here with Whisper's pad
        token so both sequences align position by position before their
        embeddings are summed.
        """
        padding = torch.full((decoder_input_ids.shape[0], prompt_len), self.whisper.config.pad_token_id, dtype=torch.long, device=device)
        decoder_input_ids = torch.cat((padding, decoder_input_ids), dim=1)
        token_embeds = self.whisper.model.decoder.embed_tokens(decoder_input_ids)
        slot_embeds = self.slot_decoder.embed_tokens(slot_input_ids) + token_embeds
        hidden = self.slot_decoder(inputs_embeds=slot_embeds, encoder_hidden_states=whisper_encoder_output).last_hidden_state
        return hidden[:, prompt_len:, :]

    def forward(self, input_features: Dict[str, torch.Tensor], decoder_input_ids: torch.Tensor = None, slot_input_ids: torch.Tensor = None, labels: Dict[str, torch.Tensor] = None):
        """Teacher-forced pass returning slot log-probabilities.

        Args:
            input_features: dict with ``"audio"`` (Whisper input features) and
                ``"prev_sys_da"`` (prompt tokens for the slot decoder, concatenated along dim 1).
            decoder_input_ids, slot_input_ids: decoder inputs, given together with equal
                shapes. They are used as given, without shifting.
            labels: dict with ``"tokens"`` and ``"slots"``. Used only when no input ids
                are given; they are shifted right to form the decoder inputs.

        Returns:
            ``(slot_log_probs, hidden)``. ``slot_log_probs`` has shape
            (batch, num_output_slots, seq_len), as ``NLLLoss`` expects. ``hidden`` holds
            the Whisper encoder/decoder and slot-decoder last hidden states.

        Raises:
            ValueError: if only one of the id tensors is given, their shapes differ,
                or neither ids nor labels are given.
        """
        self._check_paired_input_ids(decoder_input_ids, slot_input_ids)
        audio_input, prev_sys_da = input_features["audio"], input_features["prev_sys_da"]
        if decoder_input_ids is None:
            if labels is None:
                raise ValueError("Either labels or decoder_input_ids and slot_input_ids must be provided.")
            # Training: build decoder inputs by shifting the targets right behind the start tokens.
            decoder_input_ids = shift_tokens_right(
                labels["tokens"], self.whisper.config.pad_token_id, self.whisper.config.decoder_start_token_id)
            slot_input_ids = shift_tokens_right(labels["slots"], self.slot_pad_token, self.slot_sos_token)
        prompt_len = prev_sys_da.shape[1]
        # Prepend prev_sys_da as prompt tokens for the slot decoder.
        slot_input_ids = torch.cat((prev_sys_da, slot_input_ids), dim=1)
        with torch.no_grad():
            # Pass the shifted ids as decoder_input_ids: passing them as ``labels`` makes
            # Whisper shift them a second time and compute an unused loss.
            whisper_output = self.whisper(audio_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
        whisper_encoder_output = whisper_output.encoder_hidden_states[-1]
        whisper_decoder_output = whisper_output.decoder_hidden_states[-1]
        slots_decoder_output = self._run_slot_decoder(
            decoder_input_ids, slot_input_ids, whisper_encoder_output, prompt_len, decoder_input_ids.device)
        slot_logits = self.out_slot_predictor(slots_decoder_output)
        # NLLLoss expects (batch_size, num_classes, sequence_len) log-probabilities.
        slot_predictions = F.log_softmax(slot_logits.permute((0, 2, 1)), dim=1)
        return slot_predictions, {"whisper_encoder": whisper_encoder_output, "whisper_decoder": whisper_decoder_output, "slot_decoder": slots_decoder_output}

    def slot_decoder_forward(self, decoder_input_ids: torch.Tensor, slot_input_ids: torch.Tensor, whisper_encoder_output: torch.Tensor, whisper_decoder_output: torch.Tensor, prev_sys_da: torch.Tensor, beam_width, device):
        """Return next-slot log-probabilities (one row per hypothesis) for the last position.

        ``decoder_input_ids`` must already start with Whisper's start token.
        ``slot_input_ids`` must not: the slot start token and ``prev_sys_da``
        (repeated ``beam_width`` times) are prepended here.
        ``whisper_decoder_output`` is currently unused; it is kept for signature
        compatibility with the disabled concatenation variant.
        """
        slot_input_ids = prepend_sos_token_last_dim(slot_input_ids, self.slot_sos_token)
        slot_input_ids = torch.cat((prev_sys_da.repeat(beam_width, 1), slot_input_ids), dim=1)
        slots_decoder_output = self._run_slot_decoder(
            decoder_input_ids, slot_input_ids, whisper_encoder_output, prev_sys_da.shape[1], device)
        next_slot_logits = self.out_slot_predictor(slots_decoder_output[:, -1, :])
        return F.log_softmax(next_slot_logits, dim=-1)

    @torch.no_grad()
    def generate(self, input_features: Dict[str, torch.Tensor], decoder_input_ids: torch.Tensor, max_length: int = 448, beam_width: int = 1):
        """Jointly decode transcript tokens and slot labels with trie-constrained beam search.

        Args:
            input_features: dict with ``"audio"`` and ``"prev_sys_da"``, batch size 1.
            decoder_input_ids: prompt tokens (e.g. <|transcribe|>), shape (1, P), P > 0.
                They are kept at the start of the returned tokens, with padding slots under them.
            max_length: stop once hypotheses reach this many tokens.
            beam_width: number of beams. Beams that emit EOS are set aside and the beam shrinks.

        Returns:
            ``{"tokens": (1, L), "slots": (1, L)}`` for the highest summed log-probability
            among finished and still-running hypotheses.
        """
        audio_input, prev_sys_da = input_features["audio"], input_features["prev_sys_da"]
        device = audio_input.device
        assert len(audio_input.shape) == 3 and (len(prev_sys_da.shape) == len(decoder_input_ids.shape) == 2), "Provide a batch of size 1 as input for generation."
        assert audio_input.shape[0] == prev_sys_da.shape[0] == 1, "Only batch_size=1 is supported for generation."
        assert decoder_input_ids.shape[1] > 0, "Provide at least one prompt token for generation (<|transcribe|>, etc.)."
        sot_token = self.whisper.config.decoder_start_token_id
        eos_token = self.whisper.config.eos_token_id
        neg_inf = -float("inf")
        slot_conts_mask = self.slot_conts_mask.to(device)
        # The prompt tokens are part of the output; the slots under them are padding.
        tokens_out = decoder_input_ids.detach().clone().repeat(beam_width, 1)
        slots_out = torch.full(decoder_input_ids.shape, self.slot_pad_token, dtype=torch.long, device=device).repeat(beam_width, 1)
        whisper_encoder_output = self.whisper.model.encoder(audio_input).last_hidden_state

        # First step: one hypothesis, expanded into beam_width beams.
        step_ids = prepend_sos_token_last_dim(decoder_input_ids, sot_token)
        whisper_decoder_output = self.whisper.model.decoder(step_ids, encoder_hidden_states=whisper_encoder_output).last_hidden_state
        next_slot_probs = self.slot_decoder_forward(step_ids, slots_out[0].unsqueeze(0), whisper_encoder_output, whisper_decoder_output, prev_sys_da, beam_width=1, device=device)
        # No slot is open yet, so a "continue slot" label is impossible.
        next_slot_probs[0, slot_conts_mask] = neg_inf
        next_slot = torch.argmax(next_slot_probs, dim=-1)[0]
        starts_slot = next_slot.item() in self._SLOT_START_TOKENS
        next_token_probs = F.log_softmax(self.whisper.proj_out(whisper_decoder_output[:, -1, :]), dim=-1)[0]
        if starts_slot:
            self._restrict_to_tokens(next_token_probs, self.trie.root.get_possible_next_nodes())
        scores, top_indices = next_token_probs.topk(beam_width, 0, True, True)
        trie_pointers = [self.trie.root.try_move_to_node(token) if starts_slot else None for token in top_indices.tolist()]
        tokens_out = torch.cat([tokens_out, top_indices.unsqueeze(1)], dim=1)
        slots_out = torch.cat([slots_out, next_slot.unsqueeze(0).repeat(beam_width, 1)], dim=1)

        whisper_encoder_output = whisper_encoder_output.repeat(beam_width, 1, 1)
        finished_seqs = []
        while tokens_out.shape[1] < max_length:
            step_ids = prepend_sos_token_last_dim(tokens_out, sot_token)
            whisper_decoder_output = self.whisper.model.decoder(step_ids, encoder_hidden_states=whisper_encoder_output).last_hidden_state
            next_slot_probs = self.slot_decoder_forward(step_ids, slots_out, whisper_encoder_output, whisper_decoder_output, prev_sys_da, beam_width, device)
            # Constrain each beam's next slot label by its position in the trie.
            for i, pointer in enumerate(trie_pointers):
                if pointer is None:
                    continue
                if not pointer.get_possible_next_nodes():
                    # The current word cannot be extended, so the slot cannot continue.
                    next_slot_probs[i, slot_conts_mask] = neg_inf
                elif not pointer.is_end_of_word:
                    # Inside an unfinished word: the slot must continue.
                    next_slot_probs[i, ~slot_conts_mask] = neg_inf
            next_slots = torch.argmax(next_slot_probs, dim=-1)
            next_token_probs = F.log_softmax(self.whisper.proj_out(whisper_decoder_output[:, -1, :]), dim=-1)
            # Constrain each beam's next token by its chosen slot label.
            for i, (slot, pointer) in enumerate(zip(next_slots.tolist(), trie_pointers)):
                if slot in self._SLOT_START_TOKENS:
                    self._restrict_to_tokens(next_token_probs[i], self.trie.root.get_possible_next_nodes())
                elif slot in self._SLOT_CONT_TOKENS and pointer is not None:
                    self._restrict_to_tokens(next_token_probs[i], pointer.get_possible_next_nodes())
            # Score every (beam, token) extension and keep the best beam_width overall.
            next_scores = scores.unsqueeze(1) + next_token_probs
            vocab_size = next_scores.shape[-1]
            top_scores, top_indices = next_scores.view(-1).topk(beam_width, 0, True, True)
            # Undo the flattening: row = source beam, column = token id.
            next_beam_indices = top_indices // vocab_size
            next_token_indices = top_indices % vocab_size
            next_slots = next_slots[next_beam_indices]
            new_trie_pointers = []
            for beam_idx, token_idx, slot in zip(next_beam_indices.tolist(), next_token_indices.tolist(), next_slots.tolist()):
                if slot in self._SLOT_START_TOKENS:
                    new_trie_pointers.append(self.trie.root.try_move_to_node(token_idx))
                elif slot in self._SLOT_CONT_TOKENS and trie_pointers[beam_idx] is not None:
                    new_trie_pointers.append(trie_pointers[beam_idx].try_move_to_node(token_idx))
                else:
                    new_trie_pointers.append(None)
            trie_pointers = new_trie_pointers
            tokens_out = torch.cat([tokens_out[next_beam_indices], next_token_indices.unsqueeze(1)], dim=1)
            slots_out = torch.cat([slots_out[next_beam_indices], next_slots.unsqueeze(1)], dim=1)
            scores = top_scores
            eos_mask = next_token_indices == eos_token
            if eos_mask.any():
                # Move completed hypotheses aside and continue with the rest.
                for idx in torch.where(eos_mask)[0].tolist():
                    finished_seqs.append((tokens_out[idx], slots_out[idx], scores[idx].item()))
                keep = ~eos_mask
                tokens_out = tokens_out[keep]
                slots_out = slots_out[keep]
                scores = scores[keep]
                trie_pointers = [pointer for pointer, kept in zip(trie_pointers, keep.tolist()) if kept]
                beam_width = min(beam_width, len(scores))
                if beam_width == 0:
                    break  # All hypotheses finished.
                # All encoder rows are copies of the same utterance. Shrink them to the
                # surviving beams so the cross-attention batch matches the decoder batch.
                whisper_encoder_output = whisper_encoder_output[:beam_width]
        all_seqs = finished_seqs + list(zip(tokens_out, slots_out, scores.tolist()))
        best_tokens, best_slots, _ = max(all_seqs, key=lambda seq: seq[2])
        return {"tokens": best_tokens.unsqueeze(0), "slots": best_slots.unsqueeze(0)}
if __name__ == "__main__":
    # Local smoke test to run before submitting jobs to the compute cluster.
    import os
    from datasets import load_dataset

    dataset_script = os.path.normpath(os.path.join("..", "slopts_dataset", "slopts_dataset.py"))
    dataset = load_dataset(dataset_script, split="train", streaming=False)
    print("Dataset loaded.")

    from transformers import WhisperFeatureExtractor, WhisperTokenizerFast

    model_path = "mikr/whisper-large-v3-czech-cv13"
    ft_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    tokenizer = WhisperTokenizerFast.from_pretrained(model_path)

    from train import build_trie

    trie = build_trie(tokenizer)
    print("Trie built.")

    slot_settings = dict(num_output_slots=11, slot_decoder_layers=1, slot_sos_token=10, slot_pad_token=9)
    model = SloptsModel(whisper_model_path=model_path, trie=trie, **slot_settings)
    print("Model loaded.")

    # Uncomment one of the sections below to exercise the model.
    #
    # Forward (training)
    # ------------------
    # print("ASR num of params:", sum(p.numel() for p in model.whisper.parameters()))
    # print("Slot predictor num of params:", sum(p.numel() for p in model.slot_decoder.parameters())+sum(p.numel() for p in model.out_slot_predictor.parameters()))
    # audio = {"array": [sample["array"] for sample in dataset[:2]["audio"]], "sampling_rate": 16000}
    # prev_sys_da = torch.tensor([[0,0,0,0] for _ in range(2)])
    # audio = ft_extractor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
    # labels = {"tokens": torch.tensor([[258 for _ in range(20)] for _ in range(2)]), "slots": torch.tensor([[0 for _ in range(20)] for _ in range(2)])}
    # preds, _ = model({"audio": audio, "prev_sys_da": prev_sys_da}, labels=labels)
    # print(torch.argmax(preds, dim=1))
    # print(labels["slots"])
    # criterion = torch.nn.NLLLoss(ignore_index=9)
    # loss = criterion(preds, labels["slots"])
    # loss.backward()
    # print(loss)
    #
    # Generate (inference)
    # --------------------
    # print(dataset[3]["normalized"])
    # model.eval()
    # audio = {"array": [sample["array"] for sample in [dataset[3]["audio"]]], "sampling_rate": 16000}
    # audio = ft_extractor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
    # prev_sys_da = torch.tensor([[3] for _ in range(1)])
    # input_ids = torch.tensor([[50283, 50360, 50364] for _ in range(1)])
    # outputs = model.generate({"audio": audio, "prev_sys_da": prev_sys_da}, input_ids, max_length=20)
    # print(outputs)
    # print(tokenizer.batch_decode(outputs["tokens"]))