diff --git a/raft/src/test/scala/zio/raft/RaftIntegrationSpec.scala b/raft/src/test/scala/zio/raft/RaftIntegrationSpec.scala index 057cfbf6..a60a838c 100644 --- a/raft/src/test/scala/zio/raft/RaftIntegrationSpec.scala +++ b/raft/src/test/scala/zio/raft/RaftIntegrationSpec.scala @@ -5,9 +5,27 @@ import zio.test.TestAspect.withLiveClock import zio.{ZIO, durationInt} import zio.raft.LogEntry.NoopLogEntry import zio.raft.LogEntry.CommandLogEntry +import zio.LogLevel +import zio.ZLogger +import zio.Cause +import zio.FiberId +import zio.FiberRefs +import zio.LogSpan +import zio.Trace +import java.util.concurrent.ConcurrentLinkedQueue +import scala.jdk.CollectionConverters._ object RaftIntegrationSpec extends ZIOSpecDefault: + // We use TestLogger instead of ZTestLogger because ZTestLogger can cause duplicated log lines which causes flakiness in our tests. + class TestLogger extends ZLogger[String, Unit] { + val messages: ConcurrentLinkedQueue[String] = new ConcurrentLinkedQueue() + override def apply(trace: Trace, fiberId: FiberId, logLevel: LogLevel, message: () => String, cause: Cause[Any], context: FiberRefs, spans: List[LogSpan], annotations: Map[String, String]): Unit = + messages.add(message()) + + def getMessages: List[String] = messages.asScala.toList + } + private def findTheNewLeader( currentLeader: Raft[Int, TestCommands], raft1: Raft[Int, TestCommands], @@ -182,6 +200,7 @@ object RaftIntegrationSpec extends ZIOSpecDefault: }, test("read returns the correct state with multiple writes") { for + testLogger <- ZIO.succeed(new TestLogger()) ( r1, killSwitch1, @@ -189,23 +208,23 @@ object RaftIntegrationSpec extends ZIOSpecDefault: killSwitch2, r3, killSwitch3 - ) <- makeRaft().provideSomeLayer(zio.Runtime.removeDefaultLoggers >>> zio.test.ZTestLogger.default) + ) <- makeRaft().provideSomeLayer(zio.Runtime.removeDefaultLoggers >>> zio.Runtime.addLogger(testLogger)) // Making sure we call readState while there are queued write commands is difficult, // we use this approach to make sure there are some unhandled commands before we call readState, hopefully it won't be too flaky _ <- r1.sendCommand(Increase).fork.repeatN(99) - + readResult1 <- r1.readState - output <- ZTestLogger.logOutput - _ = output.foreach(s => println(s.message())) - pendingHeartbeatLogCount = output.count(_.message().contains("memberId=MemberId(peer1) read pending heartbeat")) - pendingCommandLogCount = output.count(_.message().contains("memberId=MemberId(peer1) read pending command")) + messages = testLogger.getMessages + pendingHeartbeatLogCount = messages.count(_.contains("memberId=MemberId(peer1) read pending heartbeat")) + pendingCommandLogCount = messages.count(_.contains("memberId=MemberId(peer1) read pending command")) yield assertTrue(readResult1 > 0) && assertTrue(pendingHeartbeatLogCount == 0) && assertTrue(pendingCommandLogCount == 1) - }, + } @@ TestAspect.flaky, // TODO (eran): because of the way this test is structured it is currently flaky, we'll need to find another way to send commands so the readState will have pending commands test("read returns the correct state when there are no pending writes.") { for + testLogger <- ZIO.succeed(new TestLogger()) ( r1, killSwitch1, @@ -213,7 +232,7 @@ object RaftIntegrationSpec extends ZIOSpecDefault: killSwitch2, r3, killSwitch3 - ) <- makeRaft().provideSomeLayer(zio.Runtime.removeDefaultLoggers >>> zio.test.ZTestLogger.default) + ) <- makeRaft().provideSomeLayer(zio.Runtime.removeDefaultLoggers >>> zio.Runtime.addLogger(testLogger)) _ <- r1.sendCommand(Increase) @@ -221,9 +240,9 @@ object RaftIntegrationSpec extends ZIOSpecDefault: readResult <- r1.readState // verify read waits for heartbeat and not a write/noop command - output <- ZTestLogger.logOutput - pendingHeartbeatLogCount = output.count(_.message().contains("memberId=MemberId(peer1) read pending heartbeat")) - pendingCommandLogCount = output.count(_.message().contains("memberId=MemberId(peer1) read pending command")) + messages = testLogger.getMessages + pendingHeartbeatLogCount = messages.count(_.contains("memberId=MemberId(peer1) read pending heartbeat")) + pendingCommandLogCount = messages.count(_.contains("memberId=MemberId(peer1) read pending command")) yield assertTrue(readResult == 1) && assertTrue(pendingHeartbeatLogCount == 1) && assertTrue(pendingCommandLogCount == 0) },