Skip to content

KTOR-8208 Java: Make webSocket field nullable #4688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ internal class JavaHttpWebSocket(
private val requestTime: GMTDate = GMTDate()
) : WebSocket.Listener, WebSocketSession {

private lateinit var webSocket: WebSocket
private var _webSocket: WebSocket? = null
private val webSocket: WebSocket
get() = checkNotNull(_webSocket) { "Web socket is not connected yet." }

private val socketJob = Job(callContext[Job])
private val _incoming = Channel<Frame>(Channel.UNLIMITED)
private val _outgoing = Channel<Frame>(Channel.UNLIMITED)
Expand All @@ -90,50 +93,45 @@ internal class JavaHttpWebSocket(
get() = emptyList()

init {
launch {
launch(CoroutineName("java-ws-outgoing")) {
_outgoing.consumeEach { frame ->
when (frame.frameType) {
FrameType.TEXT -> {
webSocket.sendText(String(frame.data), frame.fin).await()
}

FrameType.BINARY -> {
webSocket.sendBinary(frame.buffer, frame.fin).await()
}

FrameType.CLOSE -> {
val data = buildPacket { writeFully(frame.data) }
val code = data.readShort().toInt()
val reason = data.readText()
webSocket.sendClose(code, reason).await()
socketJob.complete()
return@launch
}

FrameType.PING -> {
webSocket.sendPing(frame.buffer).await()
}

FrameType.PONG -> {
webSocket.sendPong(frame.buffer).await()
}
webSocket.sendFrame(frame)
if (frame.frameType == FrameType.CLOSE) {
socketJob.complete()
return@launch
}
}
}

GlobalScope.launch(callContext, start = CoroutineStart.ATOMIC) {
GlobalScope.launch(callContext + CoroutineName("java-ws-closer"), start = CoroutineStart.ATOMIC) {
try {
socketJob[Job]!!.join()
} catch (cause: Throwable) {
val code = CloseReason.Codes.INTERNAL_ERROR.code.toInt()
webSocket.sendClose(code, "Client failed")
_webSocket?.sendClose(code, "Client failed")
} finally {
_incoming.close()
_outgoing.cancel()
}
}
}

private suspend fun WebSocket.sendFrame(frame: Frame) {
when (frame.frameType) {
FrameType.TEXT -> sendText(String(frame.data), frame.fin).await()
FrameType.BINARY -> sendBinary(frame.buffer, frame.fin).await()
FrameType.PING -> sendPing(frame.buffer).await()
FrameType.PONG -> sendPong(frame.buffer).await()

FrameType.CLOSE -> {
val data = buildPacket { writeFully(frame.data) }
val code = data.readShort().toInt()
val reason = data.readText()
sendClose(code, reason).await()
}
}
}

@OptIn(InternalAPI::class)
suspend fun getResponse(): HttpResponseData {
val builder = httpClient.newWebSocketBuilder()
Expand Down Expand Up @@ -163,7 +161,7 @@ internal class JavaHttpWebSocket(
var status = HttpStatusCode.SwitchingProtocols
var headers: Headers
try {
webSocket = builder.buildAsync(requestData.url.toURI(), this).await()
_webSocket = builder.buildAsync(requestData.url.toURI(), this).await()
val protocol = webSocket.subprotocol?.takeIf { it.isNotEmpty() }
headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
} catch (cause: WebSocketHandshakeException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import io.ktor.utils.io.charsets.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.runTest
import kotlinx.io.IOException
import kotlin.test.*
import kotlin.time.Duration.Companion.seconds

Expand Down Expand Up @@ -261,6 +262,24 @@ class WebSocketTest : ClientLoader() {
}
}

@Test
fun testConnectionTimeoutExceeded() = clientTests(except(ENGINES_WITHOUT_WS)) {
val nonExistingHost = "ws://192.0.2.0" // RFC 5737: TEST-NET-1

config {
install(WebSockets)
install(HttpTimeout) { connectTimeoutMillis = 1 }
}

test { client ->
val exception = assertFailsWith<IOException> {
client.webSocket(nonExistingHost) { fail("Shouldn't be reached") }
}

assertTrue(exception.suppressedExceptions.isEmpty(), "No exception should be suppressed")
}
}

@Test
fun testCountPong() = clientTests(except(ENGINES_WITHOUT_WS + "Js")) {
config {
Expand Down