Commit c2d9bcb

Add transformer encoder
1 parent a9560de · commit c2d9bcb

6 files changed, +479 -7 lines

.pre-commit-config.yaml

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,7 @@ repos:
       - id: check-docstring-first
       - id: check-yaml
       - id: debug-statements
+        exclude: glm_experiments/models/components/transformer.py
       - id: detect-private-key
       - id: check-executables-have-shebangs
       - id: check-toml
@@ -23,6 +24,7 @@ repos:
     hooks:
       - id: black
         args: [--line-length, "99"]
+        exclude: glm_experiments/models/components/transformer.py

   # python import sorting
   - repo: https://github.com/PyCQA/isort
@@ -37,6 +39,7 @@ repos:
     hooks:
       - id: pyupgrade
         args: [--py313-plus]
+        exclude: glm_experiments/models/components/transformer.py

   # python check (PEP8), programming errors and code complexity
   - repo: https://github.com/PyCQA/flake8
@@ -48,7 +51,7 @@ repos:
             "--extend-ignore",
             "E203,E402,E501,F401,F841,RST2,RST301",
             "--exclude",
-            "logs/*,data/*",
+            "logs/*,data/*,glm_experiments/models/components/transformer.py",
           ]
         additional_dependencies: [flake8-rst-docstrings==0.3.0]
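
Each hook-level exclude above is a Python regular expression that pre-commit matches against candidate file paths with search semantics. A minimal stdlib sketch of that check (the unescaped dots in the pattern happen to match the literal dots in the path):

    import re

    # pre-commit matches `exclude` as a regex against each staged file path.
    exclude = r"glm_experiments/models/components/transformer.py"
    path = "glm_experiments/models/components/transformer.py"

    # Search semantics: a match anywhere in the path makes the hook skip it.
    print(bool(re.search(exclude, path)))  # True -> transformer.py is skipped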

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
_target_: glm_experiments.models.bert_lit_module.BERTLitModule

net:
  _target_: glm_experiments.models.components.bert.BERT
  embedder:
    _target_: glm_experiments.models.components.transformer.Embedding
    vocab_size: 7
    d_model: 768 # Standard BERT-base size
  encoder:
    _target_: glm_experiments.models.components.transformer.Transformer
    hidden_size: ${..embedder.d_model} # 768
    n_layers: 12 # CS336 default
    num_heads: 12 # 12 heads → d_head = 64
    # d_ff: auto-computed as floor(768 * 8/3 / 64) * 64 = 2048
    rope_theta: 10000.0
    is_causal: false
  layer_norm:
    _target_: torch.nn.RMSNorm
    normalized_shape: ${..embedder.d_model}
  decoder:
    _target_: glm_experiments.models.components.transformer.Linear
    d_in: ${..embedder.d_model}
    d_out: ${..embedder.vocab_size}

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.001 # CS336 default
  weight_decay: 0.1 # CS336 default
  betas: [0.9, 0.98] # CS336 default (beta1, beta2)
  eps: 1.0e-9 # CS336 default

scheduler:
  _target_: transformers.get_constant_schedule_with_warmup
  _partial_: true
  num_warmup_steps: 1000 # More warmup for larger model
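
The d_ff comments in both configs describe the same sizing rule, common in CS336-style SwiGLU feed-forward layers: scale the hidden size by 8/3, then round down to a multiple of 64. A quick sketch of that arithmetic as a hypothetical helper (the repo's actual function is not shown in this commit):

    def d_ff(d_model: int, multiple: int = 64) -> int:
        # floor(d_model * 8/3 / multiple) * multiple, all in integer math
        return d_model * 8 // 3 // multiple * multiple

    print(d_ff(768))  # 2048 -> matches the BERT-base config above
    print(d_ff(128))  # 320  -> matches the small config below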
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
_target_: glm_experiments.models.bert_lit_module.BERTLitModule

net:
  _target_: glm_experiments.models.components.bert.BERT
  embedder:
    _target_: glm_experiments.models.components.transformer.Embedding
    vocab_size: 7
    d_model: 128
  encoder:
    _target_: glm_experiments.models.components.transformer.Transformer
    hidden_size: ${..embedder.d_model} # 128
    n_layers: 6 # Fewer layers for fast iteration
    num_heads: 8 # 8 heads → d_head = 16
    # d_ff: auto-computed as floor(128 * 8/3 / 64) * 64 = 320
    rope_theta: 10000.0
    is_causal: false # Bidirectional for MLM
  layer_norm:
    _target_: torch.nn.RMSNorm # Use RMSNorm to match Transformer
    normalized_shape: ${..embedder.d_model}
  decoder:
    _target_: glm_experiments.models.components.transformer.Linear
    d_in: ${..embedder.d_model}
    d_out: ${..embedder.vocab_size}

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.001 # CS336 default
  weight_decay: 0.1 # CS336 default
  betas: [0.9, 0.98] # CS336 default (beta1, beta2)
  eps: 1.0e-9 # CS336 default

scheduler:
  _target_: transformers.get_constant_schedule_with_warmup
  _partial_: true
  num_warmup_steps: 100
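
Both files are standard Hydra object configs: _target_ names the class to construct, _partial_: true wraps the optimizer and scheduler in functools.partial so the LightningModule can bind them to its parameters later, and the ${..embedder.d_model} interpolations resolve relative to the net node. A minimal sketch of building the module, assuming a hypothetical configs/model/bert_small.yaml path for the second file (the commit view does not show the config filenames):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Hypothetical location; the filenames are not visible in this commit view.
    cfg = OmegaConf.load("configs/model/bert_small.yaml")

    # Recursively builds BERTLitModule; optimizer and scheduler come back as
    # functools.partial objects because of `_partial_: true`.
    model = instantiate(cfg)

Note that transformers.get_constant_schedule_with_warmup ramps the learning rate linearly from 0 over num_warmup_steps and then holds it constant at the configured lr, which is why the larger config asks for 1000 warmup steps versus 100 for the small one.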
