Skip to content

Commit b416ab9

Browse files
committed
Black formatting
1 parent 1e059c1 commit b416ab9

88 files changed

Lines changed: 4783 additions & 1896 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ jobs:
3636
python -m pip install .
3737
- name: Test with pytest
3838
run: |
39-
pytest -vvv --cov=./ --cov-report=xml
39+
pytest test/
4040
- name: Upload coverage reports to Codecov
4141
uses: codecov/codecov-action@eaaf4bedf32dbdc6b720b63067d99c4d77d6047d # v3.1.4
4242
with:

keys_values/__main__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _check_commands():
5555

5656
class TeeOutput:
5757
"""Utility class to duplicate output to both file and stream (stdout/stderr)"""
58+
5859
def __init__(self, file_obj, stream):
5960
self.file = file_obj
6061
self.stream = stream
@@ -129,7 +130,10 @@ def main() -> None:
129130
warning_message = r"The epoch parameter in `scheduler.step\(\)` was not necessary and is being deprecated.*"
130131

131132
warnings.filterwarnings(
132-
action="ignore", message=warning_message, category=UserWarning, module=r".*torch\.optim\.lr_scheduler.*"
133+
action="ignore",
134+
message=warning_message,
135+
category=UserWarning,
136+
module=r".*torch\.optim\.lr_scheduler.*",
133137
)
134138

135139
torch.set_float32_matmul_precision("high")

