-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Expand file tree
/
Copy pathllmRequestTest.cpp
More file actions
806 lines (714 loc) · 32.2 KB
/
llmRequestTest.cpp
File metadata and controls
806 lines (714 loc) · 32.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <list>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// Short namespace aliases for the TensorRT-LLM components exercised by these tests.
namespace tb = tensorrt_llm::batch_manager;
namespace tc = tensorrt_llm::common;
namespace texec = tensorrt_llm::executor;
namespace tr = tensorrt_llm::runtime;

// Shorthands for LlmRequest's nested type aliases used throughout the file.
using SizeType32 = tb::LlmRequest::SizeType32;
using VecTokens = tb::LlmRequest::VecTokens;
using VecTokenExtraIds = tb::LlmRequest::VecTokenExtraIds;
using VecUniqueTokens = tb::LlmRequest::VecUniqueTokens;
// Shared fixture for the LlmRequest tests; no per-test preparation or cleanup is required.
class LlmRequestTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
protected:
    void SetUp() override
    {
        // Nothing to prepare.
    }

    void TearDown() override
    {
        // Nothing to clean up.
    }
};
// Checks that constructing an LlmRequest from an executor::Request copies the prompt and generation
// limits, and converts each optional per-request input (embedding bias, bad/stop words, prompt-tuning
// table with extra ids) into the tensors described by the expectations below.
TEST_F(LlmRequestTest, fromExecutorRequest)
{
    VecTokens inputTokens{1, 2, 3, 4, 5};
    SizeType32 maxNewTokens(66);
    texec::IdType requestId{77};
    // Plain request: none of the optional inputs are set.
    {
        texec::Request execReq(inputTokens, maxNewTokens);
        tb::LlmRequest llmReq(requestId, execReq);
        EXPECT_EQ(llmReq.getTokens().size(), 1);
        EXPECT_EQ(llmReq.getTokens().at(0), inputTokens);
        EXPECT_EQ(llmReq.mMaxNewTokens, maxNewTokens);
        EXPECT_EQ(llmReq.mSamplingConfig.numReturnSequences, execReq.getSamplingConfig().getNumReturnSequences());
        EXPECT_EQ(llmReq.getOrigPromptLen(), inputTokens.size());
        EXPECT_EQ(llmReq.getMaxSentTokenLen(), inputTokens.size());
        EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
        EXPECT_FALSE(llmReq.mSeqSlot);
        // No speculative decoding config, draft tokens should be empty
        EXPECT_EQ(llmReq.getNumDraftTokens(), 0);
        EXPECT_FALSE(llmReq.getEmbeddingBias().has_value());
        EXPECT_FALSE(llmReq.getBadWordsList().has_value());
        EXPECT_FALSE(llmReq.getStopWordsList().has_value());
        EXPECT_FALSE(llmReq.getPromptEmbeddingTable().has_value());
        EXPECT_FALSE(llmReq.getPromptVocabSize().has_value());
    }
    // Embedding bias: a [vocabSize] tensor is expected back with shape [1, vocabSize].
    {
        texec::Request execReq(inputTokens, maxNewTokens);
        SizeType32 vocabSize = 100;
        auto embeddingBias = texec::Tensor::cpu(texec::DataType::kFP32, {vocabSize});
        execReq.setEmbeddingBias(embeddingBias);
        tb::LlmRequest llmReq(requestId, execReq);
        // ASSERT (not EXPECT) so a missing value fails cleanly instead of throwing from value().
        ASSERT_TRUE(llmReq.getEmbeddingBias().has_value());
        auto const biasShape = llmReq.getEmbeddingBias().value()->getShape();
        EXPECT_EQ(biasShape.nbDims, 2);
        EXPECT_EQ(biasShape.d[0], 1);
        EXPECT_EQ(biasShape.d[1], vocabSize);
    }
    // Bad/stop words: each list is expected as a [1, 2, N] INT32 tensor whose first row holds the
    // concatenated word tokens and whose second row holds the cumulative end offsets, padded with -1.
    {
        texec::Request execReq(inputTokens, maxNewTokens);
        std::list<VecTokens> badWords{{1, 2, 3}, {4, 5}, {9}};
        std::list<VecTokens> stopWords{{1, 3}, {4}};
        execReq.setBadWords(badWords);
        execReq.setStopWords(stopWords);
        tb::LlmRequest llmReq(requestId, execReq);
        ASSERT_TRUE(llmReq.getBadWordsList().has_value());
        ASSERT_TRUE(llmReq.getStopWordsList().has_value());
        {
            auto badWordsTensor = llmReq.getBadWordsList().value();
            EXPECT_EQ(badWordsTensor->getDataType(), nvinfer1::DataType::kINT32);
            auto const shape = badWordsTensor->getShape();
            // ASSERT the extent before reading 2 * 6 elements from the raw buffer.
            ASSERT_EQ(shape.nbDims, 3);
            EXPECT_EQ(shape.d[0], 1);
            ASSERT_EQ(shape.d[1], 2);
            ASSERT_EQ(shape.d[2], 6);
            auto const* data = tr::bufferCast<int32_t>(*badWordsTensor);
            std::vector<int32_t> const actual(data, data + 2 * 6);
            EXPECT_THAT(actual, testing::ElementsAreArray({1, 2, 3, 4, 5, 9, 3, 5, 6, -1, -1, -1}));
        }
        {
            auto stopWordsTensor = llmReq.getStopWordsList().value();
            EXPECT_EQ(stopWordsTensor->getDataType(), nvinfer1::DataType::kINT32);
            auto const shape = stopWordsTensor->getShape();
            // ASSERT the extent before reading 2 * 3 elements from the raw buffer.
            ASSERT_EQ(shape.nbDims, 3);
            EXPECT_EQ(shape.d[0], 1);
            ASSERT_EQ(shape.d[1], 2);
            ASSERT_EQ(shape.d[2], 3);
            auto const* data = tr::bufferCast<int32_t>(*stopWordsTensor);
            std::vector<int32_t> const actual(data, data + 2 * 3);
            EXPECT_THAT(actual, testing::ElementsAreArray({1, 3, 4, 2, 3, -1}));
        }
    }
    // Prompt tuning: a [vocabSize, hiddenSize] table is expected back as [1, vocabSize, hiddenSize], and
    // the extra ids are paired element-wise with the input tokens in getUniqueTokens(0).
    {
        texec::Request execReq(inputTokens, maxNewTokens);
        SizeType32 vocabSize = 100;
        SizeType32 hiddenSize = 64;
        auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {vocabSize, hiddenSize});
        VecTokenExtraIds extraIds{1, 1, 1, 0, 0};
        texec::PromptTuningConfig config(embeddingTable, extraIds);
        execReq.setPromptTuningConfig(config);
        tb::LlmRequest llmReq(requestId, execReq);
        ASSERT_TRUE(llmReq.getPromptEmbeddingTable().has_value());
        ASSERT_TRUE(llmReq.getPromptVocabSize().has_value());
        auto const tableShape = llmReq.getPromptEmbeddingTable().value()->getShape();
        EXPECT_EQ(tableShape.nbDims, 3);
        EXPECT_EQ(tableShape.d[0], 1);
        EXPECT_EQ(tableShape.d[1], vocabSize);
        EXPECT_EQ(tableShape.d[2], hiddenSize);
        EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getDataType(), nvinfer1::DataType::kFLOAT);
        EXPECT_EQ(llmReq.getPromptVocabSize().value(), vocabSize);
        VecUniqueTokens uniqueTokens;
        for (std::size_t i = 0; i < inputTokens.size(); ++i)
        {
            uniqueTokens.push_back({inputTokens[i], extraIds[i]});
        }
        EXPECT_EQ(llmReq.getUniqueTokens(0), uniqueTokens);
    }
}
// Checks that invalid executor requests are rejected with a TllmException carrying a specific message,
// and that validate() truncates / accepts otherwise-valid requests as expected.
// NOTE(review): the validate() arguments below appear to be (maxInputLen, maxSequenceLen, maxDraftLen,
// vocabSizePadded[, maxEncoderInputLen, enableKVCacheReuse]) — confirm against llmRequest.h.
TEST_F(LlmRequestTest, invalidExecRequest)
{
    VecTokens inputTokens{1, 2, 3, 4, 5};
    SizeType32 maxNewTokens(66);
    texec::IdType requestId{77};
    // Each entry pairs a callable that must throw with a substring expected in the exception message.
    std::vector<std::pair<std::function<void()>, std::string>> lambdaErrMsgs;
    // Input is too long
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            tb::LlmRequest llmReq(requestId, execReq);
            llmReq.validate(2, 1000, 0, 32000);
        };
        lambdaErrMsgs.emplace_back(lambda, "exceeds maximum input");
    }
    // Invalid beam width
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            execReq.setSamplingConfig(texec::SamplingConfig(-1));
            tb::LlmRequest llmReq(requestId, execReq);
            llmReq.validate(500, 1000, 0, 32000);
        };
        lambdaErrMsgs.emplace_back(lambda, "beamWidth > 0");
    }
    // Invalid input draft len
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            execReq.setExternalDraftTokensConfig(texec::ExternalDraftTokensConfig({1, 2}));
            tb::LlmRequest llmReq(requestId, execReq);
            llmReq.validate(500, 1000, 1, 32000);
        };
        lambdaErrMsgs.emplace_back(lambda, "exceeds maximum draft");
    }
    // Invalid ptable shape (rejected by the LlmRequest constructor; validate() is not needed)
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {17, 32, 69});
            texec::PromptTuningConfig config(embeddingTable);
            execReq.setPromptTuningConfig(config);
            tb::LlmRequest llmReq(requestId, execReq);
        };
        lambdaErrMsgs.emplace_back(lambda, "Expected prompt embedding table to have shape");
    }
    // Invalid extra id vector's size (rejected by the LlmRequest constructor)
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {4, 8});
            VecTokenExtraIds extraIds(inputTokens.size() - 1, 0);
            texec::PromptTuningConfig config(embeddingTable, extraIds);
            execReq.setPromptTuningConfig(config);
            tb::LlmRequest llmReq(requestId, execReq);
        };
        lambdaErrMsgs.emplace_back(lambda, "must be the same as input token vector size");
    }
    // Extra ids not provided when enabling kv cache reuse with prompt table
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(inputTokens, maxNewTokens);
            auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {4, 8});
            texec::PromptTuningConfig config(embeddingTable);
            execReq.setPromptTuningConfig(config);
            tb::LlmRequest llmReq(requestId, execReq);
            llmReq.validate(500, 1000, 1, 32000, std::nullopt, true);
        };
        lambdaErrMsgs.emplace_back(lambda, "Input token extra ids must be provided");
    }
    // Invalid endId
    {
        auto lambda = [&inputTokens, maxNewTokens, requestId]()
        {
            texec::Request execReq(
                inputTokens, maxNewTokens, false, texec::SamplingConfig(), texec::OutputConfig(), -2);
            tb::LlmRequest llmReq(requestId, execReq);
            llmReq.validate(500, 1000, 1, 32000);
        };
        lambdaErrMsgs.emplace_back(lambda, "EndId (-2) is not within acceptable range [-1, 32000)");
    }
    for (auto const& [lambda, errMsg] : lambdaErrMsgs)
    {
        SCOPED_TRACE(errMsg);
        // ADD_FAILURE (not FAIL) so one misbehaving case does not return from the test and silently skip
        // the remaining cases and the checks after this loop.
        try
        {
            lambda();
            ADD_FAILURE() << "Expected failure with " << errMsg;
        }
        catch (tc::TllmException const& e)
        {
            EXPECT_THAT(e.what(), testing::HasSubstr(errMsg));
        }
        catch (std::exception const& e)
        {
            ADD_FAILURE() << "Expected TllmException with " << errMsg << " got " << e.what();
        }
        catch (...)
        {
            ADD_FAILURE() << "Expected TllmException with " << errMsg << " got a non-std exception";
        }
    }
    {
        // Validate output len truncation w/o draft tokens
        texec::Request execReq(inputTokens, maxNewTokens);
        tb::LlmRequest llmReq(requestId, execReq);
        llmReq.validate(10, 60, 0, 32000);
        EXPECT_EQ(llmReq.mMaxNewTokens, static_cast<SizeType32>(60 - inputTokens.size()));
    }
    {
        // Validate output len truncation w draft tokens
        texec::Request execReq(inputTokens, maxNewTokens);
        tb::LlmRequest llmReq(requestId, execReq);
        llmReq.validate(10, 60, 2, 32000);
        EXPECT_EQ(llmReq.mMaxNewTokens, static_cast<SizeType32>(60 - inputTokens.size() - 2));
    }
    {
        // Validate extra ids when enabling kv cache reuse with prompt table
        texec::Request execReq(inputTokens, maxNewTokens);
        auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {6, 42});
        VecTokenExtraIds extraIds(inputTokens.size(), 1);
        texec::PromptTuningConfig config(embeddingTable, extraIds);
        execReq.setPromptTuningConfig(config);
        tb::LlmRequest llmReq(requestId, execReq);
        EXPECT_EQ(static_cast<std::size_t>(llmReq.getOrigPromptLen()), inputTokens.size());
        llmReq.validate(500, 1000, 1, 32000, std::nullopt, true);
    }
    {
        using AdditionalModelOutput = texec::AdditionalModelOutput;
        // Validate additional context and gen outputs: an output flagged for context (true) is expected in
        // both maps, one without the flag only in the generation map.
        texec::Request execReq(inputTokens, maxNewTokens);
        std::vector<AdditionalModelOutput> additionalModelOutputs{
            AdditionalModelOutput{"context_gen_output", true}, AdditionalModelOutput{"gen_output", false}};
        texec::OutputConfig outputConfig;
        outputConfig.additionalModelOutputs = additionalModelOutputs;
        execReq.setOutputConfig(outputConfig);
        tb::LlmRequest llmReq(requestId, execReq);
        llmReq.validate(10, 60, 2, 32000, std::nullopt, false);
        auto const& additionalContextOutputs = llmReq.getAdditionalContextOutputs();
        EXPECT_EQ(additionalContextOutputs.count("context_gen_output"), 1);
        EXPECT_EQ(additionalContextOutputs.count("gen_output"), 0);
        auto const& additionalGenerationOutputs = llmReq.getAdditionalGenerationOutputs();
        EXPECT_EQ(additionalGenerationOutputs.count("context_gen_output"), 1);
        EXPECT_EQ(additionalGenerationOutputs.count("gen_output"), 1);
    }
}
// Pauses a request twice and checks how its prompt length, generation budget, state and generated-token
// count change. The expected values encode that mPromptLen becomes min(total tokens, maxInputLen) and
// that mMaxNewTokens shrinks by however much the prompt grew.
TEST_F(LlmRequestTest, pause)
{
    auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
    SizeType32 maxNewTokens(66);
    tb::LlmRequest::RequestIdType requestId{77};
    tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(1), false);

    // Appends `count` copies of token 1 to beam 0.
    auto const generateTokens = [&llmReq](int count)
    {
        for (int step = 0; step < count; ++step)
        {
            llmReq.addNewToken(1, 0);
        }
    };

    generateTokens(5);
    EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 5);

    // 5 prompt + 5 generated = 10 tokens, which is below the maximum input length of 12.
    llmReq.pause(12);
    EXPECT_EQ(llmReq.mPromptLen, 10);
    EXPECT_EQ(llmReq.mMaxNewTokens, 61);
    EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
    EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 0);

    generateTokens(4);
    EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 4);

    // 10 + 4 = 14 tokens now exceed the maximum input length of 12.
    llmReq.pause(12);
    EXPECT_EQ(llmReq.mPromptLen, 12);
    EXPECT_EQ(llmReq.mMaxNewTokens, 59);
    EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
    EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 0);
}
// Checks the shapes of the host logits buffers an LlmRequest allocates: context logits, generation
// logits, and (once draft tokens are set) the target model's accepted-token logits, which the test
// reads back through getGenerationLogitsHost().
TEST_F(LlmRequestTest, testAllocateLogitsBuffer)
{
    auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
    SizeType32 maxNewTokens(60);
    tb::LlmRequest::RequestIdType requestId{77};
    tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(1), false);
    EXPECT_EQ(llmReq.mPromptLen, 5);
    SizeType32 vocabSizePadded = 32000;
    nvinfer1::DataType logitsDataType = nvinfer1::DataType::kFLOAT;

    // Context logits: expected shape [promptLen, vocabSizePadded].
    EXPECT_EQ(llmReq.getContextLogitsHost(), nullptr);
    llmReq.allocContextLogitsHost(vocabSizePadded, logitsDataType);
    // ASSERT before dereferencing so a failed allocation is reported instead of crashing the test binary.
    ASSERT_NE(llmReq.getContextLogitsHost(), nullptr);
    auto contextLogitsHostShape = llmReq.getContextLogitsHost()->getShape();
    EXPECT_EQ(contextLogitsHostShape.nbDims, 2);
    EXPECT_EQ(contextLogitsHostShape.d[0], 5);
    EXPECT_EQ(contextLogitsHostShape.d[1], vocabSizePadded);

    // Generation logits: expected shape [1, maxNewTokens, vocabSizePadded].
    // NOTE(review): the leading 1 presumably matches the beam width of 1 used here — confirm.
    EXPECT_EQ(llmReq.getGenerationLogitsHost(), nullptr);
    llmReq.allocGenerationLogitsHost(vocabSizePadded, logitsDataType);
    ASSERT_NE(llmReq.getGenerationLogitsHost(), nullptr);
    auto generationLogitsHostShape = llmReq.getGenerationLogitsHost()->getShape();
    EXPECT_EQ(generationLogitsHostShape.nbDims, 3);
    EXPECT_EQ(generationLogitsHostShape.d[0], 1);
    EXPECT_EQ(generationLogitsHostShape.d[1], maxNewTokens);
    EXPECT_EQ(generationLogitsHostShape.d[2], vocabSizePadded);

    // Target model's accepted-token logits require draft tokens to be set first.
    EXPECT_EQ(llmReq.getNumDraftTokens(), 0);
    auto draftTokens = std::make_shared<VecTokens>(VecTokens{7, 8, 9});
    llmReq.setDraftTokens(draftTokens);
    EXPECT_EQ(llmReq.getNumDraftTokens(), 3);
    // Clear the generation logits so the next allocation's result is observable.
    llmReq.setGenerationLogitsHost(nullptr);
    EXPECT_EQ(llmReq.getGenerationLogitsHost(), nullptr);
    llmReq.allocTargetModelAcceptedTokenLogitsHost(vocabSizePadded, logitsDataType);
    ASSERT_NE(llmReq.getGenerationLogitsHost(), nullptr);
    auto targetModelAcceptedTokenLogitShape = llmReq.getGenerationLogitsHost()->getShape();
    EXPECT_EQ(targetModelAcceptedTokenLogitShape.nbDims, 3);
    EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[0], 1);
    // 4 == 3 draft tokens + 1 (presumably one extra position for the target model's own token — confirm).
    EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[1], 4);
    EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[2], vocabSizePadded);
}
// Checks that each beam's "last token" follows addNewToken/addNewTokens, while setGeneratedTokens
// rewrites the generated part of every beam without changing those last tokens.
TEST_F(LlmRequestTest, testLastTokensSetIndependence)
{
    tb::LlmRequest::RequestIdType requestId{77};
    SizeType32 maxNewTokens{66};
    auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
    SizeType32 beamWidth{3};
    bool streaming{false};
    // Full per-beam sequences (prompt + two generated tokens) before and after the overwrite.
    tb::LlmRequest::BeamTokens expectedInitialOutput{
        {1, 2, 3, 4, 5, 10, 20}, {1, 2, 3, 4, 5, 11, 21}, {1, 2, 3, 4, 5, 12, 22}};
    tb::LlmRequest::BeamTokens expectedOverwrittenOutput{
        {1, 2, 3, 4, 5, 100, 200}, {1, 2, 3, 4, 5, 101, 201}, {1, 2, 3, 4, 5, 102, 202}};
    tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(beamWidth), streaming);

    // Step 1: one token per beam, added individually (10, 11, 12).
    for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
    {
        llmReq.addNewToken(10 + beamIdx, beamIdx);
    }
    auto const firstLastTokens = llmReq.getLastTokens();
    EXPECT_EQ(firstLastTokens.size(), beamWidth);
    EXPECT_THAT(firstLastTokens, testing::ElementsAreArray({10, 11, 12}));

    // Step 2: one token for every beam at once.
    VecTokens batchTokens{20, 21, 22};
    llmReq.addNewTokens(batchTokens);
    for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
    {
        EXPECT_EQ(llmReq.getLastTokens(beamIdx), batchTokens[beamIdx]);
    }

    // The stored sequences reflect both steps.
    for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
    {
        EXPECT_THAT(llmReq.getTokens(beamIdx), testing::ElementsAreArray(expectedInitialOutput[beamIdx]));
    }

    // Step 3: overwrite the generated tokens; the sequences change but the last tokens do not.
    tb::LlmRequest::BeamTokens replacement{{100, 200}, {101, 201}, {102, 202}};
    llmReq.setGeneratedTokens(replacement);
    for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
    {
        EXPECT_THAT(llmReq.getTokens(beamIdx), testing::ElementsAreArray(expectedOverwrittenOutput[beamIdx]));
    }
    EXPECT_THAT(llmReq.getLastTokens(), testing::ElementsAreArray({20, 21, 22}));
}
// Checks child-request creation. With numReturnSequences unset, creating a child must throw. With
// numReturnSequences = 3, each child copies its parent's prompt, limits and state, and receives its own
// incremented random seed (parent 7 -> children 8, 9) while the parent's seed stays unchanged.
TEST_F(LlmRequestTest, testCreateRequests)
{
    auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
    SizeType32 maxNewTokens{60};
    tb::LlmRequest::RequestIdType requestId{77};
    tr::SamplingConfig samplingConfig(1);
    samplingConfig.randomSeed = std::vector<texec::RandomSeedType>{7};
    tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, samplingConfig, false);
    try
    {
        [[maybe_unused]] auto childReq = llmReq.createChildRequest(1837);
        // ADD_FAILURE (not FAIL) so the checks on llmReq2 below still run.
        ADD_FAILURE() << "Expected an exception.";
    }
    catch (tc::TllmException const& e)
    {
        EXPECT_THAT(e.what(), testing::HasSubstr("Cannot create child requests more than"));
    }
    samplingConfig.numReturnSequences = 3;
    tb::LlmRequest llmReq2(requestId, maxNewTokens, inputTokens, samplingConfig, false);
    auto childReq1 = llmReq2.createChildRequest(78);
    {
        EXPECT_EQ(llmReq2.getChildRequests().size(), 1);
        EXPECT_EQ(childReq1->mRequestId, 78);
        EXPECT_EQ(childReq1->getTokens().at(0), *inputTokens);
        // Compare against the actual parent (llmReq2), not the unrelated first request.
        EXPECT_EQ(childReq1->getNumTokens(0), llmReq2.getNumTokens(0));
        EXPECT_EQ(childReq1->getOrigPromptLen(), llmReq2.getOrigPromptLen());
        EXPECT_EQ(childReq1->mMaxNewTokens, llmReq2.mMaxNewTokens);
        EXPECT_EQ(childReq1->getState(), llmReq2.getState());
        EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{8});
        EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{7});
        EXPECT_FALSE(childReq1->mSeqSlot);
    }
    {
        auto childReq2 = llmReq2.createChildRequest(79);
        auto childRequests = llmReq2.getChildRequests();
        EXPECT_EQ(childRequests.size(), 2);
        EXPECT_EQ(childRequests.at(0), childReq1);
        EXPECT_EQ(childRequests.at(1), childReq2);
        EXPECT_EQ(childReq2->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{9});
        EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{8});
        EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{7});
    }
}
// (streaming, excludeInputFromOutput, returnAllGeneratedTokens, beamWidth, tokensPerIteration,
//  numReturnSequences)
using ParamType = std::tuple<bool, bool, bool, SizeType32, SizeType32, SizeType32>;

