Commit 7d49962

Formatting
1 parent 271e7dd commit 7d49962

8 files changed (+14 -16 lines)

mergekit/common.py (+2 -2)

@@ -39,7 +39,7 @@ def set_config_value(config: PretrainedConfig, key: str, value: Any):
     for idx, part in enumerate(parts[:-1]):
         if not hasattr(obj, part):
             raise RuntimeError(
-                f"Config {config} has no attribute {'.'.join(parts[:idx+1])}"
+                f"Config {config} has no attribute {'.'.join(parts[: idx + 1])}"
             )
         obj = getattr(obj, part)
     setattr(obj, parts[-1], value)
@@ -52,7 +52,7 @@ def get_config_value(config: PretrainedConfig, key: str) -> Any:
     for idx, part in enumerate(parts):
         if not hasattr(obj, part):
             raise RuntimeError(
-                f"Config {config} has no attribute {'.'.join(parts[:idx+1])}"
+                f"Config {config} has no attribute {'.'.join(parts[: idx + 1])}"
             )
         obj = getattr(obj, part)
     return obj
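Both hunks only reformat the slice bound: PEP 8 treats the colon in a slice like a binary operator when the bound is an expression, giving it equal space on both sides. A minimal sketch, with made-up values for parts and idx, showing the change is purely cosmetic:

    # Illustrative only: both spellings compile to the same slice.
    parts = ["model", "decoder", "layers", "0"]
    idx = 1
    assert parts[:idx+1] == parts[: idx + 1] == ["model", "decoder"]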

mergekit/io/tasks.py (+1 -1)

@@ -64,7 +64,7 @@ def _normalized_shard_name(path: str) -> int:
     name = name.lower().replace("pytorch_model", "model")
     if m := shard_name_re.search(name):
         frac = int(m.group(1)) / int(m.group(2))
-        name = f"model-{int(frac*100):03d}pct"
+        name = f"model-{int(frac * 100):03d}pct"
     return name

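The same operator-spacing rule applies inside f-string replacement fields, which modern auto-formatters can now rewrite (plausibly the source of this commit; the message only says "Formatting"). A quick check that the rendered string is unchanged, using a hypothetical frac value:

    frac = 0.25  # hypothetical shard fraction, for illustration only
    # Spacing inside the braces is cosmetic; the output is identical.
    assert f"model-{int(frac*100):03d}pct" == f"model-{int(frac * 100):03d}pct" == "model-025pct"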
mergekit/io/tensor_writer.py (+4 -4)

@@ -65,10 +65,10 @@ def _flush_current_shard(self):
         if not self.current_shard:
             return

-        LOG.info(f"Writing shard #{self.shards_written+1} to disk")
+        LOG.info(f"Writing shard #{self.shards_written + 1} to disk")

         prefix, extension = self._get_name_components()
-        shard_name = f"{prefix}-{self.shards_written+1}.{extension}"
+        shard_name = f"{prefix}-{self.shards_written + 1}.{extension}"

         for key in self.current_shard:
             self.weight_map[key] = shard_name
@@ -95,8 +95,8 @@ def finalize(self):
         total_shards = self.shards_written
         name_remap = {}
         for idx in range(total_shards):
-            name_remap[f"{prefix}-{idx+1}.{extension}"] = (
-                f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}"
+            name_remap[f"{prefix}-{idx + 1}.{extension}"] = (
+                f"{prefix}-{idx + 1:05d}-of-{total_shards:05d}.{extension}"
             )

         if total_shards < 2:
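For reference, the :05d format spec these hunks touch zero-pads the shard index into the usual sharded-checkpoint naming scheme. A sketch with invented values:

    prefix, extension = "model", "safetensors"  # hypothetical values
    idx, total_shards = 0, 3
    assert (
        f"{prefix}-{idx + 1:05d}-of-{total_shards:05d}.{extension}"
        == "model-00001-of-00003.safetensors"
    )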

mergekit/merge_methods/easy_define.py (+2 -2)

@@ -167,7 +167,7 @@ def _execute(self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs):

     tt_fields["execute"] = _execute

-    tt_name = f"{name.title().replace(' ','')}MergeTask"
+    tt_name = f"{name.title().replace(' ', '')}MergeTask"
     tt_cls = pydantic.create_model(tt_name, __base__=Task[torch.Tensor], **tt_fields)

     mm_fields = {}
@@ -220,7 +220,7 @@ def _parameters(self) -> List[ConfigParameterDef]:

     mm_fields["parameters"] = _parameters

-    mm_name = f"{name.title().replace(' ','')}MergeMethod"
+    mm_name = f"{name.title().replace(' ', '')}MergeMethod"
     mm_cls = type(mm_name, (MergeMethod,), mm_fields)
     REGISTERED_MERGE_METHODS[name] = mm_cls()
     return func
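The reformatted expression builds a CamelCase class name from a human-readable merge-method name: str.title() capitalizes each word, and the replace drops the spaces. A quick check with a hypothetical name:

    name = "task arithmetic"  # hypothetical merge-method name
    assert f"{name.title().replace(' ', '')}MergeMethod" == "TaskArithmeticMergeMethod"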

mergekit/multigpu_executor.py (+1 -3)

@@ -21,8 +21,6 @@
 import torch
 import tqdm

-from mergekit.io.tasks import TensorWriterTask
-
 from .graph import (
     Executor,
     Task,
@@ -101,7 +99,7 @@ def __init__(
         offending = [
             t.task() for t in parallel_handles if t.task().main_thread_only()
         ]
-        logging.error(f"Main-thread-only tasks in parallel section:")
+        logging.error("Main-thread-only tasks in parallel section:")
         for task in offending:
             logging.error(f"  {type(task).__name__}")
         raise RuntimeError(
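Besides spacing, this file drops an unused import and a stray f prefix on a literal with no replacement fields (the patterns pyflakes/Ruff report as F401 and F541, respectively). The prefix is inert, so the change is behavior-preserving:

    # An f-string with no replacement fields is just a plain literal;
    # linters flag it because it usually signals a forgotten placeholder.
    assert f"Main-thread-only tasks in parallel section:" == "Main-thread-only tasks in parallel section:"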

mergekit/scripts/tokensurgeon.py (+1 -1)

@@ -532,7 +532,7 @@ def build_embedding_matrix(
     )
     shared_numeric_tokens = set(orig_vocab.keys()) & set(donor_numeric_tokens)
     LOG.debug(
-        f"{len(shared_numeric_tokens)} shared numeric tokens ({100.0*len(shared_numeric_tokens)/len(donor_numeric_tokens):.2f}%)"
+        f"{len(shared_numeric_tokens)} shared numeric tokens ({100.0 * len(shared_numeric_tokens) / len(donor_numeric_tokens):.2f}%)"
     )
     LOG.debug(
         [donor_tokenizer.decode([donor_vocab[tok]]) for tok in shared_numeric_tokens]

mergekit/tokenizer/build.py (+1 -1)

@@ -257,7 +257,7 @@ def build_tokenizer(
             orig_idx = model_vocab[tok]
             if orig_idx >= vocab_size:
                 LOG.warning(
-                    f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size-1} (padding?)"
+                    f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size - 1} (padding?)"
                 )
                 continue


mergekit/tokensurgeon/magikarp.py (+2 -2)

@@ -70,7 +70,7 @@ def well_trained_tokens(
     ).float()
     threshold = torch.quantile(cos_sim, 1 - quantile, dim=0)
     LOG.debug(
-        f"Unused token threshold in embed_tokens: {threshold.item():.4f} ({int((1-quantile) * 100)}th percentile)"
+        f"Unused token threshold in embed_tokens: {threshold.item():.4f} ({int((1 - quantile) * 100)}th percentile)"
     )
     if threshold < 0.5:
         threshold = 0.5
@@ -89,7 +89,7 @@ def well_trained_tokens(
     ).float()
     threshold = torch.quantile(cos_sim, 1 - quantile, dim=0)
     LOG.debug(
-        f"Unused token threshold in lm_head: {threshold.item():.4f} ({int((1-quantile) * 100)}th percentile)"
+        f"Unused token threshold in lm_head: {threshold.item():.4f} ({int((1 - quantile) * 100)}th percentile)"
     )
     if threshold < 0.5:
         threshold = 0.5
