Skip to content

Commit 606a605

Browse files
committed
Update to ruff format
1 parent e9d7fc5 commit 606a605

5 files changed

Lines changed: 30 additions & 26 deletions

File tree

src/cosineannealingscheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ def __init__(self, optimizer, total_steps, warmup_steps, eta_min, last_epoch=-1)
1111
self.eta_min = eta_min
1212
super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
1313

14-
assert (
15-
self.total_steps > self.warmup_steps
16-
), f"total_steps: {self.total_steps} must be greater than warmup_steps: {self.warmup_steps}"
14+
assert self.total_steps > self.warmup_steps, (
15+
f"total_steps: {self.total_steps} must be greater than warmup_steps: {self.warmup_steps}"
16+
)
1717

1818
def get_lr(self):
1919
if self.last_epoch < self.warmup_steps:

src/modules/maskingmodel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,9 @@ def extract_embeddings(
433433

434434
# Compute the representation
435435
if isinstance(self.representation, nn.ModuleList):
436-
assert (
437-
self.input_representation is not None
438-
), "`input_representation` must be provided."
436+
assert self.input_representation is not None, (
437+
"`input_representation` must be provided."
438+
)
439439
for rep in self.representation:
440440
if isinstance(rep, self.input_representation):
441441
input_rep = rep

src/nets/mlp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def __init__(
2626
def forward(self, x):
2727
x = x.reshape(x.size(0), -1)
2828

29-
assert (
30-
x.shape[1] == self.input_shape[0]
31-
), f"Expected shape {self.input_shape}, got {x.shape}"
29+
assert x.shape[1] == self.input_shape[0], (
30+
f"Expected shape {self.input_shape}, got {x.shape}"
31+
)
3232

3333
x = self.l1(x)
3434
if self.hidden_shape:

src/nets/rope.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
5353
rot_dim = freqs.shape[-1]
5454
end_index = start_index + rot_dim
5555

56-
assert rot_dim <= t.shape[-1], (
57-
f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
58-
)
56+
assert (
57+
rot_dim <= t.shape[-1]
58+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
5959

6060
t_left, t, t_right = (
6161
t[..., :start_index],
@@ -173,9 +173,9 @@ def get_seq_pos(self, seq_len, device, dtype, offset=0):
173173
def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
174174
seq_dim = default(seq_dim, self.default_seq_dim)
175175

176-
assert not self.use_xpos or exists(scale), (
177-
"you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
178-
)
176+
assert (
177+
not self.use_xpos or exists(scale)
178+
), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
179179

180180
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
181181

src/prediction/dataset.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
self.overlap_ratio = overlap_ratio
4040

4141
self.n_frames = num_frames
42-
self.n_seconds = self.n_frames / self.orig_freq
42+
self.n_seconds = 30
4343

4444
assert (
4545
self.overlap_ratio >= 0 and self.overlap_ratio < 1
@@ -57,18 +57,19 @@ def compute_segments_per_file(self):
5757

5858
i = 0
5959
for filepath in tqdm(self.filelist):
60-
try:
61-
hop_size = self.n_seconds * self.overlap_ratio
60+
# try:
61+
hop_size = self.n_seconds * (1 - self.overlap_ratio)
6262

63-
metadata = torchaudio.info(self.data_dir / filepath)
64-
seconds = metadata.num_frames / metadata.sample_rate
65-
n_segments = int(seconds / hop_size)
63+
metadata = torchaudio.info(self.data_dir / filepath)
64+
seconds = metadata.num_frames / metadata.sample_rate
6665

67-
for j in range(n_segments):
68-
self.index[i] = (filepath, j)
69-
i += 1
70-
except Exception as e:
71-
print(f"Error processing file {filepath}")
66+
n_segments = int(seconds / hop_size)
67+
68+
for j in range(n_segments):
69+
self.index[i] = (filepath, j)
70+
i += 1
71+
# except Exception as e:
72+
# print(f"Error processing file {filepath}")
7273

7374
def __len__(self):
7475
return len(self.index)
@@ -111,6 +112,9 @@ def __getitem__(self, idx):
111112
audio = audio.float()
112113

113114
# TODO zero pad
115+
tgt_len = 720000
116+
if audio.size(0) < tgt_len:
117+
audio = torch.nn.functional.pad(audio, (0, tgt_len - audio.size(0)))
114118

115119
return audio, str(file_path)
116120

0 commit comments

Comments
 (0)