
Commit 4762332

aciddelgado authored and baijumeswani committed
Fix Top K CPU (#1194)
Fixes issue #1184, where top K sampling did not work on CPU.
1 parent f2f8b81 commit 4762332

File tree

5 files changed: +93 -65 lines changed

examples/python/phi3-qa.py

Lines changed: 4 additions & 6 deletions
@@ -30,6 +30,10 @@ def main(args):
 
     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
 
+    params = og.GeneratorParams(model)
+    params.set_search_options(**search_options)
+    generator = og.Generator(model, params)
+
     # Keep asking for input prompts in a loop
     while True:
         text = input("Input: ")
@@ -44,9 +48,6 @@ def main(args):
 
         input_tokens = tokenizer.encode(prompt)
 
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-        generator = og.Generator(model, params)
         generator.append_tokens(input_tokens)
         if args.verbose: print("Generator created")
 
@@ -74,9 +75,6 @@ def main(args):
         print()
         print()
 
-        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        del generator
-
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp

src/generators.h

Lines changed: 0 additions & 1 deletion
@@ -161,6 +161,5 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); /
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
 
 float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
-void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);
 
 } // namespace Generators

src/search.cpp

Lines changed: 4 additions & 1 deletion
@@ -161,8 +161,11 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
     std::vector<int> indices(scores.size());
     std::iota(indices.begin(), indices.end(), 0);
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
+    std::vector<float> top_k_scores(k);
+    for (int i = 0; i < k; i++)
+      top_k_scores[i] = scores[indices[i]];
     // Sample a token from the top K
-    std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);
+    std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
     SetNextToken(batch_id, indices[dis(gen_)]);
   }
   AppendNextTokensToSequences();
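The substance of the fix is in this hunk: previously the std::discrete_distribution was built from the first k entries of the unsorted score array (scores.begin() to scores.begin() + k) rather than from the scores belonging to the k indices selected by partial_sort, so the sampling weights did not correspond to the chosen top-K candidates. Below is a minimal standalone sketch of the corrected step; SampleTopKSketch is a hypothetical helper, not the library's API, and it assumes the scores are already non-negative (for example softmaxed probabilities), since std::discrete_distribution requires non-negative weights.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <span>
#include <vector>

// Hypothetical helper: sample one index from the k highest-scoring entries of
// `scores`, weighted by those entries' scores.
int SampleTopKSketch(std::span<const float> scores, int k, std::mt19937& gen) {
  std::vector<int> indices(scores.size());
  std::iota(indices.begin(), indices.end(), 0);
  // Partially sort so the indices of the k largest scores come first.
  std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                    [&](int i, int j) { return scores[i] > scores[j]; });
  // Gather the scores that belong to the top-k indices (what the fix adds);
  // the buggy version fed scores[0..k) to the distribution instead.
  std::vector<float> top_k_scores(k);
  for (int i = 0; i < k; i++)
    top_k_scores[i] = scores[indices[i]];
  // Draw a position in [0, k) weighted by the gathered scores, then map it
  // back to the original vocabulary index.
  std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
  return indices[dis(gen)];
}

int main() {
  std::mt19937 gen{std::random_device{}()};
  std::vector<float> probs{0.05f, 0.40f, 0.10f, 0.30f, 0.05f, 0.12f};
  // With k = 3, only indices 1, 3, and 5 (the three largest scores) can be drawn.
  std::cout << SampleTopKSketch(probs, 3, gen) << "\n";
}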

src/top_k_cpu.cpp

Lines changed: 0 additions & 32 deletions
This file was deleted.

test/sampling_tests.cpp

Lines changed: 85 additions & 25 deletions
@@ -176,31 +176,57 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   }
 }
 
+void SoftMax(std::span<float> scores, float temperature) {
+  float const max_score = *std::max_element(scores.begin(), scores.end());
+
+  // Subtract max score and scale by temperature
+  std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });
+
+  // Compute sum of exponentials
+  float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);
+
+  // Divide each score by the sum of exponentials
+  std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
+}
+
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  config.model.vocab_size = 32000; // vocab size of llama
+  const int vocab_size = 13; // vocab size of llama
+  config.model.vocab_size = vocab_size; // vocab size of llama
 
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
   params->device_type = Generators::DeviceType::CPU;
-  std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
+
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(5, 25);
-  int num_iter = 100;
+  std::vector<int> indices(vocab_size);
+  std::vector<float> logits_cpu(vocab_size * batch_size);
+  const int num_iter = 100;
+  std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
+    logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    // Shuffle integers 1 to k randomly into logits_cpu
+    for (int b = 0; b < batch_size; b++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        logits_cpu[indices[j] + vocab_size * b] = float(k - j);
+    }
+    // Set logits and get generated token
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
@@ -209,10 +235,22 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = logits_cpu[next_token + vocab_size * b];
+      logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
+  const int total_count = batch_size * num_iter;
+  for (auto& [logit, count] : logit_to_count) {
+    const float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -396,43 +434,65 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
 
 TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  config.model.vocab_size = 32000; // vocab size of llama
+  const int vocab_size = 17; // vocab size of llama
+  config.model.vocab_size = vocab_size; // vocab size of llama
 
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
   params->device_type = Generators::DeviceType::CUDA;
-  auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
-  auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
 
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(1, 25);
-  int num_iter = 100;
+  std::vector<int> indices(vocab_size);
+  const int num_iter = 100;
+  std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
     auto generator = Generators::CreateGenerator(*model, *params);
+    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
+    auto cpu_span = logits_gpu.CpuSpan();
+    // Shuffle integers 1 to k randomly into cpu_span
+    for (int b = 0; b < batch_size; b++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        cpu_span[indices[j] + vocab_size * b] = float(k - j);
+    }
+    // Copy logits onto device, set logits, and get generated token
+    logits_gpu.CopyCpuToDevice();
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
-    auto logits_cpu = logits_gpu.CopyDeviceToCpu();
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = cpu_span[next_token + vocab_size * b];
+      logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  const int total_count = batch_size * num_iter;
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
+  for (auto& [logit, count] : logit_to_count) {
+    const float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
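As a sanity check on the expected distribution these tests compare against: softmax over the integer logits 1 through 5 gives roughly 0.012, 0.032, 0.086, 0.234 and 0.636, so across the batch_size * num_iter = 500 sampled tokens per test the highest logit should be drawn about 64% of the time, and the EXPECT_NEAR tolerance of 0.1 is comfortably loose. A small standalone computation of those probabilities (illustrative only, not part of the test file):

#include <cmath>
#include <cstdio>
#include <vector>

// Recompute the expected top-k sampling frequencies used by the tests:
// softmax over the logits 1, 2, 3, 4, 5 at temperature 1.
int main() {
  const int k = 5;
  std::vector<double> p(k);
  double sum = 0.0;
  for (int i = 0; i < k; i++) sum += std::exp(double(i + 1));
  for (int i = 0; i < k; i++) p[i] = std::exp(double(i + 1)) / sum;
  // Prints approximately: 0.012 0.032 0.086 0.234 0.636
  for (int i = 0; i < k; i++) std::printf("%.3f ", p[i]);
  std::printf("\n");
}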
