
For Llama 3.1-8B-Instruct, logprobs differ significantly from the transformers model #1921

@ugadiarov-la-phystech-edu

Description

I found that the log-probabilities of generated sequences reported by CTranslate2 differ significantly from those reported by transformers for Llama 3.1-8B-Instruct. In the example below I use float32 precision, but the differences remain just as large with float16.

Environment:

  • Python 3.12.11
  • torch==2.9.0
  • ctranslate2==4.6.0
  • transformers==4.56.2
  • GPU: H100
  • CUDA Version: 12.6

Steps to reproduce:

  1. Convert Llama 3.1-8B-Instruct: ct2-transformers-converter --model /path/to/Llama3.1-8b-Instruct --quantization float32 --output_dir /path/to/ctranslate2/model/float32
  2. Run the script below.

Expected result: the logprobs of sequences generated by ctranslate2 are approximately equal to the transformers logprobs.
Actual result: the logprobs diverge significantly.

import pprint
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import ctranslate2


# generate one completion with ctranslate2
def generate_ct2(generator, tokenizer, prompt_tokens):
    results = generator.generate_batch(
        [prompt_tokens],
        sampling_temperature=1,
        sampling_topp=1,
        sampling_topk=100,
        max_length=512,
        return_scores=True,
        include_prompt_in_result=False,
        # stop on "," (id 11) or <|eot_id|>, mirroring eos_token_id on the transformers side
        end_token=[11, tokenizer.eos_token_id],
        num_hypotheses=1,
    )
    completion = tokenizer.decode(results[0].sequences_ids[0], skip_special_tokens=False)
    # sequence score reported by ctranslate2 for the generated tokens
    log_prob = results[0].scores[0]
    return completion, log_prob
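

# Cross-check sketch (not part of the original repro): Generator.score_batch
# scores a given token sequence under the model and returns per-token log
# probs. Summing them over just the completion tokens gives a cumulative
# sequence log prob that can be compared against the transformers sums.
# Assumes result.log_probs is aligned with the end of the scored sequence.
def score_ct2(generator, prompt_tokens, completion_tokens):
    result = generator.score_batch([prompt_tokens + completion_tokens])[0]
    n = len(completion_tokens)
    return sum(result.log_probs[-n:])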


# generate one completion with transformers
def generate_transformers(model, tokenizer, prompt_inputs, exclude_stop_token=True):
    outputs = model.generate(
        **prompt_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_k=100,
        top_p=1,
        num_return_sequences=1,
        # stop on "," (id 11) or <|eot_id|>, mirroring end_token on the ctranslate2 side
        eos_token_id=[11, tokenizer.eos_token_id],
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # per-token log probs of the sampled tokens, normalized from the scores returned by generate
    transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
    input_length = prompt_inputs.input_ids.shape[1]
    completion_ids = outputs.sequences[0][input_length:]
    if exclude_stop_token:
        # drop the trailing stop token ("," or <|eot_id|>)
        completion_ids = completion_ids[:-1]

    completion = tokenizer.decode(completion_ids, skip_special_tokens=False)
    # sum the per-token log probs over the kept completion tokens
    log_probs = transition_scores[0][:completion_ids.shape[0]]
    log_prob = log_probs.sum().item()
    return completion, log_prob
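

# Caveat sketch: with do_sample=True, outputs.scores above are the *processed*
# logits (after the temperature/top-k warpers), so compute_transition_scores
# reflects the renormalized sampling distribution rather than the raw model
# distribution. Recent transformers releases (4.38+, if I recall correctly)
# also expose the unprocessed logits via output_logits=True, which can be
# scored instead:
def generate_transformers_raw_scores(model, tokenizer, prompt_inputs):
    outputs = model.generate(
        **prompt_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_k=100,
        top_p=1,
        eos_token_id=[11, tokenizer.eos_token_id],
        return_dict_in_generate=True,
        output_logits=True,  # raw logits, before any logits processors/warpers
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.compute_transition_scores(outputs.sequences, outputs.logits, normalize_logits=True)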


model_name = 'meta-llama/Llama-3.1-8B-Instruct'
llama31_tokenizer = AutoTokenizer.from_pretrained(model_name)
llama31 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map="auto")

ct2_path = '/path/to/ctranslate2/model/float32'
ct2_generator = ctranslate2.Generator(ct2_path, device="cuda", device_index=0, compute_type="float32")

prompt = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
message = {"role": "user", "content": prompt}
chat_template_inputs = llama31_tokenizer.apply_chat_template([message], add_generation_prompt=True, return_tensors="pt", return_dict=True).to(llama31.device)
chat_template_tokens = llama31_tokenizer.convert_ids_to_tokens(chat_template_inputs.input_ids[0])

test_completions = {"To find out how much Janet makes at the farmers' market", "To determine how much Janet makes at the farmers' market"}
ct2_completions = {}
transformers_completions_without_stop_token = {}
transformers_completions_with_stop_token = {}

# sample until each target completion has been produced at least once
while len(ct2_completions) < len(test_completions):
    completion, logprob = generate_ct2(ct2_generator, llama31_tokenizer, chat_template_tokens)
    if completion in test_completions and completion not in ct2_completions:
        ct2_completions[completion] = logprob

while len(transformers_completions_without_stop_token) < len(test_completions):
    completion, logprob = generate_transformers(llama31, llama31_tokenizer, chat_template_inputs, exclude_stop_token=True)
    if completion in test_completions and completion not in transformers_completions_without_stop_token:
        transformers_completions_without_stop_token[completion] = logprob

while len(transformers_completions_with_stop_token) < len(test_completions):
    completion, logprob = generate_transformers(llama31, llama31_tokenizer, chat_template_inputs, exclude_stop_token=False)
    if completion[:-1] in test_completions and completion not in transformers_completions_with_stop_token:
        transformers_completions_with_stop_token[completion] = logprob

print('stop tokens:', llama31_tokenizer.convert_ids_to_tokens([11, llama31_tokenizer.eos_token_id]))
print('ctranslate2:')
pprint.pp(ct2_completions)
print('transformers without stop token:')
pprint.pp(transformers_completions_without_stop_token)
print('transformers with stop token:')
pprint.pp(transformers_completions_with_stop_token)
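
As an additional ground-truth check (not part of the script above), the raw log prob of a fixed completion can be computed with a plain teacher-forced forward pass, independent of either generation API. A minimal sketch, where prompt_ids and completion_ids are 1-D tensors of token ids (illustrative names):

def raw_sequence_logprob(model, prompt_ids, completion_ids):
    input_ids = torch.cat([prompt_ids, completion_ids]).unsqueeze(0).to(model.device)
    with torch.no_grad():
        logits = model(input_ids).logits  # (1, seq_len, vocab_size)
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    # the logits at position i predict the token at position i + 1
    start = prompt_ids.shape[0] - 1
    steps = log_probs[0, start : start + completion_ids.shape[0]]
    targets = completion_ids.to(steps.device).unsqueeze(1)
    return steps.gather(1, targets).squeeze(1).sum().item()

Scoring the same completion ids with this function against both backends should show which one deviates from the raw model distribution.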

Output:

stop tokens: [',', '<|eot_id|>']
ctranslate2:
{"To determine how much Janet makes at the farmers' market": -0.26174822449684143,
 "To find out how much Janet makes at the farmers' market": -0.11245159804821014}
transformers without stop token:
{"To find out how much Janet makes at the farmers' market": -1.1725850105285645,
 "To determine how much Janet makes at the farmers' market": -2.8389508724212646}
transformers with stop token:
{"To find out how much Janet makes at the farmers' market,": -1.4434683322906494,
 "To determine how much Janet makes at the farmers' market,": -3.1334826946258545}
