Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions ziti/src/integrationTest/kotlin/org/openziti/impl/LoadTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2018-2025 NetFoundry Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.openziti.impl

import kotlinx.serialization.json.Json
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.openziti.IdentityConfig
import org.openziti.Ziti
import org.openziti.identity.keystoreFromConfig
import org.openziti.integ.BaseTest
import java.io.ByteArrayOutputStream

class LoadTests: BaseTest() {
lateinit var cfg: IdentityConfig
@BeforeEach
fun before() {
cfg = createIdentity()
}

@AfterEach
fun after() {
Ziti.getContexts().forEach {
it.destroy()
}
}

@Test
fun testLoadConfigByteArray() {
val cfgBytes = Json.encodeToString(cfg).toByteArray(Charsets.UTF_8)

Ziti.init(cfgBytes, false)

val contexts = Ziti.getContexts()
assertEquals(1, contexts.size)
}

@Test
fun testLoadKeyStoreByteArray() {
val ks = keystoreFromConfig(cfg)
val output = ByteArrayOutputStream()
ks.store(output, charArrayOf())

val storeBtes = output.toByteArray()
Ziti.init(storeBtes, false)

val contexts = Ziti.getContexts()
assertEquals(1, contexts.size)
}
}
11 changes: 10 additions & 1 deletion ziti/src/main/kotlin/org/openziti/Ziti.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

package org.openziti

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.openziti.api.Service
import org.openziti.identity.Enroller
import org.openziti.impl.ZitiImpl
Expand All @@ -29,6 +34,8 @@ import java.io.InputStream
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.security.KeyStore
import java.util.function.Consumer
import java.util.stream.Stream
import javax.net.SocketFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocketFactory
Expand Down Expand Up @@ -122,16 +129,18 @@ object Ziti {
fun connect(addr: SocketAddress): ZitiConnection = ZitiImpl.connect(addr)

@JvmStatic
fun getContexts(): Collection<ZitiContext> = ZitiImpl.contexts
fun getContexts(): Collection<ZitiContext> = ZitiImpl.contexts.value

@JvmStatic
fun setApplicationInfo(id: String, version: String) = ZitiImpl.setApplicationInfo(id, version)

@JvmStatic
fun getServiceFor(host: String, port: Int): Pair<ZitiContext, Service>? = ZitiImpl.getServiceFor(host, port)

@JvmStatic
fun getServiceFor(addr: InetSocketAddress): Pair<ZitiContext, Service>? = ZitiImpl.getServiceFor(addr)

@JvmStatic
fun findDialInfo(addr: InetSocketAddress): Pair<ZitiContext, SocketAddress>? = ZitiImpl.findDialInfo(addr)

fun serviceUpdates(): Flow<Pair<ZitiContext, ZitiContext.ServiceEvent>> = ZitiImpl.serviceUpdates()
Expand Down
9 changes: 3 additions & 6 deletions ziti/src/main/kotlin/org/openziti/impl/ZitiContextImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,6 @@ internal class ZitiContextImpl(internal val id: Identity, enabled: Boolean) : Zi
return conn
}

internal fun dialById(serviceId: String): ZitiConnection =
servicesById[serviceId]?.let {
dial(it.name)
} ?: throw ZitiException(ZitiException.Errors.ServiceNotAvailable)


internal fun dial(host: String, port: Int): ZitiConnection {
val service = getServiceForAddress(host, port) ?: throw ZitiException(Errors.ServiceNotAvailable)
return dial(service.name)
Expand Down Expand Up @@ -325,6 +319,9 @@ internal class ZitiContextImpl(internal val id: Identity, enabled: Boolean) : Zi
}

