Skip to content

Commit 0ecbd6e

Browse files
committed
Add two variants of the SSL multi-view model
- with higher time resolution (25Hz) - using shuffled input instead of random noise
1 parent d19b7f1 commit 0ecbd6e

5 files changed

Lines changed: 307 additions & 5 deletions
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# general training parameters
2+
train.wandb_params = {
3+
"project": "mtg-ssl",
4+
"name": "mask_conformer_large_mv_au_to_all_25hz",
5+
"offline": True,
6+
# NOTE: path to logs in the BSC cluster. Change it for local experiments
7+
"save_dir": "/gpfs/projects/upf97/logs/",
8+
"entity": "mtg-upf",
9+
"group": "masking_conformer",
10+
}
11+
12+
# modules to use
13+
build_module.representation = [@nets.cqt.CQT, @nets.encodec.EnCodec, @nets.melspectrogram.MelSpectrogram, @nets.waveform.Waveform]
14+
build_module.module = @modules.maskingmodel.MaskingModel
15+
build_module.net = @nets.conformer.Conformer
16+
17+
# Choose the development dataloader
18+
build_dev_datamodule.datamodule = @discotube
19+
20+
# Lightning trainer parameters
21+
train.params = {
22+
"accelerator": "gpu",
23+
"devices": 4,
24+
"num_nodes": 2,
25+
"max_steps": 400000,
26+
"log_every_n_steps": 50,
27+
"precision": "bf16-mixed",
28+
"strategy": "ddp_find_unused_parameters_true",
29+
"num_sanity_val_steps": 0
30+
}
31+
32+
new_freq = 24000
33+
34+
# Dataloader
35+
AudioDataset.num_frames = 480000 # 30s
36+
AudioDataset.orig_freq = 16000
37+
AudioDataset.new_freq = %new_freq
38+
AudioDataset.mono = True
39+
AudioDataset.half_precision = True
40+
AudioDataModule.num_workers = 20
41+
42+
# Discogs datamodule parameters
43+
DiscotubeAudioDataModule.batch_size = 20
44+
DiscotubeAudioDataModule.data_dir = "/gpfs/scratch/upf97/mmap/"
45+
DiscotubeAudioDataModule.filelist_train = "/gpfs/projects/upf97/data/train_mmap.txt"
46+
DiscotubeAudioDataModule.filelist_val = "/gpfs/projects/upf97/data/test_mmap.txt"
47+
48+
# CosineAnnealing scheduler
49+
CosineAnnealingCallback.warmup_steps = 30000
50+
CosineAnnealingCallback.eta_min = 1e-7
51+
52+
# MelSpectrogram parameters
53+
nets.melspectrogram.MelSpectrogram.sr = %new_freq
54+
nets.melspectrogram.MelSpectrogram.win_len = 512
55+
nets.melspectrogram.MelSpectrogram.hop_len = 320
56+
nets.melspectrogram.MelSpectrogram.power = 2
57+
nets.melspectrogram.MelSpectrogram.n_mel = 96
58+
nets.melspectrogram.MelSpectrogram.norm = "slaney"
59+
nets.melspectrogram.MelSpectrogram.mel_scale = "slaney"
60+
nets.melspectrogram.MelSpectrogram.norm_std = 1.268292820667291
61+
nets.melspectrogram.MelSpectrogram.norm_mean = 2.06755686098554
62+
nets.melspectrogram.MelSpectrogram.patch_size = (96, 3)
63+
64+
# CQT parameters
65+
nets.cqt.CQT.sr = %new_freq
66+
nets.cqt.CQT.hop_len = 320
67+
nets.cqt.CQT.power = 2
68+
nets.cqt.CQT.bins_per_octave = 24
69+
nets.cqt.CQT.n_bins = 188 # ~7.8 octaves at 24 bins per octave
70+
nets.cqt.CQT.f_min = 32.703 # C1
71+
nets.cqt.CQT.magnitude = True
72+
nets.cqt.CQT.logC = True
73+
nets.cqt.CQT.norm_std = 1.9055732535255916
74+
nets.cqt.CQT.norm_mean = 4.754879065310596
75+
nets.cqt.CQT.patch_size = (188, 3)
76+
77+
# Waveform parameters
78+
nets.waveform.Waveform.sr = %new_freq
79+
nets.waveform.Waveform.norm_std = None
80+
nets.waveform.Waveform.norm_mean = None
81+
nets.waveform.Waveform.patch_size = (1, 960) # 40ms at 24 kHz (25 Hz patch rate)
82+
83+
# data augmentation
84+
nets.melspectrogram.MelSpectrogram.stretch_factor = 1
85+
nets.melspectrogram.MelSpectrogram.freq_mask_param = 0
86+
nets.melspectrogram.MelSpectrogram.time_mask_param = 0
87+
88+
# Encodec parameters
89+
nets.encodec.EnCodec.weights_path = "/gpfs/scratch/upf97/model_weights/encodec_24khz/"
90+
nets.encodec.EnCodec.norm_type = "global"
91+
nets.encodec.EnCodec.stats_path = "/gpfs/scratch/upf97/dataset_stats/discotube23/input_stats_1K_steps.json"
92+
nets.encodec.EnCodec.orig_sr = %new_freq
93+
nets.encodec.EnCodec.patch_size = (128, 3)
94+
95+
# MaskingModel parameters
96+
modules.maskingmodel.MaskingModel.num_codebooks = 1
97+
modules.maskingmodel.MaskingModel.lr = 1e-4
98+
modules.maskingmodel.MaskingModel.weight_decay = 1e-2
99+
modules.maskingmodel.MaskingModel.codebook_size = 8196
100+
modules.maskingmodel.MaskingModel.codebook_dim = 16
101+
modules.maskingmodel.MaskingModel.mask_seconds = 0.4
102+
modules.maskingmodel.MaskingModel.mask_prob = 0.6
103+
modules.maskingmodel.MaskingModel.seed = 0
104+
modules.maskingmodel.MaskingModel.plot_tokens = False
105+
modules.maskingmodel.MaskingModel.diff_input = False
106+
modules.maskingmodel.MaskingModel.input_representation = @nets.waveform.Waveform
107+
108+
# Transformer parameters
109+
nets.conformer.Conformer.embed_dim = 1024
110+
nets.conformer.Conformer.depth = 24
111+
nets.conformer.Conformer.conv_kernel_size = 5
112+
nets.conformer.Conformer.num_heads = 8
113+
nets.conformer.Conformer.mlp_ratio = 4.0
114+
nets.conformer.Conformer.mlp_residual_factor = 4.0
115+
nets.conformer.Conformer.dropout = 0.2
116+
nets.conformer.Conformer.input_dropout = 0.0
117+
nets.conformer.Conformer.use_deepnorm = True
118+
nets.conformer.Conformer.alpha_deepnorm = 2.6321480259049848 # we can tune this number
119+
nets.conformer.Conformer.beta_deepnorm = 0.022386873579657126 # we can tune this number
120+
nets.conformer.Conformer.use_rope = True
121+
nets.conformer.Conformer.num_patches = None
122+
123+
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# general training parameters
2+
train.wandb_params = {
3+
"project": "mtg-ssl",
4+
"name": "mask_conformer_large_mv_au_to_all_shuffle_mask",
5+
"offline": True,
6+
# NOTE: path to logs in the BSC cluster. Change it for local experiments
7+
"save_dir": "/gpfs/projects/upf97/logs/",
8+
"entity": "mtg-upf",
9+
"group": "masking_conformer",
10+
}
11+
12+
# modules to use
13+
build_module.representation = [@nets.cqt.CQT, @nets.encodec.EnCodec, @nets.melspectrogram.MelSpectrogram, @nets.waveform.Waveform]
14+
build_module.module = @modules.maskingmodel.MaskingModel
15+
build_module.net = @nets.conformer.Conformer
16+
17+
# Choose the development dataloader
18+
build_dev_datamodule.datamodule = @discotube
19+
20+
# Lightning trainer parameters
21+
train.params = {
22+
"accelerator": "gpu",
23+
"devices": 4,
24+
"num_nodes": 2,
25+
"max_steps": 400000,
26+
"log_every_n_steps": 50,
27+
"precision": "bf16-mixed",
28+
"strategy": "ddp_find_unused_parameters_true",
29+
"num_sanity_val_steps": 0
30+
}
31+
32+
new_freq = 24000
33+
34+
# Dataloader
35+
AudioDataset.num_frames = 480000 # 30s
36+
AudioDataset.orig_freq = 16000
37+
AudioDataset.new_freq = %new_freq
38+
AudioDataset.mono = True
39+
AudioDataset.half_precision = True
40+
AudioDataModule.num_workers = 20
41+
42+
# Discogs datamodule parameters
43+
DiscotubeAudioDataModule.batch_size = 32
44+
DiscotubeAudioDataModule.data_dir = "/gpfs/scratch/upf97/mmap/"
45+
DiscotubeAudioDataModule.filelist_train = "/gpfs/projects/upf97/data/train_mmap.txt"
46+
DiscotubeAudioDataModule.filelist_val = "/gpfs/projects/upf97/data/test_mmap.txt"
47+
48+
# CosineAnnealing scheduler
49+
CosineAnnealingCallback.warmup_steps = 30000
50+
CosineAnnealingCallback.eta_min = 1e-7
51+
52+
# MelSpectrogram parameters
53+
nets.melspectrogram.MelSpectrogram.sr = %new_freq
54+
nets.melspectrogram.MelSpectrogram.win_len = 512
55+
nets.melspectrogram.MelSpectrogram.hop_len = 320
56+
nets.melspectrogram.MelSpectrogram.power = 2
57+
nets.melspectrogram.MelSpectrogram.n_mel = 96
58+
nets.melspectrogram.MelSpectrogram.norm = "slaney"
59+
nets.melspectrogram.MelSpectrogram.mel_scale = "slaney"
60+
nets.melspectrogram.MelSpectrogram.norm_std = 1.268292820667291
61+
nets.melspectrogram.MelSpectrogram.norm_mean = 2.06755686098554
62+
nets.melspectrogram.MelSpectrogram.patch_size = (96, 4)
63+
64+
# CQT parameters
65+
nets.cqt.CQT.sr = %new_freq
66+
nets.cqt.CQT.hop_len = 320
67+
nets.cqt.CQT.power = 2
68+
nets.cqt.CQT.bins_per_octave = 24
69+
nets.cqt.CQT.n_bins = 188 # ~7.8 octaves at 24 bins per octave
70+
nets.cqt.CQT.f_min = 32.703 # C1
71+
nets.cqt.CQT.magnitude = True
72+
nets.cqt.CQT.logC = True
73+
nets.cqt.CQT.norm_std = 1.9055732535255916
74+
nets.cqt.CQT.norm_mean = 4.754879065310596
75+
nets.cqt.CQT.patch_size = (188, 4)
76+
77+
# Waveform parameters
78+
nets.waveform.Waveform.sr = %new_freq
79+
nets.waveform.Waveform.norm_std = None
80+
nets.waveform.Waveform.norm_mean = None
81+
nets.waveform.Waveform.patch_size = (1, 1280) # ~53.3ms at 24 kHz
82+
83+
# data augmentation
84+
nets.melspectrogram.MelSpectrogram.stretch_factor = 1
85+
nets.melspectrogram.MelSpectrogram.freq_mask_param = 0
86+
nets.melspectrogram.MelSpectrogram.time_mask_param = 0
87+
88+
# Encodec parameters
89+
nets.encodec.EnCodec.weights_path = "/gpfs/scratch/upf97/model_weights/encodec_24khz/"
90+
nets.encodec.EnCodec.norm_type = "global"
91+
nets.encodec.EnCodec.stats_path = "/gpfs/scratch/upf97/dataset_stats/discotube23/input_stats_1K_steps.json"
92+
nets.encodec.EnCodec.orig_sr = %new_freq
93+
nets.encodec.EnCodec.patch_size = (128, 4)
94+
95+
# MaskingModel parameters
96+
modules.maskingmodel.MaskingModel.num_codebooks = 1
97+
modules.maskingmodel.MaskingModel.lr = 1e-4
98+
modules.maskingmodel.MaskingModel.weight_decay = 1e-2
99+
modules.maskingmodel.MaskingModel.codebook_size = 8196
100+
modules.maskingmodel.MaskingModel.codebook_dim = 16
101+
modules.maskingmodel.MaskingModel.mask_seconds = 0.4
102+
modules.maskingmodel.MaskingModel.mask_prob = 0.6
103+
modules.maskingmodel.MaskingModel.seed = 0
104+
modules.maskingmodel.MaskingModel.plot_tokens = False
105+
modules.maskingmodel.MaskingModel.diff_input = False
106+
modules.maskingmodel.MaskingModel.input_representation = @nets.waveform.Waveform
107+
modules.maskingmodel.MaskingModel.masking_noise_type = "shuffled_input"
108+
109+
# Transformer parameters
110+
nets.conformer.Conformer.embed_dim = 1024
111+
nets.conformer.Conformer.depth = 24
112+
nets.conformer.Conformer.conv_kernel_size = 5
113+
nets.conformer.Conformer.num_heads = 8
114+
nets.conformer.Conformer.mlp_ratio = 4.0
115+
nets.conformer.Conformer.mlp_residual_factor = 4.0
116+
nets.conformer.Conformer.dropout = 0.2
117+
nets.conformer.Conformer.input_dropout = 0.0
118+
nets.conformer.Conformer.use_deepnorm = True
119+
nets.conformer.Conformer.alpha_deepnorm = 2.6321480259049848 # we can tune this number
120+
nets.conformer.Conformer.beta_deepnorm = 0.022386873579657126 # we can tune this number
121+
nets.conformer.Conformer.use_rope = True
122+
nets.conformer.Conformer.num_patches = None
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash
2+
3+
#SBATCH --job-name a2a_large_25hz
4+
#SBATCH --account=upf97
5+
#SBATCH --partition=acc
6+
#SBATCH --qos=acc_resa
7+
#SBATCH --nodes=2 # This needs to match Trainer(num_nodes=...)
8+
#SBATCH --cpus-per-task=20
9+
#SBATCH --gres=gpu:4
10+
#SBATCH --ntasks-per-node=4
11+
#SBATCH --time=72:00:00
12+
#SBATCH --output=debug_%j_output.txt
13+
#SBATCH --mail-type=all
14+
#SBATCH --mail-user=pablo.alonso@upf.edu
15+
# interrupt and resubmit 90 seconds before training ends (experimental)
16+
# https://pytorch-lightning.readthedocs.io/en/1.2.10/clouds/slurm.html#wall-time-auto-resubmit
17+
# SBATCH --signal=SIGUSR1@90
18+
19+
export SRUN_CPUS_PER_TASK=$SLURM_CPUS_PER_TASK
20+
21+
source /gpfs/projects/upf97/envs/mtg-bsc/bin/activate
22+
23+
srun python3 src/train.py cfg/config_masking_conformer_multiview_au_to_all_large_25hz.gin
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash
2+
3+
#SBATCH --job-name a2a
4+
#SBATCH --account=upf97
5+
#SBATCH --partition=acc
6+
#SBATCH --qos=acc_resa
7+
#SBATCH --nodes=2 # This needs to match Trainer(num_nodes=...)
8+
#SBATCH --cpus-per-task=20
9+
#SBATCH --gres=gpu:4
10+
#SBATCH --ntasks-per-node=4
11+
#SBATCH --time=72:00:00
12+
#SBATCH --output=debug_%j_output.txt
13+
#SBATCH --mail-type=all
14+
#SBATCH --mail-user=pablo.alonso@upf.edu
15+
# interrupt and resubmit 90 seconds before training ends (experimental)
16+
# https://pytorch-lightning.readthedocs.io/en/1.2.10/clouds/slurm.html#wall-time-auto-resubmit
17+
# SBATCH --signal=SIGUSR1@90
18+
19+
export SRUN_CPUS_PER_TASK=$SLURM_CPUS_PER_TASK
20+
21+
source /gpfs/projects/upf97/envs/mtg-bsc/bin/activate
22+
23+
srun python3 src/train.py cfg/config_masking_conformer_multiview_au_to_all_large_shuffle_mask.gin

