Skip to content

Commit efbc391

Browse files
DO NOT MERGE: Socket test better injection
1 parent a65d7fd commit efbc391

File tree

7 files changed

+231
-184
lines changed

7 files changed

+231
-184
lines changed

airbyte-cdk/bulk/core/base/src/main/kotlin/io/airbyte/cdk/output/OutputConsumer.kt

+4-6
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ import java.util.function.Consumer
3333

3434
/** Emits the [AirbyteMessage] instances produced by the connector. */
3535
//@DefaultImplementation(StdoutOutputConsumer::class)
36-
@DefaultImplementation(UnixDomainSocketOutputConsumer::class)
36+
@DefaultImplementation(UnixDomainSocketOutputConsumerProvider::class)
3737
abstract class OutputConsumer(private val clock: Clock) : Consumer<AirbyteMessage>, AutoCloseable {
3838
/**
3939
* The constant emittedAt timestamp we use for record timestamps.
@@ -108,7 +108,9 @@ abstract class OutputConsumer(private val clock: Clock) : Consumer<AirbyteMessag
108108
)
109109
}
110110

111-
abstract fun getS(num: Int): List<OutputConsumer>?
111+
/**
 * Returns the socket-backed consumer for the given [part].
 *
 * This base implementation has nothing to hand out and always throws; it is `open` so that
 * subclasses owning socket consumers can override it.
 *
 * @throws UnsupportedOperationException always, in this base implementation.
 */
open fun getSocketConsumer(part: Int): UnixDomainSocketOutputConsumer =
    throw UnsupportedOperationException("Not implemented")
112114
}
113115

