From 0852578d71fd5147b5cdf33bfdccce9bf800a253 Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Wed, 15 Jan 2025 19:09:53 -0800 Subject: [PATCH 1/2] Add encoded size to start_pos --- torchchat/generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index e271f5027..c37ac9d49 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1187,6 +1187,7 @@ def callback(x, *, done_generating=False): skip_cache_setup=not is_first_sample, max_seq_length=max_seq_length, ) + start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: start_pos += token_tensor.size(0) From e238a46766df37cb2cf7300e09859c618fc3002d Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Wed, 15 Jan 2025 20:05:30 -0800 Subject: [PATCH 2/2] Only in chat mode --- torchchat/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index c37ac9d49..a14ece1ad 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1187,7 +1187,8 @@ def callback(x, *, done_generating=False): skip_cache_setup=not is_first_sample, max_seq_length=max_seq_length, ) - start_pos += encoded.size(0) + if generator_args.chat_mode: + start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: start_pos += token_tensor.size(0)