|
| 1 | +From 17b38c53c0c75ab431bcf340614233c6301f1037 Mon Sep 17 00:00:00 2001 |
| 2 | +From: AllSpark <allspark@microsoft.com> |
| 3 | +Date: Thu, 14 May 2026 08:38:24 +0000 |
| 4 | +Subject: [PATCH] names: bound DNS compression-pointer dereferences during |
| 5 | + decode to mitigate DoS; introduce DNSDecodeError and shared decode context; |
| 6 | + add context manager; apply per-message counter in Message.decode and per-call |
| 7 | + in Name.decode |
| 8 | + |
| 9 | +Signed-off-by: Azure Linux Security Servicing Account <azurelinux-security@microsoft.com> |
| 10 | +Upstream-reference: AI Backport of https://github.com/twisted/twisted/commit/2d196123264efb0027eecfe1b430be4a9babdbd8.patch |
| 11 | +--- |
| 12 | + src/twisted/names/dns.py | 159 ++++++++++++++++++++++++++--- |
| 13 | + src/twisted/names/test/test_dns.py | 85 +++++++++++++++ |
| 14 | + 2 files changed, 229 insertions(+), 15 deletions(-) |
| 15 | + |
| 16 | +diff --git a/src/twisted/names/dns.py b/src/twisted/names/dns.py |
| 17 | +index 02ea2b6..df14b54 100644 |
| 18 | +--- a/src/twisted/names/dns.py |
| 19 | ++++ b/src/twisted/names/dns.py |
| 20 | +@@ -10,10 +10,12 @@ Future Plans: |
| 21 | + """ |
| 22 | + |
| 23 | + # System imports |
| 24 | ++import contextvars |
| 25 | + import inspect |
| 26 | + import random |
| 27 | + import socket |
| 28 | + import struct |
| 29 | ++from contextlib import contextmanager |
| 30 | + from io import BytesIO |
| 31 | + from itertools import chain |
| 32 | + from typing import Optional, SupportsInt, Union |
| 33 | +@@ -125,6 +127,7 @@ __all__ = [ |
| 34 | + "OP_UPDATE", |
| 35 | + "PORT", |
| 36 | + "AuthoritativeDomainError", |
| 37 | ++ "DNSDecodeError", |
| 38 | + "DNSQueryTimeoutError", |
| 39 | + "DomainError", |
| 40 | + ] |
| 41 | +@@ -424,6 +427,86 @@ def readPrecisely(file, l): |
| 42 | + raise EOFError |
| 43 | + return buff |
| 44 | + |
| 45 | ++class DNSDecodeError(ValueError): |
| 46 | ++ """ |
| 47 | ++ Raised when a DNS message cannot be decoded because it violates a |
| 48 | ++ protocol-level safety limit. |
| 49 | ++ """ |
| 50 | ++ |
| 51 | ++ |
| 52 | ++class _DecodeContext: |
| 53 | ++ """ |
| 54 | ++ Mutable state shared between the L{IEncodable} decoders invoked while |
| 55 | ++ reading a single DNS message. |
| 56 | ++ |
| 57 | ++ The primary purpose is to bound the total number of compression-pointer |
| 58 | ++ jumps taken across every name in the message, defending against packets |
| 59 | ++ that fan out thousands of records pointing to deeply chained pointers. |
| 60 | ++ |
| 61 | ++ This class is private. External callers must not rely on it; the |
| 62 | ++ per-message scope is installed and torn down by L{Message.decode} |
| 63 | ++ through L{_decodeContextVar}. |
| 64 | ++ |
| 65 | ++ @ivar jumps: The number of compression pointers followed so far. |
| 66 | ++ @ivar maxJumps: The inclusive upper bound on L{jumps}. Exceeding it |
| 67 | ++ causes L{registerJump} to raise L{DNSDecodeError}. |
| 68 | ++ """ |
| 69 | ++ |
| 70 | ++ __slots__ = ("jumps", "maxJumps") |
| 71 | ++ |
| 72 | ++ def __init__(self, maxJumps: int = 1000) -> None: |
| 73 | ++ self.jumps = 0 |
| 74 | ++ self.maxJumps = maxJumps |
| 75 | ++ |
| 76 | ++ def registerJump(self) -> None: |
| 77 | ++ """ |
| 78 | ++ Record that a compression pointer has been followed. |
| 79 | ++ |
| 80 | ++ The check is performed before any further bytes are read so the |
| 81 | ++ caller fails fast as soon as the aggregate limit is breached, even |
| 82 | ++ if additional records remain in the buffer. |
| 83 | ++ |
| 84 | ++ @raise DNSDecodeError: if the cumulative number of jumps exceeds |
| 85 | ++ L{maxJumps}. |
| 86 | ++ """ |
| 87 | ++ self.jumps += 1 |
| 88 | ++ if self.jumps > self.maxJumps: |
| 89 | ++ raise DNSDecodeError( |
| 90 | ++ "Too many compression pointers while decoding DNS message " |
| 91 | ++ f"(limit is {self.maxJumps})" |
| 92 | ++ ) |
| 93 | ++ |
| 94 | ++ |
| 95 | ++# Private module-level L{contextvars.ContextVar} used to share a single |
| 96 | ++# L{_DecodeContext} across the re-entrant calls performed while decoding one |
| 97 | ++# DNS message. L{contextvars} (rather than a plain module attribute) is used |
| 98 | ++# on purpose: although Twisted's reactor is single-threaded, message decoding |
| 99 | ++# is re-entrant across many records in a single pass and L{ContextVar} |
| 100 | ++# guarantees the scope is restored correctly on exit -- and remains isolated |
| 101 | ++# per-task should a future caller decode messages from multiple |
| 102 | ++# L{asyncio}-style contexts concurrently. |
| 103 | ++_decodeContextVar: contextvars.ContextVar[_DecodeContext | None] = ( |
| 104 | ++ contextvars.ContextVar("_dnsDecodeContext", default=None) |
| 105 | ++) |
| 106 | ++ |
| 107 | ++ |
| 108 | ++@contextmanager |
| 109 | ++def _installDecodeContext(context: _DecodeContext): |
| 110 | ++ """ |
| 111 | ++ Install C{context} on L{_decodeContextVar} for the duration of the |
| 112 | ++ C{with} block and restore the previous value on exit. |
| 113 | ++ |
| 114 | ++ This wraps the L{contextvars.ContextVar.set} / L{contextvars.ContextVar.reset} |
| 115 | ++ token dance so call sites can use a plain C{with} statement. |
| 116 | ++ |
| 117 | ++ @param context: The L{_DecodeContext} to install as the active context. |
| 118 | ++ """ |
| 119 | ++ token = _decodeContextVar.set(context) |
| 120 | ++ try: |
| 121 | ++ yield context |
| 122 | ++ finally: |
| 123 | ++ _decodeContextVar.reset(token) |
| 124 | ++ |
| 125 | + |
| 126 | + class IEncodable(Interface): |
| 127 | + """ |
| 128 | +@@ -530,8 +613,17 @@ class Name: |
| 129 | + |
| 130 | + @ivar name: A byte string giving the name. |
| 131 | + @type name: L{bytes} |
| 132 | ++ |
| 133 | ++ @ivar maxCompressionPointers: Per-message cap on the total number of |
| 134 | ++ compression-pointer dereferences L{decode} will follow before |
| 135 | ++ raising L{DNSDecodeError}. Defaults to C{1000}. Override it on |
| 136 | ++ a subclass or individual instance to tune the trade-off between |
| 137 | ++ tolerance for legitimately verbose messages and resistance to |
| 138 | ++ denial-of-service attacks. |
| 139 | + """ |
| 140 | + |
| 141 | ++ maxCompressionPointers: int = 1000 |
| 142 | ++ |
| 143 | + def __init__(self, name=b""): |
| 144 | + """ |
| 145 | + @param name: A name. |
| 146 | +@@ -576,16 +668,33 @@ class Name: |
| 147 | + """ |
| 148 | + Decode a byte string into this Name. |
| 149 | + |
| 150 | ++ When invoked from L{Message.decode}, a shared compression-pointer |
| 151 | ++ counter is picked up transparently from the private |
| 152 | ++ L{_decodeContextVar}. Standalone callers get a fresh per-call |
| 153 | ++ counter seeded from L{maxCompressionPointers}, so existing code |
| 154 | ++ keeps working unchanged while still being protected against |
| 155 | ++ pathological inputs. |
| 156 | ++ |
| 157 | + @type strio: file |
| 158 | + @param strio: Bytes will be read from this file until the full Name |
| 159 | +- is decoded. |
| 160 | ++ is decoded. |
| 161 | ++ |
| 162 | ++ @type length: L{int} or L{None} |
| 163 | ++ @param length: Present for compatibility with the L{IEncodable} |
| 164 | ++ interface; ignored by this decoder. |
| 165 | + |
| 166 | + @raise EOFError: Raised when there are not enough bytes available |
| 167 | +- from C{strio}. |
| 168 | ++ from C{strio}. |
| 169 | + |
| 170 | +- @raise ValueError: Raised when the name cannot be decoded (for example, |
| 171 | +- because it contains a loop). |
| 172 | ++ @raise ValueError: Raised when the name cannot be decoded because |
| 173 | ++ it contains a compression loop. |
| 174 | ++ |
| 175 | ++ @raise DNSDecodeError: Raised when the cumulative number of |
| 176 | ++ compression-pointer jumps exceeds the configured limit. |
| 177 | + """ |
| 178 | ++ context = _decodeContextVar.get() |
| 179 | ++ if context is None: |
| 180 | ++ context = _DecodeContext(maxJumps=self.maxCompressionPointers) |
| 181 | + visited = set() |
| 182 | + self.name = b"" |
| 183 | + off = 0 |
| 184 | +@@ -597,6 +706,7 @@ class Name: |
| 185 | + return |
| 186 | + if (l >> 6) == 3: |
| 187 | + new_off = (l & 63) << 8 | ord(readPrecisely(strio, 1)) |
| 188 | ++ context.registerJump() |
| 189 | + if new_off in visited: |
| 190 | + raise ValueError("Compression loop in encoded name") |
| 191 | + visited.add(new_off) |
| 192 | +@@ -2454,8 +2564,17 @@ class Message(tputil.FancyEqMixin): |
| 193 | + header fields. |
| 194 | + @ivar _sectionNames: The names of attributes representing the record |
| 195 | + sections of this message. |
| 196 | ++ |
| 197 | ++ @ivar maxCompressionPointers: Per-message cap on the total number of |
| 198 | ++ compression-pointer dereferences L{decode} will follow across every |
| 199 | ++ name in the message before raising L{DNSDecodeError}. Defaults to |
| 200 | ++ C{1000}. Override it on a subclass or individual instance to tune |
| 201 | ++ the trade-off between tolerance for legitimately verbose messages |
| 202 | ++ and resistance to denial-of-service attacks. |
| 203 | + """ |
| 204 | + |
| 205 | ++ maxCompressionPointers: int = 1000 |
| 206 | ++ |
| 207 | + compareAttributes = ( |
| 208 | + "id", |
| 209 | + "answer", |
| 210 | +@@ -2670,19 +2789,29 @@ class Message(tputil.FancyEqMixin): |
| 211 | + self.checkingDisabled = (byte4 >> 4) & 1 |
| 212 | + self.rCode = byte4 & 0xF |
| 213 | + |
| 214 | +- self.queries = [] |
| 215 | +- for i in range(nqueries): |
| 216 | +- q = Query() |
| 217 | +- try: |
| 218 | +- q.decode(strio) |
| 219 | +- except EOFError: |
| 220 | +- return |
| 221 | +- self.queries.append(q) |
| 222 | ++ # A single shared counter bounds the total compression-pointer work |
| 223 | ++ # performed across every name in this message. It is installed on |
| 224 | ++ # the private context variable so nested record decoders pick it up |
| 225 | ++ # without needing to thread it through each signature. |
| 226 | ++ decodeContext = _DecodeContext(maxJumps=self.maxCompressionPointers) |
| 227 | ++ with _installDecodeContext(decodeContext): |
| 228 | ++ self.queries = [] |
| 229 | ++ for i in range(nqueries): |
| 230 | ++ q = Query() |
| 231 | ++ try: |
| 232 | ++ q.decode(strio) |
| 233 | ++ except EOFError: |
| 234 | ++ return |
| 235 | ++ self.queries.append(q) |
| 236 | + |
| 237 | +- items = ((self.answers, nans), (self.authority, nns), (self.additional, nadd)) |
| 238 | ++ items = ( |
| 239 | ++ (self.answers, nans), |
| 240 | ++ (self.authority, nns), |
| 241 | ++ (self.additional, nadd), |
| 242 | ++ ) |
| 243 | + |
| 244 | +- for (l, n) in items: |
| 245 | +- self.parseRecords(l, n, strio) |
| 246 | ++ for l, n in items: |
| 247 | ++ self.parseRecords(l, n, strio) |
| 248 | + |
| 249 | + def parseRecords(self, list, num, strio): |
| 250 | + for i in range(num): |
| 251 | +diff --git a/src/twisted/names/test/test_dns.py b/src/twisted/names/test/test_dns.py |
| 252 | +index 6286026..a23f19d 100644 |
| 253 | +--- a/src/twisted/names/test/test_dns.py |
| 254 | ++++ b/src/twisted/names/test/test_dns.py |
| 255 | +@@ -347,6 +347,54 @@ class NameTests(unittest.TestCase): |
| 256 | + stream = BytesIO(b"\xc0\x00") |
| 257 | + self.assertRaises(ValueError, name.decode, stream) |
| 258 | + |
| 259 | ++ def test_rejectTooManyCompressionPointers(self): |
| 260 | ++ """ |
| 261 | ++ L{Name.decode} raises L{dns.DNSDecodeError} when it would have to |
| 262 | ++ follow more than L{Name.maxCompressionPointers} compression |
| 263 | ++ pointers to finish decoding a name. |
| 264 | ++ """ |
| 265 | ++ # Four distinct pointers chained end-to-end, terminated by a zero |
| 266 | ++ # label byte. With maxCompressionPointers of three the fourth |
| 267 | ++ # dereference must trip the safety limit. |
| 268 | ++ payload = b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\x00" |
| 269 | ++ name = dns.Name() |
| 270 | ++ name.maxCompressionPointers = 3 |
| 271 | ++ self.assertRaises( |
| 272 | ++ dns.DNSDecodeError, name.decode, BytesIO(payload) |
| 273 | ++ ) |
| 274 | ++ |
| 275 | ++ def test_decodeRecoversAfterDNSDecodeError(self): |
| 276 | ++ """ |
| 277 | ++ After L{Name.decode} raises L{dns.DNSDecodeError}, subsequent |
| 278 | ++ L{Name.decode} calls continue to work. No residual |
| 279 | ++ compression-pointer counter leaks across calls, so a legitimate |
| 280 | ++ name decoded right after a hostile one still succeeds. |
| 281 | ++ """ |
| 282 | ++ # First, force a DNSDecodeError by decoding a payload that |
| 283 | ++ # exceeds the configured limit. |
| 284 | ++ hostile = dns.Name() |
| 285 | ++ hostile.maxCompressionPointers = 3 |
| 286 | ++ self.assertRaises( |
| 287 | ++ dns.DNSDecodeError, |
| 288 | ++ hostile.decode, |
| 289 | ++ BytesIO(b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\x00"), |
| 290 | ++ ) |
| 291 | ++ |
| 292 | ++ # Then prove the process has not been poisoned: a legitimate |
| 293 | ++ # name still decodes normally, both with a fresh instance and |
| 294 | ++ # with the instance that just errored. |
| 295 | ++ stream = BytesIO() |
| 296 | ++ dns.Name(b"example.org").encode(stream) |
| 297 | ++ |
| 298 | ++ fresh = dns.Name() |
| 299 | ++ stream.seek(0) |
| 300 | ++ fresh.decode(stream) |
| 301 | ++ self.assertEqual(fresh.name, b"example.org") |
| 302 | ++ |
| 303 | ++ stream.seek(0) |
| 304 | ++ hostile.decode(stream) |
| 305 | ++ self.assertEqual(hostile.name, b"example.org") |
| 306 | ++ |
| 307 | + def test_equality(self): |
| 308 | + """ |
| 309 | + L{Name} instances are equal as long as they have the same value for |
| 310 | +@@ -756,6 +804,43 @@ class MessageTests(unittest.SynchronousTestCase): |
| 311 | + """ |
| 312 | + self.assertEqual(dns.Message().authenticData, 0) |
| 313 | + |
| 314 | ++ def test_rejectCompressionPointerFlood(self): |
| 315 | ++ """ |
| 316 | ++ L{Message.decode} installs a shared compression-pointer counter and |
| 317 | ++ raises L{dns.DNSDecodeError} when the aggregate number of pointer |
| 318 | ++ dereferences across every record in the message exceeds |
| 319 | ++ L{dns.Message.maxCompressionPointers}. |
| 320 | ++ """ |
| 321 | ++ chainLength = 100 |
| 322 | ++ numRecords = 8000 |
| 323 | ++ header = struct.pack( |
| 324 | ++ "!H2B4H", 0x1234, 0x80, 0x00, 0, numRecords, 0, 0 |
| 325 | ++ ) |
| 326 | ++ |
| 327 | ++ # Long compression chain inside the RDATA of an unknown |
| 328 | ++ # record so that subsequent records can aim pointers at it. |
| 329 | ++ owner = b"\x04rrrr\x00" |
| 330 | ++ chainBase = len(header) + len(owner) + 10 |
| 331 | ++ chain = bytearray() |
| 332 | ++ for i in range(chainLength): |
| 333 | ++ chain += struct.pack("!H", 0xC000 | (chainBase + 2 * (i + 1))) |
| 334 | ++ chain += b"\x04test\x00" |
| 335 | ++ |
| 336 | ++ firstRecord = ( |
| 337 | ++ owner |
| 338 | ++ + struct.pack("!HHIH", 999, 1, 0, len(chain)) |
| 339 | ++ + bytes(chain) |
| 340 | ++ ) |
| 341 | ++ followupRecord = ( |
| 342 | ++ struct.pack("!H", 0xC000 | chainBase) |
| 343 | ++ + struct.pack("!HHIH", 1, 1, 0, 4) |
| 344 | ++ + b"\x00\x00\x00\x00" |
| 345 | ++ ) |
| 346 | ++ payload = header + firstRecord + followupRecord * (numRecords - 1) |
| 347 | ++ |
| 348 | ++ message = dns.Message() |
| 349 | ++ self.assertRaises(dns.DNSDecodeError, message.decode, BytesIO(payload)) |
| 350 | ++ |
| 351 | + def test_authenticDataOverride(self): |
| 352 | + """ |
| 353 | + L{dns.Message.__init__} accepts a C{authenticData} argument which |
| 354 | +-- |
| 355 | +2.45.4 |
| 356 | + |
0 commit comments