Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/forty-results-speak.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@hono/node-ws': minor
---

Fix WebSocket connections failing when the endpoint is registered under `app.route()`.
57 changes: 57 additions & 0 deletions packages/node-ws/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -312,4 +312,61 @@ describe('WebSocket helper', () => {
expect(clientWs).toBeTruthy()
expect(wss.clients.size).toBe(1)
})

it('Should work with app.route()', async () => {
const subApp = new Hono()
subApp.get(
'/ws',
upgradeWebSocket(() => ({
onOpen(_, ws) {
ws.send('Hello from sub app')
},
}))
)

app.route('/sub', subApp)

const ws = new WebSocket('ws://localhost:3030/sub/ws')
const mainPromise = new Promise<string>((resolve, reject) => {
ws.onmessage = (event) => {
resolve(event.data as string)
}
ws.onerror = () => {
reject(new Error('WebSocket error'))
}
})

expect(await mainPromise).toBe('Hello from sub app')
ws.close()
})

it('Should work with nested app.route()', async () => {
const subSubApp = new Hono()
subSubApp.get(
'/ws',
upgradeWebSocket(() => ({
onOpen(_, ws) {
ws.send('Hello from nested')
},
}))
)

const subApp = new Hono()
subApp.route('/nested', subSubApp)

app.route('/sub', subApp)

const ws = new WebSocket('ws://localhost:3030/sub/nested/ws')
const mainPromise = new Promise<string>((resolve, reject) => {
ws.onmessage = (event) => {
resolve(event.data as string)
}
ws.onerror = () => {
reject(new Error('WebSocket error'))
}
})

expect(await mainPromise).toBe('Hello from nested')
ws.close()
})
})
31 changes: 16 additions & 15 deletions packages/node-ws/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,20 @@ const CONNECTION_SYMBOL_KEY: unique symbol = Symbol('CONNECTION_SYMBOL_KEY')
*/
export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
const wss = new WebSocketServer({ noServer: true })
const waiterMap = new Map<
IncomingMessage,
{ resolve: (ws: WebSocket) => void; connectionSymbol: symbol }
>()
const upgradeAllowed = new WeakSet<IncomingMessage>()
const waiterMap = new Map<IncomingMessage, (ws: WebSocket) => void>()

wss.on('connection', (ws, request) => {
const waiter = waiterMap.get(request)
if (waiter) {
waiter.resolve(ws)
const resolve = waiterMap.get(request)
if (resolve) {
resolve(ws)
waiterMap.delete(request)
}
})

const nodeUpgradeWebSocket = (request: IncomingMessage, connectionSymbol: symbol) => {
const nodeUpgradeWebSocket = (request: IncomingMessage) => {
return new Promise<WebSocket>((resolve) => {
waiterMap.set(request, { resolve, connectionSymbol })
waiterMap.set(request, resolve)
})
}

Expand All @@ -72,15 +70,13 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
const env: {
incoming: IncomingMessage
outgoing: undefined
[CONNECTION_SYMBOL_KEY]?: symbol
} = {
incoming: request,
outgoing: undefined,
}
await init.app.request(url, { headers: headers }, env)
const waiter = waiterMap.get(request)

if (!waiter || waiter.connectionSymbol !== env[CONNECTION_SYMBOL_KEY]) {
if (!upgradeAllowed.has(request)) {
socket.end(
'HTTP/1.1 400 Bad Request\r\n' +
'Connection: close\r\n' +
Expand All @@ -91,6 +87,9 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
return
}

// Remove the mark after checking to prevent memory leak
upgradeAllowed.delete(request)

wss.handleUpgrade(request, socket, head, (ws) => {
wss.emit('connection', ws, request)
})
Expand All @@ -102,10 +101,12 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
return
}

const connectionSymbol = generateConnectionSymbol()
c.env[CONNECTION_SYMBOL_KEY] = connectionSymbol
const request = c.env.incoming as IncomingMessage

// Instead of writing to c.env, use a WeakSet to track the request object directly
upgradeAllowed.add(request)
;(async () => {
const ws = await nodeUpgradeWebSocket(c.env.incoming, connectionSymbol)
const ws = await nodeUpgradeWebSocket(request)

// buffer messages to handle messages received before the events are set up
const messagesReceivedInStarting: [data: WebSocket.RawData, isBinary: boolean][] = []
Expand Down
Loading