diff --git a/client-server-client/src/main/scala/zio/raft/client/PendingQueries.scala b/client-server-client/src/main/scala/zio/raft/client/PendingQueries.scala index 6e8e904..9f8ceae 100644 --- a/client-server-client/src/main/scala/zio/raft/client/PendingQueries.scala +++ b/client-server-client/src/main/scala/zio/raft/client/PendingQueries.scala @@ -50,6 +50,10 @@ case class PendingQueries( ZIO.succeed(pending) } } + + /** Die all pending queries with the given error. */ + def dieAll(error: Throwable): UIO[Unit] = + ZIO.foreachDiscard(queries.values)(data => data.promise.die(error).ignore) } object PendingQueries { diff --git a/client-server-client/src/main/scala/zio/raft/client/PendingRequests.scala b/client-server-client/src/main/scala/zio/raft/client/PendingRequests.scala index a192104..a5346e8 100644 --- a/client-server-client/src/main/scala/zio/raft/client/PendingRequests.scala +++ b/client-server-client/src/main/scala/zio/raft/client/PendingRequests.scala @@ -71,6 +71,10 @@ case class PendingRequests( case Some(data) => data.promise.die(error).ignore.as(copy(requests = requests.removed(requestId))) case None => ZIO.succeed(this) } + + /** Die all pending requests with the given error. */ + def dieAll(error: Throwable): UIO[Unit] = + ZIO.foreachDiscard(requests.values)(data => data.promise.die(error).ignore) } object PendingRequests { diff --git a/client-server-client/src/main/scala/zio/raft/client/RaftClient.scala b/client-server-client/src/main/scala/zio/raft/client/RaftClient.scala index 44d75cf..92197fe 100644 --- a/client-server-client/src/main/scala/zio/raft/client/RaftClient.scala +++ b/client-server-client/src/main/scala/zio/raft/client/RaftClient.scala @@ -250,10 +250,14 @@ object RaftClient { connectToMember(leaderId, transport, s"Not leader at $currentMemberId") case RejectionReason.SessionExpired => - ZIO.dieMessage("Session not found - cannot continue") + pendingRequests.dieAll(new RuntimeException("Session expired")) *> + pendingQueries.dieAll(new RuntimeException("Session expired")) *> + ZIO.dieMessage("Session not found - cannot continue") case RejectionReason.InvalidCapabilities => - ZIO.dieMessage(s"Invalid capabilities - cannot connect: ${capabilities}") + pendingRequests.dieAll(new RuntimeException("Invalid capabilities")) *> + pendingQueries.dieAll(new RuntimeException("Invalid capabilities")) *> + ZIO.dieMessage(s"Invalid capabilities - cannot connect: ${capabilities}") } } else { ZIO.logWarning("Nonce mismatch, ignoring SessionRejected").as(this) @@ -399,12 +403,15 @@ object RaftClient { connectToMember(leaderId, transport, s"Not leader at $currentMemberId") case RejectionReason.SessionExpired => - ZIO.logWarning("Session not found - cannot continue") *> ZIO.dieMessage( - "Session not found - cannot continue" - ) + pendingRequests.dieAll(new RuntimeException("Session expired")) *> + pendingQueries.dieAll(new RuntimeException("Session expired")) *> + ZIO.logWarning("Session not found - cannot continue") *> + ZIO.dieMessage("Session not found - cannot continue") case RejectionReason.InvalidCapabilities => - ZIO.dieMessage(s"Invalid capabilities - cannot connect: ${capabilities}") + pendingRequests.dieAll(new RuntimeException("Invalid capabilities")) *> + pendingQueries.dieAll(new RuntimeException("Invalid capabilities")) *> + ZIO.dieMessage(s"Invalid capabilities - cannot connect: ${capabilities}") } } else { ZIO.logWarning("Nonce mismatch, ignoring SessionRejected").as(this) @@ -561,7 +568,8 @@ object RaftClient { case StreamEvent.ServerMsg(RequestError(requestId, RequestErrorReason.ResponseEvicted)) => if (pendingRequests.contains(requestId)) ZIO.logError(s"RequestError: ResponseEvicted for request $requestId, terminating client") *> - pendingRequests.die(requestId, new RuntimeException("ResponseEvicted")) *> + pendingRequests.dieAll(new RuntimeException("ResponseEvicted")) *> + pendingQueries.dieAll(new RuntimeException("ResponseEvicted")) *> ZIO.dieMessage("ResponseEvicted") else ZIO.logWarning(s"RequestError for non-pending request $requestId, ignoring").as(this) @@ -621,9 +629,10 @@ object RaftClient { ) case StreamEvent.ServerMsg(SessionClosed(SessionCloseReason.SessionExpired, _)) => - ZIO.logWarning("Session closed due to timeout, terminating client") *> ZIO.dieMessage( - "Session timed out: shutting down client." - ) + pendingRequests.dieAll(new RuntimeException("Session expired")) *> + pendingQueries.dieAll(new RuntimeException("Session expired")) *> + ZIO.logWarning("Session closed due to timeout, terminating client") *> + ZIO.dieMessage("Session timed out: shutting down client.") case StreamEvent.KeepAliveTick => for { diff --git a/client-server-client/src/test/scala/zio/raft/client/PendingQueriesSpec.scala b/client-server-client/src/test/scala/zio/raft/client/PendingQueriesSpec.scala index a7c5832..d96b902 100644 --- a/client-server-client/src/test/scala/zio/raft/client/PendingQueriesSpec.scala +++ b/client-server-client/src/test/scala/zio/raft/client/PendingQueriesSpec.scala @@ -72,6 +72,33 @@ object PendingQueriesSpec extends ZIOSpecDefault { assertTrue(d1.lastSentAt == currentTime) && assertTrue(d2.lastSentAt == currentTime.minusSeconds(5)) } + + test("dieAll completes all pending promises with death") { + for { + p1 <- Promise.make[Nothing, ByteVector] + p2 <- Promise.make[Nothing, ByteVector] + p3 <- Promise.make[Nothing, ByteVector] + now <- Clock.instant + pq = PendingQueries.empty + .add(CorrelationId.fromString("c1"), ByteVector(1), p1, now) + .add(CorrelationId.fromString("c2"), ByteVector(2), p2, now) + .add(CorrelationId.fromString("c3"), ByteVector(3), p3, now) + _ <- pq.dieAll(new RuntimeException("all dead")) + fiber1 <- p1.await.fork + fiber2 <- p2.await.fork + fiber3 <- p3.await.fork + exit1 <- fiber1.await + exit2 <- fiber2.await + exit3 <- fiber3.await + } yield assertTrue(exit1.isFailure) && assertTrue(exit2.isFailure) && assertTrue(exit3.isFailure) + } + + test("dieAll on empty pending queries succeeds") { + for { + pq <- ZIO.succeed(PendingQueries.empty) + _ <- pq.dieAll(new RuntimeException("boom")) + } yield assertTrue(true) + } } } diff --git a/client-server-client/src/test/scala/zio/raft/client/PendingRequestsSpec.scala b/client-server-client/src/test/scala/zio/raft/client/PendingRequestsSpec.scala index 3698ea6..5b82524 100644 --- a/client-server-client/src/test/scala/zio/raft/client/PendingRequestsSpec.scala +++ b/client-server-client/src/test/scala/zio/raft/client/PendingRequestsSpec.scala @@ -88,5 +88,35 @@ object PendingRequestsSpec extends ZIOSpecDefault { pending1 <- pending.die(rid, new RuntimeException("boom")) } yield assertTrue(pending1 == pending) } + + test("dieAll completes all pending promises with death") { + val rid1 = RequestId.fromLong(1L) + val rid2 = RequestId.fromLong(2L) + val rid3 = RequestId.fromLong(3L) + for { + p1 <- Promise.make[Nothing, ByteVector] + p2 <- Promise.make[Nothing, ByteVector] + p3 <- Promise.make[Nothing, ByteVector] + now <- Clock.instant + pending = PendingRequests.empty + .add(rid1, ByteVector.fromValidHex("aa"), p1, now) + .add(rid2, ByteVector.fromValidHex("bb"), p2, now) + .add(rid3, ByteVector.fromValidHex("cc"), p3, now) + _ <- pending.dieAll(new RuntimeException("all dead")) + fiber1 <- p1.await.fork + fiber2 <- p2.await.fork + fiber3 <- p3.await.fork + exit1 <- fiber1.await + exit2 <- fiber2.await + exit3 <- fiber3.await + } yield assertTrue(exit1.isFailure) && assertTrue(exit2.isFailure) && assertTrue(exit3.isFailure) + } + + test("dieAll on empty pending requests succeeds") { + for { + pending <- ZIO.succeed(PendingRequests.empty) + _ <- pending.dieAll(new RuntimeException("boom")) + } yield assertTrue(true) + } } }