From e11a307e0e064ba7253a27a9b218c7ecafb27564 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Thu, 16 Jan 2025 11:26:58 -0800
Subject: [PATCH 01/15] fix top k

---
 examples/python/phi3-qa.py |  9 +++++----
 src/generators.h           |  1 -
 src/search.cpp             | 13 ++++++++++---
 src/top_k_cpu.cpp          | 32 --------------------------------
 4 files changed, 15 insertions(+), 40 deletions(-)
 delete mode 100644 src/top_k_cpu.cpp

diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py
index 6d4abfd96..323b4163e 100644
--- a/examples/python/phi3-qa.py
+++ b/examples/python/phi3-qa.py
@@ -30,6 +30,10 @@ def main(args):
 
     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
 
+    params = og.GeneratorParams(model)
+    params.set_search_options(**search_options)
+    generator = og.Generator(model, params)
+
     # Keep asking for input prompts in a loop
     while True:
         text = input("Input: ")
@@ -44,9 +48,6 @@ def main(args):
 
         input_tokens = tokenizer.encode(prompt)
 
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-        generator = og.Generator(model, params)
         generator.append_tokens(input_tokens)
         if args.verbose: print("Generator created")
@@ -75,7 +76,7 @@ def main(args):
         print()
 
         # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        del generator
+        # del generator
 
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp

diff --git a/src/generators.h b/src/generators.h
index 28a71e580..87468bb03 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -161,6 +161,5 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
 float Float16ToFloat32(uint16_t v);  // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
-void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);
 
 }  // namespace Generators

diff --git a/src/search.cpp b/src/search.cpp
index 274a5b2dd..c04d2def5 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -161,9 +161,13 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
     std::vector<int> indices(scores.size());
     std::iota(indices.begin(), indices.end(), 0);
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
+    std::vector<float> top_k_scores(k);
+    for (int i = 0; i < k; i++)
+      top_k_scores.push_back(scores[indices[i]]);
     // Sample a token from the top K
-    std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);
-    SetNextToken(batch_id, indices[dis(gen_)]);
+    std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
+    int randi = dis(gen_);
+    SetNextToken(batch_id, indices[randi]);
   }
   AppendNextTokensToSequences();
 }
@@ -209,12 +213,15 @@ void GreedySearch_Cpu::SampleTopKTopP(int k, float p, float temperature) {
     std::vector<int> indices(scores.size());
     std::iota(indices.begin(), indices.end(), 0);
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
+    std::vector<float> top_k_scores(k);
+    for (int i = 0; i < k; i++)
+      top_k_scores.push_back(scores[indices[i]]);
     // Sample a probability threshold
     float threshold = dis(gen_);
     int32_t token = indices[k - 1];
     // Find the first token where the cumulative probability exceeds the threshold
     for (int i = 0; i < k; i++) {
-      threshold -= scores[indices[i]];
+      threshold -= top_k_scores[indices[i]];
       if (threshold > 0) {
         continue;
       }
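Note on the SampleTopK change above: before this patch, std::discrete_distribution was built from scores.begin(), scores.begin() + k, i.e. from the first k entries of the raw score array, rather than from the scores belonging to the k indices selected by partial_sort, so the sampled position did not correspond to the top-k tokens. The sketch below is a minimal, standalone illustration of the intended behaviour, not the repository's code; it already uses the indexed assignment that a later patch in this series switches to, and the helper name SampleTopK and the toy inputs are assumptions for illustration only.

    // Standalone sketch: pick the k best-scoring indices, then sample one of
    // them with probability proportional to its score.
    #include <algorithm>
    #include <iostream>
    #include <numeric>
    #include <random>
    #include <vector>

    int SampleTopK(const std::vector<float>& scores, int k, std::mt19937& gen) {
      std::vector<int> indices(scores.size());
      std::iota(indices.begin(), indices.end(), 0);
      std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                        [&](int i, int j) { return scores[i] > scores[j]; });
      std::vector<float> top_k_scores(k);
      for (int i = 0; i < k; i++)
        top_k_scores[i] = scores[indices[i]];  // weights in sorted top-k order
      std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
      return indices[dis(gen)];  // dis() indexes the sorted top-k, map back to a vocab id
    }

    int main() {
      std::mt19937 gen{0};
      std::vector<float> scores{0.1f, 0.7f, 0.05f, 0.9f, 0.2f};
      std::cout << SampleTopK(scores, 3, gen) << "\n";  // one of indices 3, 1, 4
    }

std::discrete_distribution normalizes its weights itself, so the weights only need to be the (non-negative) scores of the selected indices, in the same order as the sorted index list.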
diff --git a/src/top_k_cpu.cpp b/src/top_k_cpu.cpp
deleted file mode 100644
index 745d02c01..000000000
--- a/src/top_k_cpu.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-#include "generators.h"
-namespace Generators {
-
-void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs) {
-  int32_t k = static_cast<int32_t>(top_k.size());
-  assert(k <= inputs.size());  // Use a smaller top_k span if k is larger than inputs
-
-  // Min heap to store pairs of (element, index)
-  std::priority_queue<std::pair<float, int32_t>, std::vector<std::pair<float, int32_t>>, std::greater<>> pq;
-
-  // Add first k elements into the heap
-  for (int32_t i = 0; i < k; i++) {
-    pq.push(std::make_pair(inputs[i], i));
-  }
-
-  // For the rest of the elements we already have k, so remove the smallest on each iteration
-  for (int32_t i = k; i < inputs.size(); i++) {
-    // Entry is smaller than the smallest, so don't bother
-    if (inputs[i] <= pq.top().first)
-      continue;
-
-    pq.pop();
-    pq.push(std::make_pair(inputs[i], i));
-  }
-
-  for (int i = 0; i < k; i++) {
-    top_k[k - i - 1] = pq.top().second;
-    pq.pop();
-  }
-}
-
-}  // namespace Generators

From e6b14fa0a644c4ad65ccae758ef0ffc76fcad678 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Thu, 16 Jan 2025 12:08:08 -0800
Subject: [PATCH 02/15] del generator no more

---
 examples/python/phi3-qa.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py
index 323b4163e..56cc8a82d 100644
--- a/examples/python/phi3-qa.py
+++ b/examples/python/phi3-qa.py
@@ -75,9 +75,6 @@ def main(args):
         print()
         print()
 
-        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        # del generator
-
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp

From 101b3adb95cec80d9687e2d3050f1527fbe86679 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Thu, 16 Jan 2025 12:10:04 -0800
Subject: [PATCH 03/15] line was weird

---
 src/search.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/search.cpp b/src/search.cpp
index c04d2def5..4bc25d460 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -166,8 +166,7 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
       top_k_scores.push_back(scores[indices[i]]);
     // Sample a token from the top K
     std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
-    int randi = dis(gen_);
-    SetNextToken(batch_id, indices[randi]);
+    SetNextToken(batch_id, indices[dis(gen_)]);
   }
   AppendNextTokensToSequences();
 }

From 952754cbae4372ea36a1450f1b5230cba56d6bd7 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 08:36:55 -0800
Subject: [PATCH 04/15] fix dumb mistakes

---
 src/search.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/search.cpp b/src/search.cpp
index 4bc25d460..a1a9f6890 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -163,7 +163,7 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
     std::vector<float> top_k_scores(k);
     for (int i = 0; i < k; i++)
-      top_k_scores.push_back(scores[indices[i]]);
+      top_k_scores[i] = scores[indices[i]];
     // Sample a token from the top K
     std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
     SetNextToken(batch_id, indices[dis(gen_)]);
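The hunk above fixes a container bug introduced in PATCH 01: std::vector<float> top_k_scores(k) already value-initializes k zero elements, so the subsequent push_back calls appended after that zero prefix, leaving the distribution with 2k weights (the first k all zero) and letting dis(gen_) return positions that no longer line up with indices. Indexed assignment keeps the vector at exactly k weights. A toy demonstration of the pitfall (illustrative only, not project code):

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<float> v(3);        // already {0, 0, 0}
      for (int i = 0; i < 3; i++)
        v.push_back(float(i + 1));    // appends: {0, 0, 0, 1, 2, 3}
      std::cout << v.size() << "\n";  // prints 6, not 3
    }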
@@ -212,15 +212,12 @@ void GreedySearch_Cpu::SampleTopKTopP(int k, float p, float temperature) {
     std::vector<int> indices(scores.size());
     std::iota(indices.begin(), indices.end(), 0);
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
-    std::vector<float> top_k_scores(k);
-    for (int i = 0; i < k; i++)
-      top_k_scores.push_back(scores[indices[i]]);
     // Sample a probability threshold
     float threshold = dis(gen_);
     int32_t token = indices[k - 1];
     // Find the first token where the cumulative probability exceeds the threshold
     for (int i = 0; i < k; i++) {
-      threshold -= top_k_scores[indices[i]];
+      threshold -= scores[indices[i]];
       if (threshold > 0) {
         continue;
       }

From 6efed8e7001d705d84ccd410801ce10f3dd8ac20 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 11:46:34 -0800
Subject: [PATCH 05/15] new tests

---
 test/sampling_tests.cpp | 96 +++++++++++++++++++++++++++++++++--------
 1 file changed, 77 insertions(+), 19 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index a71910b15..cf826d1de 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -176,15 +176,30 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   }
 }
 
+// TODO(aciddelgado): this is copy-pasted from softmax.h but I think that might be fine... not sure how we feel about that
+void SoftMax(std::span<float> scores, float temperature) {
+  float const max_score = *std::max_element(scores.begin(), scores.end());
+
+  // Subtract max score and scale by temperature
+  std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });
+
+  // Compute sum of exponentials
+  float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);
+
+  // Divide each score by the sum of exponentials
+  std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
+}
+
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   int batch_size = 5;
   int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
 
   Generators::Config config;
-  config.model.vocab_size = 32000;  // vocab size of llama
+  int vocab_size = 32000;  // vocab size of llama
+  config.model.vocab_size = vocab_size;  // vocab size of llama
 
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
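The new tests plant the logits 1..k at random vocabulary positions and then check that, over many iterations, each planted logit is sampled with roughly its softmax probability. For k = 5 and temperature 1 those expected frequencies are approximately 0.012, 0.032, 0.086, 0.234 and 0.636, which is what the EXPECT_NEAR checks further down compare the empirical counts against (with a 0.1 tolerance). A standalone sketch of that expected distribution, computed the same way as the SoftMax helper added above (the program below is illustrative only, not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const int k = 5;
      std::vector<float> p(k);
      for (int i = 0; i < k; i++) p[i] = float(i + 1);           // logits 1..k
      const float max_score = *std::max_element(p.begin(), p.end());
      float sum = 0.0f;
      for (float& x : p) { x = std::exp(x - max_score); sum += x; }
      for (float& x : p) x /= sum;                               // softmax
      for (float x : p) std::printf("%.4f ", x);                 // 0.0117 0.0317 0.0861 0.2341 0.6364
      std::printf("\n");
    }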
@@ -192,15 +207,23 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
   params->device_type = Generators::DeviceType::CPU;
-  std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
+
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(5, 25);
+  std::vector<int> indices(vocab_size);
+  std::vector<float> logits_cpu(vocab_size * batch_size);
   int num_iter = 100;
+  std::map<float, int> logit_to_count;
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
+    logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    for (int i = 0; i < batch_size; i++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        logits_cpu[indices[j] + vocab_size * i] = float(k - j);
+    }
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
@@ -209,10 +232,23 @@
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = logits_cpu[next_token + vocab_size * b];
+      if (logit_to_count.find(next_token_score) == logit_to_count.end())
+        logit_to_count[next_token_score] = 1;
+      else
+        logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  std::vector<float> expected_distributions(5);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  int total_count = batch_size * num_iter;
+  for (auto& [logit, count] : logit_to_count) {
+    float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -398,11 +434,12 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   int batch_size = 5;
   int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
 
   Generators::Config config;
-  config.model.vocab_size = 32000;  // vocab size of llama
+  int vocab_size = 32000;  // vocab size of llama
+  config.model.vocab_size = vocab_size;  // vocab size of llama
 
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
@@ -410,29 +447,50 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
   params->device_type = Generators::DeviceType::CUDA;
-  auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
-  auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
+
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(1, 25);
+  std::vector<int> indices(vocab_size);
+  std::vector<float> logits_cpu(vocab_size * batch_size);
   int num_iter = 100;
+  std::map<float, int> logit_to_count;
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
     auto generator = Generators::CreateGenerator(*model, *params);
+    logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    for (int i = 0; i < batch_size; i++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        logits_cpu[indices[j] + vocab_size * i] = float(k - j);
+    }
+    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
+    auto cpu_span = logits_gpu.CpuSpan();
+    std::copy(logits_cpu.begin(), logits_cpu.end(), cpu_span.begin());
+    logits_gpu.CopyCpuToDevice();
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
-    auto logits_cpu = logits_gpu.CopyDeviceToCpu();
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = logits_cpu[next_token + vocab_size * b];
+      if (logit_to_count.find(next_token_score) == logit_to_count.end())
+        logit_to_count[next_token_score] = 1;
+      else
+        logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  std::vector<float> expected_distributions(5);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  int total_count = batch_size * num_iter;
+  for (auto& [logit, count] : logit_to_count) {
+    float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {

From 30bdcda8a88b76ec40993f6a6eaceba479fe2130 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 15:00:01 -0800
Subject: [PATCH 06/15] nice comments

---
 test/sampling_tests.cpp | 61 +++++++++++++++++++++--------------------
 1 file changed, 32 insertions(+), 29 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index cf826d1de..8adb47ab0 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -176,7 +176,6 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   }
 }
 
-// TODO(aciddelgado): this is copy-pasted from softmax.h but I think that might be fine... not sure how we feel about that
 void SoftMax(std::span<float> scores, float temperature) {
   float const max_score = *std::max_element(scores.begin(), scores.end());
 
@@ -192,11 +191,11 @@ void SoftMax(std::span<float> scores, float temperature) {
 
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  int vocab_size = 32000;  // vocab size of llama
+  const int vocab_size = 13;  // vocab size of llama
   config.model.vocab_size = vocab_size;  // vocab size of llama
+
   // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
@@ -213,17 +212,21 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   std::mt19937 engine(rd());
   std::vector<int> indices(vocab_size);
   std::vector<float> logits_cpu(vocab_size * batch_size);
-  int num_iter = 100;
+  const int num_iter = 100;
   std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
     auto generator = Generators::CreateGenerator(*model, *params);
     logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    // Shuffle integers 1 to k randomly into logits_cpu
     for (int i = 0; i < batch_size; i++) {
       std::iota(indices.begin(), indices.end(), 0);
      std::shuffle(indices.begin(), indices.end(), engine);
       for (int j = 0; j < k; j++)
         logits_cpu[indices[j] + vocab_size * i] = float(k - j);
     }
+    // Set logits and get generated token
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
@@ -233,20 +236,19 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
       auto next_token_score = logits_cpu[next_token + vocab_size * b];
-      if (logit_to_count.find(next_token_score) == logit_to_count.end())
-        logit_to_count[next_token_score] = 1;
-      else
-        logit_to_count[next_token_score]++;
+      logit_to_count[next_token_score]++;
       EXPECT_GT(next_token_score, 0.0f);
     }
   }
-  std::vector<float> expected_distributions(5);
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
   for (int i = 0; i < k; i++)
     expected_distributions[i] = float(i + 1);
   SoftMax(expected_distributions, 1.0f);
-  int total_count = batch_size * num_iter;
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
+  const int total_count = batch_size * num_iter;
   for (auto& [logit, count] : logit_to_count) {
-    float expected_distribution = expected_distributions[int(logit) - 1];
+    const float expected_distribution = expected_distributions[int(logit) - 1];
     EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
   }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  int vocab_size = 32000;  // vocab size of llama
+  const int vocab_size = 13;  // vocab size of llama
   config.model.vocab_size = vocab_size;  // vocab size of llama
 
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
   std::random_device rd;
   std::mt19937 engine(rd());
   std::vector<int> indices(vocab_size);
   std::vector<float> logits_cpu(vocab_size * batch_size);
-  int num_iter = 100;
+  const int num_iter = 100;
   std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
     auto generator = Generators::CreateGenerator(*model, *params);
-    logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
+    auto cpu_span = logits_gpu.CpuSpan();
+    // Shuffle integers 1 to k randomly into cpu_span
     for (int i = 0; i < batch_size; i++) {
       std::iota(indices.begin(), indices.end(), 0);
       std::shuffle(indices.begin(), indices.end(), engine);
       for (int j = 0; j < k; j++)
-        logits_cpu[indices[j] + vocab_size * i] = float(k - j);
+        cpu_span[indices[j] + vocab_size * i] = float(k - j);
     }
-    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
-    auto cpu_span = logits_gpu.CpuSpan();
-    std::copy(logits_cpu.begin(), logits_cpu.end(), cpu_span.begin());
+    // Copy logits onto device, set logits, and get generated token
     logits_gpu.CopyCpuToDevice();
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + vocab_size * b];
-      if (logit_to_count.find(next_token_score) == logit_to_count.end())
-        logit_to_count[next_token_score] = 1;
-      else
-        logit_to_count[next_token_score]++;
+      auto next_token_score = cpu_span[next_token + vocab_size * b];
+      logit_to_count[next_token_score]++;
       EXPECT_GT(next_token_score, 0.0f);
     }
   }
-  std::vector<float> expected_distributions(5);
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
   for (int i = 0; i < k; i++)
     expected_distributions[i] = float(i + 1);
   SoftMax(expected_distributions, 1.0f);
-  int total_count = batch_size * num_iter;
+  const int total_count = batch_size * num_iter;
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
   for (auto& [logit, count] : logit_to_count) {
-    float expected_distribution = expected_distributions[int(logit) - 1];
+    const float expected_distribution = expected_distributions[int(logit) - 1];
     EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
   }
 }

From 5b86288db0623283cb310f4bc23534ed4affba93 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 15:02:42 -0800
Subject: [PATCH 07/15] one more

---
 test/sampling_tests.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 8adb47ab0..4d85d558f 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -461,7 +461,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   // Run test
   for (int i = 0; i < num_iter; i++) {
     auto generator = Generators::CreateGenerator(*model, *params);
-    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
+    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
     auto cpu_span = logits_gpu.CpuSpan();
     // Shuffle integers 1 to k randomly into cpu_span

From 1bb5a2224949051a7230657eb060d59724a739cd Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 15:13:16 -0800
Subject: [PATCH 08/15] comment

---
 test/sampling_tests.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 4d85d558f..629315d11 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -454,7 +454,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   std::random_device rd;
   std::mt19937 engine(rd());
   std::vector<int> indices(vocab_size);
-  std::vector<float> logits_cpu(vocab_size * batch_size);
   const int num_iter = 100;
   std::map<float, int> logit_to_count;

From 8342260423e12849f7221ac5e18749530c9e2aac Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 15:31:37 -0800
Subject: [PATCH 09/15] error cmake

---
 test/sampling_tests.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 629315d11..91dcedc79 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -220,11 +220,11 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     auto generator = Generators::CreateGenerator(*model, *params);
     logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
     // Shuffle integers 1 to k randomly into logits_cpu
-    for (int i = 0; i < batch_size; i++) {
+    for (int b = 0; b < batch_size; i++) {
       std::iota(indices.begin(), indices.end(), 0);
       std::shuffle(indices.begin(), indices.end(), engine);
       for (int j = 0; j < k; j++)
-        logits_cpu[indices[j] + vocab_size * i] = float(k - j);
+        logits_cpu[indices[j] + vocab_size * b] = float(k - j);
     }
     // Set logits and get generated token
     auto logits_copy = logits_cpu;
@@ -463,11 +463,11 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
     Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
     auto cpu_span = logits_gpu.CpuSpan();
     // Shuffle integers 1 to k randomly into cpu_span
-    for (int i = 0; i < batch_size; i++) {
+    for (int b = 0; b < batch_size; b++) {
       std::iota(indices.begin(), indices.end(), 0);
       std::shuffle(indices.begin(), indices.end(), engine);
       for (int j = 0; j < k; j++)
-        cpu_span[indices[j] + vocab_size * i] = float(k - j);
+        cpu_span[indices[j] + vocab_size * b] = float(k - j);
     }
     // Copy logits onto device, set logits, and get generated token
     logits_gpu.CopyCpuToDevice();

From bcb1b56c69a5788f70c2044ffee864726e40c1f4 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 17 Jan 2025 16:58:56 -0800
Subject: [PATCH 10/15] meaningless change

---
 test/sampling_tests.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 91dcedc79..73758a354 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -195,7 +195,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   const int k = 5;
 
   Generators::Config config;
-  const int vocab_size = 13;  // vocab size of llama
+  const int vocab_size = 17;  // vocab size of llama
   config.model.vocab_size = vocab_size;  // vocab size of llama
 
   // Create a generator
@@ -438,7 +438,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   const int k = 5;
 
   Generators::Config config;
-  const int vocab_size = 13;  // vocab size of llama
+  const int vocab_size = 17;  // vocab size of llama
   config.model.vocab_size = vocab_size;  // vocab size of llama
 
   // Create a generator

From 0508930d7a8165dcfca2c4b3d5b30f8137df9368 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 21 Jan 2025 08:21:31 -0800
Subject: [PATCH 11/15] try printing

---
 test/sampling_tests.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 73758a354..a05fe995a 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -190,6 +190,7 @@ void SoftMax(std::span<float> scores, float temperature) {
 }
 
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
+  std::cout << "Iteration " << i << std::endl;
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   const int batch_size = 5;
   const int k = 5;
@@ -459,6 +460,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   // Run test
   for (int i = 0; i < num_iter; i++) {
+    std::cout << "Iteration " << i << std::endl;
     auto generator = Generators::CreateGenerator(*model, *params);
     Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
     auto cpu_span = logits_gpu.CpuSpan();

From 611639321e2f348324b62a36119e342fb52afd69 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 21 Jan 2025 08:34:29 -0800
Subject: [PATCH 12/15] try printing

---
 test/sampling_tests.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index a05fe995a..7cfce8dc9 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -190,7 +190,6 @@ void SoftMax(std::span<float> scores, float temperature) {
 }
 
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
-  std::cout << "Iteration " << i << std::endl;
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   const int batch_size = 5;
   const int k = 5;
@@ -217,6 +217,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   // Run test
   for (int i = 0; i < num_iter; i++) {
+    std::cout << "Iteration " << i << std::endl;
     auto generator = Generators::CreateGenerator(*model, *params);
     logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
     // Shuffle integers 1 to k randomly into logits_cpu

From c790edc79035bbe7a38f388bdbc0a2834394a16d Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 21 Jan 2025 08:56:17 -0800
Subject: [PATCH 13/15] more prints

---
 test/sampling_tests.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 7cfce8dc9..b6c34dbd1 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -221,6 +221,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     auto generator = Generators::CreateGenerator(*model, *params);
     logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
     // Shuffle integers 1 to k randomly into logits_cpu
+    std::cout << "Shuffling logits" << std::endl;
     for (int b = 0; b < batch_size; i++) {
       std::iota(indices.begin(), indices.end(), 0);
       std::shuffle(indices.begin(), indices.end(), engine);
       for (int j = 0; j < k; j++)
         logits_cpu[indices[j] + vocab_size * b] = float(k - j);
     }
     // Set logits and get generated token
+    std::cout << "Generating next token" << std::endl;
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
     // Verify outputs match expected outputs
+    std::cout << "Checking outputs" << std::endl;
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
       auto next_token_score = logits_cpu[next_token + vocab_size * b];
@@ -242,11 +245,13 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     }
   }
   // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::cout << "Calculating expected distribution" << std::endl;
   std::vector<float> expected_distributions(k);
   for (int i = 0; i < k; i++)
     expected_distributions[i] = float(i + 1);
   SoftMax(expected_distributions, 1.0f);
   // Check that the distribution of tokens generated by the model is close to the expected distribution
+  std::cout << "Checking distribution" << std::endl;
   const int total_count = batch_size * num_iter;
   for (auto& [logit, count] : logit_to_count) {
     const float expected_distribution = expected_distributions[int(logit) - 1];
     EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);

From 90d596998f105f18e1d06130f638f4dd9bc6291b Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 21 Jan 2025 09:08:45 -0800
Subject: [PATCH 14/15] more prints

---
 test/sampling_tests.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index b6c34dbd1..e041a2fe6 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -195,7 +195,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   const int k = 5;
 
   Generators::Config config;
-  const int vocab_size = 17;  // vocab size of llama
+  const int vocab_size = 13;  // vocab size of llama
   config.model.vocab_size = vocab_size;  // vocab size of llama
 
   // Create a generator
@@ -224,9 +224,12 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     std::cout << "Shuffling logits" << std::endl;
     for (int b = 0; b < batch_size; i++) {
       std::iota(indices.begin(), indices.end(), 0);
+      std::cout << "iota done" << std::endl;
       std::shuffle(indices.begin(), indices.end(), engine);
+      std::cout << "shuffle done" << std::endl;
      for (int j = 0; j < k; j++)
        logits_cpu[indices[j] + vocab_size * b] = float(k - j);
+      std::cout << "batch " << b << " done" << std::endl;
     }
     // Set logits and get generated token
     std::cout << "Generating next token" << std::endl;
     auto logits_copy = logits_cpu;
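Before the final patch: the hang that the std::cout probes in the last few patches were apparently chasing comes from the loop header introduced in PATCH 09, for (int b = 0; b < batch_size; i++) -- the increment updates the outer iteration counter i instead of b, so b never reaches batch_size and the inner loop never terminates. The patch that follows changes the increment to b++ and removes the debug prints. A bounded, illustrative-only reproduction of the bug (not project code):

    #include <cstdio>

    int main() {
      const int batch_size = 3;
      int i = 0, guard = 0;
      for (int b = 0; b < batch_size; i++) {    // buggy form: b never changes
        if (++guard > 5) { std::printf("still b == %d after %d passes\n", b, guard); break; }
      }
      for (int b = 0; b < batch_size; b++)      // fixed form used by the next patch
        std::printf("row %d\n", b);
    }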
From 601c2616106287e754465e6e20861936fefc2c3a Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 21 Jan 2025 09:34:47 -0800
Subject: [PATCH 15/15] fixed infinite loop dumb

---
 test/sampling_tests.cpp | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index e041a2fe6..6ca0080eb 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -217,29 +217,22 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   // Run test
   for (int i = 0; i < num_iter; i++) {
-    std::cout << "Iteration " << i << std::endl;
     auto generator = Generators::CreateGenerator(*model, *params);
     logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
     // Shuffle integers 1 to k randomly into logits_cpu
-    std::cout << "Shuffling logits" << std::endl;
-    for (int b = 0; b < batch_size; i++) {
+    for (int b = 0; b < batch_size; b++) {
       std::iota(indices.begin(), indices.end(), 0);
-      std::cout << "iota done" << std::endl;
       std::shuffle(indices.begin(), indices.end(), engine);
-      std::cout << "shuffle done" << std::endl;
       for (int j = 0; j < k; j++)
         logits_cpu[indices[j] + vocab_size * b] = float(k - j);
-      std::cout << "batch " << b << " done" << std::endl;
     }
     // Set logits and get generated token
-    std::cout << "Generating next token" << std::endl;
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
     // Verify outputs match expected outputs
-    std::cout << "Checking outputs" << std::endl;
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
       auto next_token_score = logits_cpu[next_token + vocab_size * b];
@@ -248,13 +241,11 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     }
   }
   // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
-  std::cout << "Calculating expected distribution" << std::endl;
   std::vector<float> expected_distributions(k);
   for (int i = 0; i < k; i++)
     expected_distributions[i] = float(i + 1);
   SoftMax(expected_distributions, 1.0f);
   // Check that the distribution of tokens generated by the model is close to the expected distribution
-  std::cout << "Checking distribution" << std::endl;
   const int total_count = batch_size * num_iter;
   for (auto& [logit, count] : logit_to_count) {
     const float expected_distribution = expected_distributions[int(logit) - 1];
     EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
@@ -468,7 +459,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   // Run test
   for (int i = 0; i < num_iter; i++) {
-    std::cout << "Iteration " << i << std::endl;
     auto generator = Generators::CreateGenerator(*model, *params);
     Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
     auto cpu_span = logits_gpu.CpuSpan();