Skip to content

Commit e17d1cd

Browse files
fix: Multiple potential issues with the tcp DNS client (#32636)
* fix: improve TCP DNS client interactions with TCP actor * ensure that TCP connection is closed when TcpDnsClient fails * Fix for potential re-order of response bytes during high load * Deathwatch the connection to make sure we are never stuck in a broken state --------- Co-authored-by: Levi Ramsey <levi.ramsey@alum.cs.umass.edu>
1 parent d21bf75 commit e17d1cd

File tree

4 files changed

+302
-49
lines changed

4 files changed

+302
-49
lines changed

akka-actor-tests/src/test/scala/akka/io/dns/internal/DnsClientSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ package akka.io.dns.internal
66

77
import java.net.InetSocketAddress
88
import java.util.concurrent.atomic.AtomicBoolean
9-
109
import scala.collection.immutable.Seq
1110
import scala.concurrent.duration._
12-
1311
import akka.actor.Props
1412
import akka.io.Udp
1513
import akka.io.dns.{ RecordClass, RecordType }
1614
import akka.io.dns.internal.DnsClient.{ Answer, Question4 }
15+
import akka.testkit.WithLogCapturing
1716
import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe }
1817

19-
class DnsClientSpec extends AkkaSpec with ImplicitSender {
18+
class DnsClientSpec extends AkkaSpec("""akka.loglevel = DEBUG
19+
akka.loggers = ["akka.testkit.SilenceAllTestEventListener"]""") with ImplicitSender with WithLogCapturing {
2020
"The async DNS client" should {
2121
val exampleRequest = Question4(42, "akka.io")
2222
val exampleRequestMessage =

akka-actor-tests/src/test/scala/akka/io/dns/internal/TcpDnsClientSpec.scala

Lines changed: 153 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@
55
package akka.io.dns.internal
66

77
import java.net.InetSocketAddress
8-
9-
import scala.collection.immutable.Seq
10-
8+
import akka.actor.ActorRef
9+
import akka.actor.PoisonPill
1110
import akka.actor.Props
1211
import akka.io.Tcp
1312
import akka.io.Tcp.{ Connected, PeerClosed, Register }
1413
import akka.io.dns.{ RecordClass, RecordType }
1514
import akka.io.dns.internal.DnsClient.Answer
15+
import akka.testkit.EventFilter
16+
import akka.testkit.WithLogCapturing
1617
import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe }
1718

18-
class TcpDnsClientSpec extends AkkaSpec with ImplicitSender {
19+
class TcpDnsClientSpec extends AkkaSpec("""akka.loglevel = DEBUG
20+
akka.loggers = ["akka.testkit.SilenceAllTestEventListener"]""") with ImplicitSender with WithLogCapturing {
1921
import TcpDnsClient._
2022

2123
"The async TCP DNS client" should {
@@ -51,6 +53,27 @@ class TcpDnsClientSpec extends AkkaSpec with ImplicitSender {
5153
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
5254
}
5355

56+
"terminated if the connection terminates unexpectedly" in {
57+
val tcpExtensionProbe = TestProbe()
58+
val answerProbe = TestProbe()
59+
60+
val client = system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, answerProbe.ref)))
61+
62+
client ! exampleRequestMessage
63+
64+
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
65+
tcpExtensionProbe.lastSender ! Connected(dnsServerAddress, localAddress)
66+
expectMsgType[Register]
67+
val registered = tcpExtensionProbe.lastSender
68+
69+
expectMsgType[Tcp.Write]
70+
71+
answerProbe.watch(client)
72+
73+
registered ! PoisonPill
74+
answerProbe.expectTerminated(client)
75+
}
76+
5477
"accept a fragmented TCP response" in {
5578
val tcpExtensionProbe = TestProbe()
5679
val answerProbe = TestProbe()
@@ -72,22 +95,80 @@ class TcpDnsClientSpec extends AkkaSpec with ImplicitSender {
7295
answerProbe.expectMsg(Answer(42, Nil))
7396
}
7497

75-
"accept merged TCP responses" in {
98+
"accept multiple fragmented TCP responses" in {
7699
val tcpExtensionProbe = TestProbe()
77100
val answerProbe = TestProbe()
78101

79102
val client = system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, answerProbe.ref)))
80103

81104
client ! exampleRequestMessage
82-
client ! exampleRequestMessage.copy(id = 43)
83105

84106
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
85107
tcpExtensionProbe.lastSender ! Connected(dnsServerAddress, localAddress)
86108
expectMsgType[Register]
87109
val registered = tcpExtensionProbe.lastSender
88110

111+
// pretend write+ack+write happened, so three requests were written, and now all of the responses are coming back.
112+
// (we need to make sure the buffer is not reordered by sandwiched responses)
89113
expectMsgType[Tcp.Write]
114+
val fullResponse1 = encodeLength(exampleResponseMessage.write().length) ++ exampleResponseMessage.write()
115+
val exampleResponseMessage2 = exampleResponseMessage.copy(id = 43)
116+
val fullResponse2 = encodeLength(exampleResponseMessage2.write().length) ++ exampleResponseMessage2.write()
117+
val exampleResponseMessage3 = exampleResponseMessage.copy(id = 44)
118+
val fullResponse3 = encodeLength(exampleResponseMessage3.write().length) ++ exampleResponseMessage3.write()
119+
registered ! Tcp.Received(fullResponse1.take(8))
120+
Thread.sleep(30) // give things some time to go wrong
121+
registered ! Tcp.Received(fullResponse1.drop(8) ++ fullResponse2.take(8))
122+
Thread.sleep(30)
123+
registered ! Tcp.Received(fullResponse2.drop(8) ++ fullResponse3.take(8))
124+
Thread.sleep(30)
125+
registered ! Tcp.Received(fullResponse3.drop(8))
126+
127+
answerProbe.expectMsg(Answer(42, Nil))
128+
answerProbe.expectMsg(Answer(43, Nil))
129+
answerProbe.expectMsg(Answer(44, Nil))
130+
}
131+
132+
"respect backpressure from the TCP actor" in {
133+
val tcpExtensionProbe = TestProbe()
134+
135+
val client = system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, ActorRef.noSender)))
136+
137+
client ! exampleRequestMessage
138+
client ! exampleRequestMessage.copy(id = 43)
139+
140+
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
141+
tcpExtensionProbe.lastSender ! Connected(dnsServerAddress, localAddress)
142+
expectMsgType[Register]
143+
val registered = tcpExtensionProbe.lastSender
144+
145+
val ack = expectMsgType[Tcp.Write].ack
146+
expectNoMessage()
147+
registered ! ack
148+
90149
expectMsgType[Tcp.Write]
150+
}
151+
152+
"accept merged TCP responses" in {
153+
val tcpExtensionProbe = TestProbe()
154+
val answerProbe = TestProbe()
155+
156+
val client = system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, answerProbe.ref)))
157+
158+
client ! exampleRequestMessage
159+
client ! exampleRequestMessage.copy(id = 43)
160+
161+
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
162+
tcpExtensionProbe.lastSender ! Connected(dnsServerAddress, localAddress)
163+
expectMsgType[Register]
164+
val registered = tcpExtensionProbe.lastSender
165+
166+
var ack = expectMsgType[Tcp.Write].ack
167+
168+
registered ! ack
169+
170+
ack = expectMsgType[Tcp.Write].ack
171+
91172
val fullResponse =
92173
encodeLength(exampleResponseMessage.write().length) ++ exampleResponseMessage.write() ++
93174
encodeLength(exampleResponseMessage.write().length) ++ exampleResponseMessage.copy(id = 43).write()
@@ -97,5 +178,71 @@ class TcpDnsClientSpec extends AkkaSpec with ImplicitSender {
97178
answerProbe.expectMsg(Answer(42, Nil))
98179
answerProbe.expectMsg(Answer(43, Nil))
99180
}
181+
182+
"report its failure to the outer client" in {
183+
val tcpExtensionProbe = TestProbe()
184+
val answerProbe = TestProbe()
185+
186+
val failToConnectClient =
187+
system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, answerProbe.ref)))
188+
189+
failToConnectClient ! exampleRequestMessage
190+
191+
val connect = tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
192+
193+
failToConnectClient ! Tcp.CommandFailed(connect)
194+
195+
answerProbe.expectMsg(DnsClient.TcpDropped)
196+
197+
val closesWithErrorClient =
198+
system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, answerProbe.ref)))
199+
200+
closesWithErrorClient ! exampleRequestMessage
201+
202+
tcpExtensionProbe.expectMsg(connect)
203+
tcpExtensionProbe.lastSender ! Connected(dnsServerAddress, localAddress)
204+
expectMsgType[Register]
205+
val registered = tcpExtensionProbe.lastSender
206+
207+
registered ! Tcp.ErrorClosed("BOOM!")
208+
209+
answerProbe.expectMsg(DnsClient.TcpDropped)
210+
}
211+
212+
"should drop older requests when TCP connection backpressures" in {
213+
val tcpExtensionProbe = TestProbe()
214+
val connectionProbe = TestProbe()
215+
216+
val client = system.actorOf(Props(new TcpDnsClient(tcpExtensionProbe.ref, dnsServerAddress, ActorRef.noSender)))
217+
218+
// initial request
219+
val initialRequest = exampleRequestMessage.copy(id = 1)
220+
client ! initialRequest
221+
222+
tcpExtensionProbe.expectMsg(Tcp.Connect(dnsServerAddress))
223+
tcpExtensionProbe.lastSender.tell(Connected(dnsServerAddress, localAddress), connectionProbe.ref)
224+
connectionProbe.expectMsgType[Register]
225+
val registered = connectionProbe.lastSender
226+
227+
var write = connectionProbe.expectMsgType[Tcp.Write]
228+
write.data.drop(2) shouldBe initialRequest.write()
229+
230+
// 1 in flight, 15 more, buffer fits 10, should drop 5 oldest (id 2 - 6)
231+
EventFilter.warning(occurrences = 5, pattern = "Dropping oldest buffered DNS request").intercept {
232+
(2 to 16).foreach { i =>
233+
client ! exampleRequestMessage.copy(id = i.toShort)
234+
}
235+
}
236+
237+
// initial write is acked
238+
registered ! write.ack
239+
240+
// the rest of the buffered should be handled
241+
(7 to 16).foreach { i =>
242+
write = connectionProbe.expectMsgType[Tcp.Write]
243+
write.data.drop(2) shouldBe exampleRequestMessage.copy(id = i.toShort).write()
244+
registered ! write.ack // next ack
245+
}
246+
}
100247
}
101248
}

