Skip to content

Commit 819fb6a

Browse files
committed
address PR comments
1 parent a85be75 commit 819fb6a

File tree

14 files changed

+423
-446
lines changed

14 files changed

+423
-446
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package zio.raft.client
2+
3+
import zio.*
4+
import zio.stream.*
5+
import zio.raft.protocol.*
6+
import scodec.bits.BitVector
7+
import zio.zmq.{ZContext, ZSocket}
8+
import zio.raft.protocol.Codecs.{clientMessageCodec, serverMessageCodec}
9+
10+
/** Transport abstraction for client-server communication.
11+
*/
12+
trait ClientTransport {
13+
def connect(address: String): ZIO[Any, Throwable, Unit]
14+
def disconnect(): ZIO[Any, Throwable, Unit]
15+
def sendMessage(message: ClientMessage): ZIO[Any, Throwable, Unit]
16+
def incomingMessages: ZStream[Any, Throwable, ServerMessage]
17+
}
18+
19+
object ClientTransport {
20+
21+
/** Create a ZeroMQ-based client transport.
22+
*/
23+
def make(config: ClientConfig): ZIO[ZContext & Scope, Throwable, ClientTransport] =
24+
for {
25+
socket <- ZSocket.client
26+
_ <- socket.options.setLinger(0)
27+
_ <- socket.options.setHeartbeat(1.seconds, 10.second, 30.second)
28+
timeoutConnectionClosed = serverMessageCodec
29+
.encode(SessionClosed(SessionCloseReason.ConnectionClosed, None))
30+
.require
31+
.toByteArray
32+
_ <- socket.options.setHiccupMessage(timeoutConnectionClosed)
33+
_ <- socket.options.setHighWatermark(200000, 200000)
34+
lastAddressRef <- Ref.make(Option.empty[String])
35+
transport = new Zmq(socket, lastAddressRef)
36+
} yield transport
37+
38+
/** ZeroMQ-based implementation of ClientTransport.
39+
*/
40+
private[client] class Zmq(socket: ZSocket, lastAddressRef: Ref[Option[String]])
41+
extends ClientTransport {
42+
override def connect(address: String): ZIO[Any, Throwable, Unit] =
43+
for {
44+
_ <- lastAddressRef.set(Some(address))
45+
_ <- socket.connect(address)
46+
} yield ()
47+
48+
override def disconnect(): ZIO[Any, Throwable, Unit] =
49+
for {
50+
lastAddress <- lastAddressRef.get
51+
_ <- lastAddress match {
52+
case Some(address) => socket.disconnect(address)
53+
case None => ZIO.unit
54+
}
55+
_ <- lastAddressRef.set(None)
56+
} yield ()
57+
58+
override def sendMessage(message: ClientMessage): ZIO[Any, Throwable, Unit] =
59+
for {
60+
bytes <- ZIO.attempt(clientMessageCodec.encode(message).require.toByteArray)
61+
_ <- socket.sendImmediately(bytes)
62+
} yield ()
63+
64+
override def incomingMessages: ZStream[Any, Throwable, ServerMessage] =
65+
socket.stream.mapZIO { msg =>
66+
ZIO.attempt(serverMessageCodec.decode(BitVector(msg.data())).require.value)
67+
}
68+
}
69+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package zio.raft.client
2+
3+
import zio.*
4+
import zio.raft.protocol.*
5+
import scodec.bits.ByteVector
6+
import java.time.Instant
7+
8+
/** Manages pending requests with lastSentAt timestamps for retry.
9+
*/
10+
case class PendingRequests(
11+
requests: Map[RequestId, PendingRequests.PendingRequestData]
12+
) {
13+
def add(
14+
requestId: RequestId,
15+
payload: ByteVector,
16+
promise: Promise[Throwable, ByteVector],
17+
sentAt: Instant
18+
): PendingRequests =
19+
copy(requests = requests.updated(requestId, PendingRequests.PendingRequestData(payload, promise, sentAt, sentAt)))
20+
21+
def complete(requestId: RequestId, result: ByteVector): ZIO[Any, Nothing, PendingRequests] =
22+
requests.get(requestId) match {
23+
case Some(data) =>
24+
data.promise.succeed(result).as(copy(requests = requests.removed(requestId)))
25+
case None =>
26+
ZIO.succeed(this)
27+
}
28+
29+
/** Resend all pending requests (used after successful connection). Returns updated PendingRequests with new
30+
* lastSentAt timestamps.
31+
*/
32+
def resendAll(transport: ClientTransport): UIO[PendingRequests] =
33+
ZIO.foldLeft(requests.toList)(this) { case (pending, (requestId, data)) =>
34+
for {
35+
now <- Clock.instant
36+
request = ClientRequest(requestId, data.payload, now)
37+
_ <- transport.sendMessage(request).orDie
38+
_ <- ZIO.logDebug(s"Resending pending request: $requestId")
39+
updatedData = data.copy(lastSentAt = now)
40+
} yield PendingRequests(pending.requests.updated(requestId, updatedData))
41+
}
42+
43+
/** Resend expired requests and update lastSentAt.
44+
*/
45+
def resendExpired(transport: ClientTransport, currentTime: Instant, timeout: Duration): UIO[PendingRequests] = {
46+
val timeoutSeconds = timeout.toSeconds
47+
ZIO.foldLeft(requests.toList)(this) { case (pending, (requestId, data)) =>
48+
val elapsed = Duration.fromInterval(data.lastSentAt, currentTime)
49+
if (elapsed > timeout) {
50+
val request = ClientRequest(requestId, data.payload, currentTime)
51+
for {
52+
_ <- transport.sendMessage(request).orDie
53+
_ <- ZIO.logDebug(s"Resending timed out request: $requestId")
54+
updatedData = data.copy(lastSentAt = currentTime)
55+
} yield PendingRequests(pending.requests.updated(requestId, updatedData))
56+
} else {
57+
ZIO.succeed(pending)
58+
}
59+
}
60+
}
61+
}
62+
63+
object PendingRequests {
64+
def empty: PendingRequests = PendingRequests(Map.empty)
65+
66+
case class PendingRequestData(
67+
payload: ByteVector,
68+
promise: Promise[Throwable, ByteVector],
69+
createdAt: Instant,
70+
lastSentAt: Instant
71+
)
72+
}

0 commit comments

Comments
 (0)