-
Notifications
You must be signed in to change notification settings - Fork 592
Expand file tree
/
Copy pathaudio_flamingo_3.py
More file actions
259 lines (211 loc) · 9.48 KB
/
audio_flamingo_3.py
File metadata and controls
259 lines (211 loc) · 9.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import os
import tempfile
from typing import List, Optional, Tuple, Union
import numpy as np
import soundfile as sf
import torch
import transformers
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from tqdm import tqdm
from transformers import AutoProcessor
# The model class may be absent from the installed transformers release.
# Fall back to None so this module can still be imported; AudioFlamingo3.__init__
# checks for None and raises an ImportError with upgrade instructions.
try:
    from transformers import AudioFlamingo3ForConditionalGeneration
except ImportError:
    AudioFlamingo3ForConditionalGeneration = None
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
@register_model("audio_flamingo_3")
class AudioFlamingo3(lmms):
    """
    Audio-Flamingo-3 Model
    https://github.com/NVIDIA/audio-flamingo

    Wraps ``AudioFlamingo3ForConditionalGeneration`` from ``transformers`` for
    generation-style evaluation. Only ``generate_until`` is implemented;
    ``loglikelihood`` and ``generate_until_multi_round`` raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        pretrained: str = "nvidia/audio-flamingo-3-hf",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "cuda",
        batch_size: Optional[Union[int, str]] = 1,
        use_cache: bool = True,
        **kwargs,
    ) -> None:
        """Load the model and processor and set up (optional) data parallelism.

        Args:
            pretrained: Model identifier or path passed to ``from_pretrained``.
            device: Device used for inputs. Only honoured when running a single
                process with ``device_map="auto"``; in every other case the
                device is ``cuda:<local_process_index>``.
            device_map: ``"auto"`` is passed through to ``from_pretrained``
                when running a single process; any other value is replaced by
                ``cuda:<local_process_index>``.
            batch_size: Per-process batch size; converted with ``int()``.
            use_cache: Forwarded to ``model.generate`` as ``use_cache``.
            **kwargs: Must be empty.

        Raises:
            AssertionError: If unexpected keyword arguments are given, or the
                distributed type is neither FSDP nor MULTI_GPU when running
                with more than one process.
            ImportError: If the installed ``transformers`` does not provide
                ``AudioFlamingo3ForConditionalGeneration``.
        """
        super().__init__()
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes == 1 and device_map == "auto":
            # Single process with automatic placement: keep the caller's choices.
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            # Multi-process runs and explicit single-process runs both pin the
            # model to this process's local CUDA device.
            local_device = f"cuda:{accelerator.local_process_index}"
            self._device = torch.device(local_device)
            self.device_map = local_device

        if AudioFlamingo3ForConditionalGeneration is None:
            raise ImportError("AudioFlamingo3ForConditionalGeneration is not available in transformers " f"{transformers.__version__}. Please upgrade transformers/accelerate in this env, e.g. " "`pip install -U transformers accelerate`.")

        self._model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            pretrained,
            torch_dtype="auto",
            device_map=self.device_map,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(pretrained)
        self.processor.tokenizer.padding_side = "left"
        self._tokenizer = self.processor.tokenizer
        self._config = self.model.config
        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            # With device_map="auto" the weights were already placed at load
            # time (possibly across several devices); moving the whole model
            # onto one device afterwards would undo that placement.
            if self.device_map != "auto":
                self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        """Model config captured at load time."""
        return self._config

    @property
    def tokenizer(self):
        """Tokenizer taken from the processor."""
        return self._tokenizer

    @property
    def model(self):
        """The underlying model, unwrapped from accelerate when it was prepared."""
        # `accelerator` is only set as an attribute in the multi-process path.
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        """End-of-text token id (the tokenizer's ``eos_token_id``)."""
        return self.tokenizer.eos_token_id

    @property
    def batch_size(self):
        """Per-process batch size."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Device that model inputs are moved to."""
        return self._device

    @property
    def rank(self):
        """Index of this process (0 when running a single process)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes (1 when running a single process)."""
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError("Loglikelihood is not implemented for AudioFlamingo3")

    def _save_audio_to_temp(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Save audio array to a temporary file and return the path.

        The caller owns the returned file and is responsible for deleting it.
        """
        # Reserve a unique path, then close the handle right away so no file
        # descriptor is leaked; the audio is written through the path instead.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
        try:
            sf.write(temp_path, audio_array, sampling_rate)
        except Exception:
            # The caller never learns about this path, so clean it up here.
            self._remove_temp_files([temp_path])
            raise
        return temp_path

    @staticmethod
    def _remove_temp_files(paths: List[str]) -> None:
        """Best-effort deletion of temporary files; failures are only logged."""
        for path in paths:
            try:
                os.unlink(path)
            except OSError as e:
                eval_logger.debug(f"Could not remove temp file {path}: {e}")

    @staticmethod
    def _apply_stop_sequences(text: str, until: List[str]) -> str:
        """Cut ``text`` at the first occurrence of each non-empty stop string."""
        for term in until:
            if term:
                text = text.split(term)[0]
        return text

    def _build_conversation(self, context: str, audios, temp_files: List[str]) -> list:
        """Build a single-turn user conversation for the chat template.

        Args:
            context: Text prompt; skipped when empty or whitespace-only.
            audios: Iterable of dicts providing ``"array"`` and
                ``"sampling_rate"``.
            temp_files: List that receives the path of every temp file created
                here, so the caller can delete them afterwards.

        Returns:
            A one-message conversation whose content holds the text (if any)
            followed by one audio entry per clip.
        """
        content = []
        # Add text prompt first (as per official example)
        if context and context.strip():
            content.append({"type": "text", "text": context})
        # Add audio content after text
        for audio in audios:
            # Save audio to temp file (processor can handle file paths)
            temp_path = self._save_audio_to_temp(audio["array"], audio["sampling_rate"])
            temp_files.append(temp_path)
            content.append({"type": "audio", "path": temp_path})
        return [{"role": "user", "content": content}]

    def _generate_single(self, conversation: list, gen_kwargs: dict) -> str:
        """Run generation for one conversation and return the decoded new text.

        ``gen_kwargs`` must contain ``temperature``, ``top_p``, ``num_beams``
        and ``max_new_tokens``.
        """
        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
        ).to(self.device if self.device_map != "auto" else "cuda")

        temperature = gen_kwargs["temperature"]
        cont = self.model.generate(
            **inputs,
            # Sample only for a strictly positive temperature; None means greedy.
            do_sample=temperature is not None and temperature > 0,
            temperature=temperature,
            top_p=gen_kwargs["top_p"],
            num_beams=gen_kwargs["num_beams"],
            max_new_tokens=gen_kwargs["max_new_tokens"],
            min_new_tokens=1,
            use_cache=self.use_cache,
        )

        # Trim input tokens from output
        generated_ids_trimmed = cont[:, inputs.input_ids.shape[1] :]
        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a text answer for every request.

        Each request's ``args`` is unpacked as ``(context, gen_kwargs,
        doc_to_visual, doc_id, task, split)``. Generation kwargs of the first
        request in a chunk are used for the whole chunk. Conversations are
        processed one at a time; a sample whose generation fails yields an
        empty string and the error is logged.

        Args:
            requests: Instances to answer.

        Returns:
            One answer string per request, in the original request order.

        Raises:
            ValueError: If ``gen_kwargs["until"]`` is neither a string nor a list.
        """
        res = []

        def _collate(x):
            # Sort key: longest tokenized context first, then the context itself.
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)

        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]

            # Get audio data from task
            batched_audios = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]

            # Work on a copy so the request's own kwargs dict is left untouched.
            gen_kwargs = dict(all_gen_kwargs[0])

            until = [self.tokenizer.decode([self.eot_token_id])]
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] " f"but got {type(until)}")

            gen_kwargs.setdefault("max_new_tokens", 256)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            temp_files = []
            try:
                # Build conversations for each item in the batch
                conversations = [self._build_conversation(context, audios, temp_files) for context, audios in zip(contexts, batched_audios)]

                # Process each conversation individually to avoid batching
                # issues; a failure only blanks the affected sample.
                answers = []
                for context, conv in zip(contexts, conversations):
                    try:
                        answer = self._generate_single(conv, gen_kwargs)
                        answer = self._apply_stop_sequences(answer, until)
                    except Exception as e:
                        eval_logger.error(f"Error while generating: {e}. Context: {context}")
                        answer = ""
                    answers.append(answer)
            finally:
                # Clean up temp files, including when building or generation raised
                self._remove_temp_files(temp_files)

            for ans, context in zip(answers, contexts):
                res.append(ans)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
                pbar.update(1)

        # Reorder results back to original order
        res = re_ords.get_original(res)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        """Not supported for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError("Multi-round generation is not implemented")