114116
/** Configuration properties prefix for [StdoutOutputConsumer]. */
@@ -253,10 +255,6 @@ open class StdoutOutputConsumer(
253255
private val namespacedTemplates = ConcurrentHashMap<String, StreamToTemplateMap>()
254256
private val unNamespacedTemplates = StreamToTemplateMap()
255257

256-
override fun getS(num: Int): List<OutputConsumer>? {
257-
return null
258-
}
259-
260258
companion object {
261259
const val META_PREFIX = ""","meta":"""
262260
}

airbyte-cdk/bulk/core/base/src/main/kotlin/io/airbyte/cdk/output/UnixDomainSocketOutputConsumer.kt

-130
This file was deleted.
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,176 @@
1+
package io.airbyte.cdk.output
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude
4+
import com.fasterxml.jackson.core.JsonGenerator
5+
import com.fasterxml.jackson.core.util.MinimalPrettyPrinter
6+
import com.fasterxml.jackson.databind.DeserializationFeature
7+
import com.fasterxml.jackson.databind.ObjectMapper
8+
import com.fasterxml.jackson.databind.SequenceWriter
9+
import com.fasterxml.jackson.databind.node.ObjectNode
10+
import com.fasterxml.jackson.dataformat.smile.databind.SmileMapper
11+
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
12+
import com.fasterxml.jackson.module.afterburner.AfterburnerModule
13+
import com.fasterxml.jackson.module.kotlin.kotlinModule
14+
import io.airbyte.protocol.models.v0.AirbyteMessage
15+
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
16+
import io.github.oshai.kotlinlogging.KotlinLogging
17+
import io.micronaut.context.annotation.Value
18+
import jakarta.inject.Singleton
19+
import java.io.BufferedOutputStream
20+
import java.io.File
21+
import java.io.OutputStream
22+
import java.io.PrintStream
23+
import java.net.StandardProtocolFamily
24+
import java.net.UnixDomainSocketAddress
25+
import java.nio.channels.Channels
26+
import java.nio.channels.ServerSocketChannel
27+
import java.nio.channels.SocketChannel
28+
import java.time.Clock
29+
30+
/** File-name pattern of a socket file; the `%d` placeholder is filled in with `String.format`. */
private const val SOCKET_NAME_TEMPLATE = "ab_socket_%d"

/** Absolute path pattern of a socket file, built on top of [SOCKET_NAME_TEMPLATE]. */
private const val SOCKET_FULL_PATH = "/var/run/sockets/" + SOCKET_NAME_TEMPLATE

// Alternative location, left disabled:
// private const val SOCKET_FULL_PATH = "/tmp/$SOCKET_NAME_TEMPLATE"

/** File-level logger. */
private val logger = KotlinLogging.logger {}
34+
35+
/**
 * Tunables for socket-based output.
 *
 * Every property carries a default, so an implementation only has to override the values it
 * wants to change.
 */
interface SocketConfig {
    /** Number of sockets; defaults to 1. */
    val numSockets: Int
        get() = 1

    /** Buffer size in bytes; defaults to 8 KiB. */
    val bufferByteSize: Int
        get() = 8_192

    /** Name of the output format; defaults to `"jsonl"`. */
    val outputFormat: String
        get() = "jsonl"

    /** Whether serialized output should be discarded instead of sent; defaults to `false`. */
    val devNullAfterSerialization: Boolean
        get() = false
}
41+
42+
/**
 * A [StdoutOutputConsumer] that additionally owns one [UnixDomainSocketOutputConsumer] per
 * configured socket and hands them out by index through [getSocketConsumer].
 *
 * All socket consumers are constructed eagerly, in index order, while this object is being
 * constructed.
 *
 * @param clock forwarded to the parent class and to every socket consumer.
 * @param stdout forwarded to the parent class and to every socket consumer.
 * @param bufferByteSizeThresholdForFlush forwarded to the parent class and to every socket
 *   consumer.
 * @param configuration source of the socket count, buffer size, output format and the
 *   dev-null flag.
 */
@Singleton
class UnixDomainSocketOutputConsumerProvider(
    clock: Clock,
    stdout: PrintStream,
    @Value("\${$CONNECTOR_OUTPUT_PREFIX.buffer-byte-size-threshold-for-flush}")
    bufferByteSizeThresholdForFlush: Int,
    configuration: SocketConfig,
) : StdoutOutputConsumer(stdout, clock, bufferByteSizeThresholdForFlush) {
    /** Number of socket consumers owned by this provider. */
    val numSockets = configuration.numSockets

    /** Buffer size, in bytes, handed to every socket consumer. */
    val bufferByteSize = configuration.bufferByteSize

    /** Output format name handed to every socket consumer. */
    val outputFormat: String = configuration.outputFormat

    private val socketConsumers: List<UnixDomainSocketOutputConsumer> =
        (0 until numSockets).map { socketIndex ->
            UnixDomainSocketOutputConsumer(
                socketNum = socketIndex,
                bufferSize = bufferByteSize,
                outputFormat = outputFormat,
                clock = clock,
                stdout = stdout,
                bufferByteSizeThresholdForFlush = bufferByteSizeThresholdForFlush,
                devNullAfterSerialization = configuration.devNullAfterSerialization,
            )
        }

    /**
     * Closes the parent consumer and then every socket consumer.
     *
     * Each close is attempted even when an earlier one failed, so that a single failure does
     * not leave the remaining sockets open. The first failure is rethrown once everything has
     * been attempted; later failures are attached to it as suppressed exceptions.
     */
    override fun close() {
        var failure: Exception? = null
        try {
            super.close()
        } catch (e: Exception) {
            failure = e
        }
        for (consumer in socketConsumers) {
            try {
                consumer.close()
            } catch (e: Exception) {
                if (failure == null) failure = e else failure.addSuppressed(e)
            }
        }
        failure?.let { throw it }
    }

    /**
     * Returns the socket consumer at index [part].
     *
     * @throws IndexOutOfBoundsException if [part] is not in `0 until numSockets`.
     */
    override fun getSocketConsumer(part: Int): UnixDomainSocketOutputConsumer =
        socketConsumers[part]
}
75+
76+
/** An [OutputStream] that silently discards everything written to it. */
class DevNullOutputStream : OutputStream() {
    /** Discards the single byte [b]. */
    override fun write(b: Int) = Unit

    /** Discards the requested range of [b]; the arguments are not inspected. */
    override fun write(b: ByteArray, off: Int, len: Int) = Unit
}
85+
86+
/**
 * Output consumer that serializes [AirbyteMessage]s onto a Unix domain socket.
 *
 * Construction is blocking: it removes any stale socket file, binds a listening socket at the
 * path derived from `socketNum`, and waits until one peer connects. The listening socket is
 * closed as soon as that connection has been accepted (or if binding/accepting fails), since
 * no further connection is ever accepted.
 *
 * Instances keep an unsynchronized record counter; nothing here makes concurrent calls to the
 * `accept` methods safe.
 *
 * @param socketNum number substituted into the socket path pattern.
 * @param bufferSize size, in bytes, of the buffer placed in front of the output.
 * @param outputFormat `"json"` selects newline-separated JSON; `"devnull"` additionally makes
 *   the record-data overload of `accept` a no-op; any other value selects Smile.
 * @param clock forwarded to the parent class.
 * @param stdout forwarded to the parent class.
 * @param bufferByteSizeThresholdForFlush forwarded to the parent class.
 * @param devNullAfterSerialization when `true`, messages are still serialized but the bytes
 *   are discarded instead of being written to the socket.
 */
class UnixDomainSocketOutputConsumer(
    socketNum: Int,
    bufferSize: Int = DEFAULT_BUFFER_SIZE,
    val outputFormat: String,
    clock: Clock,
    stdout: PrintStream,
    bufferByteSizeThresholdForFlush: Int,
    val devNullAfterSerialization: Boolean = false,
) : StdoutOutputConsumer(stdout, clock, bufferByteSizeThresholdForFlush) {
    private val socketChannel: SocketChannel
    private val bufferedOutputStream: BufferedOutputStream
    private val writer: SequenceWriter

    /** Messages written since the last explicit flush of [bufferedOutputStream]. */
    private var numRecords: Int = 0

    init {
        val socketPath = String.format(SOCKET_FULL_PATH, socketNum)
        logger.info { "Using socket..." }
        val socketFile = File(socketPath)
        logger.info { "Socket File path $socketPath" }
        // A file left over from a previous run would make bind() fail.
        if (socketFile.exists()) {
            socketFile.delete()
        }
        val address = UnixDomainSocketAddress.of(socketFile.toPath())
        // `use` releases the listening socket on success and on failure alike; the accepted
        // connection stays open independently of it.
        socketChannel =
            ServerSocketChannel.open(StandardProtocolFamily.UNIX).use { serverSocketChannel ->
                serverSocketChannel.bind(address)
                serverSocketChannel.accept()
            }

        val sink: OutputStream =
            if (devNullAfterSerialization) {
                DevNullOutputStream()
            } else {
                Channels.newOutputStream(socketChannel)
            }
        bufferedOutputStream = sink.buffered(bufferSize)

        // NOTE(review): SocketConfig defaults outputFormat to "jsonl", which does not match
        // "json" here, so the default configuration selects Smile. Confirm this is intended.
        val messageWriter =
            if (outputFormat == "json") {
                configure(ObjectMapper())
                    .writerFor(AirbyteMessage::class.java)
                    .with(MinimalPrettyPrinter(System.lineSeparator()))
            } else {
                configure(SmileMapper()).writerFor(AirbyteMessage::class.java)
            }
        writer = messageWriter.writeValues(bufferedOutputStream)
    }

    /** Applies the modules and features shared by the JSON and Smile mappers. */
    private fun configure(objectMapper: ObjectMapper): ObjectMapper =
        objectMapper
            .registerModule(JavaTimeModule())
            .registerModule(AfterburnerModule())
            .registerModule(kotlinModule())
            .configure(JsonGenerator.Feature.WRITE_BIGDECIMAL_AS_PLAIN, true)
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .setSerializationInclusion(JsonInclude.Include.NON_NULL)

    /**
     * Serializes [airbyteMessage] to the output and explicitly flushes the buffered stream
     * once every [FLUSH_EVERY_N_RECORDS] messages.
     */
    override fun accept(airbyteMessage: AirbyteMessage) {
        writer.write(airbyteMessage)
        numRecords += 1
        if (numRecords == FLUSH_EVERY_N_RECORDS) {
            bufferedOutputStream.flush()
            numRecords = 0
        }
    }

    /**
     * Wraps [recordData] into a RECORD message for the given [namespace] and [streamName],
     * stamped with the current clock time, and writes it like any other message.
     *
     * Does nothing when [outputFormat] is `"devnull"`.
     */
    fun accept(recordData: ObjectNode, namespace: String, streamName: String) {
        if (outputFormat == "devnull") {
            return
        }
        val recordMessage =
            AirbyteRecordMessage()
                .withNamespace(namespace)
                .withStream(streamName)
                .withData(recordData)
                .withEmittedAt(clock.millis())
        accept(AirbyteMessage().withType(AirbyteMessage.Type.RECORD).withRecord(recordMessage))
    }

    /**
     * Closes the parent consumer, then flushes pending output and releases the writer and the
     * accepted socket connection. The socket resources are released even if an earlier step
     * throws.
     */
    override fun close() {
        try {
            super.close()
        } finally {
            socketChannel.use {
                writer.use { sequenceWriter ->
                    sequenceWriter.flush()
                    bufferedOutputStream.flush()
                }
            }
        }
    }

    private companion object {
        /** Number of written messages after which the buffered stream is explicitly flushed. */
        const val FLUSH_EVERY_N_RECORDS = 100_000
    }
}

0 commit comments

Comments
 (0)