diff --git a/src/omnilingual_asr/__init__.py b/src/omnilingual_asr/__init__.py index 7109e48..ac6a340 100644 --- a/src/omnilingual_asr/__init__.py +++ b/src/omnilingual_asr/__init__.py @@ -38,7 +38,6 @@ def setup_fairseq2_extension(container: DependencyContainer) -> None: def _register_models(container: DependencyContainer) -> None: - # Only adding custom wav2vec2 archs for wav2vec2_ssl model in fs2 register_omnilingual_asr_wav2vec2_ssl_configs(container) diff --git a/src/omnilingual_asr/datasets/impl/manifest_asr_dataset.py b/src/omnilingual_asr/datasets/impl/manifest_asr_dataset.py index b38a1c2..b377c1a 100644 --- a/src/omnilingual_asr/datasets/impl/manifest_asr_dataset.py +++ b/src/omnilingual_asr/datasets/impl/manifest_asr_dataset.py @@ -60,7 +60,6 @@ def create_reader( storage_config: ManifestStorageConfig, task_config: AsrTaskConfig, ) -> DataReader[Seq2SeqBatch]: - storage = ManifestStorage( splits=self.splits, manifest_dir=self.manifest_dir, diff --git a/src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py b/src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py index f193031..215b0a1 100644 --- a/src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py +++ b/src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py @@ -205,7 +205,6 @@ def is_train_streaming(split: str, sync_mode: SyncMode) -> bool: @override def create_raw_data_pipeline(self, split: str, gangs: Gangs) -> DataPipelineBuilder: - config = self.config schema: LangASRSchema = LangASRSchema() diff --git a/src/omnilingual_asr/datasets/tasks/asr_task.py b/src/omnilingual_asr/datasets/tasks/asr_task.py index 8f59681..215ccd5 100644 --- a/src/omnilingual_asr/datasets/tasks/asr_task.py +++ b/src/omnilingual_asr/datasets/tasks/asr_task.py @@ -156,7 +156,6 @@ def apply_processing_pipeline( # type: ignore[override] tokenizer: Tokenizer, dtype: torch.dtype, ) -> DataPipelineBuilder: - config = self.config # Filtering audio to optimize before batching @@ -246,7 +245,6 @@ def add_tokenization_pipeline( text_selector: str, audio_length_selector: str, ) -> DataPipelineBuilder: - builder = filter_empty_text(builder, text_selector=text_selector) builder = filter_fast_speech( @@ -262,9 +260,13 @@ def add_tokenization_pipeline( text_selector=text_selector, ) - builder = filter_unknown_sequences( - builder, unk_idx=tokenizer.vocab_info.unk_idx, text_selector=text_selector # type: ignore - ) + unk_idx = tokenizer.vocab_info.unk_idx + if unk_idx is not None: + builder = filter_unknown_sequences( + builder, + unk_idx=unk_idx, + text_selector=text_selector, # type: ignore + ) if filter_long_text_threshold is not None: builder = filter_long_text( @@ -273,9 +275,10 @@ def add_tokenization_pipeline( text_selector=text_selector, ) - if remove_unknown: + if remove_unknown and unk_idx is not None: builder = filter_unknown_tokens( - builder, unk_idx=tokenizer.vocab_info.unk_idx # type: ignore + builder, + unk_idx=unk_idx, # type: ignore ) return builder @@ -293,7 +296,6 @@ def add_bucketing_pipeline( batch_size: int, no_padding: bool, ) -> DataPipelineBuilder: - if batching is BatchingStrategy.LENGTH: builder = add_length_batching( builder, @@ -358,7 +360,6 @@ def add_audio_processing_pipeline( npc: int, unified_audio_feature_keys: bool, ) -> DataPipelineBuilder: - builder = add_audio_decoding( builder, dtype=dtype, @@ -412,7 +413,6 @@ def add_postprocessing_pipeline( num_prefetch: int, no_padding: bool, ) -> DataPipelineBuilder: - if no_padding: log.warning( "Collating without padding is currently not supported, 
defaulting to padding." diff --git a/src/omnilingual_asr/datasets/tasks/ssl_task.py b/src/omnilingual_asr/datasets/tasks/ssl_task.py index 048ff70..187bdb3 100644 --- a/src/omnilingual_asr/datasets/tasks/ssl_task.py +++ b/src/omnilingual_asr/datasets/tasks/ssl_task.py @@ -130,7 +130,6 @@ def apply_processing_pipeline( # type: ignore[override] gangs: Gangs, dtype: torch.dtype, ) -> DataPipelineBuilder: - config = self.config # Shuffle individual samples @@ -218,7 +217,6 @@ def add_bucketing_pipeline( batch_size: int, no_padding: bool, ) -> DataPipelineBuilder: - if batching is BatchingStrategy.LENGTH: builder = add_length_batching( builder, @@ -260,7 +258,6 @@ def add_audio_processing_pipeline( no_padding: bool, seed: int, ) -> DataPipelineBuilder: - builder = add_audio_decoding( builder, dtype=dtype, @@ -320,7 +317,6 @@ def add_postprocessing_pipeline( num_prefetch: int, no_padding: bool, ) -> DataPipelineBuilder: - collater = Collater(pad_value=None if no_padding else 0) builder.map(collater, num_parallel_calls=npc) diff --git a/src/omnilingual_asr/models/inference/align.py b/src/omnilingual_asr/models/inference/align.py new file mode 100644 index 0000000..a64dfd3 --- /dev/null +++ b/src/omnilingual_asr/models/inference/align.py @@ -0,0 +1,485 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import re +import types +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast + +import numpy as np +import torch +import torch.nn.functional as F +from fairseq2.nn import BatchLayout +from torch import Tensor + +try: + from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline + from omnilingual_asr.models.wav2vec2_llama.model import Wav2Vec2LlamaModel + from omnilingual_asr.models.wav2vec2_llama.syntax import lang_id_getter +except ImportError: + pass # Type checking only + + +# ============================================================================= +# SCRIPT DETECTION & UTILS +# ============================================================================= + +# Regex for scripts that typically do not use whitespace for word separation. +# Covers: CJK Unified Ideographs, Hiragana, Katakana, Thai, Lao, Khmer, Myanmar. +_NON_SPACED_SCRIPTS = re.compile( + r"[\u4E00-\u9FFF\u3400-\u4DBF\u3040-\u309F\u30A0-\u30FF" # CJK + Kana + r"\u0E00-\u0E7F\u0E80-\u0EFF" # Thai + Lao + r"\u1780-\u17FF\u1000-\u109F]" # Khmer + Myanmar +) + + +def detect_alignment_mode(text: str) -> str: + """ + Language-agnostic detection of alignment granularity. + Returns 'char' if the text contains characters from known non-spaced scripts + (CJK, Thai, etc.), otherwise returns 'word'. + """ + if _NON_SPACED_SCRIPTS.search(text): + return "char" + return "word" + + +def chunk_waveform( + waveform: Tensor, sample_rate: int, max_duration_sec: float +) -> List[Tuple[Tensor, float]]: + if waveform.ndim != 1: + raise ValueError("waveform must be 1D") + chunk_samples = int(max_duration_sec * sample_rate) + T = waveform.shape[0] + chunks: List[Tuple[Tensor, float]] = [] + for s in range(0, T, chunk_samples): + e = min(T, s + chunk_samples) + chunks.append((waveform[s:e], s / float(sample_rate))) + return chunks + + +def split_text_units(transcript: str, mode: str) -> List[Tuple[str, int, int]]: + """ + Split transcript into alignment units. 
+ Returns list of (unit_text, start_char_idx, end_char_idx). + """ + units: List[Tuple[str, int, int]] = [] + if mode == "word": + # Split on whitespace + for m in re.finditer(r"\S+", transcript): + units.append((m.group(0), m.start(), m.end())) + elif mode == "char": + # Split into individual characters (ignoring whitespace) + for idx, ch in enumerate(transcript): + if not ch.isspace(): + units.append((ch, idx, idx + 1)) + else: + raise ValueError(f"Unknown mode: {mode}") + return units + + +# ============================================================================= +# CTC ALIGNMENT LOGIC +# ============================================================================= + + +def _make_seqs_layout(x_batched: Tensor, seq_lens: Sequence[int]) -> Any: + return BatchLayout.of(batch=x_batched, seq_lens=list(seq_lens)) + + +def _ensure_tensor(maybe_ret: Any) -> Tensor: + return maybe_ret[0] if isinstance(maybe_ret, (tuple, list)) else maybe_ret + + +def get_ctc_logits( + asr_model: Any, waveform: Tensor, sample_rate: int +) -> Tuple[Tensor, float]: + device = next(asr_model.parameters()).device + + if hasattr(asr_model, "encoder_frontend"): + frontend = asr_model.encoder_frontend + elif hasattr(asr_model, "frontend"): + frontend = asr_model.frontend + else: + raise RuntimeError("Model has no encoder_frontend or frontend.") + + try: + f_dtype = next(frontend.parameters()).dtype + except Exception: + f_dtype = next(asr_model.parameters()).dtype + + with torch.no_grad(): + wav = waveform.to(device=device, dtype=f_dtype) + x = wav.unsqueeze(0) + seqs_layout = _make_seqs_layout(x, [wav.shape[0]]) + + ret = frontend.extract_features(x, seqs_layout) + if isinstance(ret, (tuple, list)): + feats, layout = ret[0], ret[1] + else: + feats, layout = ret, seqs_layout + + try: + feats = _ensure_tensor(frontend.process_features(feats, layout, None)) + except Exception: + pass + + if hasattr(asr_model, "encoder"): + enc = _ensure_tensor(asr_model.encoder(feats, layout)) + elif hasattr(asr_model, "w2v2"): + enc = _ensure_tensor(asr_model.w2v2(feats, layout)) + else: + raise RuntimeError("CTC model missing encoder block") + + proj = None + for name in ("final_proj", "ctc_head", "output_proj", "proj", "decoder"): + if hasattr(asr_model, name): + proj = getattr(asr_model, name) + break + if proj is None: + raise RuntimeError("CTC model missing projection head.") + + logits = proj(enc) + logits = logits.squeeze(0) + + s_frames = feats.shape[1] + t_samples = waveform.shape[0] + stride_samples = max(1, round(t_samples / s_frames)) + stride_seconds = stride_samples / float(sample_rate) + + return logits, stride_seconds + + +def align_ctc( + model: Any, + waveform: Tensor, + sample_rate: int, + transcript: str, +) -> List[dict]: + """ + Aligns text to audio using CTC boundaries. Returns list of dicts. + """ + if not transcript.strip(): + return [] + + # 1. Get Logits & Boundaries + logits, stride_seconds = get_ctc_logits(model, waveform, sample_rate) + + path = torch.argmax(logits, dim=-1).tolist() + boundaries = [0] + prev_token = -1 + for frame_idx, token_id in enumerate(path): + if token_id != 0 and token_id != prev_token: + boundaries.append(frame_idx) + prev_token = token_id + elif token_id != 0: + prev_token = token_id + boundaries.append(len(path)) + + # 2. Detect Mode & Split + mode = detect_alignment_mode(transcript) + units = split_text_units(transcript, mode) + + if not boundaries or len(boundaries) < 2 or not units: + return [] + + # 3. 
Interpolate + num_frames = len(boundaries) - 1 + U = len(units) + results: List[dict] = [] + + # Key name based on mode + key_name = "word" if mode == "word" else "char" + + for u_idx, (tok, _s, _e) in enumerate(units): + start_bin = int(round((u_idx) * num_frames / U)) + end_bin = int(round((u_idx + 1) * num_frames / U)) + + start_bin = min(max(start_bin, 0), num_frames) + end_bin = min(max(end_bin, start_bin + 1), num_frames) + + start_frame = boundaries[start_bin] + end_frame = boundaries[end_bin] + + results.append( + { + key_name: tok, + "start": start_frame * stride_seconds, + "end": end_frame * stride_seconds, + } + ) + + return results + + +# ============================================================================= +# LLM ALIGNMENT LOGIC (DTW) +# ============================================================================= + + +class AttentionStore: + def __init__(self) -> None: + self.weights: Dict[int, torch.Tensor] = {} + + def add_weights(self, layer_idx: int, attn: torch.Tensor) -> None: + self.weights[layer_idx] = attn.detach().cpu() + + def clear(self) -> None: + self.weights = {} + + +_attention_store = AttentionStore() + + +def make_patched_sdpa_forward(layer_idx: int, original_forward_func): + def patched_sdpa_forward(self, *args, **kwargs): + context, attn_weights_orig = original_forward_func(*args, **kwargs) + try: + if len(args) >= 3: + query, key = args[0], args[2] + if isinstance(query, torch.Tensor) and isinstance(key, torch.Tensor): + q = query.to(torch.float32) + k = key.to(torch.float32) + if q.ndim == 4 and k.ndim == 4 and q.shape[0] == 1: + B, T, H, D = q.shape + q_flat = q.reshape(T, H * D) + k_flat = k.reshape(T, H * D) + dim = max(1, H * D) + sim = (q_flat @ k_flat.T) / (float(dim) ** 0.5) + attn = F.softmax(sim, dim=-1) + _attention_store.add_weights(layer_idx, attn) + except Exception: + pass + return context, attn_weights_orig + + return patched_sdpa_forward + + +def forced_alignment_dtw(similarity_matrix: np.ndarray) -> List[int]: + """Pure Numpy DTW for forced alignment.""" + N_text, N_audio = similarity_matrix.shape + score = np.full((N_text + 1, N_audio + 1), -1e9, dtype=np.float32) + score[0, 0] = 0.0 + + for i in range(1, N_text + 1): + for j in range(1, N_audio + 1): + s = similarity_matrix[i - 1, j - 1] + score_diag = score[i - 1, j - 1] + score_left = score[i, j - 1] + score[i, j] = max(score_diag, score_left) + s + + path = [] + i, j = N_text, int(np.argmax(score[N_text, :])) + + while i > 0 and j > 0: + path.append((i - 1, j - 1)) + s_diag = score[i - 1, j - 1] + s_left = score[i, j - 1] + if s_diag >= s_left: + i, j = i - 1, j - 1 + else: + j -= 1 + + path = path[::-1] + token_frames: List[List[int]] = [[] for _ in range(N_text)] + for t_idx, f_idx in path: + token_frames[t_idx].append(f_idx) + + aligned_frames = [] + last_frame = 0 + for tf in token_frames: + if tf: + avg = int(np.mean(tf)) + aligned_frames.append(avg) + last_frame = avg + else: + aligned_frames.append(last_frame) + + return aligned_frames + + +@torch.inference_mode() +def align_llm( + pipeline: "ASRInferencePipeline", + waveform: Tensor, + transcript: str, + lang: Optional[str] = None, +) -> List[dict]: + if not transcript.strip(): + return [] + + # Narrow type for MyPy + model = cast("Wav2Vec2LlamaModel", pipeline.model) + + lang = lang if lang else "eng_Latn" + + # 1. 
Prepare Inputs + try: + if waveform.ndim > 1: + waveform = waveform.mean(dim=0) + + audio_data = [{"waveform": waveform, "sample_rate": 16000}] + audio_candidate = next( + iter(pipeline._build_audio_wavform_pipeline(audio_data).and_return()) + ) + audio_batch = pipeline._create_batch_simple([(audio_candidate, lang)]) + + source_seq_lens = audio_batch.source_seq_lens + if isinstance(source_seq_lens, torch.Tensor): + seq_lens_list: List[int] = source_seq_lens.tolist() + else: + seq_lens_list = list(source_seq_lens) + + audio_features, _ = model.embed_audio( + audio_batch.source_seqs.to(dtype=pipeline.dtype), + seq_lens_list, + ) + + token_ids = pipeline.token_encoder(transcript) + text_embeddings = model.embed_text( + token_ids.unsqueeze(0).to(pipeline.device), pipeline.dtype + ) + + vocab_info = pipeline.tokenizer.vocab_info + bos_emb = model.embed_text( + torch.tensor([[vocab_info.bos_idx]], device=pipeline.device), pipeline.dtype + ) + sep_emb = model.embed_text( + torch.tensor([[vocab_info.size]], device=pipeline.device), pipeline.dtype + ) + + lang_emb = torch.zeros( + 1, 0, audio_features.shape[-1], device=pipeline.device, dtype=pipeline.dtype + ) + lang_mapping = getattr(model, "lang_mapping", None) + if model.lang_embeddings is not None and lang_mapping is not None: + lid = lang_id_getter(lang_mapping, lang) + lang_emb = model.lang_embeddings( + torch.tensor([lid], device=pipeline.device).unsqueeze(0) + ) + + full_input = torch.cat( + [audio_features, sep_emb, lang_emb, bos_emb, text_embeddings], dim=1 + ) + + except Exception as e: + print(f"Error preparing LLM alignment inputs: {e}") + return [] + + # 2. Run Decoder with Hook + _attention_store.clear() + original_forwards: Dict[int, Callable[..., Any]] = {} + patched_layers: List[int] = [] + + try: + for i, layer in enumerate(model.llama_decoder.layers): + sdpa = getattr(layer.self_attn, "sdpa", None) + sdpa_forward = getattr(sdpa, "forward", None) + if sdpa is None or sdpa_forward is None: + continue + + original_forwards[i] = cast(Callable[..., Any], sdpa_forward) + sdpa.forward = types.MethodType( + make_patched_sdpa_forward(i, original_forwards[i]), sdpa + ) + patched_layers.append(i) + + B, T_full, _ = full_input.shape + layout = BatchLayout( + shape=(B, T_full), seq_lens=[T_full], packed=False, device=full_input.device + ) + + model.llama_decoder(seqs=full_input, seqs_layout=layout, state_bag=None) + + except Exception as e: + print(f"Error running LLM alignment pass: {e}") + return [] + finally: + for i in patched_layers: + layer = model.llama_decoder.layers[i] + sdpa = getattr(layer.self_attn, "sdpa", None) + if sdpa is not None and i in original_forwards: + sdpa.forward = original_forwards[i] + + if not _attention_store.weights: + return [] + + # 3. DTW + L_audio = audio_features.shape[1] + L_pre = sep_emb.shape[1] + lang_emb.shape[1] + bos_emb.shape[1] + L_text = text_embeddings.shape[1] + + query_start = L_audio + L_pre + + sorted_layers = sorted(_attention_store.weights.keys()) + num_layers = len(sorted_layers) + start_layer = int(num_layers * 0.4) + end_layer = int(num_layers * 0.9) + selected = [l for l in sorted_layers if start_layer <= l < end_layer] + if not selected: + selected = sorted_layers + + avg_attn = torch.zeros_like(_attention_store.weights[selected[0]]) + for l in selected: + avg_attn += _attention_store.weights[l] + avg_attn /= len(selected) + + cross_attn = avg_attn[query_start : query_start + L_text, :L_audio].numpy() + + aligned_frames = forced_alignment_dtw(cross_attn) + + # 4. 
Decode to Words/Chars (Mode detection logic) + decoded_tokens = [ + pipeline.token_decoder(token_ids[i : i + 1]) for i in range(len(token_ids)) + ] + + mode = detect_alignment_mode(transcript) + key_name = "word" if mode == "word" else "char" + + # Group tokens + groups = [] + current_group: List[int] = [] + + if mode == "word": + # Group subwords into words (SentencePiece logic) + for i, tk in enumerate(decoded_tokens): + if (tk.startswith(" ") or tk in ".,!?") and current_group: + groups.append(current_group) + current_group = [] + current_group.append(i) + if current_group: + groups.append(current_group) + else: + # Char mode: We still have subword tokens. We must map tokens to text characters. + # This is complex for LLM tokens vs unicode chars. + # Simplified approach for LLM: One token usually contains multiple chars or one char. + # We will output token-level timestamps but label them "char" if they are single chars. + # Ideally, we would map token-frames to char-offsets, but that requires a char-to-token map. + # Fallback: Just report token-level as "char" units for now, grouping strictly by token index. + for i in range(len(decoded_tokens)): + groups.append([i]) + + results = [] + frame_dur = 0.02 # 20ms + + for indices in groups: + s_idx, e_idx = indices[0], indices[-1] + s_frame, e_frame = aligned_frames[s_idx], aligned_frames[e_idx] + + start = s_frame * frame_dur + end = e_frame * frame_dur + frame_dur + + text_fragment = "".join([decoded_tokens[i] for i in indices]) + if mode == "word": + text_fragment = text_fragment.replace(" ", "") + + if not text_fragment: + continue + + results.append({key_name: text_fragment, "start": start, "end": end}) + + return results diff --git a/src/omnilingual_asr/models/inference/pipeline.py b/src/omnilingual_asr/models/inference/pipeline.py index e0ba01c..f3bc660 100644 --- a/src/omnilingual_asr/models/inference/pipeline.py +++ b/src/omnilingual_asr/models/inference/pipeline.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Final, List, Tuple +from typing import Any, Dict, Final, List, Tuple, cast import numpy as np import torch @@ -33,6 +33,11 @@ from numpy.typing import NDArray from omnilingual_asr.datasets.utils.audio import add_waveform_processing +from omnilingual_asr.models.inference.align import ( + align_ctc, + align_llm, + chunk_waveform, +) from omnilingual_asr.models.wav2vec2_llama.beamsearch import ( Wav2Vec2LlamaBeamSearchSeq2SeqGenerator, ) @@ -141,12 +146,13 @@ def assert_max_length( current_sample_rate = audio_data["sample_rate"] waveform_len_s = len(waveform) / current_sample_rate if waveform_len_s > MAX_ALLOWED_AUDIO_SEC: - raise ValueError(f"Max audio length is capped at {MAX_ALLOWED_AUDIO_SEC}s") + raise ValueError( + f"Audio length {waveform_len_s:.2f}s exceeds max {MAX_ALLOWED_AUDIO_SEC}s. Use chunk_len parameter to process long files." 
+ ) return audio_data class ASRInferencePipeline: - def __init__( self, model_card: str | None, @@ -256,7 +262,8 @@ def __init__( text_collate_opts = CollateOptionsOverride("text", pad_value=pad_idx) self.full_collater = Collater( - pad_value=0, overrides=[text_collate_opts] # Default pad value for audio + pad_value=0, + overrides=[text_collate_opts], # Default pad value for audio ) self.collater_audio = Collater(pad_value=0) self.collater_text = Collater(pad_value=pad_idx) @@ -412,8 +419,7 @@ def _apply_model(self, batch: Seq2SeqBatch) -> List[str]: return transcriptions def _build_audio_wavform_pipeline( - self, - inp_list: AudioInput, + self, inp_list: AudioInput, check_max_length: bool = True ) -> DataPipelineBuilder: """Process audio inputs using fairseq2.data pipeline similar to ASR task.""" # Build pipeline based on input type @@ -455,7 +461,7 @@ def _build_audio_wavform_pipeline( # Check max allowed length if non-streaming non_streaming = not self.streaming_config.is_streaming - if non_streaming: + if non_streaming and check_max_length: builder = builder.map(assert_max_length, selector="data") # Add waveform processing @@ -488,7 +494,9 @@ def _process_context_audio( return None # Use the same audio processing pipeline as main inference - builder = self._build_audio_wavform_pipeline([example.audio for example in context_examples]) # type: ignore[arg-type] + # Cast mixed list to AudioInput to satisfy type checker + raw_audio = cast(AudioInput, [example.audio for example in context_examples]) + builder = self._build_audio_wavform_pipeline(raw_audio) context_audio_tensors = list(builder.and_return()) collated_audio = self.collater_audio(context_audio_tensors) @@ -564,11 +572,14 @@ def transcribe( *, lang: List[str | None] | List[str] | List[None] | None = None, batch_size: int = 2, - ) -> List[str]: + chunk_len: float | None = None, + ) -> Tuple[List[str], List[List[Dict[str, Any]]]]: """ Transcribes `AudioInput` into text by preprocessing (decoding, resample to 16kHz, converting to mono, normalizing) each input sample and performing inference with `self.model`. + Returns text and extracted timestamps. + Works for both CTC and LLM model variants by optionally allowing a language conditioning token to help with LLM generation. It is ignored when performing inference with CTC. See `omnilingual_asr/models/wav2vec2_llama/lang_ids.py` for supported languages. @@ -580,13 +591,16 @@ def transcribe( - `List[ dict[str, Any] ]`: Pre-decoded audio with 'waveform' and 'sample_rate' keys `lang`: Language code for the input audios (e.g., 'eng_Latn', ...) (default: None) - List [ str | None ]`: Any combination of missing and available language ids. - `batch_size`: Number of audio samples to process in each batch. + `batch_size`: Number of audio samples to process in each batch (per chunk). + `chunk_len`: Maximum length in seconds for processing. Longer files will be split. Returns: - `List[str]`: Transcribed texts. + Tuple[List[str], List[List[Dict[str, Any]]]]: + - List of transcribed texts. + - List of List of timestamp dicts (e.g. [{'word': 'Hello', 'start': 0.0, 'end': 0.5}]) """ if len(inp) == 0: - return [] + return [], [] # fmt: off is_ctc_model = isinstance(self.model, Wav2Vec2AsrModel) @@ -608,24 +622,77 @@ def transcribe( assert len(lang) == len(inp), f"`lang` must be a list of the same length as `inp` ({len(inp)}), but is {len(lang)}." 
# fmt: on - # Process audio files using fairseq2.data pipeline - builder = DataPipeline.zip( - [ - self._build_audio_wavform_pipeline(inp).and_return(), - read_sequence(lang).and_return(), - ] - ) - - builder = builder.bucket(batch_size) - builder = builder.map(self._create_batch_simple) - builder = builder.prefetch(1) - builder = builder.map(self._apply_model) - builder = builder.yield_from( - lambda seq: read_sequence(seq).and_return() - ) # flatten the sequence of sequences - - transcriptions = list(builder.and_return()) - return transcriptions + final_transcripts = [] + final_timestamps = [] + + # Pre-load waveforms one by one to handle chunking without pipeline's strict length check + # This replaces the streaming data pipeline with an eager loading strategy to support chunking/alignment + for idx, (input_item, input_lang) in enumerate(zip(inp, lang)): + # Use pipeline builder to decode/resample/norm, but bypass length check + single_input: AudioInput = cast(AudioInput, [input_item]) + p = self._build_audio_wavform_pipeline( + single_input, check_max_length=False + ).and_return() + waveform = next(iter(p)) # Tensor[T] + + duration = waveform.shape[0] / 16000.0 + + if chunk_len is not None and duration > chunk_len: + chunks = chunk_waveform(waveform, 16000, chunk_len) + else: + if duration > MAX_ALLOWED_AUDIO_SEC and chunk_len is None: + raise ValueError( + f"Audio {idx} duration {duration:.2f}s > {MAX_ALLOWED_AUDIO_SEC}s. Provide chunk_len parameter." + ) + chunks = [(waveform, 0.0)] + + input_text_parts = [] + input_timestamps = [] + + chunk_waveforms = [c[0] for c in chunks] + offsets = [c[1] for c in chunks] + + # Process chunks in batches + for i in range(0, len(chunk_waveforms), batch_size): + batch_wavs = chunk_waveforms[i : i + batch_size] + batch_offsets = offsets[i : i + batch_size] + + batch_data = [(w, input_lang) for w in batch_wavs] + seq2seq_batch = self._create_batch_simple(batch_data) + texts = self._apply_model(seq2seq_batch) + + for j, text in enumerate(texts): + wav_segment = batch_wavs[j] + offset = batch_offsets[j] + + if not text.strip(): + input_text_parts.append("") + continue + + chunk_ts = [] + try: + if isinstance(self.model, Wav2Vec2AsrModel): + chunk_ts = align_ctc(self.model, wav_segment, 16000, text) + elif isinstance(self.model, Wav2Vec2LlamaModel): + chunk_ts = align_llm(self, wav_segment, text, input_lang) + except Exception as e: + log.warning( + f"Alignment failed for chunk {i + j} of input {idx}: {e}" + ) + chunk_ts = [] + + for w in chunk_ts: + w["start"] += offset + w["end"] += offset + input_timestamps.append(w) + + input_text_parts.append(text) + + full_transcript = " ".join(t for t in input_text_parts if t.strip()) + final_transcripts.append(full_transcript) + final_timestamps.append(input_timestamps) + + return final_transcripts, final_timestamps @torch.inference_mode() def transcribe_with_context( @@ -697,6 +764,4 @@ def transcribe_with_context( combined_builder = combined_builder.yield_from( lambda seq: read_sequence(seq).and_return() ) - - transcriptions = list(combined_builder.and_return()) - return transcriptions + return list(combined_builder.and_return()) diff --git a/src/omnilingual_asr/models/wav2vec2_llama/beamsearch.py b/src/omnilingual_asr/models/wav2vec2_llama/beamsearch.py index 04ca543..59b8f18 100644 --- a/src/omnilingual_asr/models/wav2vec2_llama/beamsearch.py +++ b/src/omnilingual_asr/models/wav2vec2_llama/beamsearch.py @@ -423,7 +423,8 @@ def generate_hypotheses_one_segment( # Choose nbest if self.config.length_norm: 
n_tokens = torch.logical_and( - out_tokens[:, :t] != self.pad_idx, out_tokens[:, :t] != self.eos_idx # type: ignore[arg-type] + out_tokens[:, :t] != self.pad_idx, + out_tokens[:, :t] != self.eos_idx, # type: ignore[arg-type] ).sum(dim=1, keepdim=True) if n_tokens[0, 0] > 0: candidate_scores = (scores.unsqueeze(1) * n_tokens + log_probs) / ( diff --git a/src/omnilingual_asr/models/wav2vec2_llama/model.py b/src/omnilingual_asr/models/wav2vec2_llama/model.py index 489b1a9..a9d908b 100644 --- a/src/omnilingual_asr/models/wav2vec2_llama/model.py +++ b/src/omnilingual_asr/models/wav2vec2_llama/model.py @@ -1017,7 +1017,9 @@ def embed_audio( seqs, seqs_layout ) enc_out, _ = self.encoder_frontend.process_features( - enc_out, enc_layout, self.masker if self.training else None # type: ignore + enc_out, + enc_layout, + self.masker if self.training else None, # type: ignore ) enc_out = self.encoder(enc_out, enc_layout) diff --git a/workflows/dataprep/audio_tools.py b/workflows/dataprep/audio_tools.py index cd6206a..c70c315 100644 --- a/workflows/dataprep/audio_tools.py +++ b/workflows/dataprep/audio_tools.py @@ -284,9 +284,10 @@ def binary_to_list_int8(binary_array: pa.Array | pa.ChunkedArray) -> pa.Array: data_np = np.frombuffer(data, dtype="int8")[offsets_np[0] :] # type: ignore offsets_np -= offsets_np[0] + offsets_array = pa.array(offsets_np, type=pa.int32()) values_array = pa.array(data_np, type=pa.int8()) list_array = pa.ListArray.from_arrays( - offsets_np, values_array, mask=binary_array.is_null() + offsets_array, values_array, mask=binary_array.is_null() ) return list_array diff --git a/workflows/dataprep/text_tools.py b/workflows/dataprep/text_tools.py index 074e9d7..120e158 100644 --- a/workflows/dataprep/text_tools.py +++ b/workflows/dataprep/text_tools.py @@ -80,7 +80,6 @@ def text_normalize( # The lookaround enables overlapping pattern matches to be replaced if remove_numbers: - digits_pattern = "[" + config["digit_set"] digits_pattern += "]+" diff --git a/workflows/recipes/wav2vec2/asr/criterion.py b/workflows/recipes/wav2vec2/asr/criterion.py index b1331b9..d4856b1 100644 --- a/workflows/recipes/wav2vec2/asr/criterion.py +++ b/workflows/recipes/wav2vec2/asr/criterion.py @@ -106,9 +106,7 @@ def _forward_with_logits(self, batch: Seq2SeqBatch) -> Tuple[ """ if isinstance(self._model.base_module, Wav2Vec2LlamaModel): # Llama model requires batch.extras for constructing the input batch - return self._model.module( - batch, return_logits=True - ) # type: ignore[call-overload] + return self._model.module(batch, return_logits=True) # type: ignore[call-overload] else: source_seqs, source_seqs_layout = batch.as_source_input() # Audio target_seqs, target_seqs_layout = batch.as_target_input() # Text tokens
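
Usage sketch for the updated `ASRInferencePipeline.transcribe` API changed in pipeline.py above: the method now returns a `(transcripts, timestamps)` tuple and accepts a keyword-only `chunk_len` argument so audio longer than `MAX_ALLOWED_AUDIO_SEC` can be split into chunks before decoding. This is an illustrative sketch, not part of the patch: the pipeline construction, audio file name, and chunk length are assumptions; only the call signature, the `eng_Latn` language code, and the shape of the returned timestamp dicts come from the diff.

# Assumes `pipeline` is an already-initialized ASRInferencePipeline and that
# "speech.wav" exists on disk; both are illustrative placeholders.
texts, timestamps = pipeline.transcribe(
    ["speech.wav"],          # AudioInput: file paths, numpy arrays, or {"waveform", "sample_rate"} dicts
    lang=["eng_Latn"],       # one entry per input; ignored by the CTC model variant
    batch_size=2,            # chunks are batched per input
    chunk_len=30.0,          # assumed value: split inputs longer than 30 s before decoding
)

print(texts[0])              # full transcript, chunk texts joined with spaces
for unit in timestamps[0]:
    # Each entry looks like {"word": "Hello", "start": 0.0, "end": 0.5};
    # the key is "char" instead of "word" for non-spaced scripts (CJK, Thai, ...),
    # as selected by detect_alignment_mode() in align.py.
    print(unit)

Note that alignment failures are non-fatal: per the diff, a chunk whose `align_ctc`/`align_llm` call raises only logs a warning and contributes no timestamps, so `timestamps[i]` may be shorter than the word count of `texts[i]`.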