Skip to content

Commit f4b2875

Browse files
committed
Add trtllm TriggerPhrase LP
Signed-off-by: aerdem4 <ahmeterd4@gmail.com>
1 parent cc08f36 commit f4b2875

File tree

6 files changed

+161
-40
lines changed

6 files changed

+161
-40
lines changed

example_notebooks/transformers/trigger_phrase_logits_processor.ipynb

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@
2828
"name": "stderr",
2929
"output_type": "stream",
3030
"text": [
31-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
32-
" warnings.warn(\n",
33-
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
34-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
35-
" warnings.warn(\n"
31+
"Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
3632
]
3733
}
3834
],
@@ -70,14 +66,9 @@
7066
"name": "stderr",
7167
"output_type": "stream",
7268
"text": [
73-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
74-
" warnings.warn(\n",
75-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
76-
" warnings.warn(\n",
77-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
78-
" warnings.warn(\n",
7969
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
80-
"Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n"
70+
"Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n",
71+
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
8172
]
8273
},
8374
{
@@ -113,9 +104,9 @@
113104
"\n",
114105
"Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n",
115106
"\n",
116-
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n",
107+
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n",
117108
"\n",
118-
"Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n",
109+
"Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n",
119110
"\n",
120111
"I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n",
121112
"</think>\n",
@@ -215,39 +206,64 @@
215206
"\n",
216207
"Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n",
217208
"\n",
218-
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n",
209+
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n",
219210
"\n",
220-
"Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n",
211+
"Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n",
221212
"\n",
222213
"I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n",
223-
",,,\n",
214+
"Wait, but in the problem statement, it says to make it recursive. So, the function as written is recursive, but it's not optimized. So, I think this should be the solution.\n",
215+
"Wait, but the problem says to make it recursive, so perhaps the function should handle larger n efficiently. But without memoization, it's not efficient. So, maybe the function should be written with memoization or an iterative approach.\n",
216+
"\n",
217+
"Alternatively, I can implement it iteratively, which is more efficient. Let's think about that.\n",
218+
"\n",
219+
"An iterative approach would start from the bottom and build up the Fibonacci numbers. For example:\n",
220+
"\n",
221+
"def fibonacci(n):\n",
222+
" if n == 0:\n",
223+
" return 0\n",
224+
" a, b = 0, 1\n",
225+
" for _ in range(2, n+1):\n",
226+
" a, b = b, a + b\n",
227+
" return b\n",
228+
"\n",
229+
"Wait, let's test this. For n=5, the loop runs from 2 to 5, which is 4 iterations. Let's see:\n",
230+
"\n",
231+
"Initial a=0, b=1.\n",
232+
"\n",
233+
"After first iteration (i=2): a becomes 1, b becomes 0+1=1.\n",
234+
"\n",
235+
"After second iteration (i=3): a=1, b=1+1=2.\n",
236+
"\n",
237+
"After third iteration (i=4): a=2, b=1+2=3.\n",
238+
"\n",
239+
"After fourth iteration (i=5): a=3, b=2+3=5.\n",
240+
"\n",
241+
"So, return b=5, which is correct.\n",
242+
"\n",
243+
"This approach is more efficient as it runs in O(n) time and uses O(1) space, which is better for large n.\n",
224244
"\n",
225-
"Wait, but the problem says to make it recursive. So, the function should call itself with smaller arguments. The approach I have is correct and recursive.\n",
245+
"But the problem says to make it recursive. So, perhaps the recursive approach is acceptable, but the iterative approach is better for larger n.\n",
226246
"\n",
227-
"So, the final function is as I wrote above.\n",
228-
",,,\n",
247+
"But since the problem doesn't specify, I'll proceed with the recursive approach as it's straightforward.\n",
229248
"\n",
230-
"Wait, but in the function, for n=2, it's F(1)+F(0) = 1+0=1, which is correct. For n=3, F(2)+F(1)=1+1=2. So, the function works as expected.\n",
249+
"Wait, but the problem says to make it recursive. So, perhaps the function should be written recursively, but without memoization. So, the function as I wrote earlier is correct.\n",
231250
"\n",
232-
"I think this should solve the problem.\n",
251+
"Alternatively, I can write it with memoization using a dictionary to store previously computed values. But that's more advanced and perhaps beyond the scope of this problem.\n",
252+
"\n",
253+
"So, the function I wrote earlier is correct and meets the requirements.\n",
233254
"</think>\n",
234255
"\n",
235256
"To solve this problem, we need to generate the nth Fibonacci number using a recursive approach. The Fibonacci sequence is a series of numbers where each number is the sum of the two preceding ones, starting from 0 and 1. \n",
236257
"\n",
237258
"### Approach\n",
238-
"The approach to solve this problem involves using recursion, which is a method where a function calls itself with a modified parameter to achieve the desired result. Here's a step-by-step breakdown of the approach:\n",
259+
"The Fibonacci sequence is defined as follows:\n",
260+
"- F(0) = 0\n",
261+
"- F(1) = 1\n",
262+
"- F(n) = F(n-1) + F(n-2) for n >= 2\n",
239263
"\n",
240-
"1. **Base Cases**: \n",
241-
" - If `n` is 0, return 0.\n",
242-
" - If `n` is 1, return 1.\n",
243-
" \n",
244-
"2. **Recursive Case**:\n",
245-
" - For any `n` greater than 1, the nth Fibonacci number is the sum of the (n-1)th and (n-2)th Fibonacci numbers. This is achieved by recursively calling the function with `n-1` and `n-2` and adding their results.\n",
246-
"\n",
247-
"This approach ensures that each Fibonacci number is computed by breaking down the problem into smaller subproblems, which are then solved recursively.\n",
264+
"Given the requirement to use a recursive approach, we can define a function that calls itself with smaller values of n until it reaches the base cases. The function will handle the base cases directly and use recursion for the general case.\n",
248265
"\n",
249266
"### Solution Code\n",
250-
"\n",
251267
"```python\n",
252268
"def fibonacci(n):\n",
253269
" if n == 0:\n",
@@ -259,10 +275,16 @@
259275
"```\n",
260276
"\n",
261277
"### Explanation\n",
262-
"- **Base Cases**: The function first checks if `n` is 0 or 1. If `n` is 0, it returns 0. If `n` is 1, it returns 1. These are the simplest cases of the Fibonacci sequence.\n",
263-
"- **Recursive Case**: For any `n` greater than 1, the function calls itself with `n-1` and `n-2`, and returns the sum of these two recursive calls. This builds up the solution by solving smaller subproblems and combining their results.\n",
278+
"The function `fibonacci` takes an integer `n` as input and returns the nth Fibonacci number. \n",
279+
"\n",
280+
"1. **Base Cases**:\n",
281+
" - If `n` is 0, the function returns 0.\n",
282+
" - If `n` is 1, the function returns 1.\n",
283+
"\n",
284+
"2. **Recursive Case**:\n",
285+
" - For `n >= 2`, the function calls itself with `n-1` and `n-2` and returns the sum of these two recursive calls. This builds up the Fibonacci sequence from the bottom up, ensuring that each value is computed only once.\n",
264286
"\n",
265-
"This approach is straightforward and leverages the divide-and-conquer strategy inherent in recursion, making it easy to understand and implement. However, it's important to note that this approach has a time complexity of O(2^n) due to the exponential number of function calls, which is not efficient for large values of `n`. For larger values, an iterative approach or memoization would be more efficient.\n",
287+
"This approach is straightforward and leverages the recursive nature of the Fibonacci sequence, making it easy to understand and implement. However, it's important to note that for very large values of `n`, this approach can be inefficient due to repeated calculations. For larger values, an iterative approach or memoization would be more efficient.\n",
266288
"-----END-----\n",
267289
"\n"
268290
]
@@ -332,9 +354,9 @@
332354
"\n",
333355
"Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n",
334356
"\n",
335-
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n",
357+
"Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n",
336358
"\n",
337-
"Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n",
359+
"Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n",
338360
"\n",
339361
"I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n",
340362
"</think>\n",
@@ -348,7 +370,7 @@
348370
" return fibonacci(n-1) + fibonacci(n-2)\n",
349371
"```\n",
350372
"\n",
351-
"This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and recursively computes the value for larger n by summing the two preceding Fibonacci numbers.\n",
373+
"This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and for other values, it recursively calculates the sum of the two preceding Fibonacci numbers. While this implementation is straightforward, it's not the most efficient for large values of n due to repeated calculations.\n",
352374
"-----END-----\n",
353375
"\n"
354376
]

