Skip to content

Commit b0c0014

Browse files
committed
Implement CLAP style pretraining (WIP)
1 parent 6c78ba2 commit b0c0014

6 files changed

Lines changed: 560 additions & 19 deletions

File tree

cfg/config_text_audio_dev.gin

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# general training parameters
2+
train.wandb_params = {
3+
"project": "mtg-text_audio",
4+
"name": "dev.mtg_text_audio",
5+
"offline": False,
6+
"entity": "mtg-upf",
7+
}
8+
9+
new_freq = 24000
10+
11+
# modules to use
12+
__main__.build_module.module = @modules.clap.CLAP
13+
14+
# Choose the devalopment dataloader
15+
build_dev_datamodule.datamodule = @discotube_test_audio
16+
17+
18+
# Lighting trainer parameters
19+
train.params = {
20+
"accelerator": "gpu",
21+
"devices": 1,
22+
"max_steps": 400,
23+
"log_every_n_steps": 50,
24+
"precision": "bf16-mixed",
25+
"strategy": "ddp_find_unused_parameters_true"
26+
}
27+
28+
modules.clap.CLAP.audio_encoder_name = "/Users/palonso/data/text_audio/ssl-mtg-weights/bm23z5le/checkpoints/config_masking_conformer_multiview_enc_to_encmelcqt_small.gin"
29+
modules.clap.CLAP.text_encoder_name = "sentence-transformers/all-mpnet-base-v2"
30+
modules.clap.CLAP.proj_size = 512
31+
modules.clap.CLAP.temp = 0.1
32+
modules.clap.CLAP.lr = 1e-4
33+
modules.clap.CLAP.weight_decay = 1e-2
34+
modules.clap.CLAP.seed = 0
35+
36+
# CosineAnnealing scheduler
37+
CosineAnnealingCallback.warmup_steps = 30000
38+
CosineAnnealingCallback.eta_min = 1e-7
39+
40+
# MelSpectrogram parameters
41+
nets.melspectrogram.MelSpectrogram.sr = 16000
42+
nets.melspectrogram.MelSpectrogram.win_len = 512
43+
nets.melspectrogram.MelSpectrogram.hop_len = 256
44+
nets.melspectrogram.MelSpectrogram.power = 2
45+
nets.melspectrogram.MelSpectrogram.n_mel = 96
46+
nets.melspectrogram.MelSpectrogram.norm = "slaney"
47+
nets.melspectrogram.MelSpectrogram.mel_scale = "slaney"
48+
nets.melspectrogram.MelSpectrogram.norm_std = 1.268292820667291
49+
nets.melspectrogram.MelSpectrogram.norm_mean = 2.06755686098554
50+
51+
# data augmentation
52+
nets.melspectrogram.MelSpectrogram.stretch_factor = 1
53+
nets.melspectrogram.MelSpectrogram.freq_mask_param = 0
54+
nets.melspectrogram.MelSpectrogram.time_mask_param = 0
55+
nets.melspectrogram.MelSpectrogram.patch_size = (96, 4)
56+
57+
58+
# Transformer parameters
59+
nets.conformer.Conformer.patch_size = (96, 4)
60+
nets.conformer.Conformer.embed_dim = 512
61+
nets.conformer.Conformer.depth = 2
62+
nets.conformer.Conformer.conv_kernel_size = 5
63+
nets.conformer.Conformer.num_heads = 8
64+
nets.conformer.Conformer.mlp_ratio = 4.0
65+
nets.conformer.Conformer.mlp_residual_factor = 4.0
66+
nets.conformer.Conformer.dropout = 0.0
67+
nets.conformer.Conformer.input_dropout = 0.0
68+
nets.conformer.Conformer.use_deepnorm = True
69+
nets.conformer.Conformer.alpha_deepnorm = 2.21 # we can tune this number
70+
nets.conformer.Conformer.beta_deepnorm = 0.0026 # we can tune this number
71+
nets.conformer.Conformer.use_rope = True
72+
nets.conformer.Conformer.num_patches = None
73+
74+
# Dataloader
75+
AudioDataset.num_frames = 16000
76+
AudioDataset.orig_freq = 16000
77+
AudioDataset.new_freq = 16000
78+
AudioDataset.mono = True
79+
AudioDataset.half_precision = True
80+
AudioDataModule.num_workers = 0
81+
82+
# Discogs datamodule parameters
83+
DiscotubeTextAudioDataModule.batch_size = 1
84+
DiscotubeTextAudioDataModule.num_workers = 8
85+
DiscotubeTextAudioDataModule.data_dir = "/Users/palonso/data/text_audio/discotube_sample/audio/"
86+
DiscotubeTextAudioDataModule.filelist_train = "/Users/palonso/data/text_audio/discotube_sample/ids"
87+
DiscotubeTextAudioDataModule.filelist_val = "/Users/palonso/data/text_audio/discotube_sample/ids"
88+
DiscotubeTextAudioDataModule.metadata_youtube_file = "/Users/palonso/data/text_audio/discotube_sample/yotube_metadata.jsonl"
89+
DiscotubeTextAudioDataModule.metadata_discogs_file = "/Users/palonso/data/text_audio/discotube_sample/discogs_metadata.jsonl"
90+
DiscotubeTextAudioDataModule.metadata_id_map_file = "/Users/palonso/data/text_audio/discotube_sample/youtube_to_discgos_map.jsonl"
91+
92+
93+

