Skip to content

Commit 716c542

Browse files
authored
If we receive an upgrade request after not upgrading, return redirect and close connection (#108)
* If we receive an websocket upgrade after not upgrading redirect and close connection * Include query in redirect * Fix for 5.10 * Add test for closing conenction on unexpected upgrade * Add detailed comment
1 parent c6f5ff0 commit 716c542

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

Sources/HummingbirdWebSocket/WebSocketChannel.swift

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl
9898
}
9999
}
100100
}
101-
self.responder = responder
101+
self.responder = Self.getUpgradeResponder(responder)
102102
}
103103

104104
@available(*, deprecated, renamed: "init(responder:configuration:additionalChannelHandlers:shouldUpgrade:)")
@@ -167,7 +167,46 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl
167167
}
168168
return promise.futureResult
169169
}
170-
self.responder = responder
170+
self.responder = Self.getUpgradeResponder(responder)
171+
}
172+
173+
/// Return HTTP responder that responds with a redirect and connection closure on receiving an
174+
/// upgrade header set to websocket
175+
///
176+
/// The responder passed in as a parameter is called from the resultant responder if no upgrade
177+
/// header is found.
178+
///
179+
/// This is a temporary solution to the fact that the NIO upgrade code does not support parsing
180+
/// upgrade headers after having received a normal HTTP request. By returning a redirect to the
181+
/// same URI and closing the connection we are forcing the client to open a new connection
182+
/// where the upgrade code path will run.
183+
///
184+
/// - Parameter responder: HTTP responder to call
185+
/// - Returns: Result of HTTP responder or redirect
186+
static func getUpgradeResponder(_ responder: @escaping HTTPChannelHandler.Responder) -> HTTPChannelHandler.Responder {
187+
struct RedirectCloseError: Error {}
188+
return {
189+
(
190+
request: Request,
191+
responseWriter: consuming ResponseWriter,
192+
channel: Channel
193+
) in
194+
if request.headers[.upgrade] == "websocket" {
195+
var path = request.uri.path
196+
if let query = request.uri.query {
197+
path += "?\(query)"
198+
}
199+
let headers: HTTPFields = [
200+
.connection: "close",
201+
.location: path,
202+
]
203+
let response = HTTPResponse(status: .temporaryRedirect, headerFields: headers)
204+
try await responseWriter.writeResponse(response)
205+
throw RedirectCloseError()
206+
} else {
207+
try await responder(request, responseWriter, channel)
208+
}
209+
}
171210
}
172211

173212
/// Setup channel to accept HTTP1 with a WebSocket upgrade

Sources/HummingbirdWebSocket/WebSocketRouter.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ extension HTTP1WebSocketUpgradeChannel {
192192
}
193193
return promise.futureResult
194194
}
195-
self.responder = responder
195+
self.responder = Self.getUpgradeResponder(responder)
196196
}
197197
}
198198

Tests/HummingbirdWebSocketTests/WebSocketTests.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,4 +866,31 @@ final class HummingbirdWebSocketTests: XCTestCase {
866866
try await httpClient.shutdown()
867867
}
868868

869+
func testUpgradeAfterNotUpgraded() async throws {
870+
let router = Router()
871+
router.get("/") { _, _ in
872+
"Helllo"
873+
}
874+
let app = Application(
875+
router: router,
876+
server: .http1WebSocketUpgrade { _, _, _ in
877+
.dontUpgrade
878+
}
879+
)
880+
try await app.test(.live) { client in
881+
try await client.execute(uri: "/", method: .get) { response in
882+
XCTAssertEqual(response.status, .ok)
883+
}
884+
// perform upgrade
885+
try await client.execute(uri: "/test?this=that", method: .get, headers: [.upgrade: "websocket"]) { response in
886+
XCTAssertEqual(response.status, .temporaryRedirect)
887+
XCTAssertEqual(response.headers[.location], "/test?this=that")
888+
}
889+
// check channel has been closed
890+
do {
891+
try await client.execute(uri: "/", method: .get)
892+
} catch let error as ChannelError where error == .ioOnClosedChannel {
893+
}
894+
}
895+
}
869896
}

0 commit comments

Comments
 (0)