example_notebooks/trtllm/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am get
2626
3. Battery"
2727
2828
python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "Tell me the Nobel Prizes in 1977"
29+
30+
python example_notebooks/trtllm/trigger_phrase_logits_processor.py -p "Generate a python function to calculate nth fibonacci number. Make it recursive. Keep thinking short."
2931
```
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from transformers import AutoTokenizer
2+
from logits_processor_zoo.trtllm import TriggerPhraseLogitsProcessor
3+
from utils import TRTLLMTester, get_parser
4+
5+
6+
if __name__ == "__main__":
7+
args = get_parser()
8+
9+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
10+
llm_tester = TRTLLMTester(args.model_name, args.backend)
11+
12+
lp = TriggerPhraseLogitsProcessor("...Wait, let me think more.", " function", tokenizer,
13+
trigger_count=2, trigger_after=False)
14+
llm_tester.run([args.prompt], logits_processor=lp)
15+
16+
lp = TriggerPhraseLogitsProcessor("\n```python", " function", tokenizer, trigger_count=1, trigger_after=True)
17+
llm_tester.run([args.prompt], logits_processor=lp)

logits_processor_zoo/transformers/trigger_phrase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
5555
if scores[i, :].argmax() == self.trigger_token and it == -1:
5656
self.iterators[i] = 0
5757
if not self.trigger_after:
58-
scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]])
58+
scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[0]])
5959
self.iterators[i] += 1
6060
elif len(self.phrase_tokens) > it >= 0:
6161
scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]])

