|
12 | 12 | import io.undertow.server.HttpHandler; |
13 | 13 | import io.undertow.server.HttpServerExchange; |
14 | 14 | import io.undertow.util.Headers; |
| 15 | +import io.undertow.util.HttpString; |
15 | 16 | import io.undertow.util.PathMatcher; |
16 | 17 | import io.undertow.websockets.WebSocketConnectionCallback; |
17 | 18 | import io.undertow.websockets.WebSocketProtocolHandshakeHandler; |
18 | | -import io.undertow.websockets.client.WebSocketClient; |
| 19 | +import io.undertow.websockets.client.WebSocketClientNegotiation; |
19 | 20 | import io.undertow.websockets.core.WebSocketChannel; |
| 21 | +import io.undertow.websockets.core.protocol.Handshake; |
| 22 | +import io.undertow.websockets.core.protocol.version07.Hybi07Handshake; |
| 23 | +import io.undertow.websockets.core.protocol.version08.Hybi08Handshake; |
| 24 | +import io.undertow.websockets.core.protocol.version13.Hybi13Handshake; |
20 | 25 | import io.undertow.websockets.spi.WebSocketHttpExchange; |
21 | 26 | import org.slf4j.Logger; |
22 | 27 | import org.slf4j.LoggerFactory; |
23 | 28 |
|
24 | 29 | import java.io.IOException; |
25 | 30 | import java.net.URI; |
26 | 31 | import java.net.URISyntaxException; |
27 | | -import java.util.Map; |
| 32 | +import java.util.*; |
28 | 33 | import java.util.concurrent.ConcurrentHashMap; |
29 | 34 |
|
30 | 35 | public class WebSocketRouterHandler implements MiddlewareHandler, WebSocketConnectionCallback { |
@@ -69,7 +74,17 @@ public void handleRequest(HttpServerExchange exchange) throws Exception { |
69 | 74 | if (isWsRequest) { |
70 | 75 | // Delegate to Undertow's WebSocketProtocolHandshakeHandler |
71 | 76 | // which handles the upgrade and calls our onConnect callback |
72 | | - new WebSocketProtocolHandshakeHandler(this).handleRequest(exchange); |
| 77 | + String protocolHeader = exchange.getRequestHeaders().getFirst("Sec-WebSocket-Protocol"); |
| 78 | + if (protocolHeader != null) { |
| 79 | + Set<String> subprotocols = new LinkedHashSet<>(Arrays.asList(protocolHeader.split(","))); |
| 80 | + Collection<Handshake> handshakes = new ArrayList<>(); |
| 81 | + handshakes.add(new Hybi13Handshake(subprotocols, false)); |
| 82 | + handshakes.add(new Hybi08Handshake(subprotocols, false)); |
| 83 | + handshakes.add(new Hybi07Handshake(subprotocols, false)); |
| 84 | + new WebSocketProtocolHandshakeHandler(handshakes, this).handleRequest(exchange); |
| 85 | + } else { |
| 86 | + new WebSocketProtocolHandshakeHandler(this).handleRequest(exchange); |
| 87 | + } |
73 | 88 | } else { |
74 | 89 | // Not a websocket request for us, pass to next handler |
75 | 90 | Handler.next(exchange, next); |
@@ -191,13 +206,23 @@ public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) |
191 | 206 | } |
192 | 207 |
|
193 | 208 | /* create new connection to downstream */ |
194 | | - final var webSocketConnection = new WebSocketClient.ConnectionBuilder( |
| 209 | + String subprotocol = exchange.getRequestHeader("Sec-WebSocket-Protocol"); |
| 210 | + List<String> subprotocols = subprotocol != null ? Collections.singletonList(subprotocol) : Collections.emptyList(); |
| 211 | + WebSocketClientNegotiation negotiation = new WebSocketClientNegotiation(subprotocols, Collections.emptyList()) { |
| 212 | + @Override |
| 213 | + public void beforeRequest(Map<String, List<String>> headers) { |
| 214 | + if (subprotocol != null) { |
| 215 | + headers.put("Sec-WebSocket-Protocol", Collections.singletonList(subprotocol)); |
| 216 | + } |
| 217 | + } |
| 218 | + }; |
| 219 | + |
| 220 | + final var webSocketConnection = new io.undertow.websockets.client.WebSocketClient.ConnectionBuilder( |
195 | 221 | channel.getWorker(), |
196 | 222 | channel.getBufferPool(), |
197 | 223 | new URI(targetUri) |
198 | | - ); |
| 224 | + ).setClientNegotiation(negotiation); |
199 | 225 | final var outChannel = webSocketConnection.connect().get(); |
200 | | - |
201 | 226 | outChannel.setAttribute(WsAttributes.CHANNEL_GROUP_ID, channelId); |
202 | 227 | outChannel.setAttribute(WsAttributes.CHANNEL_DIRECTION, WsProxyClientPair.SocketFlow.PROXY_TO_DOWNSTREAM); |
203 | 228 | outChannel.getReceiveSetter().set(new WebSocketSessionProxyReceiveListener(CHANNELS)); |
|
0 commit comments