Skip to content

Android Qwen thinking mode prompt support #10668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,10 @@ private String getConversationHistory() {
prevPromptID = currentPromptID;
}
if (conversation.getIsSent()) {
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
format =
format
.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText())
.replace(PromptFormat.THINKING_MODE_PLACEHOLDER, "");
} else {
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
}
Expand All @@ -704,12 +707,12 @@ private String getConversationHistory() {

private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
if (conversationHistory.isEmpty()) {
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
}

return mCurrentSettingsFields.getFormattedSystemPrompt()
+ conversationHistory
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt, mThinkMode);
}

private void onModelRunStarted() {
Expand Down Expand Up @@ -738,7 +741,8 @@ private void onModelRunStopped() {
if (ModelUtils.getModelCategory(
mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
== ModelUtils.VISION_MODEL) {
finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
finalPrompt =
mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
} else {
finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class PromptFormat {
// Template placeholder substituted with the configured system prompt text.
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
// Template placeholder substituted with the user's raw input text.
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
// Template placeholder substituted with the model's prior response when rebuilding history.
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
// Template placeholder substituted with the thinking-mode token (e.g. an empty <think> block
// for Qwen 3 when thinking is disabled) — see getThinkingModeToken.
public static final String THINKING_MODE_PLACEHOLDER = "{{ thinking_mode }}";
// Fallback system prompt used when the user has not configured one.
public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences";

public static String getSystemPromptTemplate(ModelType modelType) {
Expand All @@ -32,7 +33,7 @@ public static String getSystemPromptTemplate(ModelType modelType) {
}
}

public static String getUserPromptTemplate(ModelType modelType) {
public static String getUserPromptTemplate(ModelType modelType, boolean thinkingMode) {
switch (modelType) {
case LLAMA_3:
case LLAMA_3_1:
Expand All @@ -43,15 +44,13 @@ public static String getUserPromptTemplate(ModelType modelType) {
+ "<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>";

case LLAVA_1_5:
case QWEN_3:
return "<|im_start|>user\n"
+ USER_PLACEHOLDER
+ "<|im_end|>\n"
+ "\n<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ "<think>\n"
+ "\n"
+ "</think>\n\n\n";
+ THINKING_MODE_PLACEHOLDER;
case LLAVA_1_5:
default:
return USER_PLACEHOLDER;
}
Expand All @@ -62,9 +61,14 @@ public static String getConversationFormat(ModelType modelType) {
case LLAMA_3:
case LLAMA_3_1:
case LLAMA_3_2:
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
return getUserPromptTemplate(modelType, false)
+ "\n"
+ ASSISTANT_PLACEHOLDER
+ "<|eot_id|>";
case LLAVA_1_5:
return USER_PLACEHOLDER + " ASSISTANT:";
case QWEN_3:
return getUserPromptTemplate(modelType, false) + "<|im_end|>\n";
default:
return USER_PLACEHOLDER;
}
Expand All @@ -86,13 +90,22 @@ public static String getStopToken(ModelType modelType) {
}
}

/**
 * Returns the token substituted for {@link #THINKING_MODE_PLACEHOLDER}.
 *
 * @param modelType the model whose template is being filled
 * @param thinkingMode whether the user enabled thinking mode
 * @return an empty pre-filled {@code <think>} block that suppresses reasoning output when
 *     thinking mode is off for Qwen 3; the empty string for every other case
 */
public static String getThinkingModeToken(ModelType modelType, boolean thinkingMode) {
  // Only Qwen 3 understands the <think> convention; all other models get no token. When
  // thinking mode is enabled the placeholder collapses to "" so the model thinks freely.
  if (modelType == ModelType.QWEN_3 && !thinkingMode) {
    return "<think>\n\n</think>\n\n\n";
  }
  return "";
}

/**
 * Returns the fixed LLaVA conversation preamble; the caller appends the user's text directly
 * after the trailing {@code "USER: "} marker.
 */
public static String getLlavaPresetPrompt() {
  final String preamble =
      "A chat between a curious human and an artificial intelligence assistant. The assistant"
          + " gives helpful, detailed, and polite answers to the human's questions. USER: ";
  return preamble;
}

public static String getFormattedLlamaGuardPrompt(String userPrompt) {
return getUserPromptTemplate(ModelType.LLAMA_GUARD_3)
return getUserPromptTemplate(ModelType.LLAMA_GUARD_3, false)
.replace(
USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ public void afterTextChanged(Editable s) {
new DialogInterface.OnClickListener() {
public void onClick(DialogInterface dialog, int whichButton) {
// Clear the messageAdapter and sharedPreference
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
mUserPromptEditText.setText(
PromptFormat.getUserPromptTemplate(mModelType, false));
}
})
.setNegativeButton(android.R.string.no, null)
Expand All @@ -295,7 +296,7 @@ private void showInvalidPromptDialog() {
.setPositiveButton(
android.R.string.yes,
(dialog, whichButton) -> {
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false));
})
.setNegativeButton(android.R.string.no, null)
.show();
Expand Down Expand Up @@ -377,7 +378,7 @@ private void setupModelTypeSelectorDialog() {
(dialog, item) -> {
mModelTypeTextView.setText(modelTypes[item]);
mModelType = ModelType.valueOf(modelTypes[item]);
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false));
dialog.dismiss();
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ public String getUserPrompt() {
return userPrompt;
}

public String getFormattedSystemAndUserPrompt(String prompt) {
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
public String getFormattedSystemAndUserPrompt(String prompt, boolean thinkingMode) {
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt, thinkingMode);
}

/**
 * Returns the model-specific system template with the configured system prompt filled in.
 */
public String getFormattedSystemPrompt() {
  String template = PromptFormat.getSystemPromptTemplate(modelType);
  return template.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
}

public String getFormattedUserPrompt(String prompt) {
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
public String getFormattedUserPrompt(String prompt, boolean thinkingMode) {
return userPrompt
.replace(PromptFormat.USER_PLACEHOLDER, prompt)
.replace(
PromptFormat.THINKING_MODE_PLACEHOLDER,
PromptFormat.getThinkingModeToken(modelType, thinkingMode));
}

public boolean getIsClearChatHistory() {
Expand Down Expand Up @@ -77,7 +81,7 @@ public SettingsFields() {
tokenizerFilePath = "";
temperature = SettingsActivity.TEMPERATURE_MIN_VALUE;
systemPrompt = "";
userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL);
userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL, false);
isClearChatHistory = false;
isLoadModel = false;
modelType = DEFAULT_MODEL;
Expand Down
Loading