Skip to content

Commit c789a40

Browse files
committed
Update ssl_mtg
Support skipping adding gin config when an empty string is passed.
1 parent 72b638f commit c789a40

1 file changed

Lines changed: 44 additions & 36 deletions

File tree

src/ssl_mtg.py

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -82,32 +82,37 @@ def get_model(
8282
# Init representation related variables
8383
sr, hop_len, patch_size = None, None, None
8484

85-
config_file = Path(config_file)
86-
87-
# Read previous config bindings
88-
bindings = []
89-
cfg_str = gin.config_str()
90-
91-
# If these are not empty, this model is part of a larger setup
92-
# Do not finish the configuration now
9385
finalize_config = False
94-
if cfg_str == "":
95-
finalize_config = True
96-
97-
lines = cfg_str.split("\n")
98-
bindings.extend(lines)
9986

100-
if encodec_weights_path is not None:
101-
bindings.append(f"nets.encodec.EnCodec.weights_path = '{encodec_weights_path}'")
102-
bindings.append("nets.encodec.EnCodec.stats_path = None")
103-
104-
# Parse the gin config
105-
gin.parse_config_files_and_bindings(
106-
[config_file],
107-
bindings,
108-
skip_unknown=True,
109-
finalize_config=finalize_config,
110-
)
87+
# When no config file is provided, it is assumed that an external
88+
# gin-config file with all the required fileds has already been parsed.
89+
# Don't try to moddify the gin configuration nor load a checkpoint.
90+
if config_file != "":
91+
# Read previous config bindings
92+
bindings = []
93+
cfg_str = gin.config_str()
94+
95+
# If these are not empty, this model is part of a larger setup
96+
# Do not finish the configuration now
97+
if cfg_str == "":
98+
finalize_config = True
99+
100+
lines = cfg_str.split("\n")
101+
bindings.extend(lines)
102+
103+
if encodec_weights_path is not None:
104+
bindings.append(
105+
f"nets.encodec.EnCodec.weights_path = '{encodec_weights_path}'"
106+
)
107+
bindings.append("nets.encodec.EnCodec.stats_path = None")
108+
109+
# Parse the gin config
110+
gin.parse_config_files_and_bindings(
111+
[str(config_file)],
112+
bindings,
113+
skip_unknown=True,
114+
finalize_config=finalize_config,
115+
)
111116

112117
gin_config = gin.get_bindings(build_module)
113118

@@ -116,11 +121,6 @@ def get_model(
116121
representation = gin_config["representation"]
117122
module = gin_config["module"]
118123

119-
# Make the checkpoint path relative to the config file location
120-
# insted of taking the absolute path
121-
ckpt_path = Path(gin_config["ckpt_path"])
122-
ckpt_path = config_file.parent / ckpt_path.name
123-
124124
# Instantiate the classes
125125
net = net()
126126

@@ -135,13 +135,21 @@ def get_model(
135135
sr = representation.sr
136136
hop_len = representation.hop_len
137137

138-
module = module.load_from_checkpoint(
139-
ckpt_path,
140-
net=net,
141-
representation=representation,
142-
strict=False,
143-
map_location=device,
144-
)
138+
if config_file != "":
139+
# Make the checkpoint path relative to the config file location
140+
# insted of taking the absolute path
141+
ckpt_path = Path(gin_config["ckpt_path"])
142+
ckpt_path = Path(config_file).parent / ckpt_path.name
143+
144+
module = module.load_from_checkpoint(
145+
ckpt_path,
146+
net=net,
147+
representation=representation,
148+
strict=False,
149+
map_location=device,
150+
)
151+
else:
152+
module = module(net=net, representation=representation)
145153

146154
# Set the model to eval mode
147155
module.eval()

0 commit comments

Comments
 (0)