Commit e722955

refactor: convert pretrained
1 parent 33eeb5a commit e722955

31 files changed: +1039, -1479 lines

.vscode/extensions.json (+7)

@@ -0,0 +1,7 @@
+{
+    "recommendations": [
+        "detachhead.basedpyright",
+        "charliermarsh.ruff",
+        "ms-python.python"
+    ]
+}

pyproject.toml (+4, -1)

@@ -21,7 +21,7 @@ build-backend = "hatchling.build"

 [dependency-groups]
 dev = [
-    "mypy>=1.13.0",
+    "basedpyright>=1.20.0",
     "pre-commit>=4.0.1",
     "ruff>=0.7.1",
 ]

@@ -131,3 +131,6 @@ docstring-code-format = false
 docstring-code-line-length = "dynamic"

+[tool.pyright]
+typeCheckingMode = "standard"
+reportUnknownMemberType = false

src/xlens/components/transformer_block.py (+1, -4)

@@ -41,10 +41,7 @@ def __init__(self, cfg: HookedTransformerConfig, block_index: int):
         self.layer_id = block_index

         if cfg.normalization_type == "LN":
-            normalization_layer: Callable[
-                [HookedTransformerConfig],
-                Callable[[Float[jax.Array, "batch pos d_model"]], Float[jax.Array, "batch pos d_model"]],
-            ] = LayerNorm
+            normalization_layer = LayerNorm
         elif cfg.normalization_type == "LNPre":
             # We've folded in LayerNorm weights, so just need the center + scale parts
             normalization_layer = LayerNormPre

src/xlens/components/unembed.py (+1, -1)

@@ -19,7 +19,7 @@ class Unembed(eqx.Module):
     def __init__(self, cfg: HookedTransformerConfig):
         self.cfg = cfg
         # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different.
-        self.W_U: Float[jax.Array, "d_model d_vocab_out"] = jnp.zeros((self.cfg.d_model, self.cfg.d_vocab_out))
+        self.W_U = jnp.zeros((self.cfg.d_model, self.cfg.d_vocab_out))

     def __call__(self, residual: Float[jax.Array, "batch pos d_model"]) -> Float[jax.Array, "batch pos d_vocab_out"]:
         return residual @ self.W_U

src/xlens/config.py (+1, -1)

@@ -168,7 +168,7 @@ def __post_init__(self):
         if not self.attn_only:
             if self.d_mlp is None:
                 # For some reason everyone hard codes in this hyper-parameter!
-                self.d_mlp: int = self.d_model * 4
+                self.d_mlp = self.d_model * 4
             assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
             assert self.act_fn in SUPPORTED_ACTIVATIONS, f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"

src/xlens/hooked_transformer.py (+3, -3)

@@ -7,7 +7,7 @@

 from xlens.components import Embed, LayerNorm, LayerNormPre, PosEmbed, RMSNorm, RMSNormPre, TransformerBlock, Unembed
 from xlens.hooks import with_cache, with_hooks
-from xlens.pretrained.loading_from_pretrained import get_pretrained_model_config, get_pretrained_state_dict
+from xlens.pretrained.convert import get_pretrained_model_config, get_pretrained_weights
 from xlens.utils import load_pretrained_weights

 from .config import HookedTransformerConfig

@@ -167,7 +167,7 @@ def from_pretrained(cls, model_name: str, hf_model=None) -> "HookedTransformer":
         """

         cfg = get_pretrained_model_config(model_name)
-        state_dict = get_pretrained_state_dict(model_name, cfg, hf_model=hf_model)
+        weights = get_pretrained_weights(cfg, model_name, hf_model=hf_model)
         model = HookedTransformer(cfg)
-        model = load_pretrained_weights(model, state_dict)
+        model = load_pretrained_weights(model, weights)
         return model
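With this change, from_pretrained fetches the config and a plain weight mapping through the converter-based path. A minimal usage sketch, assuming "gpt2" is a model name the registered converters recognize:

# Hedged usage sketch of the refactored loading path; "gpt2" is an
# illustrative model name, assumed to be handled by GPT2Converter.
from xlens.hooked_transformer import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
# Optionally reuse an already-instantiated Hugging Face model:
# model = HookedTransformer.from_pretrained("gpt2", hf_model=my_hf_model)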

src/xlens/pretrained/__init__.py (+2, -2)

@@ -1,3 +1,3 @@
-from .loading_from_pretrained import get_pretrained_model_config, get_pretrained_state_dict
+from .convert import get_pretrained_model_config, get_pretrained_weights

-__all__ = ["get_pretrained_state_dict", "get_pretrained_model_config"]
+__all__ = ["get_pretrained_weights", "get_pretrained_model_config"]
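Callers that previously imported from xlens.pretrained only need to pick up the renamed re-export:

# The package root re-exports the new names, so this import replaces the
# old get_pretrained_state_dict import after the rename.
from xlens.pretrained import get_pretrained_model_config, get_pretrained_weights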

src/xlens/pretrained/convert.py (+34, new file)

@@ -0,0 +1,34 @@
+"""Loading Pretrained Models Utilities.
+
+This module contains functions for loading pretrained models from the Hugging Face Hub.
+"""
+
+import jax
+
+from xlens.config import HookedTransformerConfig
+from xlens.pretrained.converters import (
+    GPT2Converter,
+    GPTNeoXConverter,
+    LlamaConverter,
+    MistralConverter,
+    Qwen2Converter,
+)
+from xlens.pretrained.model_converter import HuggingFaceModelConverter
+
+converter = HuggingFaceModelConverter(
+    converters=[
+        GPT2Converter(),
+        Qwen2Converter(),
+        LlamaConverter(),
+        MistralConverter(),
+        GPTNeoXConverter(),
+    ]
+)
+
+
+def get_pretrained_model_config(model_name: str) -> HookedTransformerConfig:
+    return converter.get_pretrained_model_config(model_name)
+
+
+def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model=None) -> dict[str, jax.Array]:
+    return converter.get_pretrained_weights(cfg, model_name, hf_model=hf_model)
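The new module wires a single HuggingFaceModelConverter over the per-architecture converters and exposes two thin wrappers around it. A sketch of calling them directly, again with "gpt2" as an assumed, illustrative model name:

# Hedged sketch: fetch a config and a flat weight mapping for one model.
from xlens.pretrained.convert import get_pretrained_model_config, get_pretrained_weights

cfg = get_pretrained_model_config("gpt2")      # HookedTransformerConfig
weights = get_pretrained_weights(cfg, "gpt2")  # dict[str, jax.Array]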

src/xlens/pretrained/convert_weight/__init__.py (-13, file deleted)

src/xlens/pretrained/convert_weight/gpt2.py (-80, file deleted)

src/xlens/pretrained/convert_weight/llama.py (-64, file deleted)

src/xlens/pretrained/convert_weight/mistral.py (-57, file deleted)

src/xlens/pretrained/convert_weight/neox.py (-65, file deleted)
