diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py
index 6d4abfd96..56cc8a82d 100644
--- a/examples/python/phi3-qa.py
+++ b/examples/python/phi3-qa.py
@@ -30,6 +30,10 @@ def main(args):
 
     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
 
+    params = og.GeneratorParams(model)
+    params.set_search_options(**search_options)
+    generator = og.Generator(model, params)
+
     # Keep asking for input prompts in a loop
     while True:
         text = input("Input: ")
@@ -44,9 +48,6 @@ def main(args):
 
         input_tokens = tokenizer.encode(prompt)
 
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-        generator = og.Generator(model, params)
         generator.append_tokens(input_tokens)
         if args.verbose: print("Generator created")
 
@@ -74,9 +75,6 @@ def main(args):
         print()
         print()
 
-        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        del generator
-
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp
diff --git a/src/generators.h b/src/generators.h
index 28a71e580..87468bb03 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -161,6 +161,5 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
 
 float Float16ToFloat32(uint16_t v);  // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
-void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);
 
 }  // namespace Generators
diff --git a/src/search.cpp b/src/search.cpp
index 274a5b2dd..a1a9f6890 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -161,8 +161,11 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
     std::vector<int> indices(scores.size());
     std::iota(indices.begin(), indices.end(), 0);
     std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), [scores = scores.data()](int i, int j) { return scores[i] > scores[j]; });
+    std::vector<float> top_k_scores(k);
+    for (int i = 0; i < k; i++)
+      top_k_scores[i] = scores[indices[i]];
     // Sample a token from the top K
-    std::discrete_distribution<> dis(scores.begin(), scores.begin() + k);
+    std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
     SetNextToken(batch_id, indices[dis(gen_)]);
   }
   AppendNextTokensToSequences();
diff --git a/src/top_k_cpu.cpp b/src/top_k_cpu.cpp
deleted file mode 100644
index 745d02c01..000000000
--- a/src/top_k_cpu.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-#include "generators.h"
-namespace Generators {
-
-void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs) {
-  int32_t k = static_cast<int32_t>(top_k.size());
-  assert(k <= inputs.size());  // Use a smaller top_k span if k is larger than inputs
-
-  // Min heap to store pairs of (element, index)
-  std::priority_queue<std::pair<float, int32_t>, std::vector<std::pair<float, int32_t>>, std::greater<>> pq;
-
-  // Add first k elements into the heap
-  for (int32_t i = 0; i < k; i++) {
-    pq.push(std::make_pair(inputs[i], i));
-  }
-
-  // For the rest of the elements we already have k, so remove the smallest on each iteration
-  for (int32_t i = k; i < inputs.size(); i++) {
-    // Entry is smaller than the smallest, so don't bother
-    if (inputs[i] <= pq.top().first)
-      continue;
-
-    pq.pop();
-    pq.push(std::make_pair(inputs[i], i));
-  }
-
-  for (int i = 0; i < k; i++) {
-    top_k[k - i - 1] = pq.top().second;
-    pq.pop();
-  }
-}
-
-}  // namespace Generators
diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index a71910b15..6ca0080eb 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -176,15 +176,29 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   }
 }
 
+void SoftMax(std::span<float> scores, float temperature) {
+  float const max_score = *std::max_element(scores.begin(), scores.end());
+
+  // Subtract max score and scale by temperature
+  std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); });
+
+  // Compute sum of exponentials
+  float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f);
+
+  // Divide each score by the sum of exponentials
+  std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; });
+}
+
 TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  config.model.vocab_size = 32000;  // vocab size of llama
+  const int vocab_size = 13;  // vocab size of llama
+  config.model.vocab_size = vocab_size;  // vocab size of llama
+
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
@@ -192,15 +206,27 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
   params->device_type = Generators::DeviceType::CPU;
-  std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
+
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(5, 25);
-  int num_iter = 100;
+  std::vector<int> indices(vocab_size);
+  std::vector<float> logits_cpu(vocab_size * batch_size);
+  const int num_iter = 100;
+  std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine);
+    logits_cpu = std::vector<float>(vocab_size * batch_size, 0.0f);
+    // Shuffle integers 1 to k randomly into logits_cpu
+    for (int b = 0; b < batch_size; b++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        logits_cpu[indices[j] + vocab_size * b] = float(k - j);
+    }
+    // Set logits and get generated token
     auto logits_copy = logits_cpu;
     auto logits = params->p_device->WrapMemory<float>(logits_copy);
     generator->SetLogits(logits);
@@ -209,10 +235,22 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
     // Verify outputs match expected outputs
    for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = logits_cpu[next_token + vocab_size * b];
+      logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
+  const int total_count = batch_size * num_iter;
+  for (auto& [logit, count] : logit_to_count) {
+    const float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
@@ -396,13 +434,14 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
 
 TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
-  int batch_size = 5;
-  int k = 5;
-  std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
+  const int batch_size = 5;
+  const int k = 5;
 
   Generators::Config config;
-  config.model.vocab_size = 32000;  // vocab size of llama
+  const int vocab_size = 17;  // vocab size of llama
+  config.model.vocab_size = vocab_size;  // vocab size of llama
+
+  // Create a generator
   auto params = Generators::CreateGeneratorParams(config);
   params->search.max_length = 10;
   params->search.do_sample = true;
@@ -410,29 +449,50 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
   params->device_type = Generators::DeviceType::CUDA;
-  auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
-  auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
+
+  // Create data structures for testing
   std::random_device rd;
   std::mt19937 engine(rd());
-  std::uniform_int_distribution<> dist(1, 25);
-  int num_iter = 100;
+  std::vector<int> indices(vocab_size);
+  const int num_iter = 100;
+  std::map<float, int> logit_to_count;
+
+  // Run test
   for (int i = 0; i < num_iter; i++) {
-    int num_large = dist(engine);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
     auto generator = Generators::CreateGenerator(*model, *params);
+    Generators::DeviceSpan<float> logits_gpu = params->p_device->Allocate<float>(vocab_size * batch_size);
+    auto cpu_span = logits_gpu.CpuSpan();
+    // Shuffle integers 1 to k randomly into cpu_span
+    for (int b = 0; b < batch_size; b++) {
+      std::iota(indices.begin(), indices.end(), 0);
+      std::shuffle(indices.begin(), indices.end(), engine);
+      for (int j = 0; j < k; j++)
+        cpu_span[indices[j] + vocab_size * b] = float(k - j);
+    }
+    // Copy logits onto device, set logits, and get generated token
+    logits_gpu.CopyCpuToDevice();
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
-    auto logits_cpu = logits_gpu.CopyDeviceToCpu();
     // Verify outputs match expected outputs
     for (int b = 0; b < batch_size; b++) {
       auto next_token = next_tokens[b];
-      auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b];
-      EXPECT_GT(next_token_score, 10.0f);
+      auto next_token_score = cpu_span[next_token + vocab_size * b];
+      logit_to_count[next_token_score]++;
+      EXPECT_GT(next_token_score, 0.0f);
     }
   }
+  // Calculate expected distribution of tokens by softmaxing given logits (integers 1 through k)
+  std::vector<float> expected_distributions(k);
+  for (int i = 0; i < k; i++)
+    expected_distributions[i] = float(i + 1);
+  SoftMax(expected_distributions, 1.0f);
+  const int total_count = batch_size * num_iter;
+  // Check that the distribution of tokens generated by the model is close to the expected distribution
+  for (auto& [logit, count] : logit_to_count) {
+    const float expected_distribution = expected_distributions[int(logit) - 1];
+    EXPECT_NEAR(count / float(total_count), expected_distribution, 0.1);
+  }
 }
 
 TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
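Note on the src/search.cpp change above: the old code built std::discrete_distribution from the first k entries of the unsorted scores buffer, so the sampled index did not correspond to the tokens selected by std::partial_sort. Below is a minimal standalone sketch of the corrected top-k sampling logic for reference; it is illustrative only (the function name SampleTopKSketch is made up and not part of this repository).

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Illustrative sketch: sample an index from the top-k entries of `scores`.
// Mirrors the corrected GreedySearch_Cpu::SampleTopK logic: gather the top-k
// scores first, then build the discrete_distribution from those values.
int SampleTopKSketch(const std::vector<float>& scores, int k, std::mt19937& gen) {
  std::vector<int> indices(scores.size());
  std::iota(indices.begin(), indices.end(), 0);
  // Partially sort indices so the first k refer to the largest scores
  std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                    [&scores](int i, int j) { return scores[i] > scores[j]; });
  // Gather the top-k scores; weighting by scores.begin()..scores.begin()+k
  // would use the first k *unsorted* entries, which was the bug being fixed
  std::vector<float> top_k_scores(k);
  for (int i = 0; i < k; i++)
    top_k_scores[i] = scores[indices[i]];
  std::discrete_distribution<> dis(top_k_scores.begin(), top_k_scores.end());
  return indices[dis(gen)];
}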