Skip to content

Commit 0a23e73

Browse files
authored
Merge pull request #79 from tilesprivacy/harmony-support
feat: support for gpt-oss in interactive chat
2 parents (2c2b914 + 2b29bcd) — commit 0a23e73

File tree

2 files changed: +33 −14 lines changed

server/backend/mlx_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def _extract_stop_tokens(self):
254254
if hasattr(self.tokenizer, "name_or_path"):
255255
name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower()
256256
model_type = ReasoningExtractor.detect_model_type(name_or_path)
257-
258257
if model_type:
259258
# This is a reasoning model
260259
self._is_reasoning_model = True

tiles/src/runtime/mlx.rs

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,15 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
391391
loop {
392392
if remaining_count > 0 {
393393
let chat_start = remaining_count == run_args.relay_count;
394-
if let Ok(response) =
395-
chat(&input, modelname, chat_start, &python_code, &g_reply).await
394+
if let Ok(response) = chat(
395+
&input,
396+
modelname,
397+
chat_start,
398+
&python_code,
399+
&g_reply,
400+
run_args,
401+
)
402+
.await
396403
{
397404
if response.reply.is_empty() {
398405
if !response.code.is_empty() {
@@ -401,7 +408,11 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
401408
remaining_count -= 1;
402409
} else {
403410
g_reply = response.reply.clone();
404-
println!("\n{}", response.reply.trim());
411+
if run_args.memory {
412+
println!("\n{}", response.reply.trim());
413+
} else {
414+
println!("\n");
415+
}
405416
break;
406417
}
407418
} else {
@@ -473,6 +484,7 @@ async fn chat(
473484
chat_start: bool,
474485
python_code: &str,
475486
g_reply: &str,
487+
run_args: &RunArgs,
476488
) -> Result<ChatResponse, String> {
477489
let client = Client::new();
478490

@@ -493,6 +505,7 @@ async fn chat(
493505
let mut stream = res.bytes_stream();
494506
let mut accumulated = String::new();
495507
println!();
508+
let mut is_answer_start = false;
496509
while let Some(chunk) = stream.next().await {
497510
let chunk = chunk.unwrap();
498511
let s = String::from_utf8_lossy(&chunk);
@@ -504,13 +517,20 @@ async fn chat(
504517
let data = line.trim_start_matches("data: ");
505518

506519
if data == "[DONE]" {
507-
return Ok(convert_to_chat_response(&accumulated));
520+
return Ok(convert_to_chat_response(&accumulated, run_args.memory));
508521
}
509522
// Parse JSON
510523
let v: Value = serde_json::from_str(data).unwrap();
511524
if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() {
512525
accumulated.push_str(delta);
513-
print!("{}", delta.dimmed());
526+
if !run_args.memory && delta.contains("**[Answer]**") {
527+
is_answer_start = true;
528+
}
529+
if !is_answer_start {
530+
print!("{}", delta.dimmed());
531+
} else {
532+
print!("{}", delta);
533+
}
514534
use std::io::Write;
515535
std::io::stdout().flush().ok();
516536
}
@@ -519,15 +539,18 @@ async fn chat(
519539
Err(String::from("request failed"))
520540
}
521541

522-
fn convert_to_chat_response(content: &str) -> ChatResponse {
542+
fn convert_to_chat_response(content: &str, memory_mode: bool) -> ChatResponse {
523543
ChatResponse {
524-
reply: extract_reply(content),
544+
reply: extract_reply(content, memory_mode),
525545
code: extract_python(content),
526546
}
527547
}
528548

529-
fn extract_reply(content: &str) -> String {
530-
if content.contains("<reply>") && content.contains("</reply>") {
549+
fn extract_reply(content: &str, memory_mode: bool) -> String {
550+
if !memory_mode && content.contains("**[Answer]**") {
551+
let list_a = content.split("**[Answer]**").collect::<Vec<&str>>();
552+
list_a[1].to_owned()
553+
} else if content.contains("<reply>") && content.contains("</reply>") {
531554
let list_a = content.split("<reply>").collect::<Vec<&str>>();
532555
let list_b = list_a[1].split("</reply>").collect::<Vec<&str>>();
533556
list_b[0].to_owned()
@@ -561,14 +584,11 @@ async fn wait_until_server_is_up() {
561584
}
562585

563586
fn get_default_modelfile(memory_mode: bool) -> Result<PathBuf> {
564-
// get default by the args -m
565-
// let path =
566587
if memory_mode {
567588
let path = get_lib_dir()?.join("modelfiles/mem-agent");
568589
Ok(path)
569590
} else {
570-
// let path = get_lib_dir()?.join("modelfiles/gpt-oss");
571-
let path = get_lib_dir()?.join("modelfiles/mem-agent");
591+
let path = get_lib_dir()?.join("modelfiles/gpt-oss");
572592
Ok(path)
573593
}
574594
}

0 commit comments

Comments (0)