---
title: now-playing-implementation
publishedAt: '2024-05-26'
summary: 'Streamlined implementation of the now-playing research paper'
tags:
  - research_paper
  - code
---
This blog is a complement to the Now Playing blog: that one gives a brief overview of the Now Playing paper, while this one focuses solely on implementing the neural network fingerprinter.

---

We will implement a streamlined version of the Now Playing architecture and train it on the BirdCLEF-2024 dataset from Kaggle.

We start by creating a class that loads the dataset from our file system into a format usable for training. We will use PyTorch for the implementation.
# Audio Dataset Preparation

Audio data often comes in raw formats like `.mp3`, `.wav`, or `.ogg`. To work with PyTorch we have to create a custom `Dataset` class.
The goal of this custom class is to pair audio files with their corresponding labels and apply preprocessing steps that transform raw audio into usable model input.

```
root_dir/
  ├── class_1/
  │   ├── audio1.wav
  │   ├── audio2.wav
  ├── class_2/
  │   ├── audio3.wav
```

Since we are working with audio, we will use mel spectrograms as input. We therefore transform the raw audio into mel spectrograms while loading it.

> A natural question arises: why can't we use raw audio formats such as `.mp3`, `.wav`, or `.ogg` directly?

Raw audio has several limitations for recognition tasks:
1. **High Dimensionality**: Raw audio signals are sampled at high rates, resulting in large amounts of data, which increases computational cost and memory usage.
2. **Lack of Frequency Domain Information**: Raw audio resides in the time domain, which does not explicitly reveal the frequency content of the signal. Many recognition tasks require frequency information to capture pitch, timbre, and harmonic patterns.
3. **Noise Sensitivity**: Raw audio contains all the noise present in the signal, including environmental disturbances and recording artifacts. Noise in irrelevant frequency ranges can obscure meaningful patterns, making recognition less accurate.
4. **Inefficient Feature Extraction**: Raw audio lacks a structured representation of high-level features (e.g., formants in speech, harmonics in music). Extracting relevant features directly from raw audio requires more complex models, increasing the risk of overfitting.
5. **Poor Human Perception Alignment**: Human hearing is non-linear, with greater sensitivity to certain frequency ranges. Raw audio does not reflect this perceptual bias, leading to inefficient feature extraction for tasks involving human-centric audio processing.

The mel scale addresses these limitations: it is a perceptual scale of pitches that maps frequencies onto a scale aligned with human auditory perception. It compresses the frequency spectrum non-linearly, emphasizing the frequencies humans hear more distinctly (low and mid frequencies) and de-emphasizing higher frequencies.
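
This compression is visible directly in the standard hertz-to-mel conversion formula, $m = 2595 \log_{10}(1 + f/700)$. A quick sketch in plain Python (no audio libraries needed) shows how equal steps in hertz cover fewer and fewer mels as frequency grows:

```python
import math

def hz_to_mel(f_hz: float) -> float:
    """Convert a frequency in hertz to mels (the common 2595*log10 formula)."""
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)

# Equal 1000 Hz steps in frequency cover shrinking spans on the mel scale:
prev = 0.0
for f in (1000, 2000, 3000, 4000):
    m = hz_to_mel(f)
    print(f"{f} Hz -> {m:.0f} mel (+{m - prev:.0f})")
    prev = m
```

The shrinking gaps are exactly the perceptual compression the paragraph above describes: high frequencies get squeezed into fewer mel bins.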

## Key Steps for Data Preparation

1. **Preprocessing**:
   - The `MelSpectrogram` transform converts audio waveforms into mel spectrograms.
   - To improve stability during training, mel spectrograms are normalized to zero mean and unit variance.

2. **Triplet Sampling**:
   - We will train our model with a triplet loss. It involves selecting three samples: an anchor, a sample from the same class as the anchor (positive), and a sample from a different class (negative).

## Basic Workflow of our AudioDataset Class

The `AudioDataset` class will:
- Load audio data from directories.
- Preprocess the data by converting raw audio into mel spectrograms.
- Generate triplets of anchor, positive, and negative samples.

```python
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import MelSpectrogram

# Dataset Class
class AudioDataset(Dataset):
    def __init__(self, root_dir, sample_rate=16000, n_mels=128, max_time_steps=1000):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.max_time_steps = max_time_steps
        self.data = []
        self.label_to_files = {}
        self.mel_transform = MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels, n_fft=1024, hop_length=512
        )
        self._load_data()

    # Load file paths and labels from the class folders
    def _load_data(self):
        for label_idx, label in enumerate(sorted(os.listdir(self.root_dir))):
            label_path = os.path.join(self.root_dir, label)
            if os.path.isdir(label_path):
                for file in os.listdir(label_path):
                    if file.endswith(".wav") or file.endswith(".ogg"):
                        file_path = os.path.join(label_path, file)
                        self.data.append((file_path, label_idx))
                        if label_idx not in self.label_to_files:
                            self.label_to_files[label_idx] = []
                        self.label_to_files[label_idx].append(file_path)

    # Return the length of the dataset
    def __len__(self):
        return len(self.data)

    # Load an audio file and transform it into a normalized mel spectrogram
    def _load_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
            waveform = resampler(waveform)

        mel_spectrogram = self.mel_transform(waveform)
        # Normalize to zero mean and unit variance (epsilon guards against silent clips)
        mel_spectrogram = (mel_spectrogram - mel_spectrogram.mean()) / (mel_spectrogram.std() + 1e-6)

        # Pad or trim to a fixed number of time steps
        if mel_spectrogram.shape[-1] > self.max_time_steps:
            mel_spectrogram = mel_spectrogram[:, :, :self.max_time_steps]
        else:
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, self.max_time_steps - mel_spectrogram.shape[-1]))

        return mel_spectrogram

    # Sample an (anchor, positive, negative) triplet for the given index
    def get_triplet(self, idx):
        anchor_path, anchor_label = self.data[idx]
        positive_path = random.choice(self.label_to_files[anchor_label])
        negative_label = random.choice([label for label in self.label_to_files if label != anchor_label])
        negative_path = random.choice(self.label_to_files[negative_label])

        anchor = self._load_audio(anchor_path)
        positive = self._load_audio(positive_path)
        negative = self._load_audio(negative_path)

        return anchor, positive, negative

    def __getitem__(self, idx):
        return self.get_triplet(idx)
```

# Designing the Neural Network Fingerprinter Model

The neural network fingerprinter model is at the heart of the pipeline. It learns to extract embeddings from mel spectrograms.

## Architecture

Essentially, our model learns the patterns in its input with the help of a loss function (triplet loss). Once it has learned these patterns, the model is able to produce embeddings.

The following layers are used to capture the patterns in our input:

### 1. Convolution Layers

#### Key Concept
A convolution is a mathematical operation that combines two functions (an input and a kernel) and produces an output function whose properties characterize the input.

In simple mathematical terms, convolution expresses the amount of overlap of one function $g$ (the kernel) as it is shifted over another function $f$ (the input):

$$
(f * g)(t) = \int_{-\infty}^{\infty} f(\tau)\, g(t - \tau) \, d\tau
$$

where $(f * g)(t)$ is the convolution of $f(\tau)$ and $g(t - \tau)$.

It is widely used in signal processing and convolutional neural networks (CNNs) to extract features like edges, textures, and patterns.
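
In the discrete setting the integral becomes a sum, $(f*g)[n] = \sum_k f[k]\, g[n-k]$, which is easy to check by hand. A minimal pure-Python sketch (the signal and kernel values are arbitrary illustrations, not from the paper):

```python
def convolve(f, g):
    """Full discrete convolution of two sequences: (f*g)[n] = sum_k f[k] * g[n-k]."""
    n_out = len(f) + len(g) - 1
    out = [0] * n_out
    for n in range(n_out):
        for k in range(len(f)):
            if 0 <= n - k < len(g):
                out[n] += f[k] * g[n - k]
    return out

signal = [1, 2, 3]
kernel = [1, 1]  # a simple smoothing kernel: each output sums two neighbours
print(convolve(signal, kernel))  # -> [1, 3, 5, 3]
```

This is exactly the operation `nn.Conv2d` performs in two dimensions (up to a sign flip of the kernel, since deep-learning libraries implement cross-correlation), with the kernel values learned rather than fixed.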

### 2. Adaptive Pooling
#### Key Concept

We use a technique called adaptive pooling, which lets us fix the output size regardless of the input size, so the layers that follow always receive tensors of the same shape.
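
The idea can be sketched in plain Python for the 1-D case: the input is split into `output_size` roughly equal bins (the floor/ceil binning below mirrors the scheme PyTorch uses for its adaptive pooling layers, as an assumption for illustration) and each bin is averaged, so any input length maps to the same output length:

```python
import math

def adaptive_avg_pool1d(xs, output_size):
    """Average-pool a sequence into `output_size` bins, whatever its length."""
    n = len(xs)
    out = []
    for i in range(output_size):
        start = (i * n) // output_size             # floor(i * n / output_size)
        end = math.ceil((i + 1) * n / output_size)  # ceil((i+1) * n / output_size)
        bin_vals = xs[start:end]
        out.append(sum(bin_vals) / len(bin_vals))
    return out

# Inputs of different lengths all produce exactly 4 outputs:
print(adaptive_avg_pool1d([1, 2, 3, 4, 5, 6, 7, 8], 4))  # -> [1.5, 3.5, 5.5, 7.5]
print(adaptive_avg_pool1d([1, 2, 3, 4, 5, 6], 4))
```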

### 3. Fully Connected Layers

Fully connected layers are a fundamental building block of neural networks (feedforward networks). They are used mostly in the final stages to map the extracted features to particular classes, or, as in our case, to an embedding vector.

```python
# Model Definition
class AudioCNN(nn.Module):
    def __init__(self):
        super(AudioCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.flatten = nn.Flatten()
        # Two 2x2 max-pools shrink the 128x1000 mel spectrogram to 32x250
        self.fc = nn.Sequential(
            nn.Linear(32 * (128 // 4) * (1000 // 4), 256),
            nn.ReLU(),
            nn.Linear(256, 128)  # 128-dimensional embedding
        )

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
```

# Training with Triplet Loss

## Triplet Loss

- A triplet loss ensures that embeddings are learned such that similar samples are closer in the embedding space and dissimilar samples are farther apart.
- Formula:

$$
||f(a) - f(p)||_2^2 + \text{margin} \leq ||f(a) - f(n)||_2^2
$$

Here, $f(\cdot)$ is the embedding function (our CNN in this case), and $a$, $p$, and $n$ are the anchor, positive, and negative inputs, respectively.
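
A small numeric sketch makes the objective concrete. It uses made-up 2-D embeddings and the hinge form $\max(d(a,p) - d(a,n) + \text{margin},\ 0)$ with Euclidean distance, which is the quantity PyTorch's `TripletMarginLoss` minimizes:

```python
import math

def euclidean(u, v):
    """L2 distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge form of the triplet loss: max(d(a,p) - d(a,n) + margin, 0)."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

a = [0.0, 0.0]  # made-up anchor embedding, purely illustrative
p = [0.0, 1.0]  # positive: distance 1.0 from the anchor
n = [3.0, 4.0]  # negative: distance 5.0 from the anchor

print(triplet_loss(a, p, n))  # max(1 - 5 + 1, 0) = 0.0: constraint satisfied
print(triplet_loss(a, n, p))  # swapped roles: max(5 - 1 + 1, 0) = 5.0: penalized
```

When the negative is already more than `margin` farther away than the positive, the loss is zero and that triplet stops contributing gradients.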

## Training the Model

### Collate Function
Groups anchor, positive, and negative samples into separate tensors during batching.

### Data Loader
We pass our data in batches of 4 triplets (anchor, positive, negative), reshuffling the whole dataset at each epoch.

### Model and Loss
- `AudioCNN`: the CNN model defined earlier.
- `TripletMarginLoss`:
  - `margin=1.0`: ensures negative embeddings are at least 1 unit farther from the anchor than positive embeddings.
  - `p=2`: uses the L2 (Euclidean) distance.
- `Adam`: the optimizer, with a learning rate of 0.001.

```python
# Training Script
if __name__ == "__main__":
    root_dir = "/kaggle/input/birdclef-2024/train_audio"
    dataset = AudioDataset(root_dir=root_dir)

    def collate_fn(batch):
        anchors, positives, negatives = zip(*batch)
        anchors = torch.stack(anchors)
        positives = torch.stack(positives)
        negatives = torch.stack(negatives)
        return anchors, positives, negatives

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AudioCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).to(device)

    scaler = torch.cuda.amp.GradScaler()  # Mixed precision scaler
    accumulation_steps = 2  # Gradient accumulation

    model.train()
    optimizer.zero_grad()
    for epoch in range(10):  # Number of epochs
        running_loss = 0.0
        for batch_idx, (anchors, positives, negatives) in enumerate(dataloader):
            anchors = anchors.to(device, non_blocking=True)
            positives = positives.to(device, non_blocking=True)
            negatives = negatives.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():  # Mixed precision training
                anchor_embeds = model(anchors)
                positive_embeds = model(positives)
                negative_embeds = model(negatives)
                loss = triplet_loss(anchor_embeds, positive_embeds, negative_embeds) / accumulation_steps

            scaler.scale(loss).backward()

            # Step and reset gradients only at accumulation boundaries, so
            # gradients actually accumulate across `accumulation_steps` batches
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
```