Skip to content

Commit cacdc42

Browse files
committed
[Google] [Gemini] 支持流式输出内容
1 parent 9bbb77e commit cacdc42

File tree

7 files changed

+120
-3
lines changed

7 files changed

+120
-3
lines changed

src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java renamed to src/main/java/org/devlive/sdk/common/listener/ConsoleEventSourceListener.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.devlive.sdk.openai.listener;
1+
package org.devlive.sdk.common.listener;
22

33
import lombok.Builder;
44
import lombok.extern.slf4j.Slf4j;

src/main/java/org/devlive/sdk/openai/listener/HttpServletEventSourceListener.java renamed to src/main/java/org/devlive/sdk/common/listener/HttpServletEventSourceListener.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.devlive.sdk.openai.listener;
1+
package org.devlive.sdk.common.listener;
22

33
import jakarta.servlet.http.HttpServletRequest;
44
import jakarta.servlet.http.HttpServletResponse;

src/main/java/org/devlive/sdk/platform/google/GoogleClient.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,21 @@
55
import lombok.Builder;
66
import lombok.extern.slf4j.Slf4j;
77
import okhttp3.OkHttpClient;
8+
import okhttp3.Request;
9+
import okhttp3.RequestBody;
10+
import okhttp3.sse.EventSource;
11+
import okhttp3.sse.EventSourceListener;
12+
import okhttp3.sse.EventSources;
813
import org.apache.commons.lang3.ObjectUtils;
914
import org.apache.commons.lang3.StringUtils;
1015
import org.devlive.sdk.common.DefaultClient;
1116
import org.devlive.sdk.common.exception.ParamException;
17+
import org.devlive.sdk.common.exception.RequestException;
1218
import org.devlive.sdk.common.utils.ValidateUtils;
19+
import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
1320
import org.devlive.sdk.openai.model.ProviderModel;
1421
import org.devlive.sdk.openai.model.UrlModel;
22+
import org.devlive.sdk.openai.utils.MultipartBodyUtils;
1523
import org.devlive.sdk.openai.utils.ProviderUtils;
1624
import org.devlive.sdk.platform.google.entity.ChatEntity;
1725
import org.devlive.sdk.platform.google.interceptor.GoogleInterceptor;
@@ -39,6 +47,7 @@ public class GoogleClient
3947
private String model;
4048
private VersionModel version;
4149
private GoogleApi api;
50+
private EventSourceListener listener;
4251

4352
private GoogleClient(GoogleClientBuilder builder)
4453
{
@@ -68,12 +77,20 @@ private GoogleClient(GoogleClientBuilder builder)
6877
}
6978
this.model = builder.model;
7079

80+
if (ObjectUtils.isEmpty(builder.listener)) {
81+
builder.listener(null);
82+
}
83+
super.listener = builder.listener;
84+
this.listener = builder.listener;
85+
7186
if (ObjectUtils.isEmpty(builder.client)) {
7287
builder.client(null);
7388
}
7489

7590
super.client = builder.client;
91+
this.client = builder.client;
7692
super.apiHost = builder.apiHost;
93+
this.apiHost = builder.apiHost;
7794
super.provider = ProviderModel.GOOGLE_GEMINI;
7895

7996
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
@@ -89,10 +106,38 @@ private GoogleClient(GoogleClientBuilder builder)
89106
public ChatResponse createChatCompletions(ChatEntity configure)
90107
{
91108
String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_CHAT_COMPLETIONS);
109+
if (ObjectUtils.isNotEmpty(this.listener)) {
110+
this.createEventSource(url, configure);
111+
return null;
112+
}
113+
92114
return this.api.fetchChatCompletions(url, configure)
93115
.blockingGet();
94116
}
95117