src/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from .mnist import MNISTDataModule
22
from .discotube import DiscotubeAudioDataModule, DiscotubeMultiViewAudioDataModule
3+
from .discotube_text_audio import DiscotubeTextAudioDataModule
34

45
DATASETS = {
56
"mnist": MNISTDataModule,
67
"discotube": DiscotubeAudioDataModule,
78
"discotube_multiview": DiscotubeMultiViewAudioDataModule,
9+
"discotube_text_audio": DiscotubeTextAudioDataModule,
810
}

src/data/data_utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,6 @@ def __getitem__(self, idx):
4545
# load audio
4646
audio = self.load_audio(file_path, frame_offset=self.frame_offset)
4747

48-
# downmix to mono if necessary
49-
if audio.shape[0] > 1 and self.mono:
50-
audio = torch.mean(audio, dim=0, keepdim=False)
51-
52-
# resample if necessary
53-
if self.orig_freq != self.new_freq:
54-
# only works with float tensors
55-
audio = audio.float()
56-
audio = self.resample(audio)
57-
58-
audio = audio.squeeze(0)
59-
60-
# work with 16-bit precission
61-
if self.half_precision:
62-
audio = audio.half()
63-
else:
64-
audio = audio.float()
65-
6648
return [audio]
6749

6850
def load_audio(
@@ -109,7 +91,27 @@ def load_audio(
10991
else:
11092
raise ValueError(f"Invalid frame_offset: {frame_offset}")
11193

112-
return torch.from_numpy(audio)
94+
# downmix to mono if necessary
95+
if audio.shape[0] > 1 and self.mono:
96+
audio = torch.mean(audio, dim=0, keepdim=False)
97+
98+
audio = torch.from_numpy(audio)
99+
100+
# resample if necessary
101+
if self.orig_freq != self.new_freq:
102+
# only works with float tensors
103+
audio = audio.float()
104+
audio = self.resample(audio)
105+
106+
audio = audio.squeeze(0)
107+
108+
# work with 16-bit precission
109+
if self.half_precision:
110+
audio = audio.half()
111+
else:
112+
audio = audio.float()
113+
114+
return audio
113115

114116
@staticmethod
115117
def get_audio_duration(filepath: Path):

src/data/discotube_text_audio.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import gin
2+
import json
3+
import random
4+
import traceback
5+
from pathlib import Path
6+
from typing import Union
7+
8+
import torch
9+
import pytorch_lightning as L
10+
import yaml
11+
from tqdm import tqdm
12+
from torch.utils.data import DataLoader
13+
14+
15+
from .data_utils import AudioDataset
16+
17+
18+
@gin.configurable
19+
class DiscotubeTextAudioDataset(AudioDataset):
20+
"""Generic audio dataset."""
21+
22+
def __init__(
23+
self,
24+
data_dir: Path,
25+
filelist: Path,
26+
metadata_youtube: dict,
27+
metadata_discogs: dict,
28+
metadata_id_map: dict,
29+
frame_offset: Union[int, str] = "random",
30+
):
31+
super().__init__(
32+
data_dir=data_dir,
33+
filelist=filelist,
34+
frame_offset=frame_offset,
35+
)
36+
37+
self.metadata_youtube = metadata_youtube
38+
self.metadata_discogs = metadata_discogs
39+
self.metadata_id_map = metadata_id_map
40+
41+
def __len__(self):
42+
return len(self.filelist)
43+
44+
@staticmethod
45+
def get_audio_path(youtube_id: str) -> Path:
46+
return Path(youtube_id[:2], youtube_id).with_suffix(".mmap")
47+
48+
def __getitem__(self, idx):
49+
try:
50+
id_yt = self.filelist[idx]
51+
52+
file_path = self.data_dir / self.get_audio_path(id_yt)
53+
54+
# load audio
55+
# audio = self.load_audio(file_path, frame_offset=self.frame_offset)
56+
audio = torch.rand(1, 16000 * 30)
57+
58+
# load YouTube metadata
59+
meta_youtube = self.metadata_youtube[id_yt]
60+
61+
# load discogs metadata
62+
ids_discogs = self.metadata_id_map[id_yt]
63+
64+
# sample randonly among available releases
65+
id_discogs = random.choice(ids_discogs)
66+
meta_discogs = self.metadata_discogs[id_discogs]
67+
68+
# process metadata
69+
text = self.preprocess_text(
70+
{"youtube_metadata": meta_youtube, "discogs_metadata": meta_discogs}
71+
)
72+
except Exception:
73+
print(f"Error loading {self.filelist[idx]}")
74+
print(traceback.format_exc())
75+
return [None, None]
76+
77+
return [audio, text]
78+
79+
def preprocess_text(self, metadata: dict) -> str:
80+
"""Text preprocessing"""
81+
82+
# Process YouTube metadata
83+
fields_to_keep = ["description", "categories", "tags", "view_count"]
84+
youtube_metadata = metadata["youtube_metadata"]
85+
new_youtube_metadata = {
86+
field: youtube_metadata[field]
87+
for field in fields_to_keep
88+
if field in youtube_metadata
89+
}
90+
91+
# Process Discogs metadata
92+
93+
fields_to_keep = ["labels", "genres", "styles", "country", "released"]
94+
dicogs_metadata = metadata["discogs_metadata"]
95+
new_discogs_metadata = {
96+
field: dicogs_metadata[field]
97+
for field in fields_to_keep
98+
if field in dicogs_metadata
99+
}
100+
101+
# Fetch artist description
102+
# TODO: Get this too
103+
104+
metadata = {
105+
"youtube_metadata": new_youtube_metadata,
106+
"discogs_metadata": new_discogs_metadata,
107+
}
108+
109+
# format as YAML
110+
yaml_text = yaml.dump(metadata, sort_keys=False)
111+
return yaml_text
112+
113+
114+
@gin.configurable
115+
class DiscotubeTextAudioDataModule(L.LightningDataModule):
116+
"""AudioDataModule for the Discogs dataset."""
117+
118+
def __init__(
119+
self,
120+
batch_size: int,
121+
data_dir: Path,
122+
filelist_train: Path,
123+
filelist_val: Path,
124+
metadata_youtube_file: Path,
125+
metadata_discogs_file: Path,
126+
metadata_id_map_file: Path,
127+
num_workers: int,
128+
):
129+
super().__init__()
130+
131+
self.batch_size = batch_size
132+
133+
self.data_dir = Path(data_dir)
134+
self.filelist_train = Path(filelist_train)
135+
self.filelist_val = Path(filelist_val)
136+
137+
self.num_workers = num_workers
138+
139+
self.metadata_youtube_file = metadata_youtube_file
140+
self.metadata_discogs_file = metadata_discogs_file
141+
self.metadata_id_map_file = metadata_id_map_file
142+
143+
def setup(self, stage: str):
144+
# load YouTube metadata from jsonl (one json object per line)
145+
self.metadata_youtube = dict()
146+
with open(self.metadata_youtube_file, "r") as f:
147+
for line in tqdm(f.readlines(), desc="Loading YouTube metadata"):
148+
line = json.loads(line)
149+
self.metadata_youtube[line["id"]] = line
150+
151+
# load Discogs metadata from jsonl (one json object per line)
152+
self.metadata_discogs = dict()
153+
with open(self.metadata_discogs_file, "r") as f:
154+
for line in tqdm(f.readlines(), desc="Loading Discogs metadata"):
155+
line = json.loads(line)
156+
self.metadata_discogs[line["id"]] = line
157+
158+
# load the id map from jsonl (one json object per line)
159+
self.metadata_id_map = dict()
160+
with open(self.metadata_id_map_file, "r") as f:
161+
for line in tqdm(f.readlines(), desc="Loading ID map"):
162+
line = json.loads(line)
163+
for k, v in line.items():
164+
self.metadata_id_map[k] = v
165+
166+
self.dataset_train = DiscotubeTextAudioDataset(
167+
self.data_dir,
168+
filelist=self.filelist_train,
169+
metadata_youtube=self.metadata_youtube,
170+
metadata_discogs=self.metadata_discogs,
171+
metadata_id_map=self.metadata_id_map,
172+
)
173+
self.dataset_val = DiscotubeTextAudioDataset(
174+
self.data_dir,
175+
filelist=self.filelist_val,
176+
metadata_youtube=self.metadata_youtube,
177+
metadata_discogs=self.metadata_discogs,
178+
metadata_id_map=self.metadata_id_map,
179+
)
180+
181+
def train_dataloader(self):
182+
return DataLoader(
183+
self.dataset_train,
184+
batch_size=self.batch_size,
185+
num_workers=self.num_workers,
186+
pin_memory=True,
187+
)
188+
189+
def val_dataloader(self):
190+
return DataLoader(
191+
self.dataset_val,
192+
batch_size=self.batch_size,
193+
num_workers=self.num_workers,
194+
pin_memory=True,
195+
)

0 commit comments

Comments
 (0)