I found that the log-probabilities of generated sequences differ significantly between CTranslate2 and transformers for Llama 3.1-8B-Instruct. In this example I'm using float32 precision, but the differences remain just as large with float16.
Environment:
- Python 3.12.11
- torch==2.9.0
- ctranslate2==4.6.0
- transformers==4.56.2
- GPU: H100
- CUDA Version: 12.6
Steps to reproduce:
- Convert Llama 3.1-8B-Instruct: ct2-transformers-converter --model /path/to/Llama3.1-8b-Instruct --quantization float32 --output_dir /path/to/ctranslate2/model/float32
- Run the script below.
Expected result: logprobs of sequences generated by ctranslate2 are approximately equal to transformers' logprobs.
Actual result: logprobs diverge significantly.
Script:

import pprint

import ctranslate2
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# generate one completion with ctranslate2
def generate_ct2(generator, tokenizer, prompt_tokens):
    results = generator.generate_batch(
        [prompt_tokens],
        sampling_temperature=1,
        sampling_topp=1,
        sampling_topk=100,
        max_length=512,
        return_scores=True,
        include_prompt_in_result=False,
        end_token=[11, tokenizer.eos_token_id],
        num_hypotheses=1,
    )
    completion = tokenizer.decode(results[0].sequences_ids[0], skip_special_tokens=False)
    log_prob = results[0].scores[0]
    return completion, log_prob
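
# Cross-check helper (not called in the repro below): rescore prompt + completion
# with Generator.score_batch to get per-token log probs independently of
# generate_batch. Summing the last len(completion_tokens) entries is an
# assumption on my side about how score_batch aligns log_probs with tokens.
def score_ct2(generator, tokenizer, prompt_tokens, completion_ids):
    completion_tokens = tokenizer.convert_ids_to_tokens(completion_ids)
    result = generator.score_batch([prompt_tokens + completion_tokens])[0]
    return sum(result.log_probs[-len(completion_tokens):])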
# generate one completion with transformers
def generate_transformers(model, tokenizer, prompt_inputs, exclude_stop_token=True):
    outputs = model.generate(
        **prompt_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_k=100,
        top_p=1,
        num_return_sequences=1,
        eos_token_id=[11, tokenizer.eos_token_id],
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
    input_length = prompt_inputs.input_ids.shape[1]
    completion_ids = outputs.sequences[0][input_length:]
    if exclude_stop_token:
        completion_ids = completion_ids[:-1]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=False)
    log_probs = transition_scores[0][:completion_ids.shape[0]]
    log_prob = log_probs.sum().item()
    return completion, log_prob
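
# Cross-check helper (not called in the repro below; a sketch assuming prompt_ids
# and completion_ids are 1-D LongTensors on the model's device): recompute the
# completion log prob with one forward pass and log_softmax, independently of
# compute_transition_scores.
def score_transformers(model, prompt_ids, completion_ids):
    full_ids = torch.cat([prompt_ids, completion_ids]).unsqueeze(0)
    with torch.no_grad():
        logits = model(full_ids).logits[0]
    token_log_probs = torch.log_softmax(logits[:-1], dim=-1).gather(
        1, full_ids[0][1:].unsqueeze(1)
    ).squeeze(1)
    return token_log_probs[-completion_ids.shape[0]:].sum().item()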
model_name = 'meta-llama/Llama-3.1-8B-Instruct'
llama31_tokenizer = AutoTokenizer.from_pretrained(model_name)
llama31 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map="auto")
ct2_path = '/path/to/ctranslate2/model/float32'
ct2_generator = ctranslate2.Generator(ct2_path, device="cuda", device_index=0, compute_type="float32")
prompt = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
message = {"role": "user", "content": prompt}
chat_template_inputs = llama31_tokenizer.apply_chat_template([message], add_generation_prompt=True, return_tensors="pt", return_dict=True).to(llama31.device)
chat_template_tokens = llama31_tokenizer.convert_ids_to_tokens(chat_template_inputs.input_ids[0])
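
# Sanity check (my own addition): CTranslate2 takes string tokens, so a lossy
# id -> token -> id round trip would skew the comparison.
assert llama31_tokenizer.convert_tokens_to_ids(chat_template_tokens) == chat_template_inputs.input_ids[0].tolist()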
test_completions = {"To find out how much Janet makes at the farmers' market", "To determine how much Janet makes at the farmers' market"}
ct2_completions = {}
transformers_completions_without_stop_token = {}
transformers_completions_with_stop_token = {}
while len(ct2_completions) < len(test_completions):
    completion, logprob = generate_ct2(ct2_generator, llama31_tokenizer, chat_template_tokens)
    if completion in test_completions and completion not in ct2_completions:
        ct2_completions[completion] = logprob

while len(transformers_completions_without_stop_token) < len(test_completions):
    completion, logprob = generate_transformers(llama31, llama31_tokenizer, chat_template_inputs, exclude_stop_token=True)
    if completion in test_completions and completion not in transformers_completions_without_stop_token:
        transformers_completions_without_stop_token[completion] = logprob

while len(transformers_completions_with_stop_token) < len(test_completions):
    completion, logprob = generate_transformers(llama31, llama31_tokenizer, chat_template_inputs, exclude_stop_token=False)
    if completion[:-1] in test_completions and completion not in transformers_completions_with_stop_token:
        transformers_completions_with_stop_token[completion] = logprob
print('stop tokens:', llama31_tokenizer.convert_ids_to_tokens([11, llama31_tokenizer.eos_token_id]))
print('ctranslate2:')
pprint.pp(ct2_completions)
print('transformers without stop token:')
pprint.pp(transformers_completions_without_stop_token)
print('transformers with stop token:')
pprint.pp(transformers_completions_with_stop_token)
Output:
stop tokens: [',', '<|eot_id|>']
ctranslate2:
{"To determine how much Janet makes at the farmers' market": -0.26174822449684143,
"To find out how much Janet makes at the farmers' market": -0.11245159804821014}
transformers without stop token:
{"To find out how much Janet makes at the farmers' market": -1.1725850105285645,
"To determine how much Janet makes at the farmers' market": -2.8389508724212646}
transformers with stop token:
{"To find out how much Janet makes at the farmers' market,": -1.4434683322906494,
"To determine how much Janet makes at the farmers' market,": -3.1334826946258545}