src/modules/maskingmodel.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
diff_input: bool,
4545
plot_tokens: bool = False,
4646
input_representation: nn.Module | None = None,
47+
masking_noise_type: str = "random_normal",
4748
):
4849
super(MaskingModel, self).__init__()
4950

@@ -62,6 +63,7 @@ def __init__(
6263
self.first_coverage = True
6364
self.diff_input = diff_input
6465
self.input_representation = input_representation
66+
self.masking_noise_type = masking_noise_type
6567

6668
# downstream evaluation params
6769
self.downstream_embedding_layer = set([-1])
@@ -263,7 +265,7 @@ def random_masking_simple(self, patches):
263265
return masked_spec, mask.to(patches.device)
264266

265267
def random_masking(self, patches):
266-
B, num_patches, patch_size = patches.shape
268+
B, num_patches, _ = patches.shape
267269
mx = patches.clone()
268270

269271
len_masking_spec_frames = math.ceil(
@@ -285,10 +287,19 @@ def random_masking(self, patches):
285287
if mask.size(1) > num_patches:
286288
mask = mask[:, :num_patches]
287289

288-
# Mask with random values
289-
masking_noise = (torch.randn(mx.shape, dtype=patches.dtype) * 0.1).to(
290-
patches.device
291-
) # 0 mean 0.1 std
290+
if self.masking_noise_type == "random_normal":
291+
# Mask with random values
292+
masking_noise = (torch.randn(mx.shape, dtype=patches.dtype) * 0.1).to(
293+
patches.device
294+
) # 0 mean 0.1 std
295+
elif self.masking_noise_type == "shuffled_input":
296+
# make a copy of patches shuffled on the time axis
297+
masking_noise = patches[:, torch.randperm(num_patches), :]
298+
else:
299+
raise NotImplementedError(
300+
f"Masking noise type {self.masking_noise_type} not implemented."
301+
)
302+
292303
# Apply masking in parallel
293304
mx[mask] = masking_noise[mask]
294305
# tensor 1 x N repeat to 16 x N

0 commit comments

Comments
 (0)