Skip to content

Commit 45f97ca

Browse files
committed
[bugfix] Eval- WebSocketEndpoint: suppress ClosedChannelException
clear sessions on shutdown replace text ping with WebSocket PING control frame serialize Long and Double as JSON numbers handle interruptions more gracefully
1 parent 37d5309 commit 45f97ca

4 files changed

Lines changed: 76 additions & 78 deletions

File tree

exist-core/src/main/java/org/exist/http/ws/EvalWebSocketEndpoint.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import jakarta.websocket.server.ServerEndpoint;
3737
import jakarta.websocket.server.ServerEndpointConfig;
3838
import java.io.IOException;
39+
import java.nio.channels.ClosedChannelException;
3940
import java.util.List;
4041
import java.util.Map;
4142
import java.util.concurrent.ConcurrentHashMap;
@@ -88,6 +89,7 @@ public static synchronized void shutdown() {
8889
queryExecutorService.shutdown();
8990
queryExecutorService = null;
9091
}
92+
sessions.clear();
9193
}
9294

9395
/**
@@ -216,15 +218,17 @@ public void onClose(final Session session, final CloseReason reason) {
216218

217219
@OnError
218220
public void onError(final Session session, final Throwable error) {
219-
if (error.getMessage() != null && error.getMessage().contains("Text message size")) {
220-
LOG.warn("WebSocket message exceeds {}MB buffer limit: {}", MAX_TEXT_MESSAGE_SIZE / (1024 * 1024), error.getMessage());
221-
} else {
222-
LOG.warn("WebSocket eval error: {}", error.getMessage(), error);
223-
}
224221
final EvalSession evalSession = sessions.remove(session);
225222
if (evalSession != null) {
226223
evalSession.cancelAll();
227224
}
225+
if (error instanceof ClosedChannelException) {
226+
LOG.debug("WebSocket eval client disconnected abruptly: session {}", session.getId());
227+
} else if (error.getMessage() != null && error.getMessage().contains("Text message size")) {
228+
LOG.warn("WebSocket message exceeds {}MB buffer limit: {}", MAX_TEXT_MESSAGE_SIZE / (1024 * 1024), error.getMessage());
229+
} else {
230+
LOG.warn("WebSocket eval error: {}", error.getMessage(), error);
231+
}
228232
}
229233

230234
private void handleEval(final Session session, final EvalSession evalSession,

exist-core/src/main/java/org/exist/xquery/functions/websocket/WebSocketEndpoint.java

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
import jakarta.websocket.server.ServerEndpoint;
3636
import java.io.IOException;
3737
import java.io.StringWriter;
38-
import java.util.Iterator;
38+
import java.nio.ByteBuffer;
39+
import java.nio.channels.ClosedChannelException;
3940
import java.util.Map;
4041
import java.util.concurrent.ConcurrentHashMap;
4142
import java.util.concurrent.Executors;
@@ -63,6 +64,7 @@ public class WebSocketEndpoint {
6364
private static final Logger LOG = LogManager.getLogger(WebSocketEndpoint.class);
6465
private static final JsonFactory JSON_FACTORY = new JsonFactory();
6566
private static final Map<Session, String> sessions = new ConcurrentHashMap<>();
67+
private static final ByteBuffer PING_PAYLOAD = ByteBuffer.allocate(0);
6668

6769
private static volatile boolean initialized = false;
6870
private static ScheduledExecutorService heartbeatService = null;
@@ -109,13 +111,13 @@ public static synchronized void shutdown() {
109111
monitorService.shutdown();
110112
monitorService = null;
111113
}
114+
sessions.clear();
112115
initialized = false;
113116
}
114117
}
115118

116119
@OnOpen
117120
public void openSession(final Session session) {
118-
session.setMaxIdleTimeout(10000);
119121
sessions.put(session, DEFAULT_CHANNEL);
120122
}
121123

@@ -124,6 +126,16 @@ public void closeSession(final Session session, final CloseReason closeReason) {
124126
sessions.remove(session);
125127
}
126128

129+
@OnError
130+
public void onError(final Session session, final Throwable throwable) {
131+
sessions.remove(session);
132+
if (throwable instanceof ClosedChannelException) {
133+
LOG.debug("WebSocket client disconnected abruptly: session {}", session.getId());
134+
} else {
135+
LOG.warn("WebSocket error on session {}: {}", session.getId(), throwable.getMessage(), throwable);
136+
}
137+
}
138+
127139
@OnMessage
128140
public void recv(final String message, final Session session) {
129141
try (final JsonParser parser = JSON_FACTORY.createParser(message)) {
@@ -142,7 +154,16 @@ public void recv(final String message, final Session session) {
142154
}
143155

144156
static void pingAll() {
145-
sendAll(null, "ping");
157+
for (final Session session : sessions.keySet()) {
158+
try {
159+
session.getBasicRemote().sendPing(PING_PAYLOAD.duplicate());
160+
} catch (final ClosedChannelException e) {
161+
sessions.remove(session);
162+
} catch (final IOException e) {
163+
LOG.debug("Ping failed, removing session {}: {}", session.getId(), e.getMessage());
164+
sessions.remove(session);
165+
}
166+
}
146167
}
147168

148169
/**
@@ -161,6 +182,8 @@ public static void sendAll(final String toChannel, final Map<String, Object> dat
161182
switch (entry.getValue()) {
162183
case String s -> gen.writeStringField(key, s);
163184
case Integer i -> gen.writeNumberField(key, i);
185+
case Long l -> gen.writeNumberField(key, l);
186+
case Double d -> gen.writeNumberField(key, d);
164187
case Boolean b -> gen.writeBooleanField(key, b);
165188
case null -> gen.writeNullField(key);
166189
default -> gen.writeStringField(key, entry.getValue().toString());
@@ -181,18 +204,16 @@ public static void sendAll(final String toChannel, final Map<String, Object> dat
181204
* @param message the message text
182205
*/
183206
public static void sendAll(final String toChannel, final String message) {
184-
final Iterator<Map.Entry<Session, String>> iterator = sessions.entrySet().iterator();
185-
while (iterator.hasNext()) {
186-
try {
187-
final Map.Entry<Session, String> entry = iterator.next();
188-
final Session session = entry.getKey();
189-
final String channel = entry.getValue();
190-
191-
if (toChannel == null || (!channel.equals(DEFAULT_CHANNEL) && toChannel.equals(channel))) {
207+
for (final Map.Entry<Session, String> entry : sessions.entrySet()) {
208+
final Session session = entry.getKey();
209+
final String channel = entry.getValue();
210+
if (toChannel == null || (!channel.equals(DEFAULT_CHANNEL) && toChannel.equals(channel))) {
211+
try {
192212
session.getBasicRemote().sendText(message);
213+
} catch (final IOException e) {
214+
LOG.debug("Removing disconnected WebSocket session: {}", e.getMessage());
215+
sessions.remove(session);
193216
}
194-
} catch (final IOException e) {
195-
LOG.error("Error sending message via websocket: {}", e.getMessage(), e);
196217
}
197218
}
198219
}

exist-core/src/test/java/org/exist/http/ws/EvalWebSocketEndpointTest.java

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,28 +1047,23 @@ public void monitorChannelReceivesQueryEvents() throws Exception {
10471047
final Session monitorSession = container.connectToServer(new Endpoint() {
10481048
@Override
10491049
public void onOpen(final Session session, final EndpointConfig config) {
1050+
try {
1051+
session.getBasicRemote().sendText("{\"channel\": \"_monitor\"}");
1052+
subscribedLatch.countDown();
1053+
} catch (final IOException e) {
1054+
throw new UncheckedIOException(e);
1055+
}
10501056
session.addMessageHandler(new MessageHandler.Whole<String>() {
1051-
private boolean subscribed = false;
10521057
@Override
10531058
public void onMessage(final String message) {
1054-
if (!subscribed && "ping".equals(message)) {
1055-
try {
1056-
session.getBasicRemote().sendText("{\"channel\": \"_monitor\"}");
1057-
subscribed = true;
1058-
subscribedLatch.countDown();
1059-
} catch (IOException e) {
1060-
throw new UncheckedIOException(e);
1061-
}
1062-
} else if (subscribed && !"ping".equals(message)) {
1063-
try {
1064-
final Map<String, Object> parsed = parseJson(message);
1065-
if ("monitor".equals(parsed.get("type"))) {
1066-
monitorMessages.add(parsed);
1067-
monitorEventLatch.countDown();
1068-
}
1069-
} catch (final IOException e) {
1070-
// ignore parse errors for pings etc.
1059+
try {
1060+
final Map<String, Object> parsed = parseJson(message);
1061+
if ("monitor".equals(parsed.get("type"))) {
1062+
monitorMessages.add(parsed);
1063+
monitorEventLatch.countDown();
10711064
}
1065+
} catch (final IOException e) {
1066+
// ignore non-JSON frames
10721067
}
10731068
}
10741069
});

exist-core/src/test/java/org/exist/xquery/functions/websocket/WebSocketEndpointTest.java

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,31 +45,23 @@ public class WebSocketEndpointTest {
4545
public static final ExistWebServer existWebServer = new ExistWebServer(true, false, true, true);
4646

4747
@Test
48-
public void connectAndReceiveHeartbeat() throws Exception {
48+
public void connectAndHeartbeatKeepsSessionOpen() throws Exception {
4949
final int port = existWebServer.getPort();
5050
final URI wsUri = new URI("ws://localhost:" + port + "/ws");
5151

52-
final CountDownLatch messageLatch = new CountDownLatch(1);
53-
final AtomicReference<String> receivedMessage = new AtomicReference<>();
54-
5552
final WebSocketContainer container = ContainerProvider.getWebSocketContainer();
5653
final Session session = container.connectToServer(new Endpoint() {
5754
@Override
5855
public void onOpen(final Session session, final EndpointConfig config) {
59-
session.addMessageHandler(new MessageHandler.Whole<String>() {
60-
@Override
61-
public void onMessage(final String message) {
62-
receivedMessage.set(message);
63-
messageLatch.countDown();
64-
}
65-
});
6656
}
6757
}, ClientEndpointConfig.Builder.create().build(), wsUri);
6858

6959
try {
70-
// should receive a heartbeat ping within 1 second
71-
assertTrue("Should receive a message within 2s", messageLatch.await(2, TimeUnit.SECONDS));
72-
assertEquals("ping", receivedMessage.get());
60+
// The heartbeat sends WebSocket PING control frames every 500ms.
61+
// They are handled transparently by the WS layer and do not fire onMessage.
62+
// The observable effect is that the session stays open.
63+
Thread.sleep(1500);
64+
assertTrue("Session should remain open after heartbeat interval", session.isOpen());
7365
} finally {
7466
session.close();
7567
}
@@ -88,24 +80,17 @@ public void subscribeToChannelAndReceiveMessage() throws Exception {
8880
final Session session = container.connectToServer(new Endpoint() {
8981
@Override
9082
public void onOpen(final Session session, final EndpointConfig config) {
83+
try {
84+
session.getBasicRemote().sendText("{\"channel\": \"test-channel\"}");
85+
subscribedLatch.countDown();
86+
} catch (final IOException e) {
87+
throw new UncheckedIOException(e);
88+
}
9189
session.addMessageHandler(new MessageHandler.Whole<String>() {
92-
private boolean subscribed = false;
93-
9490
@Override
9591
public void onMessage(final String message) {
96-
if (!subscribed && "ping".equals(message)) {
97-
// after first ping, subscribe to a channel
98-
try {
99-
session.getBasicRemote().sendText("{\"channel\": \"test-channel\"}");
100-
subscribed = true;
101-
subscribedLatch.countDown();
102-
} catch (IOException e) {
103-
throw new UncheckedIOException(e);
104-
}
105-
} else if (subscribed && !"ping".equals(message)) {
106-
receivedMessage.set(message);
107-
messageLatch.countDown();
108-
}
92+
receivedMessage.set(message);
93+
messageLatch.countDown();
10994
}
11095
});
11196
}
@@ -141,19 +126,12 @@ public void channelCountReflectsSubscribers() throws Exception {
141126
final Session session = container.connectToServer(new Endpoint() {
142127
@Override
143128
public void onOpen(final Session session, final EndpointConfig config) {
144-
session.addMessageHandler(new MessageHandler.Whole<String>() {
145-
@Override
146-
public void onMessage(final String message) {
147-
if ("ping".equals(message)) {
148-
try {
149-
session.getBasicRemote().sendText("{\"channel\": \"count-test\"}");
150-
subscribedLatch.countDown();
151-
} catch (IOException e) {
152-
throw new UncheckedIOException(e);
153-
}
154-
}
155-
}
156-
});
129+
try {
130+
session.getBasicRemote().sendText("{\"channel\": \"count-test\"}");
131+
subscribedLatch.countDown();
132+
} catch (final IOException e) {
133+
throw new UncheckedIOException(e);
134+
}
157135
}
158136
}, ClientEndpointConfig.Builder.create().build(), wsUri);
159137

0 commit comments

Comments
 (0)