Skip to content

Commit a7a9c1c

Browse files
committed
feat: make custom server works with stream
1 parent 6b6e91d commit a7a9c1c

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
package cc.unitmesh.devti.llms.custom
22

33
import cc.unitmesh.devti.llms.CodeCopilotProvider
4+
import cc.unitmesh.devti.llms.azure.SimpleOpenAIBody
45
import cc.unitmesh.devti.prompting.model.CustomPromptConfig
56
import cc.unitmesh.devti.settings.AutoDevSettingsState
7+
import com.fasterxml.jackson.databind.ObjectMapper
68
import com.intellij.openapi.components.Service
79
import com.intellij.openapi.diagnostic.logger
810
import com.intellij.openapi.project.Project
11+
import com.theokanning.openai.completion.chat.ChatCompletionResult
12+
import com.theokanning.openai.service.SSE
13+
import io.reactivex.BackpressureStrategy
14+
import io.reactivex.Flowable
15+
import io.reactivex.FlowableEmitter
16+
import kotlinx.coroutines.Dispatchers
917
import kotlinx.coroutines.ExperimentalCoroutinesApi
1018
import kotlinx.coroutines.channels.awaitClose
1119
import kotlinx.coroutines.flow.Flow
1220
import kotlinx.coroutines.flow.callbackFlow
21+
import kotlinx.coroutines.withContext
1322
import kotlinx.serialization.Serializable
1423
import kotlinx.serialization.encodeToString
1524
import kotlinx.serialization.json.Json
@@ -63,30 +72,35 @@ class CustomLLMProvider(val project: Project) : CodeCopilotProvider {
6372
.post(body)
6473
.build()
6574

66-
return callbackFlow {
67-
val listener = object : EventSourceListener() {
68-
override fun onOpen(eventSource: EventSource, response: Response) {
69-
println("onOpen")
70-
}
71-
72-
override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
73-
println(data)
74-
trySend(data)
75-
}
75+
val call = client.newCall(request)
76+
val emitDone = false
7677

77-
override fun onClosed(eventSource: EventSource) {
78+
val sseFlowable = Flowable
79+
.create({ emitter: FlowableEmitter<SSE> ->
80+
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
81+
}, BackpressureStrategy.BUFFER)
7882

79-
}
83+
try {
84+
return callbackFlow {
85+
withContext(Dispatchers.IO) {
86+
sseFlowable
87+
.doOnError(Throwable::printStackTrace)
88+
.blockingForEach { sse ->
89+
val result: ChatCompletionResult =
90+
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
91+
val completion = result.choices[0].message
92+
if (completion != null && completion.content != null) {
93+
trySend(completion.content)
94+
}
95+
}
8096

81-
override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
8297
close()
8398
}
8499
}
85-
86-
val eventSource = EventSources.createFactory(client).newEventSource(request, listener)
87-
88-
awaitClose {
89-
eventSource.cancel()
100+
} catch (e: Exception) {
101+
logger.error("Failed to stream", e)
102+
return callbackFlow {
103+
close()
90104
}
91105
}
92106
}

0 commit comments

Comments
 (0)