akka-actor/src/main/scala/akka/io/dns/internal/DnsClient.scala

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package akka.io.dns.internal
66

77
import java.net.{ InetAddress, InetSocketAddress }
88

9-
import scala.annotation.nowarn
109
import scala.collection.{ immutable => im }
1110
import scala.concurrent.duration._
1211
import scala.util.Try
@@ -35,6 +34,8 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
3534

3635
final case class DropRequest(question: DnsQuestion)
3736

37+
case object TcpDropped
38+
3839
// sent as an indication that as of the time of sending, `id` is not being used
3940
// by an active question. Useful for a questioner which tracks which ids it can use
4041
final case class Dropped(id: Short) extends NoSerializationVerificationNeeded
@@ -47,6 +48,8 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
4748
if (question.name.last == '.') Seq(unchangedPair, question.name.dropRight(1) -> question.qType)
4849
else Seq(unchangedPair, (question.name + '.') -> question.qType)
4950
}
51+
52+
private final case class InFlight(replyTo: ActorRef, message: Message, tcpRequest: Boolean = false)
5053
}
5154

5255
/**
@@ -62,7 +65,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
6265
val udp = IO(Udp)
6366
val tcp = IO(Tcp)
6467

65-
private[internal] var inflightRequests: Map[Short, (ActorRef, Message)] = Map.empty
68+
private var inflightRequests: Map[Short, InFlight] = Map.empty
6669

6770
lazy val tcpDnsClient: ActorRef = createTcpClient()
6871

@@ -90,12 +93,11 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
9093
/**
9194
* Silent to allow map update syntax
9295
*/
93-
@nowarn()
9496
def ready(socket: ActorRef): Receive = {
9597
case DropRequest(question) =>
9698
val id = question.id
9799
inflightRequests.get(id) match {
98-
case Some((_, sentMsg)) =>
100+
case Some(InFlight(_, sentMsg, _)) =>
99101
val sentQs = sentMsg.questions.map { question =>
100102
question.name -> question.qType
101103
}
@@ -129,7 +131,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
129131
log.debug("Resolving [{}] (A)", name)
130132

131133
val msg = message(name, id, RecordType.A)
132-
inflightRequests += (id -> (sender() -> msg))
134+
inflightRequests += (id -> InFlight(sender(), msg))
133135
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
134136
socket ! Udp.Send(msg.write(), ns)
135137
}
@@ -143,7 +145,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
143145
log.debug("Resolving [{}] (AAAA)", name)
144146

145147
val msg = message(name, id, RecordType.AAAA)
146-
inflightRequests += (id -> (sender() -> msg))
148+
inflightRequests += (id -> InFlight(sender(), msg))
147149
log.debug("Message to [{}]: [{}]", ns, msg)
148150
socket ! Udp.Send(msg.write(), ns)
149151
}
@@ -156,7 +158,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
156158
} else {
157159
log.debug("Resolving [{}] (SRV)", name)
158160
val msg = message(name, id, RecordType.SRV)
159-
inflightRequests += (id -> (sender() -> msg))
161+
inflightRequests += (id -> InFlight(sender(), msg))
160162
log.debug("Message to [{}]: [{}]", ns, msg)
161163
socket ! Udp.Send(msg.write(), ns)
162164
}
@@ -169,7 +171,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
169171
Try {
170172
val msg = Message.parse(send.payload)
171173
inflightRequests.get(msg.id).foreach {
172-
case (s, _) =>
174+
case InFlight(s, _, _) =>
173175
s ! Failure(new RuntimeException("Send failed to nameserver"))
174176
inflightRequests -= msg.id
175177
}
@@ -185,8 +187,10 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
185187
if (msg.flags.isTruncated) {
186188
log.debug("DNS response truncated, falling back to TCP")
187189
inflightRequests.get(msg.id) match {
188-
case Some((_, msg)) =>
189-
tcpDnsClient ! msg
190+
case Some(inFlight) =>
191+
inflightRequests = inflightRequests.updated(msg.id, inFlight.copy(tcpRequest = true))
192+
tcpDnsClient ! inFlight.message
193+
190194
case _ =>
191195
log.debug("Client for id {} not found. Discarding unsuccessful response.", msg.id)
192196
}
@@ -197,7 +201,7 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
197201
}
198202
case UdpAnswer(questions, response) =>
199203
inflightRequests.get(response.id) match {
200-
case Some((reply, sentMsg)) =>
204+
case Some(InFlight(reply, sentMsg, _)) =>
201205
val sentQs = sentMsg.questions.flatMap(withAndWithoutTrailingDots).toSet
202206
val answeredQs = questions.flatMap(withAndWithoutTrailingDots).toSet
203207

@@ -220,19 +224,26 @@ import akka.pattern.{ BackoffOpts, BackoffSupervisor }
220224
// for TCP, we don't have to use the question for correlation
221225
case response: Answer =>
222226
inflightRequests.get(response.id) match {
223-
case Some((reply, sentMsg)) =>
224-
reply ! response
227+
case Some(InFlight(replyTo, _, _)) =>
228+
replyTo ! response
225229
inflightRequests -= response.id
226230

227231
case None =>
228232
log.debug("Client for id [{}] not found. Discarding response.", response.id)
229233
}
230234

235+
case TcpDropped =>
236+
log.warning("TCP client failed, clearing inflight resolves which were being resolved by TCP")
237+
238+
inflightRequests = inflightRequests.filterNot {
239+
case (_, inFlight) => inFlight.tcpRequest
240+
}
241+
231242
case Udp.Unbind => socket ! Udp.Unbind
232243
case Udp.Unbound => context.stop(self)
233244
}
234245

235-
def createTcpClient() = {
246+
def createTcpClient(): ActorRef = {
236247
context.actorOf(
237248
BackoffSupervisor.props(
238249
BackoffOpts.onFailure(

0 commit comments

Comments
 (0)