Skip to content

Commit e923e16

Browse files
committed
feat: SSE 拦截器
1 parent 487b8f8 commit e923e16

File tree

10 files changed

+236
-44
lines changed

10 files changed

+236
-44
lines changed

forest-core/src/main/java/com/dtflys/forest/http/ForestRequest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
import java.util.function.BiConsumer;
119119
import java.util.function.Consumer;
120120
import java.util.function.Function;
121+
import java.util.function.Supplier;
121122
import java.util.stream.Collectors;
122123

123124
import static com.dtflys.forest.mapping.MappingParameter.TARGET_BODY;
@@ -4678,6 +4679,16 @@ public <R> R getAttachment(String name, Class<R> clazz) {
46784679
}
46794680
return clazz.cast(result);
46804681
}
4682+
4683+
4684+
public <R> R getOrAddAttachment(String name, Supplier<R> supplier) {
4685+
Object obj = getAttachment(name);
4686+
if (obj == null) {
4687+
obj = supplier.get();
4688+
addAttachment(name, obj);
4689+
}
4690+
return (R) obj;
4691+
}
46814692

46824693
/**
46834694
* 获取序列化器

forest-core/src/main/java/com/dtflys/forest/http/ForestSSE.java

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import com.dtflys.forest.annotation.SSEMessage;
88
import com.dtflys.forest.annotation.SSERetryMessage;
99
import com.dtflys.forest.exceptions.ForestRuntimeException;
10+
import com.dtflys.forest.interceptor.Interceptor;
11+
import com.dtflys.forest.interceptor.SSEInterceptor;
1012
import com.dtflys.forest.reflection.MethodLifeCycleHandler;
1113
import com.dtflys.forest.sse.EventSource;
1214
import com.dtflys.forest.sse.ForestSSEListener;
@@ -81,22 +83,38 @@ void init(final ForestRequest request) {
8183
this.request.setLifeCycleHandler(new MethodLifeCycleHandler<InputStream>(InputStream.class, InputStream.class) {});
8284
final Class<?> clazz = this.getClass();
8385
final Method[] methods = ReflectUtils.getMethods(clazz);
84-
85-
// 批量注册 SSE 控制器类中的消息处理方法
86-
for (final Method method : methods) {
87-
final Annotation[] annArray = method.getAnnotations();
88-
for (final Annotation ann : annArray) {
89-
if (ann instanceof SSEMessage) {
90-
registerMessageMethod(method, ann, null);
91-
} else if (ann instanceof SSEDataMessage) {
92-
registerMessageMethod(method, ann, "data");
93-
} else if (ann instanceof SSEEventMessage) {
94-
registerMessageMethod(method, ann, "event");
95-
} else if (ann instanceof SSEIdMessage) {
96-
registerMessageMethod(method, ann, "id");
97-
} else if (ann instanceof SSERetryMessage) {
98-
registerMessageMethod(method, ann, "retry");
99-
}
86+
final List<Interceptor> interceptors = request.getInterceptorChain().getInterceptors();
87+
for (final Interceptor interceptor : interceptors) {
88+
if (interceptor instanceof SSEInterceptor) {
89+
final Class<?> interceptorClass = interceptor.getClass();
90+
final Method[] interceptorMethods = ReflectUtils.getMethods(interceptorClass);
91+
registerMethodArray(interceptor, interceptorMethods);
92+
}
93+
}
94+
registerMethodArray(this, methods);
95+
}
96+
}
97+
98+
/**
99+
* 批量注册 SSE 控制器类中的消息处理方法
100+
*
101+
* @param instance 方法所属实例
102+
* @param methods Java 方法数组
103+
*/
104+
private void registerMethodArray(Object instance, final Method[] methods) {
105+
for (final Method method : methods) {
106+
final Annotation[] annArray = method.getAnnotations();
107+
for (final Annotation ann : annArray) {
108+
if (ann instanceof SSEMessage) {
109+
registerMessageMethod(instance, method, ann, null);
110+
} else if (ann instanceof SSEDataMessage) {
111+
registerMessageMethod(instance, method, ann, "data");
112+
} else if (ann instanceof SSEEventMessage) {
113+
registerMessageMethod(instance, method, ann, "event");
114+
} else if (ann instanceof SSEIdMessage) {
115+
registerMessageMethod(instance, method, ann, "id");
116+
} else if (ann instanceof SSERetryMessage) {
117+
registerMessageMethod(instance, method, ann, "retry");
100118
}
101119
}
102120
}
@@ -105,18 +123,19 @@ void init(final ForestRequest request) {
105123
/**
106124
* 注册 SSE 消息处理方法
107125
*
126+
* @param instance 方法所属实例
108127
* @param method Java 方法对象
109128
* @param ann 注解对象
110129
* @param defaultName SSE 消息默认名称
111130
* @since 1.6.0
112131
*/
113-
private void registerMessageMethod(Method method, Annotation ann, String defaultName) {
132+
private void registerMessageMethod(Object instance, Method method, Annotation ann, String defaultName) {
114133
final Map<String, Object> attrs = ReflectUtils.getAttributesFromAnnotation(ann);
115134
final String valueRegex = String.valueOf(attrs.getOrDefault("valueRegex", ""));
116135
final String valuePrefix = String.valueOf(attrs.getOrDefault("valuePrefix", ""));
117136
final String valuePostfix = String.valueOf(attrs.getOrDefault("valuePostfix", ""));
118137
final String annName = defaultName != null ? defaultName : String.valueOf(attrs.getOrDefault("name", ""));
119-
final SSEMessageMethod sseMessageMethod = new SSEMessageMethod(this, method);
138+
final SSEMessageMethod sseMessageMethod = new SSEMessageMethod(instance, method);
120139
if (StringUtils.isEmpty(valueRegex) && StringUtils.isEmpty(valuePrefix) && StringUtils.isEmpty(valuePostfix)) {
121140
addConsumer(annName, (eventSource, name, value) -> sseMessageMethod.invoke(eventSource));
122141
} else {
@@ -501,6 +520,19 @@ public ForestSSE addOnRetryMatchesPrefix(String valuePrefix, SSEStringMessageCon
501520
public ForestSSE addOnRetryMatchesPostfix(String valuePostfix, SSEStringMessageConsumer consumer) {
502521
return addConsumerMatchesPostfix("retry", valuePostfix, consumer);
503522
}
523+
524+
private void doOnOpen(final EventSource eventSource) {
525+
final List<Interceptor> interceptors = eventSource.getRequest().getInterceptorChain().getInterceptors();
526+
for (Interceptor interceptor : interceptors) {
527+
if (interceptor instanceof SSEInterceptor) {
528+
((SSEInterceptor) interceptor).onSSEOpen(eventSource);
529+
}
530+
}
531+
onOpen(eventSource);
532+
if (onOpenConsumer != null) {
533+
onOpenConsumer.accept(eventSource);
534+
}
535+
}
504536

505537
/**
506538
* 监听打开回调函数:在开始 SSE 数据流监听的时候调用
@@ -509,9 +541,19 @@ public ForestSSE addOnRetryMatchesPostfix(String valuePostfix, SSEStringMessageC
509541
* @since 1.6.0
510542
*/
511543
protected void onOpen(EventSource eventSource) {
512-
if (onOpenConsumer != null) {
513-
onOpenConsumer.accept(eventSource);
544+
}
545+
546+
private void doOnClose(final ForestRequest request, final ForestResponse response) {
547+
final List<Interceptor> interceptors = request.getInterceptorChain().getInterceptors();
548+
for (Interceptor interceptor : interceptors) {
549+
if (interceptor instanceof SSEInterceptor) {
550+
((SSEInterceptor) interceptor).onSSEClose(request, response);
551+
}
514552
}
553+
if (onCloseConsumer != null) {
554+
onCloseConsumer.accept(request, response);
555+
}
556+
onClose(request, response);
515557
}
516558

517559
/**
@@ -522,9 +564,6 @@ protected void onOpen(EventSource eventSource) {
522564
* @since 1.6.0
523565
*/
524566
protected void onClose(ForestRequest request, ForestResponse response) {
525-
if (onCloseConsumer != null) {
526-
onCloseConsumer.accept(request, response);
527-
}
528567
}
529568

530569
/**
@@ -600,7 +639,7 @@ public <R extends ForestSSE> R listen() {
600639
throw new ForestRuntimeException(e);
601640
}
602641
} else {
603-
response = this.request.execute(new TypeReference<ForestResponse<InputStream>>() {});
642+
response = this.request.execute(new TypeReference<ForestResponse<InputStream>>() {});
604643
}
605644
if (response == null) {
606645
return (R) this;
@@ -609,7 +648,7 @@ public <R extends ForestSSE> R listen() {
609648
return (R) this;
610649
}
611650
final EventSource openEventSource = new EventSource("open", request, response);
612-
this.onOpen(openEventSource);
651+
this.doOnOpen(openEventSource);
613652
if (SSEMessageResult.CLOSE.equals(openEventSource.getMessageResult())) {
614653
onClose(request, response);
615654
return (R) this;
@@ -634,7 +673,7 @@ public <R extends ForestSSE> R listen() {
634673
} catch (IOException e) {
635674
throw new ForestRuntimeException(e);
636675
} finally {
637-
onClose(request, response);
676+
doOnClose(request, response);
638677
}
639678
});
640679
} catch (Exception e) {

forest-core/src/main/java/com/dtflys/forest/interceptor/Interceptor.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import com.dtflys.forest.reflection.ForestMethod;
1818
import com.dtflys.forest.utils.ForestProgress;
1919

20+
import java.util.function.Supplier;
21+
2022
/**
2123
* Forest拦截器接口
2224
* <p>拦截器在请求的初始化、发送请求前、发送成功、发送失败等生命周期中都会被调用
@@ -238,15 +240,15 @@ default Object getAttribute(ForestRequest request, String name) {
238240
* @param request Forest请求对象
239241
* @param name 属性名称
240242
* @param clazz 属性值的类型对象
241-
* @param <T> 属性值类型的泛型
243+
* @param <R> 属性值类型的泛型
242244
* @return Attribute 属性值
243245
*/
244-
default <T> T getAttribute(ForestRequest request, String name, Class<T> clazz) {
246+
default <R> R getAttribute(ForestRequest request, String name, Class<R> clazz) {
245247
Object obj = request.getInterceptorAttribute(this.getClass(), name);
246248
if (obj == null) {
247249
return null;
248250
}
249-
return (T) obj;
251+
return clazz.cast(obj);
250252
}
251253

252254
/**
@@ -309,4 +311,24 @@ default Double getAttributeAsDouble(ForestRequest request, String name) {
309311
return (Double) attr;
310312
}
311313

314+
/**
315+
* 获取或添加请求在本拦截器中的 Attribute 属性
316+
* <p>当 Attribute 属性中不存在属性名称所对应的值,则添加属性值</p>
317+
*
318+
* @param request Forest请求对象
319+
* @param name 属性名称
320+
* @param supplier 属性值回调函数
321+
* @return 属性值
322+
* @param <R> 属性值类型
323+
* @since 1.6.1
324+
*/
325+
default <R> R getOrAddAttribute(ForestRequest request, String name, Supplier<R> supplier) {
326+
Object obj = getAttribute(request, name);
327+
if (obj == null && supplier != null) {
328+
obj = supplier.get();
329+
addAttribute(request, name, obj);
330+
}
331+
return (R) obj;
332+
}
333+
312334
}

forest-core/src/main/java/com/dtflys/forest/interceptor/InterceptorChain.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,8 @@ public void afterExecute(ForestRequest request, ForestResponse response) {
138138
item.afterExecute(request, response);
139139
}
140140
}
141+
142+
public LinkedList<Interceptor> getInterceptors() {
143+
return interceptors;
144+
}
141145
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.dtflys.forest.interceptor;
2+
3+
import com.dtflys.forest.http.ForestRequest;
4+
import com.dtflys.forest.http.ForestResponse;
5+
import com.dtflys.forest.sse.EventSource;
6+
7+
/**
8+
* Forest SSE 拦截器
9+
*
10+
* @since 1.6.1
11+
*/
12+
public interface SSEInterceptor extends Interceptor {
13+
14+
/**
15+
* 监听打开回调函数:在开始 SSE 数据流监听的时候调用
16+
*
17+
* @param eventSource SSE 事件来源
18+
* @since 1.6.1
19+
*/
20+
default void onSSEOpen(EventSource eventSource) {
21+
}
22+
23+
/**
24+
* 监听关闭回调函数:在结束 SSE 数据流监听的时候调用
25+
*
26+
* @param request Forest 请求对象
27+
* @param response Forest 响应对象
28+
* @since 1.6.1
29+
*/
30+
default void onSSEClose(ForestRequest request, ForestResponse response) {
31+
}
32+
33+
}

forest-core/src/main/java/com/dtflys/forest/sse/EventSource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import com.dtflys.forest.http.ForestResponse;
55

66
/**
7-
* Forest SSE Event Source
7+
* Forest SSE 事件来源
88
*
99
* @since 1.6.0
1010
*/

forest-core/src/main/java/com/dtflys/forest/sse/SSEMessageMethod.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,32 @@
1515
import java.lang.reflect.Type;
1616
import java.util.function.Function;
1717

18+
/**
19+
* SSE 消息方法
20+
* <p>用于包装注册好的 SSE 消息处理方法</p>
21+
*
22+
* @since 1.6.1
23+
*/
1824
public class SSEMessageMethod {
1925

20-
private final ForestSSE sse;
26+
/**
27+
* 方法所属实例
28+
*/
29+
private final Object instance;
2130

31+
/**
32+
* Java 方法
33+
*/
2234
private final Method method;
23-
35+
36+
/**
37+
* 方法参数值获取函数表
38+
*/
2439
private Function<EventSource, ?>[] argumentFunctions;
2540

2641

27-
public SSEMessageMethod(ForestSSE sse, Method method) {
28-
this.sse = sse;
42+
public SSEMessageMethod(Object instance, Method method) {
43+
this.instance = instance;
2944
this.method = method;
3045
init();
3146
}
@@ -77,7 +92,7 @@ public void invoke(final EventSource eventSource) {
7792
final boolean accessible = method.isAccessible();
7893
method.setAccessible(true);
7994
try {
80-
method.invoke(sse, args);
95+
method.invoke(instance, args);
8196
} catch (InvocationTargetException | IllegalAccessException e) {
8297
throw new ForestRuntimeException(e);
8398
} finally {

forest-core/src/test/java/com/dtflys/forest/test/http/sse/SSEClient.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import com.dtflys.forest.test.sse.MySSEInterceptor;
99

1010
@Address(host = "localhost", port = "{port}")
11-
@BaseRequest(interceptor = MySSEInterceptor.class)
1211
public interface SSEClient {
1312

1413
@Get("/sse")
@@ -17,4 +16,7 @@ public interface SSEClient {
1716
@Get("/sse")
1817
MySSEHandler testSSE_withCustomClass();
1918

19+
@Get(url = "/sse", interceptor = MySSEInterceptor.class)
20+
ForestSSE testSSE_withInterceptor();
21+
2022
}

forest-core/src/test/java/com/dtflys/forest/test/http/sse/TestSSEClient.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,30 @@ public void testSSE_withCustomClass() {
9292
);
9393
}
9494

95+
@Test
96+
public void testSSE_withInterceptor() {
97+
server.enqueue(new MockResponse().setResponseCode(200).setBody(
98+
"data:start\n" +
99+
"data:{\"event\": \"message\", \"conversation_id\": \"aee49897-5214308b6b2d\", \"message_id\": \"9e292a7d\", \"created_at\": 1734689225 \"answer\": \"I\", \"from_variable_selector\": null}\n" +
100+
"event:{\"name\":\"Peter\",\"age\": \"18\",\"phone\":\"12345678\"}\n" +
101+
"event:close\n" +
102+
"data:dont show"
103+
));
104+
105+
ForestSSE sse = sseClient.testSSE_withInterceptor().listen();
106+
107+
System.out.println(sse.getRequest().getAttachment("text"));
108+
assertThat(sse.getRequest().getAttachment("text").toString()).isEqualTo(
109+
"MySSEInterceptor onSuccess\n" +
110+
"MySSEInterceptor afterExecute\n" +
111+
"MySSEInterceptor onSSEOpen\n" +
112+
"Receive data: start\n" +
113+
"Receive data: {\"event\": \"message\", \"conversation_id\": \"aee49897-5214308b6b2d\", \"message_id\": \"9e292a7d\", \"created_at\": 1734689225 \"answer\": \"I\", \"from_variable_selector\": null}\n" +
114+
"name: Peter; age: 18; phone: 12345678\n" +
115+
"receive close --- close\n" +
116+
"MySSEInterceptor onSSEClose"
117+
);
118+
}
119+
120+
95121
}

0 commit comments

Comments
 (0)