Skip to content

Commit 850f762

Browse files
committed
Fix interference between clap model and main interface
Do not load CLAP through the main module factory, to prevent a circular import issue.
1 parent 336d603 commit 850f762

7 files changed

Lines changed: 47 additions & 47 deletions

cfg/text_audio/config_clap_mpnet_base_v2_ssl_a2a_large.gin

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@ train.params = {
1818
"strategy": "ddp_find_unused_parameters_true"
1919
}
2020

21-
modules.clap.CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/pq1scuuq/checkpoints/config_masking_conformer_multiview_au_to_all_large.gin"
22-
modules.clap.CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
23-
modules.clap.CLAP.audio_encoder_params = {
21+
CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/pq1scuuq/checkpoints/config_masking_conformer_multiview_au_to_all_large.gin"
22+
CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
23+
CLAP.audio_encoder_params = {
2424
"encodec_weights_path": "/gpfs/scratch/upf97/model_weights/encodec_24khz/"
2525
}
26-
modules.clap.CLAP.proj_size = 512
27-
modules.clap.CLAP.temp = 0.1
28-
modules.clap.CLAP.lr = 1e-5
29-
modules.clap.CLAP.weight_decay = 1e-2
30-
modules.clap.CLAP.seed = 0
31-
modules.clap.CLAP.train_audio_encoder = True
32-
modules.clap.CLAP.train_text_encoder = False
33-
modules.clap.CLAP.tokenizers_parallelism = False
34-
modules.clap.CLAP.aggregation_type = "attention_pooler"
35-
modules.clap.CLAP.n_pool_att_heads = 8
26+
CLAP.proj_size = 512
27+
CLAP.temp = 0.1
28+
CLAP.lr = 1e-5
29+
CLAP.weight_decay = 1e-2
30+
CLAP.seed = 0
31+
CLAP.train_audio_encoder = True
32+
CLAP.train_text_encoder = False
33+
CLAP.tokenizers_parallelism = False
34+
CLAP.aggregation_type = "attention_pooler"
35+
CLAP.n_pool_att_heads = 8
3636

3737
# CosineAnnealing scheduler
3838
CosineAnnealingCallback.warmup_steps = 20000

cfg/text_audio/config_clap_mpnet_base_v2_ssl_a2a_small.gin

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,28 @@ train.wandb_params = {
1010
# Lightning trainer parameters
1111
train.params = {
1212
"accelerator": "gpu",
13-
"devices": 4,
13+
"devices": 1,
1414
"max_steps": 400000,
1515
"log_every_n_steps": 50,
1616
"precision": "bf16-mixed",
1717
"strategy": "ddp_find_unused_parameters_true"
1818
}
1919

20-
modules.clap.CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/jpwu50v3/checkpoints/config_masking_conformer_multiview_au_to_all_small.gin"
21-
modules.clap.CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
22-
modules.clap.CLAP.audio_encoder_params = {
20+
CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/jpwu50v3/checkpoints/config_masking_conformer_multiview_au_to_all_small.gin"
21+
CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
22+
CLAP.audio_encoder_params = {
2323
"encodec_weights_path": "/gpfs/scratch/upf97/model_weights/encodec_24khz/"
2424
}
25-
modules.clap.CLAP.proj_size = 512
26-
modules.clap.CLAP.temp = 0.1
27-
modules.clap.CLAP.lr = 1e-4
28-
modules.clap.CLAP.weight_decay = 1e-2
29-
modules.clap.CLAP.seed = 0
30-
modules.clap.CLAP.train_audio_encoder = True
31-
modules.clap.CLAP.train_text_encoder = False
32-
modules.clap.CLAP.tokenizers_parallelism = False
33-
modules.clap.CLAP.aggregation_type = "attention_pooler"
34-
modules.clap.CLAP.n_pool_att_heads = 8
25+
CLAP.proj_size = 512
26+
CLAP.temp = 0.1
27+
CLAP.lr = 1e-4
28+
CLAP.weight_decay = 1e-2
29+
CLAP.seed = 0
30+
CLAP.train_audio_encoder = True
31+
CLAP.train_text_encoder = False
32+
CLAP.tokenizers_parallelism = False
33+
CLAP.aggregation_type = "attention_pooler"
34+
CLAP.n_pool_att_heads = 8
3535

3636
# CosineAnnealing scheduler
3737
CosineAnnealingCallback.warmup_steps = 20000

cfg/text_audio/config_clap_mpnet_base_v2_ssl_a2a_small_debug.gin

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@ train.params = {
1717
"strategy": "ddp_find_unused_parameters_true"
1818
}
1919

20-
modules.clap.CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/jpwu50v3/checkpoints/config_masking_conformer_multiview_au_to_all_small.gin"
21-
modules.clap.CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
22-
modules.clap.CLAP.audio_encoder_params = {
20+
CLAP.audio_encoder_name = "/gpfs/projects/upf97/logs/mtg-ssl/jpwu50v3/checkpoints/config_masking_conformer_multiview_au_to_all_small.gin"
21+
CLAP.text_encoder_name = "/gpfs/scratch/upf97/model_weights/all-mpnet-base-v2"
22+
CLAP.audio_encoder_params = {
2323
"encodec_weights_path": "/gpfs/scratch/upf97/model_weights/encodec_24khz/"
2424
}
25-
modules.clap.CLAP.proj_size = 512
26-
modules.clap.CLAP.temp = 0.1
27-
modules.clap.CLAP.lr = 1e-4
28-
modules.clap.CLAP.weight_decay = 1e-2
29-
modules.clap.CLAP.seed = 0
30-
modules.clap.CLAP.train_audio_encoder = True
31-
modules.clap.CLAP.train_text_encoder = False
32-
modules.clap.CLAP.tokenizers_parallelism = False
25+
CLAP.proj_size = 512
26+
CLAP.temp = 0.1
27+
CLAP.lr = 1e-4
28+
CLAP.weight_decay = 1e-2
29+
CLAP.seed = 0
30+
CLAP.train_audio_encoder = True
31+
CLAP.train_text_encoder = False
32+
CLAP.tokenizers_parallelism = False
3333

3434
# CosineAnnealing scheduler
3535
CosineAnnealingCallback.warmup_steps = 1000

src/modules/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from .classifier import Classifier
2-
from .clap import CLAP
32
from .maskingmodel import MaskingModel
43
from .simclr import SimCLR
54

65
MODULES = {
76
"classifier": Classifier,
87
"simclr": SimCLR,
98
"maskingmodel": MaskingModel,
10-
"clap": CLAP,
119
}
1210

1311

src/ssl_mtg.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ def get_model(
107107
bindings.append("nets.encodec.EnCodec.stats_path = None")
108108

109109
# Parse the gin config
110-
gin.parse_config_files_and_bindings(
111-
[str(config_file)],
112-
bindings,
113-
skip_unknown=True,
114-
finalize_config=finalize_config,
115-
)
110+
with gin.unlock_config():
111+
gin.parse_config_files_and_bindings(
112+
[str(config_file)],
113+
bindings,
114+
skip_unknown=True,
115+
finalize_config=finalize_config,
116+
)
116117

117118
gin_config = gin.get_bindings(build_module)
118119

src/train_clap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cosineannealingscheduler import CosineAnnealingCallback
1212
from data import DATASETS
1313
from modules import MODULES
14+
from modules_clap.clap import CLAP
1415

1516
from callbacks import GinConfigSaverCallback
1617
from utils import gin_config_to_readable_dictionary, build_dev_datamodule
@@ -64,7 +65,7 @@ def train(
6465
try:
6566
gin.parse_config_file(args.train_config, skip_unknown=True)
6667

67-
module = MODULES["clap"]()
68+
module = CLAP()
6869
ckpt_path = None
6970

7071
datamodule = build_dev_datamodule()

0 commit comments

Comments
 (0)