
Fix Top K CPU #1194

Merged · 16 commits · Jan 21, 2025
10 changes: 4 additions & 6 deletions examples/python/phi3-qa.py
@@ -30,6 +30,10 @@ def main(args):

chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)

# Keep asking for input prompts in a loop
while True:
text = input("Input: ")
@@ -44,9 +48,6 @@ def main(args):

input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)
generator.append_tokens(input_tokens)
if args.verbose: print("Generator created")

@@ -74,9 +75,6 @@ def main(args):
print()
print()

# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

if args.timings:
prompt_time = first_token_timestamp - started_timestamp
run_time = time.time() - first_token_timestamp
1 change: 0 additions & 1 deletion src/generators.h
@@ -161,6 +161,5 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); /
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);

float Float16ToFloat32(uint16_t v); // v is an IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);

} // namespace Generators
5 changes: 4 additions & 1 deletion src/search.cpp
@@ -161,8 +161,11 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
std::vector<int> indices(scores.size());
std::iota(indices.begin(), indices.end(), 0);
std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
std::vector<float> top_k_scores(k);
for (int i = 0; i < k; i++)
top_k_scores[i] = scores[indices[i]];
// Sample a token from the top K
std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);
std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
SetNextToken(batch_id, indices[dis(gen_)]);
}
AppendNextTokensToSequences();
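
The hunk above is the core of the fix: std::discrete_distribution was previously built from scores.begin() to scores.begin() + k, i.e. the raw scores of token ids 0..k-1, rather than the scores belonging to the indices selected by std::partial_sort. Below is a minimal standalone sketch of the corrected logic, assuming scores holds one batch row of non-negative sampling weights (e.g. softmaxed logits); the free function and its signature are illustrative, not the library's API.

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Illustrative sketch (not the library API): sample one token id from the
// top-k entries of `scores`, weighted by their scores.
int SampleTopK(const std::vector<float>& scores, int k, std::mt19937& gen) {
  // Rank all token ids by score and move the k best to the front of `indices`.
  std::vector<int> indices(scores.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                    [&scores](int i, int j) { return scores[i] > scores[j]; });

  // Gather the scores that belong to those k ids. The old code skipped this
  // step and sampled from scores[0..k), i.e. the first k token ids.
  std::vector<float> top_k_scores(k);
  for (int i = 0; i < k; i++)
    top_k_scores[i] = scores[indices[i]];

  // Sample a position among the top k, weighted by its score, then map it
  // back to the original token id.
  std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
  return indices[dis(gen)];
}

std::discrete_distribution normalizes whatever weights it is given, so the gathered top-k scores can be passed directly without renormalizing over the truncated set.
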
32 changes: 0 additions & 32 deletions src/top_k_cpu.cpp

This file was deleted.

110 changes: 85 additions & 25 deletions test/sampling_tests.cpp
@@ -176,31 +176,57 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
}
}

void SoftMax(std::span<float> scores, float temperature) {
float const max_score = *std::max_element(scores.begin(), scores.end());

// Subtract max score and scale by temperature
std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });

// Compute sum of exponentials
float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);

// Divide each score by the sum of exponentials
std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
}

TEST(SamplingTests, RandomizedSamplingTopKCpu) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
const int batch_size = 5;
const int k = 5;

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
const int vocab_size = 13; // small vocab size for the test
config.model.vocab_size = vocab_size;

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
params->device_type = Generators::DeviceType::CPU;
std::vector<float> logits_cpu(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(5, 25);
int num_iter = 100;
std::vector<int> indices(vocab_size);
std::vector<float> logits_cpu(vocab_size * batch_size);
const int num_iter = 100;
std::map<float, int> logit_to_count;

// Run test
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
auto generator = Generators::CreateGenerator(*model, *params);
CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
// Scatter the integers 1..k into random positions of each batch row of logits_cpu
for (int b = 0; b < batch_size; b++) {
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
logits_cpu[indices[j] + vocab_size * b] = float(k - j);
}
// Set logits and get generated token
auto logits_copy = logits_cpu;
auto logits = params->p_device->WrapMemory<float>(logits_copy);
generator->SetLogits(logits);
@@ -209,10 +235,22 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = logits_cpu[next_token + vocab_size * b];
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
// Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
std::vector<float> expected_distributions(k);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
// Check that the distribution of tokens generated by the model is close to the expected distribution
const int total_count = batch_size * num_iter;
for (auto& [logit, count] : logit_to_count) {
const float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}

TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -396,43 +434,65 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {

TEST(SamplingTests, RandomizedSamplingTopKCuda) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
const int batch_size = 5;
const int k = 5;

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
const int vocab_size = 17; // small vocab size for the test
config.model.vocab_size = vocab_size;

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
params->device_type = Generators::DeviceType::CUDA;
auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(1, 25);
int num_iter = 100;
std::vector<int> indices(vocab_size);
const int num_iter = 100;
std::map<float, int> logit_to_count;

// Run test
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
auto generator = Generators::CreateGenerator(*model, *params);
Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
auto cpu_span = logits_gpu.CpuSpan();
// Scatter the integers 1..k into random positions of each batch row of cpu_span
for (int b = 0; b < batch_size; b++) {
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
cpu_span[indices[j] + vocab_size * b] = float(k - j);
}
// Copy logits onto device, set logits, and get generated token
logits_gpu.CopyCpuToDevice();
generator->SetLogits(logits_gpu);
generator->GenerateNextToken();
auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
auto logits_cpu = logits_gpu.CopyDeviceToCpu();
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = cpu_span[next_token + vocab_size * b];
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
// Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
std::vector<float> expected_distributions(k);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
const int total_count = batch_size * num_iter;
// Check that the distribution of tokens generated by the model is close to the expected distribution
for (auto& [logit, count] : logit_to_count) {
const float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}

TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
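
Both the rewritten CPU and CUDA tests seed each batch row with the logits 1..k at shuffled positions and then assert that the empirical frequency of each sampled token lands within 0.1 of the softmax of those logits. For k = 5 and temperature 1.0 that expected distribution is roughly {0.012, 0.032, 0.086, 0.234, 0.636}, so the largest logit should be drawn about 64% of the time. A small standalone sketch of that expectation (illustrative only, not part of the test suite):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int k = 5;
  std::vector<float> weights(k);
  float sum = 0.0f;
  for (int i = 0; i < k; i++) {
    weights[i] = std::exp(float(i + 1));  // the test logits are the integers 1..k
    sum += weights[i];
  }
  for (int i = 0; i < k; i++)
    std::printf("logit %d -> expected sampling frequency %.3f\n", i + 1, weights[i] / sum);
  return 0;
}

With total_count = batch_size * num_iter = 500 draws per run, the 0.1 tolerance leaves ample headroom over sampling noise.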