// Builds a readable gtest name from the parameter tuple, e.g. "llmRequestTestExclInputBw2TokensPerIt3N1".
std::string generateTestName(testing::TestParamInfo<ParamType> const& info)
{
    auto const& [streaming, excludeInputFromOutput, returnAllGeneratedTokens, beamWidth, tokensPerIteration,
        numReturnSequences]
        = info.param;

    std::string testName{"llmRequestTest"};
    // Boolean options contribute a tag only when enabled.
    testName += streaming ? "Streaming" : "";
    testName += excludeInputFromOutput ? "ExclInput" : "";
    testName += returnAllGeneratedTokens ? "RetAllTokens" : "";
    // Numeric options are always encoded.
    testName.append("Bw").append(std::to_string(beamWidth));
    testName.append("TokensPerIt").append(std::to_string(tokensPerIteration));
    testName.append("N").append(std::to_string(numReturnSequences));
    return testName;
}
// Value-parameterized variant of LlmRequestTest; each instance receives one ParamType tuple.
class ParamTest
    : public LlmRequestTest
    , public ::testing::WithParamInterface<ParamType>
{
};
// Exercises LlmRequest::createResponse for one parameter combination. It checks that no response exists
// before any token is generated, that intermediate responses exist only in streaming mode, and that the
// final responses' tokens and isSequenceFinal/isFinal flags follow which (parent/child) requests completed.
TEST_P(ParamTest, createResponse)
{
    bool const streaming{std::get<0>(GetParam())};
    bool const excludeInputFromOutput{std::get<1>(GetParam())};
    bool const returnAllGeneratedTokens{std::get<2>(GetParam())};
    SizeType32 const beamWidth{std::get<3>(GetParam())};
    SizeType32 const tokensPerIteration{std::get<4>(GetParam())};
    SizeType32 const numReturnSequences{std::get<5>(GetParam())};
    auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
    SizeType32 maxNewTokens(66);
    tb::LlmRequest::RequestIdType requestId{77};
    tr::SamplingConfig samplingConfig(beamWidth);
    // Set numReturnSequences for non-beam-search requests, and for beam search that returns fewer
    // sequences than beams; otherwise it is left as nullopt.
    if (beamWidth == 1 || numReturnSequences < beamWidth)
    {
        samplingConfig.numReturnSequences = numReturnSequences;
    }
    auto numReturnBeams = samplingConfig.getNumReturnBeams();
    // Expect one sequence per request in beam search; otherwise one request (parent or child) per
    // returned sequence.
    auto numSequences = beamWidth > 1 ? 1 : numReturnSequences;
    std::vector<std::shared_ptr<tb::LlmRequest>> llmRequests;
    llmRequests.emplace_back(
        std::make_shared<tb::LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, streaming));
    {
        auto llmReq = llmRequests.at(0);
        llmReq->setExcludeInputFromOutput(excludeInputFromOutput);
        // Streaming beam search without returnAllGeneratedTokens is expected to be rejected.
        if (streaming && beamWidth > 1 && !returnAllGeneratedTokens)
        {
            EXPECT_THROW(
                llmReq->setReturnAllGeneratedTokens(returnAllGeneratedTokens), tensorrt_llm::common::TllmException);
            return;
        }
        llmReq->setReturnAllGeneratedTokens(returnAllGeneratedTokens);
    }
    if (beamWidth == 1)
    {
        auto llmReq = llmRequests.at(0);
        for (auto seqIdx = 1; seqIdx < numReturnSequences; seqIdx++)
        {
            tb::LlmRequest::RequestIdType childReqId{77 + static_cast<tb::LlmRequest::RequestIdType>(seqIdx)};
            auto childReq = llmReq->createChildRequest(childReqId);
            EXPECT_EQ(childReq->getReturnAllGeneratedTokens(), llmReq->getReturnAllGeneratedTokens());
            EXPECT_TRUE(childReq->isChild());
            llmRequests.emplace_back(std::move(childReq));
        }
    }
    // Nothing has been generated yet, so no request may produce a response.
    for (auto& llmReq : llmRequests)
    {
        auto response = llmReq->createResponse();
        EXPECT_FALSE(response);
    }
    SizeType32 constexpr numIterations{5};
    // Sequence seqIdx always generates token seqIdx + 1, so the outputs of different sequences differ.
    std::vector<texec::TokenIdType> newTokens(numSequences);
    std::iota(newTokens.begin(), newTokens.end(), 1);
    // All iterations but the last, with every request still in progress.
    for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
    {
        auto llmReq = llmRequests.at(seqIdx);
        auto const newToken = newTokens.at(seqIdx);
        for (int i = 0; i < numIterations - 1; ++i)
        {
            for (int j = 0; j < tokensPerIteration; ++j)
            {
                llmReq->addNewTokens(VecTokens(numReturnBeams, newToken));
            }
            llmReq->setState(tb::LlmRequestState::kGENERATION_IN_PROGRESS);
            auto response = llmReq->createResponse();
            // Intermediate responses exist only in streaming mode. ASSERT so the value access below is safe.
            ASSERT_EQ(response.has_value(), streaming);
            if (streaming)
            {
                EXPECT_EQ(response->getRequestId(), requestId);
                auto result = response->getResult();
                ASSERT_EQ(result.outputTokenIds.size(), static_cast<std::size_t>(numReturnBeams));
                // With returnAllGeneratedTokens each response repeats every token generated so far;
                // otherwise it carries only this iteration's tokens.
                auto const expectedSize
                    = returnAllGeneratedTokens ? (i + 1) * tokensPerIteration : tokensPerIteration;
                VecTokens const expectedTokens(static_cast<std::size_t>(expectedSize), newToken);
                for (int beamIdx = 0; beamIdx < numReturnBeams; ++beamIdx)
                {
                    auto const& beamTokens = result.outputTokenIds.at(beamIdx);
                    EXPECT_EQ(beamTokens.size(), expectedTokens.size());
                    EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
                }
            }
            // A second call in the same iteration has nothing new to report.
            response = llmReq->createResponse();
            EXPECT_FALSE(response);
        }
    }
    // Last iteration: every sequence generates its tokens, but only the parent request completes.
    for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
    {
        for (int j = 0; j < tokensPerIteration; ++j)
        {
            llmRequests.at(seqIdx)->addNewTokens(VecTokens(numReturnBeams, newTokens.at(seqIdx)));
        }
    }
    llmRequests.at(0)->setState(tb::LlmRequestState::kGENERATION_COMPLETE);
    auto const numNewTokens = numIterations * tokensPerIteration;
    for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
    {
        auto llmReq = llmRequests.at(seqIdx);
        auto response = llmReq->createResponse();
        // Without streaming, requests that have not completed must not respond yet.
        if (!streaming && llmReq->getState() != tb::LlmRequestState::kGENERATION_COMPLETE)
        {
            EXPECT_FALSE(response);
            continue;
        }
        ASSERT_TRUE(response.has_value()) << "seqIdx " << seqIdx;
        EXPECT_FALSE(response->hasError()) << "seqIdx " << seqIdx;
        // All responses carry the request id of the original (parent) request.
        EXPECT_EQ(response->getRequestId(), requestId);
        auto result = response->getResult();
        ASSERT_EQ(result.outputTokenIds.size(), static_cast<std::size_t>(numReturnBeams));
        // Only the first sequence has finished.
        EXPECT_EQ(result.isSequenceFinal, seqIdx == 0) << "seqIdx " << seqIdx;
        EXPECT_EQ(result.isFinal, numSequences == 1) << "seqIdx " << seqIdx;
        auto const newToken = newTokens.at(seqIdx);
        // Non-streaming: everything generated, prefixed by the prompt unless it is excluded.
        // Streaming: all generated tokens, or only the last iteration's tokens.
        VecTokens expectedTokens;
        if (!streaming)
        {
            if (!excludeInputFromOutput)
            {
                expectedTokens = *inputTokens;
            }
            expectedTokens.resize(expectedTokens.size() + static_cast<std::size_t>(numNewTokens), newToken);
        }
        else
        {
            auto const expectedSize = returnAllGeneratedTokens ? numNewTokens : tokensPerIteration;
            expectedTokens.assign(static_cast<std::size_t>(expectedSize), newToken);
        }
        for (int beamIdx = 0; beamIdx < numReturnBeams; ++beamIdx)
        {
            auto const& beamTokens = result.outputTokenIds.at(beamIdx);
            EXPECT_EQ(beamTokens.size(), expectedTokens.size());
            EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
        }
    }
    // Complete the remaining child requests (no-op when there is only one sequence).
    for (auto seqIdx = 1; seqIdx < numSequences; seqIdx++)
    {
        auto llmReq = llmRequests.at(seqIdx);
        for (int j = 0; j < tokensPerIteration; ++j)
        {
            // numReturnBeams, consistent with every other addNewTokens call in this test.
            llmReq->addNewTokens(VecTokens(numReturnBeams, newTokens.at(seqIdx)));
        }
        llmReq->setState(tb::LlmRequestState::kGENERATION_COMPLETE);
    }
    // Checked only after all children completed, so isFinal is expected on every one of them.
    for (auto seqIdx = 1; seqIdx < numSequences; seqIdx++)
    {
        auto response = llmRequests.at(seqIdx)->createResponse();
        ASSERT_TRUE(response.has_value()) << "seqIdx " << seqIdx;
        EXPECT_FALSE(response->hasError()) << "seqIdx " << seqIdx;
        auto result = response->getResult();
        // All sequences have finished.
        EXPECT_TRUE(result.isSequenceFinal) << "seqIdx " << seqIdx;
        EXPECT_TRUE(result.isFinal) << "seqIdx " << seqIdx;
    }
}
// Regression test for nvbug/5961736: createResult() must produce a valid
// response with contextPhaseParams when the request is in
// kDISAGG_CONTEXT_COMPLETE, not just kDISAGG_CONTEXT_TRANS_IN_PROGRESS.
// Without the fix, createResult() returns nullopt for CONTEXT_COMPLETE,
// causing ctx_request_id=None in the disaggregated serving response.
TEST_F(LlmRequestTest, createResultDisaggContextComplete)
{
    VecTokens inputTokens{1, 2, 3, 4, 5};
    SizeType32 maxNewTokens{10};
    texec::IdType requestId{42};
    // Build an executor::Request and configure it as context-only with ContextPhaseParams.
    texec::Request execReq(inputTokens, maxNewTokens);
    execReq.setRequestType(texec::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
    texec::ContextPhaseParams ctxParams({100}, requestId, static_cast<void*>(nullptr), std::nullopt);
    execReq.setContextPhaseParams(std::move(ctxParams));
    tb::LlmRequest llmReq(requestId, execReq);
    EXPECT_TRUE(llmReq.isContextOnlyRequest());
    // Add a generated token (required by createResult's firstGenTokens extraction).
    llmReq.addNewTokens(VecTokens{42});
    // Verify isFinished() covers DISAGG_CONTEXT_COMPLETE.
    llmReq.setState(tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
    EXPECT_TRUE(llmReq.isFinished());
    // This is the regression case — without the fix, createResult() returns nullopt
    // because DISAGG_CONTEXT_COMPLETE was not handled by createResult's early guard
    // or its context-phase branch.
    auto response = llmReq.createResult(/*useFastLogits=*/false, /*mpiWorldRank=*/0);
    ASSERT_TRUE(response.has_value()) << "createResult() must not return nullopt for DISAGG_CONTEXT_COMPLETE";
    // ASSERT (not EXPECT): the next check dereferences contextPhaseParams, which would be undefined
    // behavior on an empty optional.
    ASSERT_TRUE(response->contextPhaseParams.has_value())
        << "contextPhaseParams must be populated for context-only DISAGG_CONTEXT_COMPLETE requests";
    EXPECT_EQ(response->contextPhaseParams->getReqId(), requestId);
    EXPECT_TRUE(response->isSequenceFinal);
}
// Full cross product of request options exercised by ParamTest.createResponse; the axis order must
// match ParamType.
INSTANTIATE_TEST_SUITE_P(LlmRequestTest, ParamTest,
    testing::Combine(
        // streaming — TODO: Support and add coverage for streamLLM
        testing::Values(false),
        // excludeInputFromOutput: false, true
        testing::Bool(),
        // returnAllGeneratedTokens: false, true
        testing::Bool(),
        // beamWidth
        testing::Values(1, 2),
        // tokensPerIteration
        testing::Values(1, 3),
        // numReturnSequences
        testing::Values(1, 2)),
    generateTestName);