Skip to content

Commit 452276a

Browse files
authored
Refactor: Extract shared constants, populate __init__.py, fix Python version constraint (jhauret#63)
2 parents e479891 + 70538e1 commit 452276a

14 files changed

Lines changed: 134 additions & 22 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "0.1.1"
44
description = "Speech to Phoneme, Bandwidth Extension and Speaker Verification using the Vibravox dataset."
55
authors = [{ name = "Julien Hauret", email = "j.hauret.33@gmail.com" }]
66
readme = "README.md"
7-
requires-python = "==3.12.0"
7+
requires-python = ">=3.12.0"
88
dependencies = [
99
"moshi==0.2.4",
1010
"isort",

vibravox/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Vibravox: Speech to Phoneme, Bandwidth Extension and Speaker Verification using the Vibravox dataset."""
2+
3+
from vibravox.constants import LIST_OF_VIBRAVOX, AVAILABLE_SENSORS
4+
5+
__all__ = [
6+
"LIST_OF_VIBRAVOX",
7+
"AVAILABLE_SENSORS",
8+
]

vibravox/constants.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Constants used throughout the vibravox package."""
2+
3+
# List of supported Vibravox datasets
4+
LIST_OF_VIBRAVOX = [
5+
"Cnam-LMSSC/vibravox",
6+
"Cnam-LMSSC/vibravox2",
7+
"Cnam-LMSSC/vibravox-test",
8+
"Cnam-LMSSC/non_curated_vibravox",
9+
"Cnam-LMSSC/vibravox_enhanced_by_EBEN",
10+
]
11+
12+
# List of available sensors
13+
AVAILABLE_SENSORS = [
14+
"headset_microphone",
15+
"throat_microphone",
16+
"forehead_accelerometer",
17+
"rigid_in_ear_microphone",
18+
"soft_in_ear_microphone",
19+
"temple_vibration_pickup",
20+
]

vibravox/datasets/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Dataset utilities for the Vibravox package."""
2+
3+
from vibravox.datasets.speech_noise import SpeechNoiseDataset
4+
5+
__all__ = [
6+
"SpeechNoiseDataset",
7+
]
vibravox/lightning_datamodules/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Lightning DataModules for the Vibravox package."""
2+
3+
from vibravox.lightning_datamodules.bwe import BWELightningDataModule
4+
from vibravox.lightning_datamodules.noisybwe import NoisyBWELightningDataModule
5+
from vibravox.lightning_datamodules.spkv import SPKVLightningDataModule
6+
from vibravox.lightning_datamodules.stp import STPLightningDataModule
7+
8+
__all__ = [
9+
"BWELightningDataModule",
10+
"NoisyBWELightningDataModule",
11+
"SPKVLightningDataModule",
12+
"STPLightningDataModule",
13+
]

vibravox/lightning_datamodules/bwe.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,13 @@
77
from torch.nn.utils.rnn import pad_sequence
88
from torch.utils.data import DataLoader
99

10+
from vibravox.constants import LIST_OF_VIBRAVOX
1011
from vibravox.torch_modules.dsp.data_augmentation import WaveformDataAugmentation
1112
from vibravox.utils import set_audio_duration
1213

1314

1415
class BWELightningDataModule(LightningDataModule):
1516

16-
LIST_OF_VIBRAVOX = [
17-
"Cnam-LMSSC/vibravox",
18-
"Cnam-LMSSC/vibravox2",
19-
"Cnam-LMSSC/vibravox-test",
20-
"Cnam-LMSSC/non_curated_vibravox",
21-
"Cnam-LMSSC/vibravox_enhanced_by_EBEN",
22-
]
23-
2417
def __init__(
2518
self,
2619
sample_rate: int = 16000,
@@ -63,12 +56,12 @@ def __init__(
6356

6457
self.dataset_name_principal = dataset_name_principal
6558
assert (
66-
dataset_name_principal in self.LIST_OF_VIBRAVOX
59+
dataset_name_principal in LIST_OF_VIBRAVOX
6760
), f"dataset_name_principal {dataset_name_principal} not supported."
6861

6962
self.dataset_name_secondary = dataset_name_secondary
7063
assert (
71-
dataset_name_secondary is None or dataset_name_secondary in self.LIST_OF_VIBRAVOX
64+
dataset_name_secondary is None or dataset_name_secondary in LIST_OF_VIBRAVOX
7265
), f"dataset_name_secondary {dataset_name_secondary} not supported."
7366

7467
self.subset = subset

vibravox/lightning_datamodules/noisybwe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from datasets import Audio, load_dataset
66
from torch.nn.utils.rnn import pad_sequence
77
from torch.utils.data import DataLoader
8-
from lightning import LightningDataModule
98
from vibravox.utils import mix_speech_and_noise_without_rescaling
109
from vibravox.utils import set_audio_duration
1110
from vibravox.torch_modules.dsp.data_augmentation import WaveformDataAugmentation

vibravox/lightning_datamodules/stp.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,13 @@
55
from lightning import LightningDataModule
66
from torch.utils.data import DataLoader
77
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
8+
9+
from vibravox.constants import LIST_OF_VIBRAVOX
810
from vibravox.torch_modules.dsp.data_augmentation import WaveformDataAugmentation
911

1012

1113
class STPLightningDataModule(LightningDataModule):
1214

13-
LIST_OF_VIBRAVOX = [
14-
"Cnam-LMSSC/vibravox",
15-
"Cnam-LMSSC/vibravox2",
16-
"Cnam-LMSSC/vibravox-test",
17-
"Cnam-LMSSC/non_curated_vibravox",
18-
"Cnam-LMSSC/vibravox_enhanced_by_EBEN",
19-
]
20-
2115
def __init__(
2216
self,
2317
sample_rate: int = 16000,
@@ -57,12 +51,12 @@ def __init__(
5751
self.sample_rate = sample_rate
5852
self.dataset_name_principal = dataset_name_principal
5953
assert (
60-
dataset_name_principal in self.LIST_OF_VIBRAVOX
54+
dataset_name_principal in LIST_OF_VIBRAVOX
6155
), f"dataset_name_principal {dataset_name_principal} not supported."
6256

6357
self.dataset_name_secondary = dataset_name_secondary
6458
assert (
65-
dataset_name_secondary is None or dataset_name_secondary in self.LIST_OF_VIBRAVOX
59+
dataset_name_secondary is None or dataset_name_secondary in LIST_OF_VIBRAVOX
6660
), f"dataset_name_secondary {dataset_name_secondary} not supported."
6761
self.subset = subset
6862
self.sensor = sensor
vibravox/lightning_modules/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Lightning Modules for the Vibravox package."""
2+
3+
from vibravox.lightning_modules.base_se import BaseSELightningModule
4+
from vibravox.lightning_modules.eben import EBENLightningModule
5+
from vibravox.lightning_modules.ecapa2 import ECAPA2LightningModule
6+
from vibravox.lightning_modules.regressive_mimi import RegressiveMimiLightningModule
7+
from vibravox.lightning_modules.wav2vec2_for_stp import Wav2Vec2ForSTPLightningModule
8+
9+
__all__ = [
10+
"BaseSELightningModule",
11+
"EBENLightningModule",
12+
"ECAPA2LightningModule",
13+
"RegressiveMimiLightningModule",
14+
"Wav2Vec2ForSTPLightningModule",
15+
]

vibravox/metrics/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Metrics for the Vibravox package."""
2+
3+
from vibravox.metrics.embedding_distance import BinaryEmbeddingDistance
4+
from vibravox.metrics.equal_error_rate import EqualErrorRate
5+
from vibravox.metrics.minimum_dcf import MinimumDetectionCostFunction
6+
from vibravox.metrics.noresqa_mos import NoresqaMOS
7+
from vibravox.metrics.torchsquim_stoi import TorchsquimSTOI
8+
9+
__all__ = [
10+
"BinaryEmbeddingDistance",
11+
"EqualErrorRate",
12+
"MinimumDetectionCostFunction",
13+
"NoresqaMOS",
14+
"TorchsquimSTOI",
15+
]

0 commit comments

Comments
 (0)