logits_processor_zoo/trtllm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .cite_prompt import CiteFromPromptLogitsProcessor
2121
from .multiple_choice import MultipleChoiceLogitsProcessor
2222
from .prevent_hallucination import PreventHallucinationLogitsProcessor
23+
from .trigger_phrase import TriggerPhraseLogitsProcessor
2324

2425
__all__ = ['GenLengthLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 'CiteFromPromptLogitsProcessor',
25-
'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor']
26+
'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor', 'TriggerPhraseLogitsProcessor']
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import List, Optional
19+
from transformers import PreTrainedTokenizer
20+
import torch
21+
from logits_processor_zoo.utils import enforce_tokens, text_to_token
22+
from tensorrt_llm.sampling_params import LogitsProcessor
23+
24+
25+
class TriggerPhraseLogitsProcessor(LogitsProcessor):
26+
"""
27+
A logits processor which triggers phrases when it encounters a given token.
28+
29+
Parameters
30+
----------
31+
phrase (str): The phrase to be generated by LLM when it encounters the trigger token.
32+
trigger_token_phrase (str): One token phrase in string to trigger phrases.
33+
tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
34+
trigger_count (int): How many times the phrase will be triggered.
35+
trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token.
36+
"""
37+
def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer,
38+
trigger_count: int = 1, trigger_after: bool = False):
39+
self.tokenizer = tokenizer
40+
self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
41+
self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False)
42+
self.initial_trigger_count = trigger_count
43+
self.trigger_after = trigger_after
44+
self.iterators = None
45+
self.trigger_counts = None
46+
47+
def _init_before_gen(self, beam_width):
48+
self.iterators = -torch.ones(beam_width, dtype=torch.int32)
49+
self.trigger_counts = self.initial_trigger_count*torch.ones(beam_width, dtype=torch.int32)
50+
51+
def __call__(self, req_id: int, logits: torch.Tensor,
52+
token_ids: List[List[int]], stream_ptr: Optional[int],
53+
client_id: Optional[int]) -> None:
54+
beam_width = len(token_ids)
55+
if self.iterators is None:
56+
self._init_before_gen(beam_width)
57+
58+
stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
59+
60+
with torch.cuda.stream(stream):
61+
for i in range(beam_width): # iterate over beams
62+
if self.trigger_counts[i] <= 0:
63+
continue
64+
65+
current_index = self.iterators[i].item()
66+
67+
if logits[0, i].argmax() == self.trigger_token and current_index == -1:
68+
self.iterators[i] = 0
69+
print("triggering...")
70+
if not self.trigger_after:
71+
enforce_tokens(logits[0, i], [self.phrase_tokens[0]])
72+
self.iterators[i] += 1
73+
elif len(self.phrase_tokens) > current_index >= 0:
74+
enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]])
75+
self.iterators[i] += 1
76+
77+
if len(self.phrase_tokens) == self.iterators[i].item(): # phrase completed, reset for next trigger
78+
self.iterators[i] = -1
79+
self.trigger_counts[i] -= 1

0 commit comments

Comments
 (0)