1 change: 0 additions & 1 deletion src/omnilingual_asr/__init__.py
@@ -38,7 +38,6 @@ def setup_fairseq2_extension(container: DependencyContainer) -> None:
 
 
 def _register_models(container: DependencyContainer) -> None:
-
     # Only adding custom wav2vec2 archs for wav2vec2_ssl model in fs2
     register_omnilingual_asr_wav2vec2_ssl_configs(container)
 
@@ -60,7 +60,6 @@ def create_reader(
         storage_config: ManifestStorageConfig,
         task_config: AsrTaskConfig,
     ) -> DataReader[Seq2SeqBatch]:
-
         storage = ManifestStorage(
             splits=self.splits,
             manifest_dir=self.manifest_dir,
@@ -205,7 +205,6 @@ def is_train_streaming(split: str, sync_mode: SyncMode) -> bool:
 
     @override
     def create_raw_data_pipeline(self, split: str, gangs: Gangs) -> DataPipelineBuilder:
-
         config = self.config
         schema: LangASRSchema = LangASRSchema()
 
20 changes: 10 additions & 10 deletions src/omnilingual_asr/datasets/tasks/asr_task.py
@@ -156,7 +156,6 @@ def apply_processing_pipeline(  # type: ignore[override]
         tokenizer: Tokenizer,
         dtype: torch.dtype,
     ) -> DataPipelineBuilder:
-
         config = self.config
 
         # Filtering audio to optimize before batching
@@ -246,7 +245,6 @@ def add_tokenization_pipeline(
         text_selector: str,
         audio_length_selector: str,
     ) -> DataPipelineBuilder:
-
         builder = filter_empty_text(builder, text_selector=text_selector)
 
         builder = filter_fast_speech(
@@ -262,9 +260,13 @@
             text_selector=text_selector,
         )
 
-        builder = filter_unknown_sequences(
-            builder, unk_idx=tokenizer.vocab_info.unk_idx, text_selector=text_selector  # type: ignore
-        )
+        unk_idx = tokenizer.vocab_info.unk_idx
+        if unk_idx is not None:
+            builder = filter_unknown_sequences(
+                builder,
+                unk_idx=unk_idx,
+                text_selector=text_selector,  # type: ignore
+            )
 
         if filter_long_text_threshold is not None:
             builder = filter_long_text(
@@ -273,9 +275,10 @@
                 builder,
                 text_selector=text_selector,
             )
 
-        if remove_unknown:
+        if remove_unknown and unk_idx is not None:
             builder = filter_unknown_tokens(
-                builder, unk_idx=tokenizer.vocab_info.unk_idx  # type: ignore
+                builder,
+                unk_idx=unk_idx,  # type: ignore
             )
         return builder
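For context on the two hunks above: in fairseq2-style tokenizers, `vocab_info.unk_idx` is an `Optional[int]`, so the old calls papered over the `None` case with a `# type: ignore`. The new code narrows the Optional into a local `unk_idx` once and skips UNK filtering entirely when the vocabulary defines no UNK token. A minimal sketch of the pattern (the `VocabInfo` stub and filter signature below are illustrative stand-ins, not the project's actual definitions):

from dataclasses import dataclass
from typing import Optional


@dataclass
class VocabInfo:
    # Stand-in for a fairseq2-style VocabularyInfo, where unk_idx is Optional[int].
    unk_idx: Optional[int]


def filter_unknown_sequences(seqs: list[list[int]], unk_idx: int) -> list[list[int]]:
    # Illustrative stand-in: drop any sequence that contains the UNK index.
    return [s for s in seqs if unk_idx not in s]


vocab = VocabInfo(unk_idx=None)  # e.g. a tokenizer with no UNK token
seqs = [[5, 7, 9], [5, 3, 2]]

# Old shape: passing vocab.unk_idx straight through hid the None case behind
# "# type: ignore" and would have compared token ids against None.
# New shape: narrow the Optional once, then filter only when UNK exists.
unk_idx = vocab.unk_idx
if unk_idx is not None:
    seqs = filter_unknown_sequences(seqs, unk_idx)

print(seqs)  # [[5, 7, 9], [5, 3, 2]] -- nothing filtered, no None comparison

The same narrowed `unk_idx` then guards `filter_unknown_tokens` in the second hunk, so both filters are skipped consistently for UNK-less vocabularies.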

@@ -293,7 +296,6 @@ def add_bucketing_pipeline(
         batch_size: int,
         no_padding: bool,
     ) -> DataPipelineBuilder:
-
         if batching is BatchingStrategy.LENGTH:
             builder = add_length_batching(
                 builder,
@@ -358,7 +360,6 @@ def add_audio_processing_pipeline(
         npc: int,
         unified_audio_feature_keys: bool,
     ) -> DataPipelineBuilder:
-
         builder = add_audio_decoding(
             builder,
             dtype=dtype,
@@ -412,7 +413,6 @@ def add_postprocessing_pipeline(
         num_prefetch: int,
         no_padding: bool,
     ) -> DataPipelineBuilder:
-
         if no_padding:
             log.warning(
                 "Collating without padding is currently not supported, defaulting to padding."
4 changes: 0 additions & 4 deletions src/omnilingual_asr/datasets/tasks/ssl_task.py
@@ -130,7 +130,6 @@ def apply_processing_pipeline(  # type: ignore[override]
         gangs: Gangs,
         dtype: torch.dtype,
     ) -> DataPipelineBuilder:
-
         config = self.config
 
         # Shuffle individual samples
@@ -218,7 +217,6 @@ def add_bucketing_pipeline(
         batch_size: int,
         no_padding: bool,
     ) -> DataPipelineBuilder:
-
         if batching is BatchingStrategy.LENGTH:
             builder = add_length_batching(
                 builder,
@@ -260,7 +258,6 @@ def add_audio_processing_pipeline(
         no_padding: bool,
         seed: int,
     ) -> DataPipelineBuilder:
-
         builder = add_audio_decoding(
             builder,
             dtype=dtype,
@@ -320,7 +317,6 @@ def add_postprocessing_pipeline(
         num_prefetch: int,
         no_padding: bool,
     ) -> DataPipelineBuilder:
-
         collater = Collater(pad_value=None if no_padding else 0)
 
         builder.map(collater, num_parallel_calls=npc)
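The unchanged `Collater(pad_value=None if no_padding else 0)` context line is what the `no_padding` flag feeds into: with a pad value, the collater right-pads variable-length sequences into a batch; with `None`, sequences are collated as-is. A rough plain-Python sketch of that behavior (an illustration of the idea, not fairseq2's actual `Collater` implementation):

from typing import Optional


def collate(seqs: list[list[int]], pad_value: Optional[int]) -> list[list[int]]:
    # pad_value=None models the no-padding case: every sequence must already
    # have the same length. An int pad_value right-pads to the longest one.
    if pad_value is None:
        assert len({len(s) for s in seqs}) == 1, "no-padding collation needs equal lengths"
        return seqs
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]


print(collate([[1, 2, 3], [4, 5]], pad_value=0))  # [[1, 2, 3], [4, 5, 0]]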