keys_values/adapter.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,15 @@ def __init__(self, config: Config, **mha_kwargs) -> None:
5353
self.transformer = nn.ModuleDict(
5454
dict(
5555
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
56-
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
56+
h=nn.ModuleList(
57+
Block(config, block_idx) for block_idx in range(config.n_layer)
58+
),
5759
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
5860
)
5961
)
6062
self.mha = MultiHeadSelfAttention(
61-
config, **transform_mha_kwargs(mha_kwargs, config),
63+
config,
64+
**transform_mha_kwargs(mha_kwargs, config),
6265
)
6366
self.max_seq_length = self.config.block_size
6467
self._start_of_layer_hook = None
@@ -102,6 +105,7 @@ class CausalSelfAttention(BaseCausalSelfAttention):
102105
attention over the adaption prompt.
103106
104107
"""
108+
105109
def __init__(
106110
self,
107111
config: Config,
@@ -140,7 +144,13 @@ def _transform_output(
140144
prefix = self.adapter_wte.weight.reshape(1, a_num, self.config.n_embd)
141145
aqkv = self.qkv(prefix)
142146
q_per_kv = self.config.n_head // self.config.n_query_groups
143-
aqkv = aqkv.view(1, a_num, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
147+
aqkv = aqkv.view(
148+
1,
149+
a_num,
150+
self.config.n_query_groups,
151+
q_per_kv + 2,
152+
self.config.head_size,
153+
)
144154
aqkv = aqkv.permute(0, 2, 3, 1, 4)
145155
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
146156
if self.config.n_query_groups != 1:
@@ -171,8 +181,12 @@ def reset_parameters(self) -> None:
171181
if hasattr(self, "gating_factor"):
172182
torch.nn.init.zeros_(self.gating_factor)
173183

174-
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
184+
def _load_from_state_dict(
185+
self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any
186+
) -> None:
175187
"""For compatibility with older checkpoints."""
176-
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
188+
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(
189+
1
190+
) == self.config.n_head:
177191
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
178192
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

keys_values/adapter_v2.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,15 @@ def __init__(self, config: Config, **mha_kwargs) -> None:
5555
self.transformer = nn.ModuleDict(
5656
dict(
5757
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
58-
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
58+
h=nn.ModuleList(
59+
Block(config, block_idx) for block_idx in range(config.n_layer)
60+
),
5961
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
6062
)
6163
)
6264
self.mha = MultiHeadSelfAttention(
63-
config, **transform_mha_kwargs(mha_kwargs, config),
65+
config,
66+
**transform_mha_kwargs(mha_kwargs, config),
6467
)
6568
self.max_seq_length = self.config.block_size
6669
self._start_of_layer_hook = None
@@ -77,9 +80,14 @@ def _init_weights(self, module: nn.Module) -> None:
7780
if isinstance(module, AdapterV2Linear):
7881
module.reset_parameters()
7982

80-
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
83+
def _load_from_state_dict(
84+
self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any
85+
) -> None:
8186
"""For compatibility with base checkpoints."""
82-
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
87+
mapping = {
88+
"lm_head.weight": "lm_head.linear.weight",
89+
"lm_head.bias": "lm_head.linear.bias",
90+
}
8391
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
8492
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
8593

@@ -107,6 +115,7 @@ def __init__(
107115

108116
class CausalSelfAttention(BaseCausalSelfAttention):
109117
"""A modification of `keys_values.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
118+
110119
def __init__(
111120
self,
112121
config: Config,
@@ -129,7 +138,11 @@ def __init__(
129138
)
130139

131140
def _load_from_state_dict(
132-
self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any,
141+
self,
142+
state_dict: Dict,
143+
prefix: str,
144+
*args: Any,
145+
**kwargs: Any,
133146
) -> None:
134147
"""For compatibility with base and/or legacy checkpoints."""
135148
mapping = {
@@ -140,13 +153,17 @@ def _load_from_state_dict(
140153
}
141154
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
142155
# For compatibility with older checkpoints
143-
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
156+
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(
157+
1
158+
) == self.config.n_head:
144159
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
145160

146161
for attr in ("weight", "bias"):
147162
legacy_key = f"{prefix}attn.linear.{attr}"
148163
current_key = f"{prefix}qkv.linear.{attr}"
149164
if legacy_key in state_dict:
150-
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)
165+
state_dict[current_key] = qkv_reassemble(
166+
state_dict.pop(legacy_key), self.config
167+
)
151168

152169
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

keys_values/array_limit.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
from typing import Optional
1515

1616

17-
REDUCTION_FACTORS = [3/4, 2/4, 1/4, 3/16, 2/16, 1/16, 3/64, 2/64, 1/64]
17+
REDUCTION_FACTORS = [
18+
3 / 4,
19+
2 / 4,
20+
1 / 4,
21+
3 / 16,
22+
2 / 16,
23+
1 / 16,
24+
3 / 64,
25+
2 / 64,
26+
1 / 64,
27+
]
1828

1929

2030
class TemporaryArrayLimit:
@@ -28,6 +38,7 @@ class TemporaryArrayLimit:
2838
to this object and read the limit from here.
2939
3040
"""
41+
3142
def __init__(self, init_val: float, name: str):
3243
if init_val <= 0:
3344
raise ValueError("Initial value must be positive (unit is GB)")

keys_values/attention.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
build_mask_slice,
2828
create_temp_array,
2929
sdpa_attention_weights,
30-
slice_as_flat, pytorch_scaled_dot_product_attention,
30+
slice_as_flat,
31+
pytorch_scaled_dot_product_attention,
3132
)
3233
from keys_values.pos_encoding import position_encoding_factory, PositionEncoding
3334
from keys_values.sdpa_wrapper import scaled_dot_product_attention as qpadded_sdpa
@@ -62,7 +63,10 @@ def values(self) -> torch.Tensor:
6263
class DefaultKeysAndValues(KeysAndValues):
6364
def __init__(self, keys: torch.Tensor, values: torch.Tensor):
6465
# The final dimension of K and V can be different (in general)
65-
assert keys.shape[:-1] == values.shape[:-1] and keys.ndim == 4, (keys.shape, values.shape)
66+
assert keys.shape[:-1] == values.shape[:-1] and keys.ndim == 4, (
67+
keys.shape,
68+
values.shape,
69+
)
6670
self._keys = keys
6771
self._values = values
6872

@@ -156,6 +160,7 @@ class MultiHeadSelfAttention:
156160
Look at :class:`DefaultUseEagerKernel` for choosing `use_eager_kernel`.
157161
158162
"""
163+
159164
def __init__(
160165
self,
161166
config: Config,
@@ -296,9 +301,12 @@ def __call__(
296301

297302
def _get_sliding_window_size(self, block_idx: int) -> Optional[int]:
298303
apply_sliding_window_attention = (
299-
self.config.sliding_window_size is not None and self.config.sliding_window_indices[block_idx] == 1
304+
self.config.sliding_window_size is not None
305+
and self.config.sliding_window_indices[block_idx] == 1
306+
)
307+
return (
308+
self.config.sliding_window_size if apply_sliding_window_attention else None
300309
)
301-
return self.config.sliding_window_size if apply_sliding_window_attention else None
302310

303311
def _sdpa_mode(
304312
self,
@@ -326,7 +334,11 @@ def _sdpa_mode(
326334
return SDPA_IMPL_EAGER_NO_BLOCKS
327335
must_eager = return_attn_weights or self.use_eager_sdpa_always
328336
if must_eager or not is_causal:
329-
if must_eager or sliding_window_size is not None or self._use_eager_kernel(kv_len, q_len):
337+
if (
338+
must_eager
339+
or sliding_window_size is not None
340+
or self._use_eager_kernel(kv_len, q_len)
341+
):
330342
return SDPA_IMPL_EAGER_BLOCKS
331343
else:
332344
return SDPA_IMPL_QPADDED_PYTORCH
@@ -455,7 +467,10 @@ def eager_scaled_dot_product_attention(
455467
attn_weights = attn_weights.sum(dim=2)
456468
if n_head != n_query_groups:
457469
attn_weights = attn_weights.view(
458-
batch_size, n_query_groups, -1, kv_len,
470+
batch_size,
471+
n_query_groups,
472+
-1,
473+
kv_len,
459474
).mean(dim=2)
460475
else:
461476
attn_weights = None
@@ -530,7 +545,11 @@ def scaled_dot_product_attention_in_blocks(
530545
source = _tmp_array[:, :n_query_groups, :, :]
531546
torch.mean(
532547
attn_weights_part.view(
533-
batch_size, n_query_groups, -1, sz, kv_len,
548+
batch_size,
549+
n_query_groups,
550+
-1,
551+
sz,
552+
kv_len,
534553
),
535554
dim=2,
536555
out=source,
@@ -542,7 +561,8 @@ def scaled_dot_product_attention_in_blocks(
542561
# - output_part (bs, nh_q, sz, hs)
543562
output_parts.append(
544563
attention_compute_weighted_values(
545-
scores=attn_weights_part, value=value32,
564+
scores=attn_weights_part,
565+
value=value32,
546566
).to(dtype)
547567
)
548568
start = end

keys_values/attention_utils.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,14 @@ def filter_sdpa_kernels(
4242
for kernel in sdpa_kernels:
4343
if kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params):
4444
continue
45-
elif kernel == SDPBackend.EFFICIENT_ATTENTION and not can_use_efficient_attention(params):
45+
elif (
46+
kernel == SDPBackend.EFFICIENT_ATTENTION
47+
and not can_use_efficient_attention(params)
48+
):
4649
continue
47-
elif kernel == SDPBackend.CUDNN_ATTENTION and not can_use_cudnn_attention(params):
50+
elif kernel == SDPBackend.CUDNN_ATTENTION and not can_use_cudnn_attention(
51+
params
52+
):
4853
continue
4954
new_kernels.append(kernel)
5055
return new_kernels
@@ -202,11 +207,21 @@ def mask_slice_bool(
202207
q_per_kv = n_head // n_query_groups
203208
assert n_head == n_query_groups * q_per_kv and q_per_kv >= 1
204209
if q_per_kv > 1:
205-
token_positions = token_positions.unsqueeze(2).expand(
206-
-1, -1, q_per_kv, -1,
207-
).reshape(batch_size, n_head, -1)
210+
token_positions = (
211+
token_positions.unsqueeze(2)
212+
.expand(
213+
-1,
214+
-1,
215+
q_per_kv,
216+
-1,
217+
)
218+
.reshape(batch_size, n_head, -1)
219+
)
208220
token_positions = token_positions.unsqueeze(2).expand(
209-
-1, -1, num, -1,
221+
-1,
222+
-1,
223+
num,
224+
-1,
210225
)
211226
kwargs = dict(device=token_positions.device, dtype=token_positions.dtype)
212227
bool_mask = (
@@ -276,7 +291,7 @@ def build_mask_slice(
276291

277292

278293
# Maximum number of `float32` entries for `tmp_array` per GB
279-
ENTRIES_PER_GB = 2 ** 28
294+
ENTRIES_PER_GB = 2**28
280295

281296
# Maximum size of `tmp_array` in GB
282297
DEFAULT_TMP_ARRAY_LIMIT_GB = 3
@@ -324,7 +339,9 @@ def create_temp_array(
324339
else:
325340
tmp_len = tmp_array_max_num_entries // factor
326341
if tmp_len < 1:
327-
raise ValueError(f"batch_size={batch_size}, n_head={n_head}, kv_len={kv_len} too large. Their product must be <= {tmp_array_max_num_entries}")
342+
raise ValueError(
343+
f"batch_size={batch_size}, n_head={n_head}, kv_len={kv_len} too large. Their product must be <= {tmp_array_max_num_entries}"
344+
)
328345
num_splits = int(math.ceil(q_len / tmp_len))
329346
shape = (batch_size, n_head, tmp_len, kv_len)
330347
kwargs = dict(device=device, dtype=torch.float32)
@@ -388,7 +405,10 @@ def sdpa_attention_weights(
388405
_, n_query_groups, kv_len, _ = key.shape
389406
# Compute attention weights f(S)
390407
attention_compute_scores(
391-
query=query, key=key, out=tmp_array, scale_factor=scale_factor,
408+
query=query,
409+
key=key,
410+
out=tmp_array,
411+
scale_factor=scale_factor,
392412
)
393413
# Attention masking
394414
if token_positions is None:
@@ -422,17 +442,21 @@ def sample_token_positions(
422442
) -> torch.Tensor:
423443
index_kwargs = dict(dtype=torch.int64, device=device)
424444
token_positions = torch.zeros(
425-
(batch_size, n_query_groups, kv_len), **index_kwargs,
445+
(batch_size, n_query_groups, kv_len),
446+
**index_kwargs,
426447
)
427448
for bs in range(batch_size):
428449
for nq in range(n_query_groups):
429450
token_positions[bs, nq, :] = torch.randperm(
430-
input_pos, **index_kwargs,
451+
input_pos,
452+
**index_kwargs,
431453
)[:kv_len]
432454
# Ensure that `input_pos:(input_pos + q_len)` is present
433455
index = torch.randperm(kv_len, **index_kwargs)[:q_len]
434456
token_positions[bs, nq, index] = torch.arange(
435-
input_pos, input_pos + q_len, **index_kwargs,
457+
input_pos,
458+
input_pos + q_len,
459+
**index_kwargs,
436460
)
437461
return token_positions
438462

0 commit comments

Comments
 (0)