Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ download = "5.6.0"

# OpenZiti Edge API
ziti-api = "0.26.42"
ziti-cli = "1.3.3"
ziti-cli = "1.5.4"

# third party
lazysodium-java = "5.1.4"
Expand Down
9 changes: 5 additions & 4 deletions ziti/src/main/kotlin/org/openziti/impl/ChannelImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ internal class ChannelImpl(val addr: String, val sslContext: SSLContext, val api
private val synchers = ConcurrentHashMap<Int, CompletableDeferred<Unit>>()

private val recMutex = Mutex()
private val receivers = mutableMapOf<Int, Channel.MessageReceiver>()
private val receivers = mutableMapOf<UInt, Channel.MessageReceiver>()
private val chState = MutableStateFlow<Channel.State>(Channel.State.Initial)
private val reconnectSignal = kotlinx.coroutines.channels.Channel<Unit>()

Expand All @@ -73,13 +73,14 @@ internal class ChannelImpl(val addr: String, val sslContext: SSLContext, val api
get() = chState.value


override fun registerReceiver(id: Int, rec: Channel.MessageReceiver) = runBlocking{
override fun registerReceiver(id: UInt, rec: Channel.MessageReceiver) = runBlocking{
recMutex.withLock {
receivers[id] = rec
}
}

override fun deregisterReceiver(id: Int): Unit = runBlocking {

override fun deregisterReceiver(id: UInt): Unit = runBlocking {
recMutex.withLock { receivers.remove(id) }
}

Expand Down Expand Up @@ -243,7 +244,7 @@ internal class ChannelImpl(val addr: String, val sslContext: SSLContext, val api
if (waiter != null) {
waiter.complete(m)
} else {
val recId = m.getIntHeader(ZitiProtocol.Header.ConnId)
val recId = m.getIntHeader(ZitiProtocol.Header.ConnId)?.toUInt()
recId?.let {
val receiver = recMutex.withLock { receivers[it] }

Expand Down
4 changes: 2 additions & 2 deletions ziti/src/main/kotlin/org/openziti/impl/ZitiContextImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ internal class ZitiContextImpl(internal val id: Identity, enabled: Boolean) : Zi

private val connCounter = AtomicInteger(0)

private val connections = sortedMapOf<Int, ZitiConnection>()
private val connections = sortedMapOf<UInt, ZitiConnection>()

init {
this._enabled = enabled
Expand Down Expand Up @@ -527,7 +527,7 @@ internal class ZitiContextImpl(internal val id: Identity, enabled: Boolean) : Zi
}.getOrElse { throw TimeoutException("failed to get service[$host:$port] in ${timeout}ms") }
}

internal fun nextConnId() = connCounter.incrementAndGet()
internal fun nextConnId() = connCounter.incrementAndGet().toUInt()

internal val channels = ConcurrentHashMap<String, Channel>()

Expand Down
4 changes: 2 additions & 2 deletions ziti/src/main/kotlin/org/openziti/net/Channel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ internal interface Channel: Closeable {
val name: String
val state: State

fun deregisterReceiver(id: Int)
fun registerReceiver(id: Int, rec: MessageReceiver)
fun deregisterReceiver(id: UInt)
fun registerReceiver(id: UInt, rec: MessageReceiver)

suspend fun Send(msg: Message)
suspend fun SendSynch(msg: Message)
Expand Down
7 changes: 7 additions & 0 deletions ziti/src/main/kotlin/org/openziti/net/Message.kt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Message(
headers.put(header.id, v.toByteArray())
}

fun setHeader(header: ZitiProtocol.Header, v: UInt): Message = this.setHeader(header, v.toInt())

fun setHeader(header: ZitiProtocol.Header, v: Int): Message = this.apply {
val b = ByteArray(4)
ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN).putInt(v)
Expand All @@ -128,6 +130,11 @@ class Message(
headers.put(header.id, b)
}

fun setHeader(header: ZitiProtocol.Header, b: Boolean) = this.apply {
val v: Byte = if (b) 1 else 0
headers.put(header.id, byteArrayOf(v))
}

fun setHeader(headerId: Int, v: Boolean) = this.apply {
val b: Byte = if (v) 1 else 0
headers.put(headerId, byteArrayOf(b))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ internal class ZitiServerSocketChannel(val ctx: ZitiContextImpl):
private val listenerId = ByteArray(32).apply {
SecureRandom().nextBytes(this)
}
val connId: Int = ctx.nextConnId()
val connId: UInt = ctx.nextConnId()
var state: State = State.initial
lateinit var incoming: Chan<Message>
lateinit var token: String
Expand Down Expand Up @@ -135,6 +135,7 @@ internal class ZitiServerSocketChannel(val ctx: ZitiContextImpl):
setHeader(Header.ConnId, connId)
setHeader(Header.SeqHeader, 0)
setHeader(Header.ListenerId, listenerId)
setHeader(Header.RouterProvidedConnId, true)

val bindId = localAddr.identity ?: (if (localAddr.useEdgeId) ctx.getId()?.name else null)

Expand Down Expand Up @@ -192,8 +193,13 @@ internal class ZitiServerSocketChannel(val ctx: ZitiContextImpl):

val child = ZitiSocketChannel(ctx)
d{"accepting child conn[${child.connId}] on parent[$connId]"}
req.getIntHeader(Header.RouterProvidedConnId)?.toUInt()?.let {
d{"setting child[${child.connId}].rtConnId = $it (router provided)"}
child.rtConnId = it
}

val connIdBuf = ByteArray(4)
ByteBuffer.wrap(connIdBuf).order(ByteOrder.LITTLE_ENDIAN).putInt(child.connId)
ByteBuffer.wrap(connIdBuf).order(ByteOrder.LITTLE_ENDIAN).putInt(child.connId.toInt())
val dialSuccess = Message(ZitiProtocol.ContentType.DialSuccess, connIdBuf)
dialSuccess.setHeader(Header.SeqHeader, 0)
dialSuccess.setHeader(Header.ConnId, connId)
Expand All @@ -206,21 +212,21 @@ internal class ZitiServerSocketChannel(val ctx: ZitiContextImpl):
child.setupCrypto(sessKeys)
} ?: child.setupCrypto(null)

val startMsg = ch.SendAndWait(dialSuccess)

if (startMsg.content == ZitiProtocol.ContentType.StateConnected) {
runCatching {
ch.SendSynch(dialSuccess)
}.onSuccess {
child.state.set(ZitiSocketChannel.State.connected)
ch.registerReceiver(child.connId, child)
ch.registerReceiver(child.rtConnId, child)
child.channel.complete(ch)
child.startCrypto(ch)
child.local = localAddr
child.remote = ZitiAddress.Session("$connId", localAddr.service,
req.getStringHeader(Header.CallerIdHeader), req.getHeader(Header.AppDataHeader))

handler.completed(child, att)
} else {
val err = Charsets.UTF_8.decode(ByteBuffer.wrap(startMsg.body)).toString()
handler.failed(IOException(err), att)
}.onFailure { t ->
val err = t.message
handler.failed(IOException(err, t), att)
}
} catch (ex: Throwable) {
when (ex) {
Expand Down
17 changes: 10 additions & 7 deletions ziti/src/main/kotlin/org/openziti/net/ZitiSocketChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ import java.util.concurrent.atomic.AtomicReference
import kotlin.text.Charsets.UTF_8
import kotlinx.coroutines.channels.Channel as Chan

internal class ZitiSocketChannel private constructor(internal val ctx: ZitiContextImpl, val connId: Int):
internal class ZitiSocketChannel private constructor(internal val ctx: ZitiContextImpl, val connId: UInt):
AsynchronousSocketChannel(Provider),
Channel.MessageReceiver,
ZitiConnection,
Expand Down Expand Up @@ -104,6 +104,9 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
override val inputSupport = InputChannel.InputSupport(receiveQueue)
val crypto = CompletableDeferred<Crypto.SecretStream?>()

/** router provided connection id */
internal var rtConnId: UInt = connId

override fun getLocalAddress(): SocketAddress? = local

override fun getRemoteAddress(): SocketAddress? = remote
Expand Down Expand Up @@ -220,7 +223,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
private fun deregister() {
ctx.launch {
channel.runCatching {
await().deregisterReceiver(connId)
await().deregisterReceiver(rtConnId)
}
}
}
Expand All @@ -239,7 +242,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
override fun shutdownOutput(): AsynchronousSocketChannel {
if (state.get() == State.connected && sentFin.compareAndSet(false, true)) {
val finMsg = Message(ZitiProtocol.ContentType.Data).apply {
setHeader(Header.ConnId, connId)
setHeader(Header.ConnId, rtConnId)
setHeader(Header.FlagsHeader, ZitiProtocol.EdgeFlags.FIN)
setHeader(Header.SeqHeader, seq.getAndIncrement())
}
Expand All @@ -264,7 +267,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
state.set(State.closed)
State.connecting, State.connected -> {
val closeMsg = Message(ZitiProtocol.ContentType.StateClosed).apply {
setHeader(Header.ConnId, connId)
setHeader(Header.ConnId, rtConnId)
}
d("closing conn = ${this.connId}")
ctx.async {
Expand Down Expand Up @@ -343,7 +346,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
}

val dataMessage = Message(ZitiProtocol.ContentType.Data, data)
dataMessage.setHeader(Header.ConnId, connId)
dataMessage.setHeader(Header.ConnId, rtConnId)
dataMessage.setHeader(Header.SeqHeader, seq.getAndIncrement())
v("sending $dataMessage")
channel.await().Send(dataMessage)
Expand Down Expand Up @@ -434,7 +437,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte

internal suspend fun doZitiHandshake(ch: Channel, remote: ZitiAddress.Dial, ns: Session, kp: KeyPair?) {
val connectMsg = Message(ZitiProtocol.ContentType.Connect, ns.token.toByteArray(UTF_8)).apply {
setHeader(Header.ConnId, connId)
setHeader(Header.ConnId, rtConnId)
setHeader(Header.SeqHeader, 0)
kp?.let {
setHeader(Header.PublicKeyHeader, it.publicKey.asBytes)
Expand Down Expand Up @@ -502,7 +505,7 @@ internal class ZitiSocketChannel private constructor(internal val ctx: ZitiConte
crypto.await()?.let {
val header = it.header()
val headerMessage = Message(ZitiProtocol.ContentType.Data, header)
.setHeader(Header.ConnId, connId)
.setHeader(Header.ConnId, rtConnId)
.setHeader(Header.SeqHeader, seq.getAndIncrement())
ch.Send(headerMessage)
}
Expand Down
Loading