Skip to content

Commit dee49a4

Browse files
authored
Merge pull request #17556 from karelbilek/kb/doh-fix-doh
dnsdist: TCP multiplexer: update should reset ttd
2 parents 88f8de7 + 2e53b18 commit dee49a4

3 files changed

Lines changed: 100 additions & 31 deletions

File tree

pdns/dnsdistdist/tcpiohandler-mplexer.hh

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,6 @@ public:
7474
return result;
7575
}
7676

77-
void add(IOState iostate, FDMultiplexer::callbackfunc_t callback, FDMultiplexer::funcparam_t callbackData, std::optional<struct timeval> ttd)
78-
{
79-
DEBUGLOG("in " << __PRETTY_FUNCTION__ << " for fd " << d_fd << ", last state was " << getState() << ", adding " << (int)iostate);
80-
if (iostate == IOState::NeedRead) {
81-
if (isWaitingForRead()) {
82-
if (ttd) {
83-
/* let's update the TTD ! */
84-
d_mplexer.setReadTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */ 0);
85-
}
86-
return;
87-
}
88-
89-
d_mplexer.addReadFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
90-
DEBUGLOG(__PRETTY_FUNCTION__ << ": add read FD " << d_fd);
91-
d_isWaitingForRead = true;
92-
}
93-
else if (iostate == IOState::NeedWrite) {
94-
if (isWaitingForWrite()) {
95-
if (ttd) {
96-
/* let's update the TTD ! */
97-
d_mplexer.setWriteTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */ 0);
98-
}
99-
return;
100-
}
101-
102-
d_mplexer.addWriteFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
103-
DEBUGLOG(__PRETTY_FUNCTION__ << ": add write FD " << d_fd);
104-
d_isWaitingForWrite = true;
105-
}
106-
}
107-
10877
void update(IOState iostate, FDMultiplexer::callbackfunc_t callback = FDMultiplexer::callbackfunc_t(), FDMultiplexer::funcparam_t callbackData = boost::any(), std::optional<struct timeval> ttd = std::nullopt)
10978
{
11079
DEBUGLOG("in " << __PRETTY_FUNCTION__ << " for fd " << d_fd << ", last state was " << getState() << " , new state is " << (int)iostate);
@@ -125,6 +94,9 @@ public:
12594
/* let's update the TTD ! */
12695
d_mplexer.setReadTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */ 0);
12796
}
97+
else {
98+
d_mplexer.resetReadTTD(d_fd);
99+
}
128100
return;
129101
}
130102

@@ -146,6 +118,9 @@ public:
146118
/* let's update the TTD ! */
147119
d_mplexer.setWriteTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */ 0);
148120
}
121+
else {
122+
d_mplexer.resetWriteTTD(d_fd);
123+
}
149124
return;
150125
}
151126

pdns/mplexer.hh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,18 @@ public:
185185
d_readCallbacks.replace(it, newEntry);
186186
}
187187

188+
void resetReadTTD(int fd)
189+
{
190+
const auto& it = d_readCallbacks.find(fd);
191+
if (it == d_readCallbacks.end()) {
192+
throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
193+
}
194+
195+
auto newEntry = *it;
196+
memset(&newEntry.d_ttd, 0, sizeof(newEntry.d_ttd));
197+
d_readCallbacks.replace(it, newEntry);
198+
}
199+
188200
void setWriteTTD(int fd, struct timeval tv, int timeout)
189201
{
190202
const auto& it = d_writeCallbacks.find(fd);
@@ -198,6 +210,18 @@ public:
198210
d_writeCallbacks.replace(it, newEntry);
199211
}
200212

213+
void resetWriteTTD(int fd)
214+
{
215+
const auto& it = d_writeCallbacks.find(fd);
216+
if (it == d_writeCallbacks.end()) {
217+
throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
218+
}
219+
220+
auto newEntry = *it;
221+
memset(&newEntry.d_ttd, 0, sizeof(newEntry.d_ttd));
222+
d_writeCallbacks.replace(it, newEntry);
223+
}
224+
201225
void alterFDToRead(int fd, callbackfunc_t toDo, const funcparam_t& parameter = funcparam_t(), const struct timeval* ttd = nullptr)
202226
{
203227
accountingRemoveFD(d_writeCallbacks, fd);

regression-tests.dnsdist/test_DOH.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,3 +2374,73 @@ def testDOHWithPaddingWithECS(self):
23742374

23752375
class TestDOHEDNSPadding(DOHEDNSPadding, DNSDistDOHTest):
23762376
_dohLibrary = "nghttp2"
2377+
2378+
2379+
class TestDOHNoIdleTimeoutKeepsConnection(DNSDistDOHTest, DNSDistTest):
2380+
_serverKey = "server.key"
2381+
_serverCert = "server.chain"
2382+
_serverName = "tls.tests.dnsdist.org"
2383+
_caCert = "ca.pem"
2384+
_dohServerPort = pickAvailablePort()
2385+
_dohBaseURL = "https://%s:%d/PowerDNS" % (_serverName, _dohServerPort)
2386+
2387+
_config_template = """
2388+
newServer{address="127.0.0.1:%d"}
2389+
addDOHLocal("127.0.0.1:%d", "%s", "%s", { "/PowerDNS" }, {idleTimeout = 0})
2390+
"""
2391+
_config_params = [
2392+
"_testServerPort",
2393+
"_dohServerPort",
2394+
"_serverCert",
2395+
"_serverKey",
2396+
]
2397+
_verboseMode = True
2398+
2399+
def testKeepsConnection(self):
2400+
"""
2401+
DOH: Keeps connection with idleTimeout
2402+
"""
2403+
name = "simple.doh.tests.powerdns.com."
2404+
query = dns.message.make_query(name, "A", "IN")
2405+
expectedQuery = dns.message.make_query(name, "A", "IN")
2406+
response = dns.message.make_response(query)
2407+
rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, dns.rdatatype.A, "127.0.0.1")
2408+
response.answer.append(rrset)
2409+
2410+
conn = self.openDOHConnection(self._dohServerPort, caFile=self._caCert, timeout=2.0)
2411+
conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
2412+
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
2413+
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
2414+
conn.setopt(pycurl.CAINFO, self._caCert)
2415+
2416+
(receivedQuery, receivedResponse) = self.sendDOHQuery(
2417+
self._dohServerPort,
2418+
self._serverName,
2419+
self._dohBaseURL,
2420+
query,
2421+
response=response,
2422+
caFile=self._caCert,
2423+
conn=conn,
2424+
)
2425+
self.assertTrue(receivedQuery)
2426+
self.assertTrue(receivedResponse)
2427+
receivedQuery.id = expectedQuery.id
2428+
self.assertEqual(expectedQuery, receivedQuery)
2429+
2430+
time.sleep(3)
2431+
2432+
(receivedQuery, receivedResponse) = self.sendDOHQuery(
2433+
self._dohServerPort,
2434+
self._serverName,
2435+
self._dohBaseURL,
2436+
query,
2437+
response=response,
2438+
caFile=self._caCert,
2439+
conn=conn,
2440+
)
2441+
self.assertTrue(receivedQuery)
2442+
self.assertTrue(receivedResponse)
2443+
receivedQuery.id = expectedQuery.id
2444+
self.assertEqual(expectedQuery, receivedQuery)
2445+
2446+
self.assertEqual(conn.getinfo(pycurl.NUM_CONNECTS), 0)

0 commit comments

Comments
 (0)