override fun destroy() {
Ziti.removeContext(this)
if (supervisor.isCompleted) return

d{"stopping networking"}
stop()
d{"stopping controller"}
Expand Down
40 changes: 21 additions & 19 deletions ziti/src/main/kotlin/org/openziti/impl/ZitiImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.openziti.*
import org.openziti.api.Service
import org.openziti.identity.Enroller
import org.openziti.identity.KeyStoreIdentity
import org.openziti.identity.findIdentityAlias
import org.openziti.identity.loadKeystore
Expand All @@ -38,7 +37,7 @@ import java.net.URI
import java.security.KeyStore

internal object ZitiImpl : Logged by ZitiLog() {
internal val contexts = mutableListOf<ZitiContextImpl>()
internal val contexts = MutableStateFlow<Collection<ZitiContextImpl>>(emptyList())
internal var appId = ""
internal var appVersion = ""

Expand All @@ -48,7 +47,7 @@ internal object ZitiImpl : Logged by ZitiLog() {
try {
Class.forName("android.util.Log")
true
} catch (cnf: ClassNotFoundException) {
} catch (_: ClassNotFoundException) {
false
}
}
Expand All @@ -59,7 +58,7 @@ internal object ZitiImpl : Logged by ZitiLog() {

fun loadContext(cfg: IdentityConfig, enabled: Boolean) =
ZitiContextImpl(cfg, enabled).also { ztx ->
contexts.add(ztx)
contexts.value += ztx
ztx.launch {
ztxEvents.emit(Ziti.IdentityEvent(Ziti.IdentityEventType.Loaded, ztx))
ztx.serviceUpdates().collect {
Expand All @@ -73,7 +72,7 @@ internal object ZitiImpl : Logged by ZitiLog() {
val idName = alias ?: findIdentityAlias(ks)
val id = KeyStoreIdentity(ks, idName)
return ZitiContextImpl(id, true).also { ctx ->
contexts.add(ctx)
contexts.value += ctx
ctx.launch {
ztxEvents.emit(Ziti.IdentityEvent(Ziti.IdentityEventType.Loaded, ctx))
ctx.serviceUpdates().collect {
Expand All @@ -95,10 +94,11 @@ internal object ZitiImpl : Logged by ZitiLog() {
return loadContext(ks, alias)
}

internal fun loadContext(id: ByteArray): ZitiContextImpl {
val ks = loadKeystore(id)
return loadContext(ks, null)
}
internal fun loadContext(id: ByteArray): ZitiContextImpl =
id.inputStream().use {
val ks = loadKeystore(it, charArrayOf())
loadContext(ks, null)
}

fun init(c: ByteArray, seamless: Boolean) {
initInternalNetworking(seamless)
Expand All @@ -115,14 +115,16 @@ internal object ZitiImpl : Logged by ZitiLog() {
}

fun removeContext(ctx: ZitiContext) {
contexts.remove(ctx)
if(ctx is ZitiContextImpl) {

if(ctx is ZitiContextImpl && contexts.value.contains(ctx)) {
val ctxs = contexts.value - ctx
contexts.value = ctxs
runBlocking { ztxEvents.emit(Ziti.IdentityEvent(Ziti.IdentityEventType.Removed, ctx)) }
ctx.destroy()
}
}

fun init(ks: KeyStore, seamless: Boolean): List<ZitiContext> {
fun init(ks: KeyStore, seamless: Boolean): Collection<ZitiContext> {
initInternalNetworking(seamless)

for (a in ks.aliases()) {
Expand All @@ -131,7 +133,7 @@ internal object ZitiImpl : Logged by ZitiLog() {
}
}

return contexts
return contexts.value
}

private fun isZitiIdentity(ks: KeyStore, alias: String): Boolean =
Expand All @@ -155,9 +157,9 @@ internal object ZitiImpl : Logged by ZitiLog() {
return ZitiDNSManager.lookup(addr.address)?.let { getServiceFor(it, addr.port) }
}

fun getServiceFor(host: String, port: Int): Pair<ZitiContext, Service>? = contexts.map { c ->
c.getServiceForAddress(host, port)?.let { Pair(c, it) }
}.filterNotNull().firstOrNull()
fun getServiceFor(host: String, port: Int): Pair<ZitiContext, Service>? = contexts.value.firstNotNullOfOrNull { c ->
c.getServiceForAddress(host, port)?.let { Pair(c, it) }
}

fun connect(addr: SocketAddress): ZitiConnection {
when (addr) {
Expand All @@ -166,7 +168,7 @@ internal object ZitiImpl : Logged by ZitiLog() {
return ztx.dial(svc.name)
}
is ZitiAddress.Dial -> {
for (c in contexts) {
for (c in contexts.value) {
c.getService(addr.service)?.let {
return c.dial(addr)
}
Expand All @@ -189,14 +191,14 @@ internal object ZitiImpl : Logged by ZitiLog() {

private val ztxEvents = MutableSharedFlow<Ziti.IdentityEvent>()
internal fun getEvents(): Flow<Ziti.IdentityEvent> = flow {
contexts.forEach {
contexts.value.forEach {
emit(Ziti.IdentityEvent(Ziti.IdentityEventType.Loaded, it))
}
emitAll(ztxEvents)
}

fun findDialInfo(addr: InetSocketAddress): Pair<ZitiContext, SocketAddress>? {
for (c in contexts) {
for (c in contexts.value) {
val dial = c.getDialAddress(addr)
if (dial != null) {
return c to dial
Expand Down
Loading