
Commit 473e0ca

Author: Michael Gschwind
Commit message: support params file
Parent: b6fc99d

File tree: 3 files changed, +96 -38 lines


export.py (+13, -1)
@@ -67,7 +67,13 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(
-        checkpoint_path, args.checkpoint_dir, device=device, precision=precision, use_tp=False)
+        checkpoint_path,
+        args.checkpoint_dir,
+        args.params_path,
+        device=device,
+        precision=precision,
+        use_tp=False
+    )

     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -155,6 +161,12 @@ def cli():
         default=None,
         help="Model checkpoint directory.",
     )
+    parser.add_argument(
+        "--params-path",
+        type=Path,
+        default=None,
+        help="Parameter file path.",
+    )
     parser.add_argument(
         "--output-pte-path",
         type=str,
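
A rough sketch of the kind of parameter file the new --params-path flag is meant to consume, based on ModelArgs.from_params in model.py below (it passes the JSON keys straight to the ModelArgs constructor). The file name, values, and the --checkpoint-path flag in the comment are illustrative assumptions, not part of the commit:

import json

# Hypothetical params file; keys must match ModelArgs field names
# (dim, n_layer, n_heads, n_local_heads, hidden_dim, vocab_size, block_size, ...).
params = {"dim": 4096, "n_layer": 32, "n_heads": 32, "n_local_heads": 32, "vocab_size": 32000}
with open("params.json", "w") as f:
    json.dump(params, f, indent=2)

# Assumed invocation (only --params-path and --output-pte-path appear in this diff):
#   python export.py --checkpoint-path model.pth --params-path params.json --output-pte-path model.pte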

generate.py (+29, -3)
@@ -277,13 +277,17 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
 def _load_model(
     checkpoint_path,
     checkpoint_dir,
+    params_path,
     device,
     precision,
     use_tp=False
 ):
     use_cuda = "cuda" in device
     with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
+        if params_path:
+            model = Transformer.from_params(params_path)
+        else:
+            model = Transformer.from_name(checkpoint_path.parent.name)

     # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     cps = []
@@ -341,6 +345,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Optional[Path] = None,
     checkpoint_dir: Optional[Path] = None,
+    params_path: Optional[Path] = None,
     tokenizer_path: Optional[Path] = None,
     compile: bool = True,
     compile_prefill: bool = False,
@@ -385,7 +390,14 @@ def main(

     print("Loading model ...")
     t0 = time.time()
-    model_ = _load_model(checkpoint_path, checkpoint_dir, device, precision, use_tp)
+    model_ = _load_model(
+        checkpoint_path,
+        checkpoint_dir,
+        params_path,
+        device,
+        precision,
+        use_tp
+    )
     if dso_path:
         assert not model_dtype, f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
@@ -424,7 +436,14 @@ def main(
         model.to(dtype=name_to_dtype(model_dtype))

     if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, None, device, precision, use_tp)
+        draft_model = _load_model(
+            draft_checkpoint_path,
+            None,
+            None,
+            device,
+            precision,
+            use_tp
+        )
     else:
         draft_model = None

@@ -593,6 +612,12 @@ def cli():
         default=None,
         help="Model checkpoint directory.",
     )
+    parser.add_argument(
+        "--params-path",
+        type=Path,
+        default=None,
+        help="Parameter file path.",
+    )
     parser.add_argument(
         "--tokenizer-path",
         type=Path,
@@ -662,6 +687,7 @@ def cli():
         args.temperature,
         args.checkpoint_path,
         args.checkpoint_dir,
+        args.params_path,
         args.tokenizer_path,
         args.compile,
         args.compile_prefill,
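
A minimal sketch of the new code path in _load_model when a params file is supplied, plus an assumed command line. The model and tokenizer paths, the --prompt flag, and the from_name argument are illustrative; only --checkpoint-dir, --params-path, and --tokenizer-path are visible in this diff:

import torch
from model import Transformer   # model.py from this commit

params_path = "params.json"     # illustrative
with torch.device("meta"):      # weights stay unmaterialized, as in _load_model
    model = (Transformer.from_params(params_path) if params_path
             else Transformer.from_name("Llama-2-7b"))  # old path: infer config from directory name

# Assumed invocation:
#   python generate.py --checkpoint-path model.pth --params-path params.json \
#       --tokenizer-path tokenizer.model --prompt "Hello"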

model.py (+54, -34)
@@ -23,23 +23,39 @@ class ModelArgs:
     block_size: int = 2048
     vocab_size: int = 32000
     n_layer: int = 32
-    n_head: int = 32
+    # n_head in gpt-fast
+    n_heads: int = 32
     dim: int = 4096
-    intermediate_size: int = None
+    # hidden_dim is intermediate_size in gpt-fast
+    hidden_dim: int = None
     n_local_heads: int = -1
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
-
+    multiple_of = 256
+    ffn_dim_multiplier = None
+
     def __post_init__(self):
         if self.n_local_heads == -1:
-            self.n_local_heads = self.n_head
-        if self.intermediate_size is None:
-            hidden_dim = 4 * self.dim
-            n_hidden = int(2 * hidden_dim / 3)
-            self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+            self.n_local_heads = self.n_heads
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate it implicitly based on dim,
+            # rounded up to a multiple of `self.multiple_of`
+            multiple_of = self.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.head_dim = self.dim // self.n_heads

+    @classmethod
+    def from_params(cls, params_path: str):
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+        return cls(**params)
+
     @classmethod
     def from_name(cls, name: str):
         print(f"name {name}")
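
A quick worked example of the hidden_dim derivation in __post_init__ above, for a 7B-sized model (dim=4096, default multiple_of=256, no ffn_dim_multiplier); the numbers are only an illustration, not part of the commit:

dim, multiple_of = 4096, 256
hidden_dim = 4 * dim                  # 16384
hidden_dim = int(2 * hidden_dim / 3)  # 10922
# round up to the next multiple of 256
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                     # 11008 -- the FFN width commonly used for 7B Llama models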
@@ -69,31 +85,31 @@ def from_name(cls, name: str):
     "CodeLlama-7b-Python-hf": dict(
         block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
     ),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
+    "7B": dict(n_layer=32, n_heads=32, dim=4096),
+    "13B": dict(n_layer=40, n_heads=40, dim=5120),
+    "30B": dict(n_layer=60, n_heads=52, dim=6656),
     "34B": dict(
         n_layer=48,
-        n_head=64,
+        n_heads=64,
         dim=8192,
         vocab_size=32000,
         n_local_heads=8,
-        intermediate_size=22016,
+        hidden_dim=22016,
         rope_base=1000000,
     ),  # CodeLlama-34B-Python-hf
     "70B": dict(
-        n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
+        n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672
     ),
     "Mistral-7B": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
-    "stories15M": dict(n_layer=6, n_head=6, dim=288),
-    "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "stories15M": dict(n_layer=6, n_heads=6, dim=288),
+    "stories110M": dict(n_layer=12, n_heads=12, dim=768),
 }

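To see the renamed fields end to end, a small sketch constructing the smallest config from the table above and reading the values __post_init__ derives (assumes model.py is importable as module model; output values follow the arithmetic shown earlier):

from model import ModelArgs    # model.py from this commit

args = ModelArgs(n_layer=6, n_heads=6, dim=288)   # the "stories15M" entry
print(args.head_dim)       # 48: dim // n_heads
print(args.n_local_heads)  # 6: defaults to n_heads when left at -1
print(args.hidden_dim)     # 768: int(2 * (4 * 288) / 3) rounded up to a multiple of 256
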
@@ -140,7 +156,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
@@ -150,8 +166,8 @@ def setup_caches(self, max_batch_size, max_seq_length):
         )

         freqs_cis = precompute_freqs_cis(
-            self.config.block_size,
-            self.config.dim // self.config.n_head,
+            self.config.dim // self.config.n_heads,
+            self.config.block_size * 2,
             self.config.rope_base,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
@@ -182,6 +198,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     def from_name(cls, name: str):
         return cls(ModelArgs.from_name(name))

+    @classmethod
+    def from_params(cls, params_path: str):
+        return cls(ModelArgs.from_params(params_path))
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
@@ -202,19 +222,19 @@ def forward(
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
-        assert config.dim % config.n_head == 0
+        assert config.dim % config.n_heads == 0

         # key, query, value projections for all heads, but in a batch
-        # total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
         # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
         self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
         self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)

         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None

-        self.n_head = config.n_head
+        self.n_heads = config.n_heads
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
@@ -243,7 +263,7 @@ def forward(
         # kv_size = self.n_local_heads * self.head_dim
         # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

-        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

@@ -255,8 +275,8 @@ def forward(
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)

-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
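
The repeat_interleave calls above are the grouped-query-attention expansion: n_local_heads key/value heads are tiled to match n_heads query heads before scaled_dot_product_attention. A standalone shape check with illustrative sizes (not taken from the diff):

import torch

n_heads, n_local_heads, head_dim, seqlen = 32, 8, 128, 16
k = torch.randn(1, n_local_heads, seqlen, head_dim)        # (bsz, kv_heads, seqlen, head_dim)
k = k.repeat_interleave(n_heads // n_local_heads, dim=1)   # tile each kv head 4 times
print(k.shape)  # torch.Size([1, 32, 16, 128]) -- each kv head now serves 4 query heads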
@@ -268,9 +288,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
@@ -289,8 +309,8 @@ def forward(self, x: Tensor) -> Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight

-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+# transposed first two arguments to align with model in ET
+def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000) -> Tensor:
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
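
Since only the argument order of precompute_freqs_cis changed (to match the ExecuTorch model), positional callers such as the setup_caches call above must now pass n_elem first. For reference, the line shown computes one inverse frequency per pair of rotary dimensions; with a toy n_elem:

import torch

n_elem, base = 8, 10000
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
print(freqs)  # [1.0, 0.1, 0.01, 0.001] -- one inverse frequency per rotary pair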
