11package io .scalecube .services .gateway .ws ;
22
3+ import io .netty .buffer .ByteBuf ;
34import io .scalecube .services .ServiceCall ;
45import io .scalecube .services .api .ServiceMessage ;
56import io .scalecube .services .exceptions .DefaultErrorMapper ;
1920import reactor .netty .http .server .HttpServerResponse ;
2021import reactor .netty .http .websocket .WebsocketInbound ;
2122import reactor .netty .http .websocket .WebsocketOutbound ;
22-
23+ import reactor . util . context . Context ;
2324
2425public class WebsocketGatewayAcceptor
2526 implements BiFunction <HttpServerRequest , HttpServerResponse , Publisher <Void >> {
@@ -48,41 +49,55 @@ public WebsocketGatewayAcceptor(
4849 public Publisher <Void > apply (HttpServerRequest httpRequest , HttpServerResponse httpResponse ) {
4950 return httpResponse .sendWebsocket (
5051 (WebsocketInbound inbound , WebsocketOutbound outbound ) ->
51- onConnect (new WebsocketGatewaySession (messageCodec , httpRequest , inbound , outbound )));
52+ onConnect (
53+ new WebsocketGatewaySession (
54+ messageCodec , httpRequest , inbound , outbound , gatewayHandler )));
5255 }
5356
5457 private Mono <Void > onConnect (WebsocketGatewaySession session ) {
5558 gatewayHandler .onSessionOpen (session );
5659
5760 session
5861 .receive ()
62+ .doOnError (th -> gatewayHandler .onSessionError (session , th ))
5963 .subscribe (
6064 byteBuf ->
61- Mono .fromCallable (() -> messageCodec .decode (byteBuf ))
62- .doOnNext (message -> metrics .markRequest ())
63- .map (this ::checkSid )
64- .flatMap (msg -> handleCancel (session , msg ))
65- .map (msg -> validateSid (session , (GatewayMessage ) msg ))
66- .map (this ::checkQualifier )
67- .map (msg -> gatewayHandler .mapMessage (session , msg ))
68- .subscribe (
69- request -> {
70- try {
71- handleMessage (session , request );
72- } catch (Exception ex ) {
73- gatewayHandler .onError (session , ex , request , null );
74- }
75- },
76- th -> handleError (session , th )),
77- th -> gatewayHandler .onError (session , th , null , null ));
65+ Mono .deferWithContext (context -> onRequest (session , byteBuf , context ))
66+ .subscriberContext (
67+ context -> gatewayHandler .onRequest (session , byteBuf , context ))
68+ .subscribe ());
7869
7970 return session .onClose (() -> gatewayHandler .onSessionClose (session ));
8071 }
8172
82- private void handleMessage (WebsocketGatewaySession session , GatewayMessage request ) {
83- Long sid = request .streamId ();
73+ private Mono <GatewayMessage > onRequest (
74+ WebsocketGatewaySession session , ByteBuf byteBuf , Context context ) {
75+ return Mono .fromCallable (() -> messageCodec .decode (byteBuf ))
76+ .doOnNext (message -> metrics .markRequest ())
77+ .map (this ::validateSid )
78+ .flatMap (msg -> onCancel (session , msg ))
79+ .map (msg -> validateSid (session , (GatewayMessage ) msg ))
80+ .map (this ::validateQualifier )
81+ .map (msg -> gatewayHandler .mapMessage (session , msg ))
82+ .doOnNext (request -> onMessage (session , request , context ))
83+ .doOnError (
84+ th -> {
85+ if (!(th instanceof WebsocketContextException )) {
86+ // decode failed at this point
87+ gatewayHandler .onError (session , th , context );
88+ return ;
89+ }
90+
91+ WebsocketContextException wex = (WebsocketContextException ) th ;
92+ wex .releaseRequest (); // release
93+
94+ onError (session , wex .request (), wex .getCause (), context );
95+ });
96+ }
8497
85- AtomicBoolean receivedError = new AtomicBoolean (false );
98+ private void onMessage (WebsocketGatewaySession session , GatewayMessage request , Context context ) {
99+ final Long sid = request .streamId ();
100+ final AtomicBoolean receivedError = new AtomicBoolean (false );
86101
87102 final Flux <ServiceMessage > serviceStream =
88103 serviceCall .requestMany (GatewayMessage .toServiceMessage (request ));
@@ -94,68 +109,41 @@ private void handleMessage(WebsocketGatewaySession session, GatewayMessage reque
94109 .map (response -> prepareResponse (sid , response , receivedError ))
95110 .doOnNext (response -> metrics .markServiceResponse ())
96111 .doFinally (signalType -> session .dispose (sid ))
97- .subscribe (
98- response ->
99- session
100- .send (response )
101- .subscribe (
102- avoid -> metrics .markResponse (),
103- th -> gatewayHandler .onError (session , th , request , response )),
104- th -> handleError (session , request , th ),
105- () -> handleCompletion (session , request , receivedError ));
112+ .flatMap (session ::send )
113+ .doOnError (th -> onError (session , request , th , context ))
114+ .doOnComplete (() -> onComplete (session , request , receivedError , context ))
115+ .subscriberContext (context )
116+ .subscribe ();
106117
107118 session .register (sid , disposable );
108119 }
109120
110- private void handleError (WebsocketGatewaySession session , Throwable throwable ) {
111- if (throwable instanceof WebsocketRequestException ) {
112- WebsocketRequestException ex = (WebsocketRequestException ) throwable ;
113- ex .releaseRequest (); // release
114- handleError (session , ex .request (), ex .getCause ());
115- } else {
116- gatewayHandler .onError (session , throwable , null , null );
117- }
118- }
119-
120- private void handleError (WebsocketGatewaySession session , GatewayMessage req , Throwable th ) {
121- gatewayHandler .onError (session , th , req , null );
121+ private void onError (
122+ WebsocketGatewaySession session , GatewayMessage req , Throwable th , Context context ) {
122123
123124 Builder builder = GatewayMessage .from (DefaultErrorMapper .INSTANCE .toMessage (th ));
124125 Optional .ofNullable (req .streamId ()).ifPresent (builder ::streamId );
125126 GatewayMessage response = builder .signal (Signal .ERROR ).build ();
126127
127- session
128- .send (response )
129- .subscribe (null , ex -> gatewayHandler .onError (session , ex , req , response ));
128+ session .send (response ).subscriberContext (context ).subscribe ();
130129 }
131130
132- private void handleCompletion (
133- WebsocketGatewaySession session , GatewayMessage req , AtomicBoolean receivedError ) {
131+ private void onComplete (
132+ WebsocketGatewaySession session ,
133+ GatewayMessage req ,
134+ AtomicBoolean receivedError ,
135+ Context context ) {
136+
134137 if (!receivedError .get ()) {
135138 Builder builder = GatewayMessage .builder ();
136139 Optional .ofNullable (req .streamId ()).ifPresent (builder ::streamId );
137140 GatewayMessage response = builder .signal (Signal .COMPLETE ).build ();
138- session .send (response ).subscribe (null , ex -> gatewayHandler .onError (session , ex , req , null ));
139- }
140- }
141141
142- private GatewayMessage checkQualifier (GatewayMessage msg ) {
143- if (msg .qualifier () == null ) {
144- throw WebsocketRequestException .newBadRequest ("qualifier is missing" , msg );
142+ session .send (response ).subscriberContext (context ).subscribe ();
145143 }
146- return msg ;
147144 }
148145
149- private GatewayMessage validateSid (WebsocketGatewaySession session , GatewayMessage msg ) {
150- if (session .containsSid (msg .streamId ())) {
151- throw WebsocketRequestException .newBadRequest (
152- "sid=" + msg .streamId () + " is already registered" , msg );
153- } else {
154- return msg ;
155- }
156- }
157-
158- private Mono <?> handleCancel (WebsocketGatewaySession session , GatewayMessage msg ) {
146+ private Mono <?> onCancel (WebsocketGatewaySession session , GatewayMessage msg ) {
159147 if (!msg .hasSignal (Signal .CANCEL )) {
160148 return Mono .just (msg );
161149 }
@@ -171,9 +159,25 @@ private Mono<?> handleCancel(WebsocketGatewaySession session, GatewayMessage msg
171159 return session .send (cancelAck ); // no need to subscribe here since flatMap will do
172160 }
173161
174- private GatewayMessage checkSid (GatewayMessage msg ) {
162+ private GatewayMessage validateQualifier (GatewayMessage msg ) {
163+ if (msg .qualifier () == null ) {
164+ throw WebsocketContextException .badRequest ("qualifier is missing" , msg );
165+ }
166+ return msg ;
167+ }
168+
169+ private GatewayMessage validateSid (WebsocketGatewaySession session , GatewayMessage msg ) {
170+ if (session .containsSid (msg .streamId ())) {
171+ throw WebsocketContextException .badRequest (
172+ "sid=" + msg .streamId () + " is already registered" , msg );
173+ } else {
174+ return msg ;
175+ }
176+ }
177+
178+ private GatewayMessage validateSid (GatewayMessage msg ) {
175179 if (msg .streamId () == null ) {
176- throw WebsocketRequestException . newBadRequest ("sid is missing" , msg );
180+ throw WebsocketContextException . badRequest ("sid is missing" , msg );
177181 } else {
178182 return msg ;
179183 }
0 commit comments