Skip to content

Commit 72b638f

Browse files
committed
Update to ruff format II
1 parent 606a605 commit 72b638f

3 files changed

Lines changed: 10 additions & 10 deletions

File tree

src/nets/rope.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
5353
rot_dim = freqs.shape[-1]
5454
end_index = start_index + rot_dim
5555

56-
assert (
57-
rot_dim <= t.shape[-1]
58-
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
56+
assert rot_dim <= t.shape[-1], (
57+
f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
58+
)
5959

6060
t_left, t, t_right = (
6161
t[..., :start_index],
@@ -173,9 +173,9 @@ def get_seq_pos(self, seq_len, device, dtype, offset=0):
173173
def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
174174
seq_dim = default(seq_dim, self.default_seq_dim)
175175

176-
assert (
177-
not self.use_xpos or exists(scale)
178-
), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
176+
assert not self.use_xpos or exists(scale), (
177+
"you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
178+
)
179179

180180
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
181181

src/prediction/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def __init__(
4141
self.n_frames = num_frames
4242
self.n_seconds = 30
4343

44-
assert (
45-
self.overlap_ratio >= 0 and self.overlap_ratio < 1
46-
), "Overlap ratio must be between 0 and 1."
44+
assert self.overlap_ratio >= 0 and self.overlap_ratio < 1, (
45+
"Overlap ratio must be between 0 and 1."
46+
)
4747

4848
self.compute_segments_per_file()
4949

src/probe/modules/sequence_classifiers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def plot_confusion_matrix(self, conf_matrix, multiclass=False):
195195
)
196196
axes = axes.flatten()
197197
labels = (
198-
[f"{i+1}" for i in range(50)] if self.labels is None else self.labels
198+
[f"{i + 1}" for i in range(50)] if self.labels is None else self.labels
199199
)
200200
for ax, cm, label in zip(axes, conf_matrix, labels):
201201
# Plot the confusion matrix in each subplot

0 commit comments

Comments
 (0)