|
1 | 1 | package cc.unitmesh.devti.llms.custom |
2 | 2 |
|
3 | 3 | import cc.unitmesh.devti.llms.CodeCopilotProvider |
| 4 | +import cc.unitmesh.devti.llms.azure.SimpleOpenAIBody |
4 | 5 | import cc.unitmesh.devti.prompting.model.CustomPromptConfig |
5 | 6 | import cc.unitmesh.devti.settings.AutoDevSettingsState |
| 7 | +import com.fasterxml.jackson.databind.ObjectMapper |
6 | 8 | import com.intellij.openapi.components.Service |
7 | 9 | import com.intellij.openapi.diagnostic.logger |
8 | 10 | import com.intellij.openapi.project.Project |
| 11 | +import com.theokanning.openai.completion.chat.ChatCompletionResult |
| 12 | +import com.theokanning.openai.service.SSE |
| 13 | +import io.reactivex.BackpressureStrategy |
| 14 | +import io.reactivex.Flowable |
| 15 | +import io.reactivex.FlowableEmitter |
| 16 | +import kotlinx.coroutines.Dispatchers |
9 | 17 | import kotlinx.coroutines.ExperimentalCoroutinesApi |
10 | 18 | import kotlinx.coroutines.channels.awaitClose |
11 | 19 | import kotlinx.coroutines.flow.Flow |
12 | 20 | import kotlinx.coroutines.flow.callbackFlow |
| 21 | +import kotlinx.coroutines.withContext |
13 | 22 | import kotlinx.serialization.Serializable |
14 | 23 | import kotlinx.serialization.encodeToString |
15 | 24 | import kotlinx.serialization.json.Json |
@@ -63,30 +72,35 @@ class CustomLLMProvider(val project: Project) : CodeCopilotProvider { |
63 | 72 | .post(body) |
64 | 73 | .build() |
65 | 74 |
|
66 | | - return callbackFlow { |
67 | | - val listener = object : EventSourceListener() { |
68 | | - override fun onOpen(eventSource: EventSource, response: Response) { |
69 | | - println("onOpen") |
70 | | - } |
71 | | - |
72 | | - override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) { |
73 | | - println(data) |
74 | | - trySend(data) |
75 | | - } |
| 75 | + val call = client.newCall(request) |
| 76 | + val emitDone = false |
76 | 77 |
|
77 | | - override fun onClosed(eventSource: EventSource) { |
| 78 | + val sseFlowable = Flowable |
| 79 | + .create({ emitter: FlowableEmitter<SSE> -> |
| 80 | + call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone)) |
| 81 | + }, BackpressureStrategy.BUFFER) |
78 | 82 |
|
79 | | - } |
| 83 | + try { |
| 84 | + return callbackFlow { |
| 85 | + withContext(Dispatchers.IO) { |
| 86 | + sseFlowable |
| 87 | + .doOnError(Throwable::printStackTrace) |
| 88 | + .blockingForEach { sse -> |
| 89 | + val result: ChatCompletionResult = |
| 90 | + ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java) |
| 91 | + val completion = result.choices[0].message |
| 92 | + if (completion != null && completion.content != null) { |
| 93 | + trySend(completion.content) |
| 94 | + } |
| 95 | + } |
80 | 96 |
|
81 | | - override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { |
82 | 97 | close() |
83 | 98 | } |
84 | 99 | } |
85 | | - |
86 | | - val eventSource = EventSources.createFactory(client).newEventSource(request, listener) |
87 | | - |
88 | | - awaitClose { |
89 | | - eventSource.cancel() |
| 100 | + } catch (e: Exception) { |
| 101 | + logger.error("Failed to stream", e) |
| 102 | + return callbackFlow { |
| 103 | + close() |
90 | 104 | } |
91 | 105 | } |
92 | 106 | } |
|
0 commit comments