Fix Top K CPU (#1194)
Fixes issue #1184, where top-k sampling didn't work on CPU: the sampler drew from the first k scores in the logits array instead of the k highest-scoring tokens selected by std::partial_sort.
aciddelgado authored and baijumeswani committed Jan 30, 2025
1 parent f2f8b81 commit 4762332
Showing 5 changed files with 93 additions and 65 deletions.
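
The root cause is visible in the src/search.cpp diff below: the CPU sampler built its std::discrete_distribution from the first k entries of the score array rather than from the k highest scores found by std::partial_sort. A minimal sketch of the difference, simplified from SampleTopK (the real code also handles temperature and batching):

```cpp
// Sketch of the bug and the fix, simplified from SampleTopK in src/search.cpp.
#include <algorithm>
#include <numeric>
#include <random>
#include <span>
#include <vector>

int SampleTopK(std::span<const float> scores, int k, std::mt19937& gen) {
  // Find the indices of the k highest scores.
  std::vector<int> indices(scores.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                    [&](int i, int j) { return scores[i] > scores[j]; });

  // Before the fix: the distribution was built from the first k scores in
  // their original positions, ignoring the sorted indices entirely:
  //   std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);

  // After the fix: gather the k highest scores, then sample among them.
  std::vector<float> top_k_scores(k);
  for (int i = 0; i < k; i++)
    top_k_scores[i] = scores[indices[i]];
  std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
  return indices[dis(gen)];
}
```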
10 changes: 4 additions & 6 deletions examples/python/phi3-qa.py
@@ -30,6 +30,10 @@ def main(args):

chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)

# Keep asking for input prompts in a loop
while True:
text = input("Input: ")
@@ -44,9 +48,6 @@

input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)
generator.append_tokens(input_tokens)
if args.verbose: print("Generator created")

@@ -74,9 +75,6 @@ def main(args):
print()
print()

# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

if args.timings:
prompt_time = first_token_timestamp - started_timestamp
run_time = time.time() - first_token_timestamp
1 change: 0 additions & 1 deletion src/generators.h
@@ -161,6 +161,5 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); /
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);

float Float16ToFloat32(uint16_t v); // v is an IEEE 754-2008 binary16 value: 1 sign bit, 5-bit exponent, 10-bit fraction
void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);

} // namespace Generators
5 changes: 4 additions & 1 deletion src/search.cpp
@@ -161,8 +161,11 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
std::vector<int> indices(scores.size());
std::iota(indices.begin(), indices.end(), 0);
std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
std::vector<float> top_k_scores(k);
for (int i = 0; i < k; i++)
top_k_scores[i] = scores[indices[i]];
// Sample a token from the top K
std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);
std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
SetNextToken(batch_id, indices[dis(gen_)]);
}
AppendNextTokensToSequences();
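
One detail worth noting: std::discrete_distribution normalizes whatever weights it is given, so feeding it the gathered top-k probabilities implicitly performs the usual top-k renormalization with no explicit division. A small standalone check (not from the repository):

```cpp
#include <iostream>
#include <random>
#include <vector>

int main() {
  // Weights need not sum to 1: discrete_distribution rescales them, so
  // {0.2, 0.3} produces the same distribution as {0.4, 0.6}.
  std::vector<float> top_k_scores{0.2f, 0.3f};
  std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
  for (double p : dis.probabilities())
    std::cout << p << ' ';  // prints: 0.4 0.6
  std::cout << '\n';
}
```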
32 changes: 0 additions & 32 deletions src/top_k_cpu.cpp

This file was deleted.

110 changes: 85 additions & 25 deletions test/sampling_tests.cpp
@@ -176,31 +176,57 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
}
}

void SoftMax(std::span<float> scores, float temperature) {
float const max_score = *std::max_element(scores.begin(), scores.end());

// Subtract max score and scale by temperature
std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });

// Compute sum of exponentials
float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);

// Divide each score by the sum of exponentials
std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
}
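
This helper uses the standard max-subtraction trick: shifting every score by the maximum leaves the softmax unchanged mathematically but keeps std::exp from overflowing. For the logits 1 through 5 that the test below plants, the expected probabilities come out to roughly 0.012, 0.032, 0.086, 0.234, and 0.636; these are the values the EXPECT_NEAR checks compare observed token frequencies against. A self-contained version of that calculation:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Expected sampling distribution for planted logits 1..k with k = 5,
  // mirroring the SoftMax helper above at temperature 1.
  std::vector<float> p{1, 2, 3, 4, 5};
  const float max_score = *std::max_element(p.begin(), p.end());
  float sum = 0.0f;
  for (float& x : p) {
    x = std::exp(x - max_score);  // shift by max for numerical stability
    sum += x;
  }
  for (float& x : p) x /= sum;
  for (float x : p) std::printf("%.3f ", x);
  // prints approximately: 0.012 0.032 0.086 0.234 0.636
  std::printf("\n");
}
```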

TEST(SamplingTests, RandomizedSamplingTopKCpu) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
const int batch_size = 5;
const int k = 5;

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
const int vocab_size = 13;  // small test vocab (down from llama's 32000)
config.model.vocab_size = vocab_size;

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
params->device_type = Generators::DeviceType::CPU;
std::vector<float> logits_cpu(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(5, 25);
int num_iter = 100;
std::vector<int> indices(vocab_size);
std::vector<float> logits_cpu(vocab_size * batch_size);
const int num_iter = 100;
std::map<float, int> logit_to_count;

// Run test
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
auto generator = Generators::CreateGenerator(*model, *params);
CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
// Plant the integers 1 through k at k random positions in logits_cpu
for (int b = 0; b < batch_size; b++) {
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
logits_cpu[indices[j] + vocab_size * b] = float(k - j);
}
// Set logits and get generated token
auto logits_copy = logits_cpu;
auto logits = params->p_device->WrapMemory<float>(logits_copy);
generator->SetLogits(logits);
@@ -209,10 +235,22 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = logits_cpu[next_token + vocab_size * b];
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
// Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
std::vector<float> expected_distributions(k);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
// Check that the distribution of tokens generated by the model is close to the expected distribution
const int total_count = batch_size * num_iter;
for (auto& [logit, count] : logit_to_count) {
const float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}
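
The 0.1 tolerance is comfortably loose relative to sampling noise (a back-of-the-envelope estimate, not from the commit): with batch_size * num_iter = 500 draws, the standard error of an observed frequency is sqrt(p * (1 - p) / 500), at most about 0.022 here, so a correct sampler should land within the tolerance on essentially every run:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const int total_count = 5 * 100;  // batch_size * num_iter in the test
  // Expected top-k probabilities for planted logits 1..5 (see SoftMax above).
  const double expected[] = {0.012, 0.032, 0.086, 0.234, 0.636};
  for (double p : expected)
    // Standard error of an observed frequency over 500 draws.
    std::printf("p = %.3f  stderr = %.4f\n",
                p, std::sqrt(p * (1.0 - p) / total_count));
}
```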

TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -396,43 +434,65 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {

TEST(SamplingTests, RandomizedSamplingTopKCuda) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
const int batch_size = 5;
const int k = 5;

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
const int vocab_size = 17;  // small test vocab (down from llama's 32000)
config.model.vocab_size = vocab_size;

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
params->device_type = Generators::DeviceType::CUDA;
auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(1, 25);
int num_iter = 100;
std::vector<int> indices(vocab_size);
const int num_iter = 100;
std::map<float, int> logit_to_count;

// Run test
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
auto generator = Generators::CreateGenerator(*model, *params);
Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
auto cpu_span = logits_gpu.CpuSpan();
// Plant the integers 1 through k at k random positions in cpu_span
for (int b = 0; b < batch_size; b++) {
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
cpu_span[indices[j] + vocab_size * b] = float(k - j);
}
// Copy logits onto device, set logits, and get generated token
logits_gpu.CopyCpuToDevice();
generator->SetLogits(logits_gpu);
generator->GenerateNextToken();
auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
auto logits_cpu = logits_gpu.CopyDeviceToCpu();
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = cpu_span[next_token + vocab_size * b];
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
// Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
std::vector<float> expected_distributions(k);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
const int total_count = batch_size * num_iter;
// Check that the distribution of tokens generated by the model is close to the expected distribution
for (auto& [logit, count] : logit_to_count) {
const float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}

TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
Expand Down