118+
private ObjectMapper createObjectMapper()
119+
{
120+
ObjectMapper objectMapper = new ObjectMapper();
121+
objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class);
122+
return objectMapper;
123+
}
124+
125+
private void createEventSource(String url, Object configure)
126+
{
127+
try {
128+
EventSource.Factory factory = EventSources.createFactory(this.client);
129+
ObjectMapper mapper = this.createObjectMapper();
130+
Request request = new Request.Builder()
131+
.url(String.join("/", this.apiHost, url))
132+
.post(RequestBody.create(MultipartBodyUtils.JSON, mapper.writeValueAsString(configure)))
133+
.build();
134+
factory.newEventSource(request, this.listener);
135+
}
136+
catch (Exception e) {
137+
throw new RequestException(String.format("Failed to create event source: %s", e.getMessage()));
138+
}
139+
}
140+
96141
public static class GoogleClientBuilder
97142
{
98143
public GoogleClientBuilder apiKey(String apiKey)
@@ -140,6 +185,9 @@ public GoogleClientBuilder client(OkHttpClient client)
140185
interceptor.setApiKey(apiKey);
141186
interceptor.setVersion(version);
142187
interceptor.setModel(model);
188+
if (listener != null) {
189+
interceptor.setStream(true);
190+
}
143191
client = client.newBuilder()
144192
.addInterceptor(interceptor)
145193
.build();

src/main/java/org/devlive/sdk/platform/google/interceptor/GoogleInterceptor.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class GoogleInterceptor
1919
{
2020
@Setter
2121
private VersionModel version;
22+
@Setter
23+
private Boolean stream = false;
2224

2325
public GoogleInterceptor()
2426
{
@@ -46,6 +48,13 @@ protected Request prepared(Request original)
4648
.addPathSegments(String.join("/", pathSegments))
4749
.addQueryParameter("key", this.getApiKey())
4850
.build();
51+
52+
if (stream) {
53+
httpUrl = httpUrl.newBuilder()
54+
.addQueryParameter("alt", "sse")
55+
.build();
56+
}
57+
4958
log.info("Google interceptor request url {}", httpUrl);
5059

5160
return original.newBuilder()

src/test/java/org/devlive/sdk/openai/StreamClientTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import com.google.common.collect.Lists;
44
import lombok.extern.slf4j.Slf4j;
5+
import org.devlive.sdk.common.listener.ConsoleEventSourceListener;
56
import org.devlive.sdk.openai.entity.ChatEntity;
67
import org.devlive.sdk.openai.entity.CompletionEntity;
78
import org.devlive.sdk.openai.entity.MessageEntity;
8-
import org.devlive.sdk.openai.listener.ConsoleEventSourceListener;
99
import org.junit.Before;
1010
import org.junit.Test;
1111

src/test/java/org/devlive/sdk/openai/listener/HttpServletEventSourceListenerTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jakarta.servlet.http.HttpServletRequest;
44
import jakarta.servlet.http.HttpServletResponse;
55
import jakarta.servlet.http.HttpSession;
6+
import org.devlive.sdk.common.listener.HttpServletEventSourceListener;
67
import org.devlive.sdk.openai.OpenAiClient;
78
import org.devlive.sdk.openai.entity.CompletionEntity;
89
import org.junit.Test;
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package org.devlive.sdk.platform.google;
2+
3+
import com.google.common.collect.Lists;
4+
import lombok.extern.slf4j.Slf4j;
5+
import org.devlive.sdk.ResourceUtils;
6+
import org.devlive.sdk.common.listener.ConsoleEventSourceListener;
7+
import org.devlive.sdk.platform.google.entity.ChatEntity;
8+
import org.devlive.sdk.platform.google.entity.ObjectEntity;
9+
import org.devlive.sdk.platform.google.entity.PartEntity;
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
13+
import java.util.List;
14+
import java.util.concurrent.CountDownLatch;
15+
16+
@Slf4j
17+
public class GoogleStreamClientTest
18+
{
19+
private GoogleClient client;
20+
private CountDownLatch countDownLatch;
21+
private String token;
22+
23+
@Before
24+
public void before()
25+
{
26+
countDownLatch = new CountDownLatch(1);
27+
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
28+
.countDownLatch(countDownLatch)
29+
.build();
30+
token = ResourceUtils.getValue("google.token");
31+
client = GoogleClient.builder()
32+
.apiKey(token)
33+
.listener(listener)
34+
.build();
35+
}
36+
37+
@Test
38+
public void testCreateChat()
39+
{
40+
List<ObjectEntity> contents = Lists.newArrayList();
41+
PartEntity part = PartEntity.builder()
42+
.text("帮我写一万字的作文")
43+
.build();
44+
ObjectEntity object = ObjectEntity.builder()
45+
.parts(Lists.newArrayList(part))
46+
.build();
47+
contents.add(object);
48+
ChatEntity chat = ChatEntity.builder()
49+
.contents(contents)
50+
.build();
51+
client.createChatCompletions(chat);
52+
try {
53+
countDownLatch.await();
54+
}
55+
catch (InterruptedException e) {
56+
log.error("Interrupted while waiting", e);
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)