From cbc51680c517ff750a7e99b3c054e17e0f17fdca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20K=C5=82ys?= Date: Wed, 17 Jun 2026 13:03:05 +0200 Subject: [PATCH] Harden Excel image downloads --- OfficeIMO.Excel/Utilities/ImageDownloader.cs | 143 ++++++++++++++---- .../Excel.HeadersFootersAndProperties.cs | 76 ++++++++++ 2 files changed, 188 insertions(+), 31 deletions(-) diff --git a/OfficeIMO.Excel/Utilities/ImageDownloader.cs b/OfficeIMO.Excel/Utilities/ImageDownloader.cs index c7b7e9b7b..e39dd6a45 100644 --- a/OfficeIMO.Excel/Utilities/ImageDownloader.cs +++ b/OfficeIMO.Excel/Utilities/ImageDownloader.cs @@ -19,6 +19,8 @@ public CacheEntry(byte[] bytes, string? contentType, DateTimeOffset expiresAt) { } private const int CacheCapacity = 32; + private const int MaxRedirects = 5; + private const int BufferSize = 81920; private static readonly TimeSpan CacheEntryLifetime = TimeSpan.FromMinutes(10); private static readonly ConcurrentDictionary Cache = new(StringComparer.OrdinalIgnoreCase); private static readonly ConcurrentQueue CacheOrder = new(); @@ -47,58 +49,137 @@ private static void TrimCache() { public static bool TryFetch(string url, int timeoutSeconds, long maxBytes, out byte[]? bytes, out string? contentType) { bytes = null; contentType = null; try { - if (Cache.TryGetValue(url, out var cached)) { + if (maxBytes <= 0 || !TryCreateHttpUri(url, out var uri)) return false; + + var cacheKey = uri.AbsoluteUri; + if (Cache.TryGetValue(cacheKey, out var cached)) { if (DateTimeOffset.UtcNow <= cached.ExpiresAt) { bytes = cached.Bytes; contentType = cached.ContentType; return true; } - Cache.TryRemove(url, out _); + Cache.TryRemove(cacheKey, out _); } #if NETFRAMEWORK - var request = (HttpWebRequest)WebRequest.Create(url); - request.AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate; - request.Timeout = Math.Max(1, timeoutSeconds) * 1000; - using (var response = (HttpWebResponse)request.GetResponse()) + using (var response = SendWithRedirects(uri, timeoutSeconds)) +#else + using (var handler = new HttpClientHandler { AllowAutoRedirect = false, AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate }) + using (var http = new HttpClient(handler) { Timeout = TimeSpan.FromSeconds(Math.Max(1, timeoutSeconds)) }) + using (var response = SendWithRedirects(http, uri)) +#endif { + if (response == null) return false; +#if NETFRAMEWORK if (response.StatusCode != HttpStatusCode.OK) return false; var ct = NormalizeContentType(response.ContentType); - if (ct == null || !ct.StartsWith("image/", StringComparison.OrdinalIgnoreCase)) return false; var len = response.ContentLength; +#else + if (!response.IsSuccessStatusCode) return false; + var ct = NormalizeContentType(response.Content.Headers.ContentType?.MediaType); + var len = response.Content.Headers.ContentLength; +#endif + if (ct == null || !ct.StartsWith("image/", StringComparison.OrdinalIgnoreCase)) return false; +#if NETFRAMEWORK if (len > 0 && len > maxBytes) return false; using var s = response.GetResponseStream(); - if (s == null) return false; - using var ms = new MemoryStream(); s.CopyTo(ms); - if (ms.Length > maxBytes) return false; - var arr = ms.ToArray(); - Cache[url] = new CacheEntry(arr, ct, DateTimeOffset.UtcNow.Add(CacheEntryLifetime)); - CacheOrder.Enqueue(url); - TrimCache(); - bytes = arr; contentType = ct; - return true; - } #else - using (var handler = new HttpClientHandler { AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate }) - using (var http = new HttpClient(handler) { Timeout = TimeSpan.FromSeconds(Math.Max(1, timeoutSeconds)) }) - using (var resp = http.GetAsync(url, HttpCompletionOption.ResponseHeadersRead).GetAwaiter().GetResult()) { - if (!resp.IsSuccessStatusCode) return false; - var ct = NormalizeContentType(resp.Content.Headers.ContentType?.MediaType); - if (ct == null || !ct.StartsWith("image/", StringComparison.OrdinalIgnoreCase)) return false; - var len = resp.Content.Headers.ContentLength; if (len.HasValue && len.Value > maxBytes) return false; - using var s = resp.Content.ReadAsStreamAsync().GetAwaiter().GetResult(); - using var ms = new MemoryStream(); s.CopyTo(ms); - if (ms.Length > maxBytes) return false; - var arr = ms.ToArray(); - Cache[url] = new CacheEntry(arr, ct, DateTimeOffset.UtcNow.Add(CacheEntryLifetime)); - CacheOrder.Enqueue(url); + using var s = response.Content.ReadAsStreamAsync().GetAwaiter().GetResult(); +#endif + if (s == null) return false; + var arr = ReadWithLimit(s, maxBytes); + if (arr == null) return false; + Cache[cacheKey] = new CacheEntry(arr, ct, DateTimeOffset.UtcNow.Add(CacheEntryLifetime)); + CacheOrder.Enqueue(cacheKey); TrimCache(); bytes = arr; contentType = ct; return true; } -#endif } catch { return false; } } + + private static bool TryCreateHttpUri(string url, out Uri uri) { + uri = null!; + if (string.IsNullOrWhiteSpace(url) || !Uri.TryCreate(url, UriKind.Absolute, out var parsed)) return false; + if (!IsHttpUri(parsed)) return false; + + uri = parsed; + return true; + } + + private static bool IsHttpUri(Uri uri) { + return string.Equals(uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase) + || string.Equals(uri.Scheme, Uri.UriSchemeHttp, StringComparison.OrdinalIgnoreCase); + } + + private static bool IsRedirect(HttpStatusCode statusCode) { + return statusCode == HttpStatusCode.Moved + || statusCode == HttpStatusCode.Redirect + || statusCode == HttpStatusCode.SeeOther + || statusCode == HttpStatusCode.TemporaryRedirect + || (int)statusCode == 308; + } + + private static Uri? ResolveRedirect(Uri currentUri, string? location) { + if (string.IsNullOrWhiteSpace(location)) return null; + if (!Uri.TryCreate(location, UriKind.RelativeOrAbsolute, out var parsed)) return null; + + var resolved = parsed.IsAbsoluteUri ? parsed : new Uri(currentUri, parsed); + return IsHttpUri(resolved) ? resolved : null; + } + +#if NETFRAMEWORK + private static HttpWebResponse? SendWithRedirects(Uri uri, int timeoutSeconds) { + var currentUri = uri; + for (int redirectCount = 0; redirectCount <= MaxRedirects; redirectCount++) { + var request = (HttpWebRequest)WebRequest.Create(currentUri); + request.AllowAutoRedirect = false; + request.AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate; + request.Timeout = Math.Max(1, timeoutSeconds) * 1000; + + var response = (HttpWebResponse)request.GetResponse(); + if (!IsRedirect(response.StatusCode)) return response; + + var nextUri = ResolveRedirect(currentUri, response.Headers[HttpResponseHeader.Location]); + response.Dispose(); + if (nextUri == null || redirectCount == MaxRedirects) return null; + currentUri = nextUri; + } + + return null; + } +#else + private static HttpResponseMessage? SendWithRedirects(HttpClient http, Uri uri) { + var currentUri = uri; + for (int redirectCount = 0; redirectCount <= MaxRedirects; redirectCount++) { + var response = http.GetAsync(currentUri, HttpCompletionOption.ResponseHeadersRead).GetAwaiter().GetResult(); + if (!IsRedirect(response.StatusCode)) return response; + + var nextUri = ResolveRedirect(currentUri, response.Headers.Location?.ToString()); + response.Dispose(); + if (nextUri == null || redirectCount == MaxRedirects) return null; + currentUri = nextUri; + } + + return null; + } +#endif + + private static byte[]? ReadWithLimit(Stream stream, long maxBytes) { + using var ms = new MemoryStream(); + var buffer = new byte[BufferSize]; + long total = 0; + while (true) { + int read = stream.Read(buffer, 0, buffer.Length); + if (read == 0) break; + + total += read; + if (total > maxBytes) return null; + ms.Write(buffer, 0, read); + } + + return ms.ToArray(); + } } } diff --git a/OfficeIMO.Tests/Excel.HeadersFootersAndProperties.cs b/OfficeIMO.Tests/Excel.HeadersFootersAndProperties.cs index 914cc3020..6bb41d205 100644 --- a/OfficeIMO.Tests/Excel.HeadersFootersAndProperties.cs +++ b/OfficeIMO.Tests/Excel.HeadersFootersAndProperties.cs @@ -227,6 +227,55 @@ public async Task ImageDownloader_Reuses_Cache_For_Repeat_Urls() } } + [Fact] + [Trait("Category","ExcelHeaderFooterImages")] + public async Task ImageDownloader_Rejects_Redirect_To_NonHttp_Target() { + OfficeIMO.Excel.ImageDownloader.ClearCache(); + + var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + var url = $"http://127.0.0.1:{port}/redirect.png"; + var response = "HTTP/1.1 302 Found\r\nLocation: file:///C:/Windows/win.ini\r\nConnection: close\r\n\r\n"; + var acceptTask = ServeSingleRawResponseAsync(listener, Encoding.ASCII.GetBytes(response)); + + try { + Assert.False(OfficeIMO.Excel.ImageDownloader.TryFetch(url, 5, 2_000_000, out var bytes, out var contentType)); + Assert.Null(bytes); + Assert.Null(contentType); + } finally { + listener.Stop(); + await acceptTask; + OfficeIMO.Excel.ImageDownloader.ClearCache(); + } + } + + [Fact] + [Trait("Category","ExcelHeaderFooterImages")] + public async Task ImageDownloader_Rejects_Response_When_Stream_Exceeds_Limit() { + OfficeIMO.Excel.ImageDownloader.ClearCache(); + + var payload = Enumerable.Repeat((byte)0x41, 64).ToArray(); + var header = Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nContent-Type: image/png\r\nConnection: close\r\n\r\n"); + var response = header.Concat(payload).ToArray(); + + var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + var url = $"http://127.0.0.1:{port}/large.png"; + var acceptTask = ServeSingleRawResponseAsync(listener, response); + + try { + Assert.False(OfficeIMO.Excel.ImageDownloader.TryFetch(url, 5, 16, out var bytes, out var contentType)); + Assert.Null(bytes); + Assert.Null(contentType); + } finally { + listener.Stop(); + await acceptTask; + OfficeIMO.Excel.ImageDownloader.ClearCache(); + } + } + [Fact] [Trait("Category","ExcelHeaderFooterImages")] public async Task Excel_HeaderImageUrl_Roundtrips_ContentType() { @@ -336,5 +385,32 @@ private static Task ServeSingleImageAsync(TcpListener listener, byte[] payload, } }); } + + private static Task ServeSingleRawResponseAsync(TcpListener listener, byte[] responseBytes) { + return Task.Run(async () => + { + try + { + using var client = await listener.AcceptTcpClientAsync(); + using var stream = client.GetStream(); + using (var reader = new StreamReader(stream, Encoding.ASCII, false, 1024, leaveOpen: true)) + { + string? line; + while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync())) { } + } + + await stream.WriteAsync(responseBytes, 0, responseBytes.Length); + await stream.FlushAsync(); + } + catch (SocketException) + { + // Listener stopped before accepting a connection; ignore for test cleanup. + } + catch (ObjectDisposedException) + { + // Listener disposed before accept completed; ignore for cleanup. + } + }); + } } }