Critical Improvements for Model Correctness, Efficiency, and Robustness #717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
inference/generate.py: 192 changes (69 additions, 123 deletions)
import os
import json
from argparse import ArgumentParser
from typing import List, Optional

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs


def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None) -> torch.Tensor:
    """Sample one token per row of `logits` with temperature scaling and optional top-k / top-p filtering."""
    # A near-zero temperature degenerates to greedy decoding.
    if temperature <= 1e-5:
        return logits.argmax(dim=-1)
    logits = logits / temperature
    if top_k is not None:
        # Mask out everything below the k-th largest logit.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
    if top_p is not None and top_p < 1.0:
        # Nucleus filtering: keep the smallest prefix of sorted tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove_mask = cum_probs > top_p
        # Shift right so the first token crossing the threshold is still kept.
        remove_mask[..., 1:] = remove_mask[..., :-1].clone()
        remove_mask[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove_indices = remove_mask.scatter(-1, sorted_indices, remove_mask)
        logits[remove_indices] = -float('Inf')
    # Gumbel-max trick: argmax of logits plus Gumbel noise is an exact sample from softmax(logits).
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10))
    return (logits + gumbel_noise).argmax(dim=-1)
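
Since `sample` has no model dependency, its three code paths can be smoke-tested in isolation. A minimal sketch (the batch and vocabulary sizes here are arbitrary, not part of this file):

    logits = torch.randn(4, 1024)
    assert sample(logits, temperature=0.0).shape == (4,)             # greedy path
    assert sample(logits, temperature=0.8, top_k=50).shape == (4,)   # top-k filtering
    assert sample(logits, temperature=0.8, top_p=0.9).shape == (4,)  # nucleus filtering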

@torch.inference_mode()
def generate(model: Transformer, prompt_tokens: List[List[int]], max_new_tokens: int, eos_id: int, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None) -> List[List[int]]:
    """Generate up to `max_new_tokens` tokens per prompt, stopping early once every sequence has emitted `eos_id`."""
    model.reset_cache()
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    # -1 marks positions that still need to be generated; prompt tokens fill the prefix.
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, device="cuda")
    prev_pos = 0
    finished = torch.zeros(len(prompt_tokens), dtype=torch.bool, device="cuda")
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        next_token = sample(logits, temperature, top_k, top_p)
        # Sequences whose prompt extends past cur_pos keep their prompt token instead of the sample.
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        # A sequence is finished once it emits eos outside its prompt region.
        finished |= (~prompt_mask[:, cur_pos] & (next_token == eos_id))
        prev_pos = cur_pos
        if finished.all():
            break
    completions = []
    for i, seq in enumerate(tokens.tolist()):
        seq = seq[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        completions.append(seq[:seq.index(eos_id)] if eos_id in seq else seq)
    return completions
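
For reference, a single-prompt call to `generate` mirrors what both branches of `main` do (a sketch; `model` and `tokenizer` are assumed to be loaded as in `main` below):

    prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Hello"}], add_generation_prompt=True)
    out = generate(model, [prompt], max_new_tokens=32, eos_id=tokenizer.eos_token_id, temperature=0.7, top_k=50, top_p=0.95)
    print(tokenizer.decode(out[0], skip_special_tokens=True))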

def main(ckpt_path: str, config: str, input_file: str = "", interactive: bool = True, max_new_tokens: int = 100, temperature: float = 0.2, top_k: Optional[int] = None, top_p: Optional[float] = None) -> None:
    # Fail fast on bad paths before touching the GPU or the process group.
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory missing: {ckpt_path}")
    if not os.path.isfile(config):
        raise FileNotFoundError(f"Config file missing: {config}")
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl", init_method="env://")
    # Silence every rank except rank 0 so multi-process output stays readable.
    global print
    if rank != 0:
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)
    with open(config) as f:
        model_args = ModelArgs(**json.load(f))
    model = Transformer(model_args).to(torch.bfloat16).cuda()
    # Each rank loads its own tensor-parallel shard.
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    if interactive:
        messages = []
        while True:
            prompt = get_input(rank, world_size)
            if prompt == "/exit":
                break
            if prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            try:
                prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
            except Exception as e:
                print(f"Tokenization error: {e}")
                messages.pop()  # drop the message that failed to tokenize so it cannot poison later turns
                continue
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature, top_k, top_p)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f if line.strip()]
        batch_size = model_args.max_batch_size
        completions = []
        # Chunk the prompts so each generate() call stays within max_batch_size.
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i+batch_size]
            batch_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": p}], add_generation_prompt=True) for p in batch_prompts]
            completion_tokens = generate(model, batch_tokens, max_new_tokens, tokenizer.eos_token_id, temperature, top_k, top_p)
            completions.extend(tokenizer.batch_decode(completion_tokens, skip_special_tokens=True))
        for prompt, completion in zip(prompts, completions):
            print(f"Prompt: {prompt}\nCompletion: {completion}\n{'-'*50}")
    if world_size > 1:
        dist.destroy_process_group()
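
The shard filename passed to `load_model` means the launch world size must match the `mp` suffix of the converted checkpoint. A pre-flight check along these lines (a hypothetical helper, not part of this change) can catch mismatches before NCCL initializes:

    def check_shards(ckpt_path: str, world_size: int) -> None:
        # Hypothetical: verify model{rank}-mp{world_size}.safetensors exists for every rank.
        missing = [r for r in range(world_size)
                   if not os.path.isfile(os.path.join(ckpt_path, f"model{r}-mp{world_size}.safetensors"))]
        if missing:
            raise FileNotFoundError(f"Missing shards for ranks {missing} at world size {world_size}")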

def get_input(rank: int, world_size: int) -> str:
    """Read a prompt on rank 0 and broadcast it so every rank sees the same input."""
    if world_size == 1 or rank == 0:
        prompt = input(">>> ")
        if world_size > 1:
            dist.broadcast_object_list([prompt], src=0)
        return prompt
    else:
        # Non-zero ranks block here until rank 0 broadcasts the prompt.
        res = [None]
        dist.broadcast_object_list(res, src=0)
        return res[0]
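
The rank-0 fan-out in `get_input` can be exercised on its own (a toy sketch assuming a `torchrun` launch; gloo is used so it also runs without GPUs):

    # save as demo.py and run with: torchrun --nproc_per_node=2 demo.py
    import torch.distributed as dist
    dist.init_process_group("gloo")
    obj = ["hello from rank 0"] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    print(f"rank {dist.get_rank()} received: {obj[0]}")
    dist.destroy_process_group()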

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top-k", type=int, default=None)
    parser.add_argument("--top-p", type=float, default=None)
    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature, args.top_k, args.top_p)
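
With the new flags wired through, typical launches look like the following (paths are illustrative, and the process count must match the checkpoint's `mp` shards):

    # single process, interactive chat
    python generate.py --ckpt-path /path/to/ckpt --config /path/to/config.json --interactive --temperature 0.7 --top-p 0.95

    # multi-GPU batch mode over a prompt file
    torchrun --nproc_per_node=2 generate.py --ckpt-path /path/to/ckpt --config /path/to/config.json --input-file prompts.txt --max-new-tokens 200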