Skip to content

Commit 63ca37e

Browse files
committed
fix assistant prefill logic (+1 squashed commits)
Squashed commits: [f4963ba] fix prefills
1 parent 53b3bf4 commit 63ca37e

1 file changed

Lines changed: 12 additions & 15 deletions

File tree

koboldcpp.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3104,29 +3104,26 @@ def raise_exception(msg):
31043104
func["arguments"] = json.loads(args)
31053105
except Exception:
31063106
pass
3107-
# Fix tool content for some templates
3108-
# if m.get("role") == "tool" and isinstance(m.get("content"), str):
3109-
# try:
3110-
# m["content"] = json.loads(m["content"])
3111-
# except Exception:
3112-
# pass
31133107
jinja_env.globals['strftime_now'] = strftime_now
31143108
jinja_env.globals['raise_exception'] = raise_exception
31153109
jinja_env.filters["tojson"] = tojson
31163110
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
31173111
text = None
3118-
last_assist_msg = messages[-1]["content"]
3112+
messages_for_render = []
3113+
assist_should_prefill = False
31193114
chat_template_kwargs = chat_template_kwargs or {}
3120-
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
3115+
last_assist_msg = ""
3116+
if messages:
3117+
last_assist_msg = messages[-1]["content"]
3118+
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
3119+
last_assist_msg = "" if not assist_should_prefill else last_assist_msg
3120+
messages_for_render = messages[:-1] if assist_should_prefill else messages
31213121
if tools and len(tools)>0:
3122-
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
3122+
text = jinja_compiled_template.render(messages=messages_for_render, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
31233123
else:
3124-
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
3125-
3126-
if assist_should_prefill and text: # handle prefill continuations
3127-
lastindex = text.rfind(last_assist_msg)
3128-
if lastindex != -1:
3129-
text = text[:lastindex + len(last_assist_msg)]
3124+
text = jinja_compiled_template.render(messages=messages_for_render, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
3125+
if assist_should_prefill and text and last_assist_msg: # handle prefill continuations
3126+
text = text + last_assist_msg
31303127
return text if text else None
31313128
except Exception as e:
31323129
print(f"Jinja formatting failed: {e}")

0 commit comments

Comments
 (0)