forked from MTG/compIAM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
193 lines (157 loc) · 7.18 KB
/
Copy path__init__.py
File metadata and controls
193 lines (157 loc) · 7.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import sys
import numpy as np
from typing import Dict
from tqdm import tqdm
from compiam.exceptions import ModelNotTrainedError
from compiam.utils.download import download_remote_model
from compiam.utils import get_logger, WORKDIR
from compiam.io import write_csv
# Module-level logger, obtained from compiam's get_logger helper for this module's name.
logger = get_logger(__name__)
class TCNTracker(object):
    """TCN beat tracker tuned to Carnatic Music.

    Wraps a ``MultiTracker`` network together with a ``PreProcessor``
    (built with ``fps=100``) and a post-processing function
    (``joint_tracker`` or ``sequential_tracker``) that turns the beat and
    downbeat activations produced by the network into the final prediction.
    """

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: Post-processing method to use. Choose from 'joint', or 'sequential'.
        :param model_version: Version of the pre-trained model to use. Choose from 42, 52, or 62.
        :param model_path: path to the folder that holds (or will hold) the model
            weights file ``multitracker_<model_version>.pth``. When given, the weights
            are loaded right away (and downloaded first if the file is missing).
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: Id of the GPU to use (-1 by default, to run on CPU). Forwarded
            to :meth:`select_gpu`.
        :raises ImportError: if torch or madmom cannot be imported.
        :raises ValueError: if ``post_processor`` or ``model_version`` is not supported.
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        # The optional packages are bound as module globals so that the other
        # methods of this class can use them after the first instantiation.
        try:
            global torch
            import torch
        except ImportError as err:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            ) from err

        try:
            global madmom
            import madmom
        except ImportError as err:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            ) from err
        ###

        # These submodules are imported lazily, after the optional
        # dependencies above have been imported successfully.
        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import joint_tracker, sequential_tracker

        # NOTE(review): 'beat' passes validation although the message does not
        # list it; like 'sequential' it selects sequential_tracker below.
        # Confirm whether it is meant to be a distinct mode.
        if post_processor not in ["beat", "joint", "sequential"]:
            raise ValueError(f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'.")
        if model_version not in [42, 52, 62]:
            raise ValueError(f"Invalid model_version: {model_version}. Choose from 42, 52, or 62.")

        # The device must be chosen before the model is built, because
        # _build_model() moves the network to self.device.
        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        # Despite its name this attribute stores the weights *file name*.
        self.model_version = f'multitracker_{model_version}.pth'
        self.download_link = download_link
        self.download_checksum = download_checksum

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)

        self.pre_processor = PreProcessor(fps=100)
        # Number of edge frames replicated at each end of the feature sequence.
        self.pad_frames = 2

        self.post_processor = (
            joint_tracker if post_processor == "joint" else sequential_tracker
        )

    def _build_model(self):
        """Build the TCN model.

        :returns: a ``MultiTracker`` instance moved to ``self.device`` and set
            to evaluation mode.
        """
        model = MultiTracker().to(self.device)
        model.eval()
        return model

    def load_model(self, model_path):
        """Load pre-trained model weights.

        :param model_path: folder containing the weights file named
            ``self.model_version``. If the file is not there, the model is
            downloaded into this folder first.
        :returns: None
        """
        weights_path = os.path.join(model_path, self.model_version)
        if not os.path.exists(weights_path):
            self.download_model(model_path)  # Downloading model weights
        self.model.load_weights(weights_path, self.device)
        self.model_path = model_path
        self.trained = True

    def download_model(self, model_path=None, force_overwrite=True):
        """Download pre-trained model.

        :param model_path: folder to download the model into. Defaults to
            ``<WORKDIR>/models/rhythm/tcn-carnatic`` when None.
        :param force_overwrite: forwarded to ``download_remote_model``.
        :returns: None
        """
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights
        os.makedirs(download_path, exist_ok=True)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )

    def predict(self, input_data: "str | np.ndarray", input_sr: int = 44100) -> Dict:
        """Run inference on input audio file.

        :param input_data: path to audio file or numpy array like audio signal.
        :param input_sr: sampling rate of ``input_data`` when it is given as an
            array (forwarded to :meth:`preprocess_audio`).
        :returns: the output of the selected post-processor applied to the beat
            and downbeat activations. NOTE(review): previously documented as a
            2-D list with beats and beat positions while annotated as Dict;
            confirm against the post-processor implementation.
        :raises ModelNotTrainedError: if no weights have been loaded.
        """
        if not self.trained:
            raise ModelNotTrainedError(
                "Model is not trained. Please load model before running inference!\n"
                "You can load the pre-trained instance with the load_model wrapper."
            )

        features = self.preprocess_audio(input_data, input_sr)
        x = torch.from_numpy(features).to(self.device)

        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            output = self.model(x)

        beats_act = output["beats"].squeeze().detach().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy()
        return self.post_processor(beats_act, downbeats_act)

    def preprocess_audio(self, input_data: "str | np.ndarray", input_sr: int = 44100) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file, or the audio signal
            as a numpy array.
        :param input_sr: Sampling rate of the input audio; only used when
            ``input_data`` is an array (files carry their own sampling rate).
        :returns: Preprocessed features as a numpy array, edge-padded with
            ``self.pad_frames`` frames on both ends and with two leading
            singleton axes added.
        :raises FileNotFoundError: if ``input_data`` is a path that does not exist.
        :raises ValueError: if ``input_data`` is neither a string nor a numpy array.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sample_rate = madmom.io.audio.load_audio_file(input_data)
        elif isinstance(input_data, np.ndarray):
            audio, sample_rate = input_data, input_sr
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        # A multi-dimensional array whose first axis has length 2 is treated as
        # channel-first stereo and averaged to mono. The rank check keeps a
        # 1-D signal that merely has two samples from collapsing to a scalar.
        if audio.ndim > 1 and audio.shape[0] == 2:
            audio = audio.mean(axis=0)
        signal = madmom.audio.Signal(audio, sample_rate, num_channels=1)

        x = self.pre_processor(signal)

        # Replicate the first and last frame so the sequence is padded at both ends.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add two leading singleton axes in front of the padded features.
        return np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

    @staticmethod
    def save_pitch(data, output_path):
        """Calling the write_csv function in compiam.io to write the output beat track in a file

        :param data: the data to write
        :param output_path: the path where the data is going to be stored
        :returns: whatever ``write_csv`` returns (documented upstream as None)
        """
        return write_csv(data, output_path)

    def select_gpu(self, gpu="-1"):
        """Select the GPU to use for inference.

        Sets ``self.device`` and records the requested id in ``self.gpu``.
        CUDA is preferred over MPS; if neither is available the CPU is used
        and a warning is logged.

        :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda:" + str(gpu))
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps:" + str(gpu))
        else:
            self.device = torch.device("cpu")
            logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu