Fix Top K CPU #1194

Merged (16 commits) on Jan 21, 2025
test/sampling_tests.cpp: 96 changes (77 additions, 19 deletions)
@@ -176,31 +176,54 @@
}
}

// TODO(aciddelgado): this is copy-pasted from softmax.h but I think that might be fine... not sure how we feel about that
void SoftMax(std::span<float> scores, float temperature) {
float const max_score = *std::max_element(scores.begin(), scores.end());

// Subtract max score and scale by temperature
std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });

// Compute sum of exponentials
float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);

// Divide each score by the sum of exponentials
std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
}
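
// Illustrative sketch (not part of the original diff): applying SoftMax to the
// expected top-k logits {1, 2, 3, 4, 5} at temperature 1.0 gives roughly
// {0.012, 0.032, 0.086, 0.234, 0.636}. The distribution checks at the end of the
// tests below compare the empirical frequency of each sampled logit value against
// these probabilities.
//
//   std::vector<float> expected{1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
//   SoftMax(expected, 1.0f);
//   // expected now holds approximately {0.012f, 0.032f, 0.086f, 0.234f, 0.636f}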

TEST(SamplingTests, RandomizedSamplingTopKCpu) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
int vocab_size = 32000; // vocab size of llama
config.model.vocab_size = vocab_size; // vocab size of llama

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
params->device_type = Generators::DeviceType::CPU;
std::vector<float> logits_cpu(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(5, 25);
std::vector<int> indices(vocab_size);
std::vector<float> logits_cpu(vocab_size * batch_size);
int num_iter = 100;
std::map<float, int> logit_to_count;
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
auto generator = Generators::CreateGenerator(*model, *params);
CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
for (int i = 0; i < batch_size; i++) {

[GitHub Actions annotations on line 221 of test/sampling_tests.cpp: the windows-cuda-x64-build and windows-cpu-x64-build jobs fail here because the warning "declaration of 'i' hides previous local declaration" is treated as an error.]
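
// Note on the check annotations above: the inner loop index `i` shadows the outer
// `i` declared by the surrounding `num_iter` loop, and the Windows jobs treat that
// warning as an error. One possible fix (an assumption, not necessarily what was
// merged) is to rename the inner index, e.g. `for (int b = 0; b < batch_size; b++)`
// and index with `vocab_size * b` below.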
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
logits_cpu[indices[j] + vocab_size * i] = float(k - j);
}
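// A copy of logits_cpu is wrapped below so the original buffer stays intact for the
// score checks further down, in case the sampling step modifies the wrapped memory
// in place.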
auto logits_copy = logits_cpu;
auto logits = params->p_device->WrapMemory<float>(logits_copy);
generator->SetLogits(logits);
@@ -209,10 +232,23 @@
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = logits_cpu[next_token + vocab_size * b];
if (logit_to_count.find(next_token_score) == logit_to_count.end())
logit_to_count[next_token_score] = 1;
else
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
std::vector<float> expected_distributions(5);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
int total_count = batch_size * num_iter;
for (auto& [logit, count] : logit_to_count) {
float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}

TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -398,41 +434,63 @@
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int batch_size = 5;
int k = 5;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4};

Generators::Config config;
config.model.vocab_size = 32000; // vocab size of llama
int vocab_size = 32000; // vocab size of llama
config.model.vocab_size = vocab_size; // vocab size of llama

// Create a generator
auto params = Generators::CreateGeneratorParams(config);
params->search.max_length = 10;
params->search.do_sample = true;
params->search.top_k = k;
params->search.batch_size = batch_size;
params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
params->device_type = Generators::DeviceType::CUDA;
auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);

// Create data structures for testing
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<> dist(1, 25);
std::vector<int> indices(vocab_size);
std::vector<float> logits_cpu(vocab_size * batch_size);
int num_iter = 100;
std::map<float, int> logit_to_count;
for (int i = 0; i < num_iter; i++) {
int num_large = dist(engine);
LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
auto generator = Generators::CreateGenerator(*model, *params);
logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
for (int i = 0; i < batch_size; i++) {
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), engine);
for (int j = 0; j < k; j++)
logits_cpu[indices[j] + vocab_size * i] = float(k - j);
}
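// The host-generated logits are staged in the CPU-side view of the DeviceSpan and
// copied to the device before being handed to the generator.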
Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
auto cpu_span = logits_gpu.CpuSpan();
std::copy(logits_cpu.begin(), logits_cpu.end(), cpu_span.begin());
logits_gpu.CopyCpuToDevice();
generator->SetLogits(logits_gpu);
generator->GenerateNextToken();
auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
auto logits_cpu = logits_gpu.CopyDeviceToCpu();
// Verify outputs match expected outputs
for (int b = 0; b < batch_size; b++) {
auto next_token = next_tokens[b];
auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
EXPECT_GT(next_token_score, 10.0f);
auto next_token_score = logits_cpu[next_token + vocab_size * b];
if (logit_to_count.find(next_token_score) == logit_to_count.end())
logit_to_count[next_token_score] = 1;
else
logit_to_count[next_token_score]++;
EXPECT_GT(next_token_score, 0.0f);
}
}
std::vector<float> expected_distributions(5);
for (int i = 0; i < k; i++)
expected_distributions[i] = float(i + 1);
SoftMax(expected_distributions, 1.0f);
int total_count = batch_size * num_iter;
for (auto& [logit, count] : logit_to_count) {
float expected_distribution = expected_distributions[int(logit) - 1];
EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
}
}

TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {