77import com .dtflys .forest .annotation .SSEMessage ;
88import com .dtflys .forest .annotation .SSERetryMessage ;
99import com .dtflys .forest .exceptions .ForestRuntimeException ;
10+ import com .dtflys .forest .interceptor .Interceptor ;
11+ import com .dtflys .forest .interceptor .SSEInterceptor ;
1012import com .dtflys .forest .reflection .MethodLifeCycleHandler ;
1113import com .dtflys .forest .sse .EventSource ;
1214import com .dtflys .forest .sse .ForestSSEListener ;
@@ -81,22 +83,38 @@ void init(final ForestRequest request) {
8183 this .request .setLifeCycleHandler (new MethodLifeCycleHandler <InputStream >(InputStream .class , InputStream .class ) {});
8284 final Class <?> clazz = this .getClass ();
8385 final Method [] methods = ReflectUtils .getMethods (clazz );
84-
85- // 批量注册 SSE 控制器类中的消息处理方法
86- for (final Method method : methods ) {
87- final Annotation [] annArray = method .getAnnotations ();
88- for (final Annotation ann : annArray ) {
89- if (ann instanceof SSEMessage ) {
90- registerMessageMethod (method , ann , null );
91- } else if (ann instanceof SSEDataMessage ) {
92- registerMessageMethod (method , ann , "data" );
93- } else if (ann instanceof SSEEventMessage ) {
94- registerMessageMethod (method , ann , "event" );
95- } else if (ann instanceof SSEIdMessage ) {
96- registerMessageMethod (method , ann , "id" );
97- } else if (ann instanceof SSERetryMessage ) {
98- registerMessageMethod (method , ann , "retry" );
99- }
86+ final List <Interceptor > interceptors = request .getInterceptorChain ().getInterceptors ();
87+ for (final Interceptor interceptor : interceptors ) {
88+ if (interceptor instanceof SSEInterceptor ) {
89+ final Class <?> interceptorClass = interceptor .getClass ();
90+ final Method [] interceptorMethods = ReflectUtils .getMethods (interceptorClass );
91+ registerMethodArray (interceptor , interceptorMethods );
92+ }
93+ }
94+ registerMethodArray (this , methods );
95+ }
96+ }
97+
98+ /**
99+ * 批量注册 SSE 控制器类中的消息处理方法
100+ *
101+ * @param instance 方法所属实例
102+ * @param methods Java 方法数组
103+ */
104+ private void registerMethodArray (Object instance , final Method [] methods ) {
105+ for (final Method method : methods ) {
106+ final Annotation [] annArray = method .getAnnotations ();
107+ for (final Annotation ann : annArray ) {
108+ if (ann instanceof SSEMessage ) {
109+ registerMessageMethod (instance , method , ann , null );
110+ } else if (ann instanceof SSEDataMessage ) {
111+ registerMessageMethod (instance , method , ann , "data" );
112+ } else if (ann instanceof SSEEventMessage ) {
113+ registerMessageMethod (instance , method , ann , "event" );
114+ } else if (ann instanceof SSEIdMessage ) {
115+ registerMessageMethod (instance , method , ann , "id" );
116+ } else if (ann instanceof SSERetryMessage ) {
117+ registerMessageMethod (instance , method , ann , "retry" );
100118 }
101119 }
102120 }
@@ -105,18 +123,19 @@ void init(final ForestRequest request) {
105123 /**
106124 * 注册 SSE 消息处理方法
107125 *
126+ * @param instance 方法所属实例
108127 * @param method Java 方法对象
109128 * @param ann 注解对象
110129 * @param defaultName SSE 消息默认名称
111130 * @since 1.6.0
112131 */
113- private void registerMessageMethod (Method method , Annotation ann , String defaultName ) {
132+ private void registerMessageMethod (Object instance , Method method , Annotation ann , String defaultName ) {
114133 final Map <String , Object > attrs = ReflectUtils .getAttributesFromAnnotation (ann );
115134 final String valueRegex = String .valueOf (attrs .getOrDefault ("valueRegex" , "" ));
116135 final String valuePrefix = String .valueOf (attrs .getOrDefault ("valuePrefix" , "" ));
117136 final String valuePostfix = String .valueOf (attrs .getOrDefault ("valuePostfix" , "" ));
118137 final String annName = defaultName != null ? defaultName : String .valueOf (attrs .getOrDefault ("name" , "" ));
119- final SSEMessageMethod sseMessageMethod = new SSEMessageMethod (this , method );
138+ final SSEMessageMethod sseMessageMethod = new SSEMessageMethod (instance , method );
120139 if (StringUtils .isEmpty (valueRegex ) && StringUtils .isEmpty (valuePrefix ) && StringUtils .isEmpty (valuePostfix )) {
121140 addConsumer (annName , (eventSource , name , value ) -> sseMessageMethod .invoke (eventSource ));
122141 } else {
@@ -501,6 +520,19 @@ public ForestSSE addOnRetryMatchesPrefix(String valuePrefix, SSEStringMessageCon
501520 public ForestSSE addOnRetryMatchesPostfix (String valuePostfix , SSEStringMessageConsumer consumer ) {
502521 return addConsumerMatchesPostfix ("retry" , valuePostfix , consumer );
503522 }
523+
524+ private void doOnOpen (final EventSource eventSource ) {
525+ final List <Interceptor > interceptors = eventSource .getRequest ().getInterceptorChain ().getInterceptors ();
526+ for (Interceptor interceptor : interceptors ) {
527+ if (interceptor instanceof SSEInterceptor ) {
528+ ((SSEInterceptor ) interceptor ).onSSEOpen (eventSource );
529+ }
530+ }
531+ onOpen (eventSource );
532+ if (onOpenConsumer != null ) {
533+ onOpenConsumer .accept (eventSource );
534+ }
535+ }
504536
505537 /**
506538 * 监听打开回调函数:在开始 SSE 数据流监听的时候调用
@@ -509,9 +541,19 @@ public ForestSSE addOnRetryMatchesPostfix(String valuePostfix, SSEStringMessageC
509541 * @since 1.6.0
510542 */
511543 protected void onOpen (EventSource eventSource ) {
512- if (onOpenConsumer != null ) {
513- onOpenConsumer .accept (eventSource );
544+ }
545+
546+ private void doOnClose (final ForestRequest request , final ForestResponse response ) {
547+ final List <Interceptor > interceptors = request .getInterceptorChain ().getInterceptors ();
548+ for (Interceptor interceptor : interceptors ) {
549+ if (interceptor instanceof SSEInterceptor ) {
550+ ((SSEInterceptor ) interceptor ).onSSEClose (request , response );
551+ }
514552 }
553+ if (onCloseConsumer != null ) {
554+ onCloseConsumer .accept (request , response );
555+ }
556+ onClose (request , response );
515557 }
516558
517559 /**
@@ -522,9 +564,6 @@ protected void onOpen(EventSource eventSource) {
522564 * @since 1.6.0
523565 */
524566 protected void onClose (ForestRequest request , ForestResponse response ) {
525- if (onCloseConsumer != null ) {
526- onCloseConsumer .accept (request , response );
527- }
528567 }
529568
530569 /**
@@ -600,7 +639,7 @@ public <R extends ForestSSE> R listen() {
600639 throw new ForestRuntimeException (e );
601640 }
602641 } else {
603- response = this .request .execute (new TypeReference <ForestResponse <InputStream >>() {});
642+ response = this .request .execute (new TypeReference <ForestResponse <InputStream >>() {});
604643 }
605644 if (response == null ) {
606645 return (R ) this ;
@@ -609,7 +648,7 @@ public <R extends ForestSSE> R listen() {
609648 return (R ) this ;
610649 }
611650 final EventSource openEventSource = new EventSource ("open" , request , response );
612- this .onOpen (openEventSource );
651+ this .doOnOpen (openEventSource );
613652 if (SSEMessageResult .CLOSE .equals (openEventSource .getMessageResult ())) {
614653 onClose (request , response );
615654 return (R ) this ;
@@ -634,7 +673,7 @@ public <R extends ForestSSE> R listen() {
634673 } catch (IOException e ) {
635674 throw new ForestRuntimeException (e );
636675 } finally {
637- onClose (request , response );
676+ doOnClose (request , response );
638677 }
639678 });
640679 } catch (Exception e ) {
0 commit comments