-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
45 lines (32 loc) · 1.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Entry point for training a model with the stable-SSL library.

Hydra composes the configuration, the local stable-ssl patches are applied,
and the trainer class registered under the configured model name is built
and run.
"""
import sys
import os

# Directory that contains this script (despite the name, not its parent).
parent_dir = os.path.split(os.path.abspath(__file__))[0]
# Put the folder one level above the script at the front of the import path,
# so the project-level ``models`` package imported below can be found when
# this file is run directly.
sys.path.insert(0, os.path.split(parent_dir)[0])
import hydra
from omegaconf import DictConfig
from models.ssl_models.stable_ssl_patches import patch_stable_ssl
from models.ssl_models.custom_config import get_args
from models.ssl_models.custom_supervised import Supervised
from models.ssl_models.custom_barlow_twins import BarlowTwins
from models.ssl_models.factored_models import CovarianceFactorization, MaskingFactorization
# Registry of trainer classes, keyed by the name expected in ``args.model.name``.
# Each value is called with the parsed arguments to build a trainer object.
model_dict = dict(
    Supervised=Supervised,
    BarlowTwins=BarlowTwins,
    CovarianceFactorization=CovarianceFactorization,
    MaskingFactorization=MaskingFactorization,
)
@hydra.main(config_path="configs/ssl_configs/")
def main(cfg: DictConfig):
    """Apply stable-ssl patches, then build and run the configured trainer.

    Builds ``args`` from ``cfg`` via ``get_args``, looks up the trainer class
    registered in ``model_dict`` under ``args.model.name``, instantiates it
    with ``args`` and calls the resulting object.

    Args:
        cfg: Configuration supplied by the ``hydra.main`` decorator.

    Raises:
        KeyError: If ``args.model.name`` is not a key of ``model_dict``; the
            message lists the registered model names.
    """
    # NOTE(review): no ``config_name`` is passed to ``hydra.main`` above, so the
    # config is presumably chosen with ``--config-name`` on the CLI — confirm.
    changed = patch_stable_ssl()
    print(f"Applied {len(changed)} patches to stable-ssl!")

    args = get_args(cfg)
    print("--- Arguments ---")
    print(args)

    model_name = args.model.name
    try:
        trainer_cls = model_dict[model_name]
    except KeyError:
        # A bare KeyError('<name>') gives no hint about valid choices; keep the
        # same exception type but make a config typo obvious.
        raise KeyError(
            f"Unknown model name {model_name!r}; "
            f"expected one of {sorted(model_dict)}"
        ) from None

    trainer = trainer_cls(args)
    trainer()
if __name__ == "__main__":
    # ``main`` is declared to take ``cfg`` but is called with no arguments:
    # the ``hydra.main`` decorator composes the config and passes it in.
    main()