diff --git a/CMakeLists.txt b/CMakeLists.txt index ff16ac3..3113340 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ set(EXTENSION_SOURCES src/httpfs_httplib_client.cpp src/multi_curl_manager.cpp src/s3fs.cpp + src/string_utils.cpp src/tcp_connection_fetcher.cpp src/tcp_connection_query_function.cpp src/tcp_ip_recorder.cpp diff --git a/src/curl_httpfs_extension.cpp b/src/curl_httpfs_extension.cpp index 5abca1f..e67d16c 100644 --- a/src/curl_httpfs_extension.cpp +++ b/src/curl_httpfs_extension.cpp @@ -3,17 +3,18 @@ #include "httpfs_client.hpp" #include "create_secret_functions.hpp" #include "duckdb.hpp" -#include "extension_loader_helper.hpp" #include "s3fs.hpp" #include "hffs.hpp" #ifdef OVERRIDE_ENCRYPTION_UTILS #include "crypto.hpp" #endif // OVERRIDE_ENCRYPTION_UTILS -namespace duckdb { -namespace { +// Extension-specific header. +#include "extension_loader_helper.hpp" + +namespace duckdb { -void LoadInternal(ExtensionLoader &loader) { +static void LoadInternal(ExtensionLoader &loader) { auto &instance = loader.GetDatabaseInstance(); auto &fs = instance.GetFileSystem(); @@ -49,7 +50,7 @@ void LoadInternal(ExtensionLoader &loader) { config.AddExtensionOption("ca_cert_file", "Path to a custom certificate file for self-signed certificates.", LogicalType::VARCHAR, Value("")); // Global S3 config - config.AddExtensionOption("s3_region", "S3 Region", LogicalType::VARCHAR, Value("us-east-1")); + config.AddExtensionOption("s3_region", "S3 Region", LogicalType::VARCHAR); config.AddExtensionOption("s3_access_key_id", "S3 Access Key ID", LogicalType::VARCHAR); config.AddExtensionOption("s3_secret_access_key", "S3 Access Key", LogicalType::VARCHAR); config.AddExtensionOption("s3_session_token", "S3 Session Token", LogicalType::VARCHAR); @@ -75,6 +76,9 @@ void LoadInternal(ExtensionLoader &loader) { config.AddExtensionOption("hf_max_per_page", "Debug option to limit number of items returned in list requests", LogicalType::UBIGINT, Value::UBIGINT(0)); + config.AddExtensionOption("merge_http_secret_into_s3_request", "Merges http secret params into S3 requests", + LogicalType::BOOLEAN, Value(true)); + auto callback_httpfs_client_implementation = [](ClientContext &context, SetScope scope, Value ¶meter) { auto &config = DBConfig::GetConfig(context); string value = StringValue::Get(parameter); @@ -85,6 +89,7 @@ void LoadInternal(ExtensionLoader &loader) { throw InvalidInputException("Unsupported option for httpfs_client_implementation, only `wasm` and " "`default` are currently supported for duckdb-wasm"); } +#ifndef EMSCRIPTEN if (value == "curl" || value == "default") { if (!config.http_util || config.http_util->GetName() != "HTTPFSUtil-Curl") { config.http_util = make_shared_ptr(); @@ -97,16 +102,19 @@ void LoadInternal(ExtensionLoader &loader) { } return; } +#endif throw InvalidInputException("Unsupported option for httpfs_client_implementation, only `curl`, `httplib` and " "`default` are currently supported"); }; config.AddExtensionOption("httpfs_client_implementation", "Select which is the HTTPUtil implementation to be used", LogicalType::VARCHAR, "default", callback_httpfs_client_implementation); + config.AddExtensionOption("enable_global_s3_configuration", + "Automatically fetch AWS credentials from environment variables.", LogicalType::BOOLEAN, + Value::BOOLEAN(true)); if (config.http_util && config.http_util->GetName() == "WasmHTTPUtils") { // Already handled, do not override } else { - // By default to use curl utils. 
config.http_util = make_shared_ptr(); } @@ -116,22 +124,22 @@ void LoadInternal(ExtensionLoader &loader) { CreateS3SecretFunctions::Register(loader); CreateBearerTokenFunctions::Register(loader); + // Extension-specific setup. + LoadExtensionInternal(loader); + #ifdef OVERRIDE_ENCRYPTION_UTILS // set pointer to OpenSSL encryption state config.encryption_util = make_shared_ptr(); #endif // OVERRIDE_ENCRYPTION_UTILS } - -} // namespace - -void CurlHttpfsExtension::Load(ExtensionLoader &loader) { +void HttpfsExtension::Load(ExtensionLoader &loader) { LoadInternal(loader); } -std::string CurlHttpfsExtension::Name() { - return "curl_httpfs"; +std::string HttpfsExtension::Name() { + return "httpfs"; } -std::string CurlHttpfsExtension::Version() const { +std::string HttpfsExtension::Version() const { #ifdef EXT_VERSION_HTTPFS return EXT_VERSION_HTTPFS; #else @@ -143,7 +151,7 @@ std::string CurlHttpfsExtension::Version() const { extern "C" { -DUCKDB_CPP_EXTENSION_ENTRY(curl_httpfs, loader) { +DUCKDB_CPP_EXTENSION_ENTRY(httpfs, loader) { duckdb::LoadInternal(loader); } } diff --git a/src/curl_request.cpp b/src/curl_request.cpp index 023871f..f541ba5 100644 --- a/src/curl_request.cpp +++ b/src/curl_request.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/assert.hpp" #include "extension_config.hpp" +#include "string_utils.hpp" namespace duckdb { @@ -20,7 +21,8 @@ CurlRequest::CurlRequest(CURL *easy_curl_p) : info(make_uniq()), ea CurlRequest::~CurlRequest() = default; void CurlRequest::SetUrl(string url) { - curl_easy_setopt(easy_curl, CURLOPT_URL, url.c_str()); + auto encoded_url = EncodeURL(url); + curl_easy_setopt(easy_curl, CURLOPT_URL, encoded_url.c_str()); info->url = std::move(url); } void CurlRequest::SetHeaders(curl_slist *headers) { diff --git a/src/httpfs.cpp b/src/httpfs.cpp index 0855c09..0164409 100644 --- a/src/httpfs.cpp +++ b/src/httpfs.cpp @@ -20,6 +20,8 @@ #include #include +#include "s3fs.hpp" + namespace duckdb { shared_ptr HTTPFSUtil::GetHTTPUtil(optional_ptr opener) { @@ -33,8 +35,9 @@ unique_ptr HTTPFSUtil::InitializeParameters(optional_ptr optional_ptr info) { auto result = make_uniq(*this); result->Initialize(opener); + result->state = HTTPState::TryGetState(opener); - // No point in continueing without an opener + // No point in continuing without an opener if (!opener) { return std::move(result); } @@ -58,23 +61,48 @@ unique_ptr HTTPFSUtil::InitializeParameters(optional_ptr FileOpener::TryGetCurrentSetting(opener, "hf_max_per_page", result->hf_max_per_page, info); FileOpener::TryGetCurrentSetting(opener, "unsafe_disable_etag_checks", result->unsafe_disable_etag_checks, info); + { + auto db = FileOpener::TryGetDatabase(opener); + if (db) { + result->user_agent = StringUtil::Format("%s %s", db->config.UserAgent(), DuckDB::SourceID()); + } + } + + unique_ptr settings_reader; + if (info && !S3FileSystem::TryGetPrefix(info->file_path).empty()) { + // This is an S3-type url, we should + const char *s3_secret_types[] = {"s3", "r2", "gcs", "aws", "http"}; + + idx_t secret_type_count = 5; + Value merge_http_secret_into_s3_request; + FileOpener::TryGetCurrentSetting(opener, "merge_http_secret_into_s3_request", + merge_http_secret_into_s3_request); + + if (!merge_http_secret_into_s3_request.IsNull() && !merge_http_secret_into_s3_request.GetValue()) { + // Drop the http secret from the lookup + secret_type_count = 4; + } + settings_reader = make_uniq(*opener, info, s3_secret_types, secret_type_count); + } else { + settings_reader = make_uniq(*opener, info, "http"); + } + // HTTP Secret 
lookups - KeyValueSecretReader settings_reader(*opener, info, "http"); string proxy_setting; - if (settings_reader.TryGetSecretKey("http_proxy", proxy_setting) && !proxy_setting.empty()) { + if (settings_reader->TryGetSecretKey("http_proxy", proxy_setting) && !proxy_setting.empty()) { idx_t port; string host; HTTPUtil::ParseHTTPProxyHost(proxy_setting, host, port); result->http_proxy = host; result->http_proxy_port = port; } - settings_reader.TryGetSecretKey("http_proxy_username", result->http_proxy_username); - settings_reader.TryGetSecretKey("http_proxy_password", result->http_proxy_password); - settings_reader.TryGetSecretKey("bearer_token", result->bearer_token); + settings_reader->TryGetSecretKey("http_proxy_username", result->http_proxy_username); + settings_reader->TryGetSecretKey("http_proxy_password", result->http_proxy_password); + settings_reader->TryGetSecretKey("bearer_token", result->bearer_token); Value extra_headers; - if (settings_reader.TryGetSecretKey("extra_http_headers", extra_headers)) { + if (settings_reader->TryGetSecretKey("extra_http_headers", extra_headers)) { auto children = MapValue::GetChildren(extra_headers); for (const auto &child : children) { auto kv = StructValue::GetChildren(child); @@ -102,11 +130,29 @@ void HTTPClientCache::StoreClient(unique_ptr client) { clients.push_back(std::move(client)); } +static void AddUserAgentIfAvailable(HTTPFSParams &http_params, HTTPHeaders &header_map) { + if (!http_params.user_agent.empty()) { + header_map.Insert("User-Agent", http_params.user_agent); + } +} + +static void AddHandleHeaders(HTTPFileHandle &handle, HTTPHeaders &header_map) { + // Inject headers from the http param extra_headers into the request + for (auto &header : handle.http_params.extra_headers) { + header_map[header.first] = header.second; + } + handle.http_params.pre_merged_headers = true; +} + unique_ptr HTTPFileSystem::PostRequest(FileHandle &handle, string url, HTTPHeaders header_map, string &buffer_out, char *buffer_in, idx_t buffer_in_len, string params) { auto &hfh = handle.Cast(); auto &http_util = hfh.http_params.http_util; + + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + PostRequestInfo post_request(url, header_map, hfh.http_params, const_data_ptr_cast(buffer_in), buffer_in_len); auto result = http_util.Request(post_request); buffer_out = std::move(post_request.buffer_out); @@ -117,6 +163,10 @@ unique_ptr HTTPFileSystem::PutRequest(FileHandle &handle, string u char *buffer_in, idx_t buffer_in_len, string params) { auto &hfh = handle.Cast(); auto &http_util = hfh.http_params.http_util; + + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + string content_type = "application/octet-stream"; PutRequestInfo put_request(url, header_map, hfh.http_params, (const_data_ptr_t)buffer_in, buffer_in_len, content_type); @@ -126,6 +176,10 @@ unique_ptr HTTPFileSystem::PutRequest(FileHandle &handle, string u unique_ptr HTTPFileSystem::HeadRequest(FileHandle &handle, string url, HTTPHeaders header_map) { auto &hfh = handle.Cast(); auto &http_util = hfh.http_params.http_util; + + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + auto http_client = hfh.GetClient(); HeadRequestInfo head_request(url, header_map, hfh.http_params); @@ -138,6 +192,10 @@ unique_ptr HTTPFileSystem::HeadRequest(FileHandle &handle, string unique_ptr HTTPFileSystem::DeleteRequest(FileHandle &handle, string url, HTTPHeaders header_map) { auto &hfh = 
handle.Cast(); auto &http_util = hfh.http_params.http_util; + + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + auto http_client = hfh.GetClient(); DeleteRequestInfo delete_request(url, header_map, hfh.http_params); auto response = http_util.Request(delete_request, http_client); @@ -161,6 +219,9 @@ unique_ptr HTTPFileSystem::GetRequest(FileHandle &handle, string u auto &hfh = handle.Cast(); auto &http_util = hfh.http_params.http_util; + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + D_ASSERT(hfh.cached_file_handle); auto http_client = hfh.GetClient(); @@ -210,6 +271,9 @@ unique_ptr HTTPFileSystem::GetRangeRequest(FileHandle &handle, str auto &hfh = handle.Cast(); auto &http_util = hfh.http_params.http_util; + AddUserAgentIfAvailable(hfh.http_params, header_map); + AddHandleHeaders(hfh, header_map); + // send the Range header to read only subset of file string range_expr = "bytes=" + to_string(file_offset) + "-" + to_string(file_offset + buffer_out_len - 1); header_map.Insert("Range", range_expr); @@ -222,13 +286,7 @@ unique_ptr HTTPFileSystem::GetRangeRequest(FileHandle &handle, str url, header_map, hfh.http_params, [&](const HTTPResponse &response) { if (static_cast(response.status) >= 400) { - string error = - "HTTP GET error on '" + url + "' (HTTP " + to_string(static_cast(response.status)) + ")"; - if (response.status == HTTPStatusCode::RangeNotSatisfiable_416) { - error += " This could mean the file was changed. Try disabling the duckdb http metadata cache " - "if enabled, and confirm the server supports range requests."; - } - throw HTTPException(response, error); + throw GetHTTPError(handle, response, url); } if (static_cast(response.status) < 300) { // done redirecting out_offset = 0; @@ -237,13 +295,19 @@ unique_ptr HTTPFileSystem::GetRangeRequest(FileHandle &handle, str string responseEtag = response.GetHeaderValue("ETag"); if (!responseEtag.empty() && responseEtag != hfh.etag) { + if (global_metadata_cache) { + global_metadata_cache->Erase(handle.path); + } throw HTTPException( response, - "ETag was initially %s and now it returned %s, this likely means the remote file has " - "changed.\nTry to restart the read or close the file-handle and read the file again (e.g. 
" - "`DETACH` in the file is a database file).\nYou can disable checking etags via `SET " + "ETag on reading file \"%s\" was initially %s and now it returned %s, this likely means " + "the " + "remote file has " + "changed.\nFor parquet or similar single table sources, consider retrying the query, for " + "persistent FileHandles such as databases consider `DETACH` and re-`ATTACH` " + "\nYou can disable checking etags via `SET " "unsafe_disable_etag_checks = true;`", - hfh.etag, response.GetHeaderValue("ETag")); + handle.path, hfh.etag, response.GetHeaderValue("ETag")); } } @@ -364,14 +428,49 @@ unique_ptr HTTPFileSystem::OpenFileExtended(const OpenFileInfo &file return std::move(handle); } +void HTTPFileHandle::AddStatistics(idx_t read_offset, idx_t read_length, idx_t read_duration) { + range_request_statistics.push_back({read_offset, read_length, read_duration}); +} + +void HTTPFileHandle::AdaptReadBufferSize(idx_t next_read_offset) { + D_ASSERT(!SkipBuffer()); + if (range_request_statistics.empty()) { + return; // No requests yet - nothing to do + } + + const auto &last_read = range_request_statistics.back(); + if (last_read.offset + last_read.length != next_read_offset) { + return; // Not reading sequentially + } + + if (read_buffer.GetSize() >= MAXIMUM_READ_BUFFER_LEN) { + return; // Already at maximum size + } + + // Grow the buffer + // TODO: can use statistics to estimate per-byte and round-trip cost using least squares, and do something smarter + read_buffer = read_buffer.GetAllocator()->Allocate(read_buffer.GetSize() * 2); +} + bool HTTPFileSystem::TryRangeRequest(FileHandle &handle, string url, HTTPHeaders header_map, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) { + auto &hfh = handle.Cast(); + + const auto timestamp_before = Timestamp::GetCurrentTimestamp(); auto res = GetRangeRequest(handle, url, header_map, file_offset, buffer_out, buffer_out_len); if (res) { // Request succeeded TODO: fix upstream that 206 is not considered success if (res->Success() || res->status == HTTPStatusCode::PartialContent_206 || res->status == HTTPStatusCode::Accepted_202) { + + if (!hfh.flags.RequireParallelAccess()) { + // Update range request statistics + const auto duration = + NumericCast(Timestamp::GetCurrentTimestamp().value - timestamp_before.value); + hfh.AddStatistics(file_offset, buffer_out_len, duration); + } + return true; } @@ -390,7 +489,7 @@ bool HTTPFileSystem::TryRangeRequest(FileHandle &handle, string url, HTTPHeaders error.Throw(); } throw HTTPException(*res, "Request returned HTTP %d for HTTP %s to '%s'", static_cast(res->status), - EnumUtil::ToString(RequestType::GET_REQUEST), res->url); + EnumUtil::ToString(RequestType::GET_REQUEST), url); } throw IOException("Unknown error for HTTP %s to '%s'", EnumUtil::ToString(RequestType::GET_REQUEST), url); } @@ -403,6 +502,9 @@ bool HTTPFileSystem::ReadInternal(FileHandle &handle, void *buffer, int64_t nr_b if (!hfh.cached_file_handle->Initialized()) { throw InternalException("Cached file not initialized properly"); } + if (hfh.cached_file_handle->GetSize() < location + nr_bytes) { + throw InternalException("Cached file length can't satisfy the requested Read"); + } memcpy(buffer, hfh.cached_file_handle->GetData() + location, nr_bytes); DUCKDB_LOG_FILE_SYSTEM_READ(handle, nr_bytes, location); hfh.file_offset = location + nr_bytes; @@ -413,8 +515,7 @@ bool HTTPFileSystem::ReadInternal(FileHandle &handle, void *buffer, int64_t nr_b idx_t buffer_offset = 0; // Don't buffer when DirectIO is set or when we are doing parallel 
reads - bool skip_buffer = hfh.flags.DirectIO() || hfh.flags.RequireParallelAccess(); - if (skip_buffer && to_read > 0) { + if (hfh.SkipBuffer() && to_read > 0) { if (!TryRangeRequest(hfh, hfh.path, {}, location, (char *)buffer, to_read)) { return false; } @@ -459,7 +560,7 @@ bool HTTPFileSystem::ReadInternal(FileHandle &handle, void *buffer, int64_t nr_b } if (to_read > 0 && hfh.buffer_available == 0) { - auto new_buffer_available = MinValue(hfh.READ_BUFFER_LEN, hfh.length - start_offset); + auto new_buffer_available = MinValue(hfh.read_buffer.GetSize(), hfh.length - start_offset); // Bypass buffer if we read more than buffer size if (to_read > new_buffer_available) { @@ -472,6 +573,8 @@ bool HTTPFileSystem::ReadInternal(FileHandle &handle, void *buffer, int64_t nr_b start_offset += to_read; break; } else { + hfh.AdaptReadBufferSize(start_offset); + new_buffer_available = MinValue(hfh.read_buffer.GetSize(), hfh.length - start_offset); if (!TryRangeRequest(hfh, hfh.path, {}, start_offset, (char *)hfh.read_buffer.get(), new_buffer_available)) { return false; @@ -500,10 +603,10 @@ void HTTPFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, id // attempt to download the full file and retry. if (handle.logger) { - DUCKDB_LOG_WARN(handle.logger, - "Falling back to full file download for file '%s': the server does not support HTTP range " - "requests. Performance and memory usage are potentially degraded.", - handle.path); + DUCKDB_LOG_WARNING(handle.logger, + "Falling back to full file download for file '%s': the server does not support HTTP range " + "requests. Performance and memory usage are potentially degraded.", + handle.path); } auto &hfh = handle.Cast(); @@ -596,7 +699,13 @@ static optional_ptr TryGetMetadataCache(optional_ptrconfig.options.http_metadata_cache_enable; + Value use_shared_cache_val; + bool use_shared_cache = false; + FileOpener::TryGetCurrentSetting(opener, "enable_http_metadata_cache", use_shared_cache_val); + if (!use_shared_cache_val.IsNull()) { + use_shared_cache = use_shared_cache_val.GetValue(); + } + if (use_shared_cache) { return httpfs.GetGlobalCache(); } else if (client_context) { @@ -694,7 +803,8 @@ void HTTPFileHandle::LoadFileInfo() { return; } else { // HEAD request fail, use Range request for another try (read only one byte) - if (flags.OpenForReading() && res->status != HTTPStatusCode::NotFound_404) { + if (flags.OpenForReading() && res->status != HTTPStatusCode::NotFound_404 && + res->status != HTTPStatusCode::MovedPermanently_301) { auto range_res = hfs.GetRangeRequest(*this, path, {}, 0, nullptr, 2); if (range_res->status != HTTPStatusCode::PartialContent_206 && range_res->status != HTTPStatusCode::Accepted_202 && range_res->status != HTTPStatusCode::OK_200) { @@ -704,7 +814,7 @@ void HTTPFileHandle::LoadFileInfo() { } res = std::move(range_res); } else { - throw HTTPException(*res, "Unable to connect to URL \"%s\": %d (%s).", res->url, + throw HTTPException(*res, "Unable to connect to URL \"%s\": %d (%s).", path, static_cast(res->status), res->GetError()); } } @@ -739,6 +849,28 @@ void HTTPFileHandle::TryAddLogger(FileOpener &opener) { } } +void HTTPFileHandle::AllocateReadBuffer(optional_ptr opener) { + D_ASSERT(!SkipBuffer()); + D_ASSERT(!read_buffer.IsSet()); + auto &allocator = opener && opener->TryGetClientContext() ? 
BufferAllocator::Get(*opener->TryGetClientContext()) + : Allocator::DefaultAllocator(); + read_buffer = allocator.Allocate(INITIAL_READ_BUFFER_LEN); +} + +void HTTPFileHandle::InitializeFromCacheEntry(const HTTPMetadataCacheEntry &cache_entry) { + last_modified = cache_entry.last_modified; + length = cache_entry.length; + etag = cache_entry.etag; +} + +HTTPMetadataCacheEntry HTTPFileHandle::GetCacheEntry() const { + HTTPMetadataCacheEntry result; + result.length = length; + result.last_modified = last_modified; + result.etag = etag; + return result; +} + void HTTPFileHandle::Initialize(optional_ptr opener) { auto &hfs = file_system.Cast(); http_params.state = HTTPState::TryGetState(opener); @@ -764,12 +896,10 @@ void HTTPFileHandle::Initialize(optional_ptr opener) { bool found = current_cache->Find(path, value); if (found) { - last_modified = value.last_modified; - length = value.length; - etag = value.etag; + InitializeFromCacheEntry(value); - if (flags.OpenForReading()) { - read_buffer = duckdb::unique_ptr(new data_t[READ_BUFFER_LEN]); + if (flags.OpenForReading() && !SkipBuffer()) { + AllocateReadBuffer(opener); } return; } @@ -784,11 +914,13 @@ void HTTPFileHandle::Initialize(optional_ptr opener) { FullDownload(hfs, should_write_cache); } if (should_write_cache) { - current_cache->Insert(path, {length, last_modified, etag}); + current_cache->Insert(path, GetCacheEntry()); } - // Initialize the read buffer now that we know the file exists - read_buffer = duckdb::unique_ptr(new data_t[READ_BUFFER_LEN]); + if (!SkipBuffer()) { + // Initialize the read buffer now that we know the file exists + AllocateReadBuffer(opener); + } } // If we're writing to a file, we might as well remove it from the cache diff --git a/src/httpfs_curl_client.cpp b/src/httpfs_curl_client.cpp index 15fcf19..2a52438 100644 --- a/src/httpfs_curl_client.cpp +++ b/src/httpfs_curl_client.cpp @@ -10,6 +10,7 @@ #include "curl_request.hpp" #include "duckdb/common/exception/http_exception.hpp" #include "multi_curl_manager.hpp" +#include "string_utils.hpp" namespace duckdb { @@ -169,7 +170,7 @@ class HTTPFSCurlClient : public HTTPClient { if (!http_params.http_proxy.empty()) { curl_easy_setopt(*curl, CURLOPT_PROXY, - StringUtil::Format("%s:%s", http_params.http_proxy, http_params.http_proxy_port).c_str()); + StringUtil::Format("%s:%d", http_params.http_proxy, http_params.http_proxy_port).c_str()); if (!http_params.http_proxy_username.empty()) { curl_easy_setopt(*curl, CURLOPT_PROXYUSERNAME, http_params.http_proxy_username.c_str()); @@ -227,7 +228,8 @@ class HTTPFSCurlClient : public HTTPClient { CURLcode res; { - curl_easy_setopt(*curl, CURLOPT_URL, request_info->url.c_str()); + const auto encoded_url = EncodeURL(request_info->url); + curl_easy_setopt(*curl, CURLOPT_URL, encoded_url.c_str()); // Perform PUT curl_easy_setopt(*curl, CURLOPT_CUSTOMREQUEST, "PUT"); // Include PUT body @@ -282,7 +284,8 @@ class HTTPFSCurlClient : public HTTPClient { CURLcode res; { // Set URL - curl_easy_setopt(*curl, CURLOPT_URL, request_info->url.c_str()); + const auto encoded_url = EncodeURL(request_info->url); + curl_easy_setopt(*curl, CURLOPT_URL, encoded_url.c_str()); // Set DELETE request method curl_easy_setopt(*curl, CURLOPT_CUSTOMREQUEST, "DELETE"); @@ -320,7 +323,8 @@ class HTTPFSCurlClient : public HTTPClient { CURLcode res; { - curl_easy_setopt(*curl, CURLOPT_URL, request_info->url.c_str()); + const auto encoded_url = EncodeURL(request_info->url); + curl_easy_setopt(*curl, CURLOPT_URL, encoded_url.c_str()); curl_easy_setopt(*curl, 
CURLOPT_POST, 1L); // Set POST body @@ -396,6 +400,10 @@ class HTTPFSCurlClient : public HTTPClient { response->url = request_info->url; if (!request_info->header_collection.empty()) { for (auto &header : request_info->header_collection.back()) { + // We should not return __RESPONSE_STATUS__ to the user. It's only there for debugging. + if (header.first == "__RESPONSE_STATUS__") { + continue; + } response->headers.Insert(header.first, header.second); } } diff --git a/src/httpfs_httplib_client.cpp b/src/httpfs_httplib_client.cpp index fd79311..4de064f 100644 --- a/src/httpfs_httplib_client.cpp +++ b/src/httpfs_httplib_client.cpp @@ -11,8 +11,8 @@ class HTTPFSClient : public HTTPClient { client = make_uniq(proto_host_port); Initialize(http_params); } - void Initialize(HTTPParams &http_p) { - HTTPFSParams &http_params = reinterpret_cast(http_p); + void Initialize(HTTPParams &http_p) override { + HTTPFSParams &http_params = (HTTPFSParams &)http_p; client->set_follow_location(http_params.follow_location); client->set_keep_alive(http_params.keep_alive); if (!http_params.ca_cert_file.empty()) { @@ -92,7 +92,11 @@ class HTTPFSClient : public HTTPClient { } // We use a custom Request method here, because there is no Post call with a contentreceiver in httplib duckdb_httplib_openssl::Request req; - req.method = "POST"; + if (info.send_post_as_get_request) { + req.method = "GET"; + } else { + req.method = "POST"; + } req.path = info.path; req.headers = TransformHeaders(info.headers, info.params); if (req.headers.find("Content-Type") == req.headers.end()) { @@ -106,18 +110,26 @@ class HTTPFSClient : public HTTPClient { info.buffer_out += string(data, data_length); return true; }; + // First assign body, this is the body that will be uploaded req.body.assign(const_char_ptr_cast(info.buffer_in), info.buffer_in_len); - return TransformResult(client->send(req)); + auto transformed_req = TransformResult(client->send(req)); + // Then, after actual re-quest, re-assign body to the response value of the POST request + transformed_req->body.assign(const_char_ptr_cast(info.buffer_in), info.buffer_in_len); + return std::move(transformed_req); } private: duckdb_httplib_openssl::Headers TransformHeaders(const HTTPHeaders &header_map, const HTTPParams ¶ms) { + auto &httpfs_params = params.Cast(); + duckdb_httplib_openssl::Headers headers; for (auto &entry : header_map) { headers.insert(entry); } - for (auto &entry : params.extra_headers) { - headers.insert(entry); + if (!httpfs_params.pre_merged_headers) { + for (auto &entry : params.extra_headers) { + headers.insert(entry); + } } return headers; } diff --git a/src/include/httpfs.hpp b/src/include/httpfs.hpp index 804968a..7e6734b 100644 --- a/src/include/httpfs.hpp +++ b/src/include/httpfs.hpp @@ -85,8 +85,13 @@ class HTTPFileHandle : public FileHandle { std::mutex mu; // Read buffer - duckdb::unique_ptr read_buffer; - constexpr static idx_t READ_BUFFER_LEN = 1000000; + AllocatedData read_buffer; + constexpr static idx_t INITIAL_READ_BUFFER_LEN = 1048576; + constexpr static idx_t MAXIMUM_READ_BUFFER_LEN = 33554432; + + // Adaptively resizes read_buffer based on range_request_statistics + void AddStatistics(idx_t read_offset, idx_t read_length, idx_t read_duration); + void AdaptReadBufferSize(idx_t next_read_offset); void AddHeaders(HTTPHeaders &map); @@ -95,6 +100,22 @@ class HTTPFileHandle : public FileHandle { // Return the client for re-use void StoreClient(unique_ptr client); + // Whether to bypass the read buffer + bool SkipBuffer() const { + return 
flags.DirectIO() || flags.RequireParallelAccess(); + } + +private: + void AllocateReadBuffer(optional_ptr opener); + + // Statistics that are used to adaptively grow the read_buffer + struct RangeRequestStatistics { + idx_t offset; + idx_t length; + idx_t duration; + }; + vector range_request_statistics; + public: void Close() override { } @@ -108,6 +129,9 @@ class HTTPFileHandle : public FileHandle { //! TODO: make base function virtual? void TryAddLogger(FileOpener &opener); + virtual void InitializeFromCacheEntry(const HTTPMetadataCacheEntry &cache_entry); + virtual HTTPMetadataCacheEntry GetCacheEntry() const; + public: //! Fully downloads a file void FullDownload(HTTPFileSystem &hfs, bool &should_write_cache); diff --git a/src/include/s3fs.hpp b/src/include/s3fs.hpp index b848d2c..d4de5b2 100644 --- a/src/include/s3fs.hpp +++ b/src/include/s3fs.hpp @@ -20,6 +20,33 @@ namespace duckdb { +class S3KeyValueReader { +public: + S3KeyValueReader(FileOpener &opener_p, optional_ptr info, const char **secret_types, + idx_t secret_types_len); + + template + SettingLookupResult TryGetSecretKeyOrSetting(const string &secret_key, const string &setting_name, TYPE &result) { + Value temp_result; + auto setting_scope = reader.TryGetSecretKeyOrSetting(secret_key, setting_name, temp_result); + if (!temp_result.IsNull() && + !(setting_scope.GetScope() == SettingScope::GLOBAL && !use_env_variables_for_secret_settings)) { + result = temp_result.GetValue(); + } + return setting_scope; + } + + template + SettingLookupResult TryGetSecretKey(const string &secret_key, TYPE &value_out) { + // TryGetSecretKey never returns anything from global scope, so we don't need to check + return reader.TryGetSecretKey(secret_key, value_out); + } + +private: + bool use_env_variables_for_secret_settings; + KeyValueSecretReader reader; +}; + struct S3AuthParams { string region; string access_key_id; @@ -34,6 +61,11 @@ struct S3AuthParams { string oauth2_bearer_token; // OAuth2 bearer token for GCS static S3AuthParams ReadFrom(optional_ptr opener, FileOpenerInfo &info); + static S3AuthParams ReadFrom(S3KeyValueReader &secret_reader, const std::string &file_path); + void SetRegion(string region_p); + +private: + void InitializeEndpoint(); }; struct AWSEnvironmentCredentialsProvider { @@ -57,14 +89,14 @@ struct AWSEnvironmentCredentialsProvider { }; struct ParsedS3Url { - const string http_proto; - const string prefix; - const string host; - const string bucket; - const string key; - const string path; - const string query_param; - const string trimmed_s3_url; + string http_proto; + string prefix; + string host; + string bucket; + string key; + string path; + string query_param; + string trimmed_s3_url; string GetHTTPUrl(S3AuthParams &auth_params, const string &http_query_string = ""); }; @@ -112,21 +144,12 @@ class S3FileHandle : public HTTPFileHandle { public: S3FileHandle(FileSystem &fs, const OpenFileInfo &file, FileOpenFlags flags, unique_ptr http_params_p, - const S3AuthParams &auth_params_p, const S3ConfigParams &config_params_p) - : HTTPFileHandle(fs, file, flags, std::move(http_params_p)), auth_params(auth_params_p), - config_params(config_params_p), uploads_in_progress(0), parts_uploaded(0), upload_finalized(false), - uploader_has_error(false), upload_exception(nullptr) { - auto_fallback_to_full_file_download = false; - if (flags.OpenForReading() && flags.OpenForWriting()) { - throw NotImplementedException("Cannot open an HTTP file for both reading and writing"); - } else if (flags.OpenForAppending()) { - throw 
NotImplementedException("Cannot open an HTTP file for appending"); - } - } + const S3AuthParams &auth_params_p, const S3ConfigParams &config_params_p); ~S3FileHandle() override; S3AuthParams auth_params; const S3ConfigParams config_params; + bool initialized_multipart_upload {false}; public: void Close() override; @@ -134,6 +157,10 @@ class S3FileHandle : public HTTPFileHandle { shared_ptr GetBuffer(uint16_t write_buffer_idx); +protected: + void InitializeFromCacheEntry(const HTTPMetadataCacheEntry &cache_entry) override; + HTTPMetadataCacheEntry GetCacheEntry() const override; + protected: string multipart_upload_id; size_t part_size; @@ -206,18 +233,19 @@ class S3FileSystem : public HTTPFileSystem { void FlushAllBuffers(S3FileHandle &handle); void ReadQueryParams(const string &url_query_param, S3AuthParams ¶ms); - static ParsedS3Url S3UrlParse(string url, S3AuthParams ¶ms); + static ParsedS3Url S3UrlParse(string url, const S3AuthParams ¶ms); static string UrlEncode(const string &input, bool encode_slash = false); static string UrlDecode(string input); + static string TryGetPrefix(const string &url); + // Uploads the contents of write_buffer to S3. // Note: caller is responsible to not call this method twice on the same buffer static void UploadBuffer(S3FileHandle &file_handle, shared_ptr write_buffer); - - vector Glob(const string &glob_pattern, FileOpener *opener = nullptr) override; - bool ListFiles(const string &directory, const std::function &callback, - FileOpener *opener = nullptr) override; + static void UploadSingleBuffer(S3FileHandle &file_handle, shared_ptr write_buffer); + static void UploadBufferImplementation(S3FileHandle &file_handle, shared_ptr write_buffer, + string query_param, bool direct_throw); //! Wrapper around BufferManager::Allocate to limit the number of buffers BufferHandle Allocate(idx_t part_size, uint16_t max_threads); @@ -227,13 +255,28 @@ class S3FileSystem : public HTTPFileSystem { return true; } - static string GetS3BadRequestError(S3AuthParams &s3_auth_params); - static string GetS3AuthError(S3AuthParams &s3_auth_params); - static string GetGCSAuthError(S3AuthParams &s3_auth_params); - static HTTPException GetS3Error(S3AuthParams &s3_auth_params, const HTTPResponse &response, const string &url); + static string GetS3BadRequestError(const S3AuthParams &s3_auth_params, string correct_region = ""); + static string ParseS3Error(const string &error); + static string GetS3AuthError(const S3AuthParams &s3_auth_params); + static string GetGCSAuthError(const S3AuthParams &s3_auth_params); + static HTTPException GetS3Error(const S3AuthParams &s3_auth_params, const HTTPResponse &response, + const string &url); + +protected: + bool ListFilesExtended(const string &directory, const std::function &callback, + optional_ptr opener) override; + bool SupportsListFilesExtended() const override { + return true; + } + unique_ptr GlobFilesExtended(const string &path, const FileGlobInput &input, + optional_ptr opener) override; + bool SupportsGlobExtended() const override { + return true; + } protected: static void NotifyUploadsInProgress(S3FileHandle &file_handle); + static string GetPrefix(const string &url); duckdb::unique_ptr CreateHandle(const OpenFileInfo &file, FileOpenFlags flags, optional_ptr opener) override; @@ -245,10 +288,14 @@ class S3FileSystem : public HTTPFileSystem { // Helper class to do s3 ListObjectV2 api call https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html struct AWSListObjectV2 { - static string Request(string &path, HTTPParams 
&http_params, S3AuthParams &s3_auth_params, - string &continuation_token, optional_ptr state, bool use_delimiter = false); + static string Request(const string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, + string &continuation_token, optional_idx max_keys = optional_idx()); static void ParseFileList(string &aws_response, vector &result); static vector ParseCommonPrefix(string &aws_response); static string ParseContinuationToken(string &aws_response); }; + +HTTPHeaders CreateS3Header(string url, string query, string host, string service, string method, + const S3AuthParams &auth_params, string date_now = "", string datetime_now = "", + string payload_hash = "", string content_type = ""); } // namespace duckdb diff --git a/src/include/string_utils.hpp b/src/include/string_utils.hpp new file mode 100644 index 0000000..7d64365 --- /dev/null +++ b/src/include/string_utils.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "duckdb/common/string.hpp" + +namespace duckdb { + +// Encode URL string. +string EncodeURL(const string &url); + +} // namespace duckdb diff --git a/src/s3fs.cpp b/src/s3fs.cpp index 5e093cb..5c618c0 100644 --- a/src/s3fs.cpp +++ b/src/s3fs.cpp @@ -2,7 +2,6 @@ #include "crypto.hpp" #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/exception/http_exception.hpp" #include "duckdb/logging/log_type.hpp" #include "duckdb/logging/file_system_logger.hpp" @@ -11,12 +10,12 @@ #include "duckdb/common/types/timestamp.hpp" #include "duckdb/function/scalar/strftime_format.hpp" #include "http_state.hpp" -#endif #include "duckdb/common/string_util.hpp" #include "duckdb/function/scalar/string_common.hpp" #include "duckdb/main/secret/secret_manager.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/multi_file/multi_file_list.hpp" #include "create_secret_functions.hpp" @@ -28,9 +27,9 @@ namespace duckdb { -static HTTPHeaders create_s3_header(string url, string query, string host, string service, string method, - const S3AuthParams &auth_params, string date_now = "", string datetime_now = "", - string payload_hash = "", string content_type = "") { +HTTPHeaders CreateS3Header(string url, string query, string host, string service, string method, + const S3AuthParams &auth_params, string date_now, string datetime_now, string payload_hash, + string content_type) { HTTPHeaders res; res["Host"] = host; @@ -184,15 +183,47 @@ S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() { } S3AuthParams S3AuthParams::ReadFrom(optional_ptr opener, FileOpenerInfo &info) { - auto result = S3AuthParams(); // Without a FileOpener we can not access settings nor secrets: return empty auth params if (!opener) { - return result; + return {}; } const char *secret_types[] = {"s3", "r2", "gcs", "aws"}; - KeyValueSecretReader secret_reader(*opener, info, secret_types, 3); + S3KeyValueReader secret_reader(*opener, info, secret_types, 3); + + return ReadFrom(secret_reader, info.file_path); +} + +bool EndpointIsAWS(const string &endpoint) { + if (endpoint.empty()) { + // default (empty) endpoint is AWS + return true; + } + if (StringUtil::StartsWith(endpoint, "s3.") && StringUtil::EndsWith(endpoint, ".amazonaws.com")) { + return true; + } + return false; +} + +void S3AuthParams::InitializeEndpoint() { + if (!EndpointIsAWS(endpoint)) { + return; + } + if (region.empty()) { + if (access_key_id.empty()) { + // no access key and no region - use legacy global endpoint + endpoint = "s3.amazonaws.com"; + return; + } + // access key but no region - default to 
us-east-1 + region = "us-east-1"; + } + endpoint = StringUtil::Format("s3.%s.amazonaws.com", region); +} + +S3AuthParams S3AuthParams::ReadFrom(S3KeyValueReader &secret_reader, const string &file_path) { + auto result = S3AuthParams(); // These settings we just set or leave to their S3AuthParams default value secret_reader.TryGetSecretKeyOrSetting("region", "s3_region", result.region); @@ -205,12 +236,11 @@ S3AuthParams S3AuthParams::ReadFrom(optional_ptr opener, FileOpenerI secret_reader.TryGetSecretKeyOrSetting("s3_url_compatibility_mode", "s3_url_compatibility_mode", result.s3_url_compatibility_mode); secret_reader.TryGetSecretKeyOrSetting("requester_pays", "s3_requester_pays", result.requester_pays); - // Endpoint and url style are slightly more complex and require special handling for gcs and r2 auto endpoint_result = secret_reader.TryGetSecretKeyOrSetting("endpoint", "s3_endpoint", result.endpoint); auto url_style_result = secret_reader.TryGetSecretKeyOrSetting("url_style", "s3_url_style", result.url_style); - if (StringUtil::StartsWith(info.file_path, "gcs://") || StringUtil::StartsWith(info.file_path, "gs://")) { + if (StringUtil::StartsWith(file_path, "gcs://") || StringUtil::StartsWith(file_path, "gs://")) { // For GCS urls we force the endpoint and vhost path style, allowing only to be overridden by secrets if (result.endpoint.empty() || endpoint_result.GetScope() != SettingScope::SECRET) { result.endpoint = "storage.googleapis.com"; @@ -221,16 +251,16 @@ S3AuthParams S3AuthParams::ReadFrom(optional_ptr opener, FileOpenerI // Read bearer token for GCS secret_reader.TryGetSecretKey("bearer_token", result.oauth2_bearer_token); } - - if (!result.region.empty() && (result.endpoint.empty() || result.endpoint == "s3.amazonaws.com")) { - result.endpoint = StringUtil::Format("s3.%s.amazonaws.com", result.region); - } else if (result.endpoint.empty()) { - result.endpoint = "s3.amazonaws.com"; - } + result.InitializeEndpoint(); return result; } +void S3AuthParams::SetRegion(string new_region) { + region = std::move(new_region); + InitializeEndpoint(); +} + unique_ptr CreateSecret(vector &prefix_paths_p, string &type, string &provider, string &name, S3AuthParams ¶ms) { auto return_value = make_uniq(prefix_paths_p, type, provider, name); @@ -257,6 +287,26 @@ unique_ptr CreateSecret(vector &prefix_paths_p, string & return return_value; } +S3FileHandle::S3FileHandle(FileSystem &fs, const OpenFileInfo &file, FileOpenFlags flags, + unique_ptr http_params_p, const S3AuthParams &auth_params_p, + const S3ConfigParams &config_params_p) + : HTTPFileHandle(fs, file, flags, std::move(http_params_p)), auth_params(auth_params_p), + config_params(config_params_p), uploads_in_progress(0), parts_uploaded(0), upload_finalized(false), + uploader_has_error(false), upload_exception(nullptr) { + auto_fallback_to_full_file_download = false; + if (flags.OpenForReading() && flags.OpenForWriting()) { + throw NotImplementedException("Cannot open an HTTP file for both reading and writing"); + } else if (flags.OpenForAppending()) { + throw NotImplementedException("Cannot open an HTTP file for appending"); + } + if (file.extended_info) { + auto entry = file.extended_info->options.find("s3_region"); + if (entry != file.extended_info->options.end()) { + auth_params.SetRegion(entry->second.ToString()); + } + } +} + S3FileHandle::~S3FileHandle() { if (Exception::UncaughtException()) { // We are in an exception, don't do anything @@ -323,7 +373,7 @@ string S3FileSystem::InitializeMultipartUpload(S3FileHandle &file_handle) { 
auto res = s3fs.PostRequest(file_handle, file_handle.path, {}, result, nullptr, 0, query_param); if (res->status != HTTPStatusCode::OK_200) { - throw HTTPException(*res, "Unable to connect to URL %s: %s (HTTP code %d)", res->url, res->GetError(), + throw HTTPException(*res, "Unable to connect to URL %s: %s (HTTP code %d)", file_handle.path, res->GetError(), static_cast(res->status)); } @@ -336,6 +386,8 @@ string S3FileSystem::InitializeMultipartUpload(S3FileHandle &file_handle) { open_tag_pos += 10; // Skip open tag + file_handle.initialized_multipart_upload = true; + return result.substr(open_tag_pos, close_tag_pos - open_tag_pos); } @@ -353,10 +405,22 @@ void S3FileSystem::NotifyUploadsInProgress(S3FileHandle &file_handle) { } void S3FileSystem::UploadBuffer(S3FileHandle &file_handle, shared_ptr write_buffer) { - auto &s3fs = (S3FileSystem &)file_handle.file_system; - string query_param = "partNumber=" + to_string(write_buffer->part_no + 1) + "&" + "uploadId=" + S3FileSystem::UrlEncode(file_handle.multipart_upload_id, true); + + UploadBufferImplementation(file_handle, write_buffer, query_param, false); + + NotifyUploadsInProgress(file_handle); +} + +void S3FileSystem::UploadSingleBuffer(S3FileHandle &file_handle, shared_ptr write_buffer) { + UploadBufferImplementation(file_handle, write_buffer, "", true); +} + +void S3FileSystem::UploadBufferImplementation(S3FileHandle &file_handle, shared_ptr write_buffer, + string query_param, bool single_upload) { + auto &s3fs = (S3FileSystem &)file_handle.file_system; + unique_ptr res; string etag; @@ -365,8 +429,8 @@ void S3FileSystem::UploadBuffer(S3FileHandle &file_handle, shared_ptrstatus != HTTPStatusCode::OK_200) { - throw HTTPException(*res, "Unable to connect to URL %s: %s (HTTP code %d)", res->url, res->GetError(), - static_cast(res->status)); + throw HTTPException(*res, "Unable to connect to URL %s: %s (HTTP code %d)", file_handle.path, + res->GetError(), static_cast(res->status)); } if (!res->headers.HasHeader("ETag")) { @@ -374,6 +438,9 @@ void S3FileSystem::UploadBuffer(S3FileHandle &file_handle, shared_ptrheaders.GetHeaderValue("ETag"); } catch (std::exception &ex) { + if (single_upload) { + throw; + } ErrorData error(ex); if (error.Type() != ExceptionType::IO && error.Type() != ExceptionType::HTTP) { throw; @@ -385,6 +452,7 @@ void S3FileSystem::UploadBuffer(S3FileHandle &file_handle, shared_ptr write_buffer) { @@ -437,6 +503,9 @@ void S3FileSystem::FlushBuffer(S3FileHandle &file_handle, shared_ptruploading) { @@ -475,6 +555,10 @@ void S3FileSystem::FlushAllBuffers(S3FileHandle &file_handle) { void S3FileSystem::FinalizeMultipartUpload(S3FileHandle &file_handle) { auto &s3fs = (S3FileSystem &)file_handle.file_system; + if (file_handle.upload_finalized) { + return; + } + file_handle.upload_finalized = true; std::stringstream ss; @@ -591,18 +675,25 @@ void S3FileSystem::ReadQueryParams(const string &url_query_param, S3AuthParams & } } -static string GetPrefix(string url) { +string S3FileSystem::TryGetPrefix(const string &url) { const string prefixes[] = {"s3://", "s3a://", "s3n://", "gcs://", "gs://", "r2://"}; for (auto &prefix : prefixes) { - if (StringUtil::StartsWith(url, prefix)) { + if (StringUtil::StartsWith(StringUtil::Lower(url), prefix)) { return prefix; } } - throw IOException("URL needs to start with s3://, gcs:// or r2://"); - return string(); + return {}; +} + +string S3FileSystem::GetPrefix(const string &url) { + auto prefix = TryGetPrefix(url); + if (prefix.empty()) { + throw IOException("URL needs to start with s3://, 
gcs:// or r2://"); + } + return prefix; } -ParsedS3Url S3FileSystem::S3UrlParse(string url, S3AuthParams ¶ms) { +ParsedS3Url S3FileSystem::S3UrlParse(string url, const S3AuthParams ¶ms) { string http_proto, prefix, host, bucket, key, path, query_param, trimmed_s3_url; prefix = GetPrefix(url); @@ -708,8 +799,8 @@ unique_ptr S3FileSystem::PostRequest(FileHandle &handle, string ur } else { // Use existing S3 authentication auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); - headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params, "", - "", payload_hash, "application/octet-stream"); + headers = CreateS3Header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params, "", "", + payload_hash, "application/octet-stream"); } return HTTPFileSystem::PostRequest(handle, http_url, headers, result, buffer_in, buffer_in_len); @@ -731,8 +822,8 @@ unique_ptr S3FileSystem::PutRequest(FileHandle &handle, string url } else { // Use existing S3 authentication auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); - headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params, "", - "", payload_hash, content_type); + headers = CreateS3Header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params, "", "", + payload_hash, content_type); } return HTTPFileSystem::PutRequest(handle, http_url, headers, buffer_in, buffer_in_len); @@ -750,8 +841,7 @@ unique_ptr S3FileSystem::HeadRequest(FileHandle &handle, string s3 headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication - headers = - create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", ""); + headers = CreateS3Header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", ""); } return HTTPFileSystem::HeadRequest(handle, http_url, headers); @@ -769,8 +859,7 @@ unique_ptr S3FileSystem::GetRequest(FileHandle &handle, string s3_ headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication - headers = - create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); + headers = CreateS3Header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); } return HTTPFileSystem::GetRequest(handle, http_url, headers); @@ -789,8 +878,7 @@ unique_ptr S3FileSystem::GetRangeRequest(FileHandle &handle, strin headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication - headers = - create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); + headers = CreateS3Header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); } return HTTPFileSystem::GetRangeRequest(handle, http_url, headers, file_offset, buffer_out, buffer_out_len); @@ -809,7 +897,7 @@ unique_ptr S3FileSystem::DeleteRequest(FileHandle &handle, string } else { // Use existing S3 authentication headers = - create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "DELETE", auth_params, "", "", "", ""); + CreateS3Header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "DELETE", auth_params, "", "", "", ""); } return HTTPFileSystem::DeleteRequest(handle, http_url, headers); @@ -831,6 +919,22 @@ unique_ptr S3FileSystem::CreateHandle(const OpenFileInfo &file, S3ConfigParams::ReadFrom(opener)); } +void S3FileHandle::InitializeFromCacheEntry(const HTTPMetadataCacheEntry &cache_entry) 
{ + HTTPFileHandle::InitializeFromCacheEntry(cache_entry); + auto entry = cache_entry.properties.find("s3_region"); + if (entry != cache_entry.properties.end()) { + auth_params.SetRegion(entry->second); + } +} + +HTTPMetadataCacheEntry S3FileHandle::GetCacheEntry() const { + auto result = HTTPFileHandle::GetCacheEntry(); + if (!auth_params.region.empty()) { + result.properties["s3_region"] = auth_params.region; + } + return result; +} + void S3FileHandle::Initialize(optional_ptr opener) { try { HTTPFileHandle::Initialize(opener); @@ -838,6 +942,7 @@ void S3FileHandle::Initialize(optional_ptr opener) { ErrorData error(ex); bool refreshed_secret = false; if (error.Type() == ExceptionType::IO || error.Type() == ExceptionType::HTTP) { + // legacy endpoint (no region) returns 400 auto context = opener->TryGetClientContext(); if (context) { auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context); @@ -849,14 +954,16 @@ void S3FileHandle::Initialize(optional_ptr opener) { } } } + string correct_region; if (!refreshed_secret) { auto &extra_info = error.ExtraInfo(); auto entry = extra_info.find("status_code"); if (entry != extra_info.end()) { - if (entry->second == "400") { - // 400: BAD REQUEST - auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params); - throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info); + if (entry->second == "301" || entry->second == "400") { + auto new_region = extra_info.find("header_x-amz-bucket-region"); + if (new_region != extra_info.end()) { + correct_region = new_region->second; + } } if (entry->second == "403") { // 403: FORBIDDEN @@ -866,19 +973,27 @@ void S3FileHandle::Initialize(optional_ptr opener) { } else { extra_text = S3FileSystem::GetS3AuthError(auth_params); } - throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info); + throw Exception(extra_info, error.Type(), error.RawMessage() + extra_text); } } - throw; + if (correct_region.empty()) { + throw; + } } // We have succesfully refreshed a secret: retry initializing with new credentials FileOpenerInfo info = {path}; auth_params = S3AuthParams::ReadFrom(opener, info); + if (!correct_region.empty()) { + DUCKDB_LOG_WARNING( + logger, + "Read S3 file \"%s\" from incorrect region \"%s\" - retrying with updated region \"%s\".\n" + "Consider setting the S3 region to this explicitly to avoid extra round-trips.", + path, auth_params.region, correct_region); + auth_params.SetRegion(std::move(correct_region)); + } HTTPFileHandle::Initialize(opener); } - auto &s3fs = file_system.Cast(); - if (flags.OpenForWriting()) { auto aws_minimum_part_size = 5242880; // 5 MiB https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html auto max_part_count = config_params.max_parts_per_file; @@ -889,8 +1004,6 @@ void S3FileHandle::Initialize(optional_ptr opener) { part_size = ((minimum_part_size + Storage::DEFAULT_BLOCK_SIZE - 1) / Storage::DEFAULT_BLOCK_SIZE) * Storage::DEFAULT_BLOCK_SIZE; D_ASSERT(part_size * max_part_count >= config_params.max_file_size); - - multipart_upload_id = s3fs.InitializeMultipartUpload(*this); } } @@ -904,13 +1017,13 @@ bool S3FileSystem::CanHandleFile(const string &fpath) { void S3FileSystem::RemoveFile(const string &path, optional_ptr opener) { auto handle = OpenFile(path, FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS, opener); if (!handle) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", "404"}}, path, "No such file or directory"); + throw IOException({{"errno", "404"}}, "Could not remove file \"%s\": %s", path, 
"No such file or directory"); } auto &s3fh = handle->Cast(); auto res = DeleteRequest(*handle, s3fh.path, {}); if (res->status != HTTPStatusCode::OK_200 && res->status != HTTPStatusCode::NoContent_204) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", to_string(static_cast(res->status))}}, + throw IOException({{"errno", to_string(static_cast(res->status))}}, "Could not remove file \"%s\": %s", path, res->GetError()); } } @@ -971,6 +1084,7 @@ void S3FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx FlushBuffer(s3fh, write_buffer); } s3fh.file_offset += bytes_to_write; + s3fh.length += bytes_to_write; bytes_written += bytes_to_write; } @@ -1002,71 +1116,116 @@ static bool Match(vector::const_iterator key, vector::const_iter return key == key_end && pattern == pattern_end; } -vector S3FileSystem::Glob(const string &glob_pattern, FileOpener *opener) { - if (opener == nullptr) { +struct S3GlobResult : public LazyMultiFileList { +public: + S3GlobResult(S3FileSystem &fs, const string &path, optional_ptr opener); + +protected: + bool ExpandNextPath() const override; + +private: + string glob_pattern; + optional_ptr opener; + mutable bool finished = false; + mutable S3AuthParams s3_auth_params; + string shared_path; + ParsedS3Url parsed_s3_url; + mutable string main_continuation_token; + mutable string current_common_prefix; + mutable string common_prefix_continuation_token; + mutable vector common_prefixes; +}; + +S3GlobResult::S3GlobResult(S3FileSystem &fs, const string &glob_pattern_p, optional_ptr opener) + : glob_pattern(glob_pattern_p), opener(opener) { + if (!opener) { throw InternalException("Cannot S3 Glob without FileOpener"); } - FileOpenerInfo info = {glob_pattern}; // Trim any query parameters from the string - S3AuthParams s3_auth_params = S3AuthParams::ReadFrom(opener, info); + s3_auth_params = S3AuthParams::ReadFrom(opener, info); // In url compatibility mode, we ignore globs allowing users to query files with the glob chars if (s3_auth_params.s3_url_compatibility_mode) { - return {glob_pattern}; + expanded_files.emplace_back(glob_pattern); + finished = true; + return; } - auto parsed_s3_url = S3UrlParse(glob_pattern, s3_auth_params); + parsed_s3_url = fs.S3UrlParse(glob_pattern, s3_auth_params); auto parsed_glob_url = parsed_s3_url.trimmed_s3_url; // AWS matches on prefix, not glob pattern, so we take a substring until the first wildcard char for the aws calls auto first_wildcard_pos = parsed_glob_url.find_first_of("*[\\"); if (first_wildcard_pos == string::npos) { - return {glob_pattern}; + expanded_files.emplace_back(glob_pattern); + finished = true; + return; } - string shared_path = parsed_glob_url.substr(0, first_wildcard_pos); + shared_path = parsed_glob_url.substr(0, first_wildcard_pos); + + fs.ReadQueryParams(parsed_s3_url.query_param, s3_auth_params); +} + +bool S3GlobResult::ExpandNextPath() const { + if (finished) { + return false; + } + + FileOpenerInfo info = {glob_pattern}; auto http_util = HTTPFSUtil::GetHTTPUtil(opener); auto http_params = http_util->InitializeParameters(opener, info); - ReadQueryParams(parsed_s3_url.query_param, s3_auth_params); - - // Do main listobjectsv2 request vector s3_keys; - string main_continuation_token; - - // Main paging loop - do { - // main listobject call, may - string response_str = AWSListObjectV2::Request(shared_path, *http_params, s3_auth_params, - main_continuation_token, HTTPState::TryGetState(opener).get()); + if (!current_common_prefix.empty()) { + // we have common prefixes left to 
scan - perform the request + auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + current_common_prefix; + + auto prefix_res = + AWSListObjectV2::Request(prefix_path, *http_params, s3_auth_params, common_prefix_continuation_token); + AWSListObjectV2::ParseFileList(prefix_res, s3_keys); + auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res); + common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end()); + common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res); + if (common_prefix_continuation_token.empty()) { + // we are done with the current common prefix + // either move on to the next one, or finish up + if (common_prefixes.empty()) { + // done - we need to do a top-level request again next + current_common_prefix = string(); + } else { + // process the next prefix + current_common_prefix = common_prefixes.back(); + common_prefixes.pop_back(); + } + } + } else { + if (!common_prefixes.empty()) { + throw InternalException("We have common prefixes but we are doing a top-level request"); + } + // issue the main request + string response_str = + AWSListObjectV2::Request(shared_path, *http_params, s3_auth_params, main_continuation_token); main_continuation_token = AWSListObjectV2::ParseContinuationToken(response_str); AWSListObjectV2::ParseFileList(response_str, s3_keys); - // Repeat requests until the keys of all common prefixes are parsed. - auto common_prefixes = AWSListObjectV2::ParseCommonPrefix(response_str); - while (!common_prefixes.empty()) { - auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + common_prefixes.back(); + // parse the list of common prefixes + common_prefixes = AWSListObjectV2::ParseCommonPrefix(response_str); + if (!common_prefixes.empty()) { + // we have common prefixes - set one up for the next request + current_common_prefix = common_prefixes.back(); common_prefixes.pop_back(); - - // TODO we could optimize here by doing a match on the prefix, if it doesn't match we can skip this prefix - // Paging loop for common prefix requests - string common_prefix_continuation_token; - do { - auto prefix_res = - AWSListObjectV2::Request(prefix_path, *http_params, s3_auth_params, - common_prefix_continuation_token, HTTPState::TryGetState(opener).get()); - AWSListObjectV2::ParseFileList(prefix_res, s3_keys); - auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res); - common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end()); - common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res); - } while (!common_prefix_continuation_token.empty()); } - } while (!main_continuation_token.empty()); + } + + if (main_continuation_token.empty() && current_common_prefix.empty()) { + // we are done + finished = true; + } vector pattern_splits = StringUtil::Split(parsed_s3_url.key, "/"); - vector result; for (auto &s3_key : s3_keys) { vector key_splits = StringUtil::Split(s3_key.path, "/"); @@ -1079,44 +1238,80 @@ vector S3FileSystem::Glob(const string &glob_pattern, FileOpener * result_full_url += '?' 
+ parsed_s3_url.query_param; } s3_key.path = std::move(result_full_url); - result.push_back(std::move(s3_key)); + if (!s3_auth_params.region.empty()) { + s3_key.extended_info->options["s3_region"] = s3_auth_params.region; + } + expanded_files.push_back(std::move(s3_key)); } } - return result; + return true; +} + +unique_ptr<MultiFileList> S3FileSystem::GlobFilesExtended(const string &path, const FileGlobInput &input, + optional_ptr<FileOpener> opener) { + return make_uniq<S3GlobResult>(*this, path, opener); } string S3FileSystem::GetName() const { return "S3FileSystem"; } -bool S3FileSystem::ListFiles(const string &directory, const std::function<void(const string &, bool)> &callback, - FileOpener *opener) { +bool S3FileSystem::ListFilesExtended(const string &directory, const std::function<void(OpenFileInfo &info)> &callback, + optional_ptr<FileOpener> opener) { string trimmed_dir = directory; - StringUtil::RTrim(trimmed_dir, PathSeparator(trimmed_dir)); - auto glob_res = Glob(JoinPath(trimmed_dir, "**"), opener); + auto sep = PathSeparator(trimmed_dir); + StringUtil::RTrim(trimmed_dir, sep); + auto glob_res = GlobFilesExtended(JoinPath(trimmed_dir, "**"), FileGlobOptions::ALLOW_EMPTY, opener); - if (glob_res.empty()) { + if (!glob_res || glob_res->GetExpandResult() == FileExpandResult::NO_FILES) { return false; } + auto base_path = trimmed_dir + sep; - for (const auto &file : glob_res) { - callback(file.path, false); + for (auto file : glob_res->Files()) { + if (!StringUtil::StartsWith(file.path, base_path)) { + throw InvalidInputException( + "Globbed directory \"%s\", but found file \"%s\" that does not start with base path \"%s\"", directory, + file.path, base_path); + } + file.path = file.path.substr(base_path.size()); + callback(file); } return true; } -string S3FileSystem::GetS3BadRequestError(S3AuthParams &s3_auth_params) { +optional_idx FindTagContents(const string &response, const string &tag, idx_t cur_pos, string &result) { + string open_tag = "<" + tag + ">"; + string close_tag = "</" + tag + ">"; + auto open_tag_pos = response.find(open_tag, cur_pos); + if (open_tag_pos == string::npos) { + // tag not found + return optional_idx(); + } + auto close_tag_pos = response.find(close_tag, open_tag_pos + open_tag.size()); + if (close_tag_pos == string::npos) { + throw InternalException("Failed to parse S3 result: found open tag for %s but did not find matching close tag", + tag); + } + result = response.substr(open_tag_pos + open_tag.size(), close_tag_pos - open_tag_pos - open_tag.size()); + return close_tag_pos + close_tag.size(); +} + +string S3FileSystem::GetS3BadRequestError(const S3AuthParams &s3_auth_params, string correct_region) { string extra_text = "\n\nBad Request - this can be caused by the S3 region being set incorrectly."; if (s3_auth_params.region.empty()) { extra_text += "\n* No region is provided."; } else { - extra_text += "\n* Provided region is \"" + s3_auth_params.region + "\""; + extra_text += "\n* Provided region is: \"" + s3_auth_params.region + "\""; + } + if (!correct_region.empty()) { + extra_text += "\n* Correct region is: \"" + correct_region + "\""; } return extra_text; } -string S3FileSystem::GetS3AuthError(S3AuthParams &s3_auth_params) { +string S3FileSystem::GetS3AuthError(const S3AuthParams &s3_auth_params) { string extra_text = "\n\nAuthentication Failure - this is usually caused by invalid or missing credentials."; if (s3_auth_params.secret_access_key.empty() && s3_auth_params.access_key_id.empty()) { extra_text += "\n* No credentials are provided."; @@ -1127,7 +1322,7 @@ string S3FileSystem::GetS3AuthError(S3AuthParams &s3_auth_params) { return extra_text; } -string 
S3FileSystem::GetGCSAuthError(S3AuthParams &s3_auth_params) { +string S3FileSystem::GetGCSAuthError(const S3AuthParams &s3_auth_params) { string extra_text = "\n\nAuthentication Failure - GCS authentication failed."; if (s3_auth_params.oauth2_bearer_token.empty() && s3_auth_params.secret_access_key.empty() && s3_auth_params.access_key_id.empty()) { @@ -1144,17 +1339,59 @@ string S3FileSystem::GetGCSAuthError(S3AuthParams &s3_auth_params) { return extra_text; } -HTTPException S3FileSystem::GetS3Error(S3AuthParams &s3_auth_params, const HTTPResponse &response, const string &url) { - string extra_text; +string S3FileSystem::ParseS3Error(const string &error) { + // S3 errors look like this: + // <Error> + // <Code>NoSuchKey</Code> + // <Message>The resource you requested does not exist</Message> + // <Resource>/mybucket/myfoto.jpg</Resource> + // <RequestId>4442587FB7D0A2F9</RequestId> + // </Error> + if (error.empty()) { + return string(); + } + // find <Error> tag + string error_xml; + idx_t err_pos = 0; + auto next_pos = FindTagContents(error, "Error", err_pos, error_xml); + if (!next_pos.IsValid()) { + return string(); + } + // find <Code> and <Message> + string error_code, error_message, extra_error_data; + idx_t cur_pos = 0; + next_pos = FindTagContents(error_xml, "Code", cur_pos, error_code); + if (!next_pos.IsValid()) { + return string(); + } + cur_pos = 0; + next_pos = FindTagContents(error_xml, "Message", cur_pos, error_message); + if (!next_pos.IsValid()) { + return string(); + } + // depending on Code, find other info + if (error_code == "InvalidAccessKeyId") { + cur_pos = 0; + next_pos = FindTagContents(error_xml, "AWSAccessKeyId", cur_pos, extra_error_data); + if (next_pos.IsValid()) { + extra_error_data = "\nInvalid Access Key: \"" + extra_error_data + "\""; + } + } + return StringUtil::Format("\n\n%s: %s%s", error_code, error_message, extra_error_data); +} + +HTTPException S3FileSystem::GetS3Error(const S3AuthParams &s3_auth_params, const HTTPResponse &response, + const string &url) { + string extra_text = ParseS3Error(response.body); if (response.status == HTTPStatusCode::BadRequest_400) { - extra_text = GetS3BadRequestError(s3_auth_params); + extra_text += GetS3BadRequestError(s3_auth_params); } if (response.status == HTTPStatusCode::Forbidden_403) { - extra_text = GetS3AuthError(s3_auth_params); + extra_text += GetS3AuthError(s3_auth_params); } auto status_message = HTTPFSUtil::GetStatusMessage(response.status); - throw HTTPException(response, "HTTP GET error reading '%s' in region '%s' (HTTP %d %s)%s", url, - s3_auth_params.region, response.status, status_message, extra_text); + return HTTPException(response, "HTTP GET error reading '%s' in region '%s' (HTTP %d %s)%s", url, + s3_auth_params.region, response.status, status_message, extra_text); } HTTPException S3FileSystem::GetHTTPError(FileHandle &handle, const HTTPResponse &response, const string &url) { @@ -1170,71 +1407,99 @@ HTTPException S3FileSystem::GetHTTPError(FileHandle &handle, const HTTPResponse return GetS3Error(s3_handle.auth_params, response, url); } -string AWSListObjectV2::Request(string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, - string &continuation_token, optional_ptr<HTTPState> state, bool use_delimiter) { - auto parsed_url = S3FileSystem::S3UrlParse(path, s3_auth_params); - - // Construct the ListObjectsV2 call - string req_path = parsed_url.path.substr(0, parsed_url.path.length() - parsed_url.key.length()); - string req_params; - if (!continuation_token.empty()) { - req_params += "continuation-token=" + S3FileSystem::UrlEncode(continuation_token, true); - req_params += "&"; - } - req_params += 
"encoding-type=url&list-type=2"; - req_params += "&prefix=" + S3FileSystem::UrlEncode(parsed_url.key, true); - - if (use_delimiter) { - req_params += "&delimiter=%2F"; - } - - string listobjectv2_url = req_path + "?" + req_params; - - auto header_map = - create_s3_header(req_path, req_params, parsed_url.host, "s3", "GET", s3_auth_params, "", "", "", ""); +string AWSListObjectV2::Request(const string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, + string &continuation_token, optional_idx max_keys) { + const idx_t MAX_RETRIES = 1; + for (idx_t it = 0; it <= MAX_RETRIES; it++) { + auto parsed_url = S3FileSystem::S3UrlParse(path, s3_auth_params); - // Get requests use fresh connection - string full_host = parsed_url.http_proto + parsed_url.host; - std::stringstream response; - GetRequestInfo get_request( - full_host, listobjectv2_url, header_map, http_params, - [&](const HTTPResponse &response) { - if (static_cast(response.status) >= 400) { - string trimmed_path = path; - StringUtil::RTrim(trimmed_path, "/"); - trimmed_path += listobjectv2_url; - throw S3FileSystem::GetS3Error(s3_auth_params, response, trimmed_path); - } - return true; - }, - [&](const_data_ptr_t data, idx_t data_length) { - response << string(const_char_ptr_cast(data), data_length); - return true; - }); - auto result = http_params.http_util.Request(get_request); - if (result->HasRequestError()) { - throw IOException("%s error for HTTP GET to '%s'", result->GetRequestError(), listobjectv2_url); - } + // Construct the ListObjectsV2 call + string req_path = parsed_url.path.substr(0, parsed_url.path.length() - parsed_url.key.length()); - return response.str(); -} + string req_params; + if (!continuation_token.empty()) { + req_params += "continuation-token=" + S3FileSystem::UrlEncode(continuation_token, true); + req_params += "&"; + } + req_params += "encoding-type=url&list-type=2"; + req_params += "&prefix=" + S3FileSystem::UrlEncode(parsed_url.key, true); + if (max_keys.IsValid()) { + req_params += "&max-keys=" + to_string(max_keys.GetIndex()); + } -optional_idx FindTagContents(const string &response, const string &tag, idx_t cur_pos, string &result) { - string open_tag = "<" + tag + ">"; - string close_tag = ""; - auto open_tag_pos = response.find(open_tag, cur_pos); - if (open_tag_pos == string::npos) { - // tag not found - return optional_idx(); - } - auto close_tag_pos = response.find(close_tag, open_tag_pos + open_tag.size()); - if (close_tag_pos == string::npos) { - throw InternalException("Failed to parse S3 result: found open tag for %s but did not find matching close tag", - tag); + auto header_map = + CreateS3Header(req_path, req_params, parsed_url.host, "s3", "GET", s3_auth_params, "", "", "", ""); + + // Get requests use fresh connection + string full_host = parsed_url.http_proto + parsed_url.host; + string listobjectv2_url = full_host + req_path + "?" 
+ req_params; + std::stringstream response; + ErrorData error; + GetRequestInfo get_request( + full_host, listobjectv2_url, header_map, http_params, + [&](const HTTPResponse &response) { + if (static_cast<int>(response.status) >= 400) { + string trimmed_path = path; + StringUtil::RTrim(trimmed_path, "/"); + error = ErrorData(S3FileSystem::GetS3Error(s3_auth_params, response, trimmed_path)); + } + return true; + }, + [&](const_data_ptr_t data, idx_t data_length) { + response << string(const_char_ptr_cast(data), data_length); + return true; + }); + auto result = http_params.http_util.Request(get_request); + if (result->HasRequestError()) { + throw IOException("%s error for HTTP GET to '%s'", result->GetRequestError(), listobjectv2_url); + } + // check + string updated_bucket_region; + if (result->status == HTTPStatusCode::MovedPermanently_301) { + string moved_error; + if (it == 0 && result->HasHeader("x-amz-bucket-region")) { + auto response_region = result->GetHeaderValue("x-amz-bucket-region"); + if (response_region == s3_auth_params.region) { + moved_error = "suggested region \"" + response_region + + "\" is the same as the region we used to make the request"; + } else { + updated_bucket_region = response_region; + } + } else { + moved_error = "HTTP response did not contain header_x-amz-bucket-region"; + } + if (!moved_error.empty()) { + throw HTTPException(*result, "HTTP 301 response when running glob \"%s\" but %s", path, moved_error); + } + } + if (error.HasError()) { + if (it == 0 && result->HasHeader("x-amz-bucket-region")) { + auto response_region = result->GetHeaderValue("x-amz-bucket-region"); + if (response_region != s3_auth_params.region) { + updated_bucket_region = response_region; + } + } + if (updated_bucket_region.empty()) { + // no updated region found + error.Throw(); + } + } + if (!updated_bucket_region.empty()) { + DUCKDB_LOG_WARNING( + http_params.logger, + "Ran S3 glob \"%s\" from incorrect region \"%s\" - retrying with updated region \"%s\".\n" + "Consider setting the S3 region to this explicitly to avoid extra round-trips.", + path, s3_auth_params.region, updated_bucket_region); + + // bucket region was updated - update and re-run the request against the correct endpoint + s3_auth_params.SetRegion(std::move(updated_bucket_region)); + continue; + } + return response.str(); } - result = response.substr(open_tag_pos + open_tag.size(), close_tag_pos - open_tag_pos - open_tag.size()); - return close_tag_pos + close_tag.size(); + throw InvalidInputException( + "Exceeded retry count in AWSListObjectV2::Request - this means we got multiple redirects to different regions"); } void AWSListObjectV2::ParseFileList(string &aws_response, vector<OpenFileInfo> &result) { @@ -1331,4 +1596,13 @@ vector<string> AWSListObjectV2::ParseCommonPrefix(string &aws_response) { return s3_prefixes; } +S3KeyValueReader::S3KeyValueReader(FileOpener &opener_p, optional_ptr<FileOpenerInfo> info, const char **secret_types, + idx_t secret_types_len) + : reader(opener_p, info, secret_types, secret_types_len) { + Value use_env_vars_for_secret_info_setting; + reader.TryGetSecretKeyOrSetting("enable_global_s3_configuration", "enable_global_s3_configuration", + use_env_vars_for_secret_info_setting); + use_env_variables_for_secret_settings = use_env_vars_for_secret_info_setting.GetValue<bool>(); +} + } // namespace duckdb diff --git a/src/string_utils.cpp b/src/string_utils.cpp new file mode 100644 index 0000000..0c3b316 --- /dev/null +++ b/src/string_utils.cpp @@ -0,0 +1,14 @@ +#include "string_utils.hpp" + +namespace duckdb { + +string 
EncodeSpaces(const string &url) { + string out; + out.reserve(url.size()); + for (char c : url) { + out += (c == ' ') ? "%20" : string(1, c); + } + return out; +} + +} // namespace duckdb
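Reviewer note (not part of the diff): a minimal sketch of how the two small helpers touched above behave - EncodeSpaces from the new src/string_utils.cpp and the FindTagContents scanner used by ParseS3Error. The include of "duckdb.hpp", the forward declaration of FindTagContents (the diff defines it in a .cpp rather than a public header), and the standalone main() harness are assumptions made purely for illustration; the snippet assumes it is compiled and linked inside the extension's build.

#include "duckdb.hpp"
#include "duckdb/common/optional_idx.hpp"
#include "string_utils.hpp" // EncodeSpaces, added by this diff

#include <cassert>

namespace duckdb {
// Defined in this diff next to ParseS3Error; forward-declared here for the sketch only.
optional_idx FindTagContents(const string &response, const string &tag, idx_t cur_pos, string &result);
} // namespace duckdb

int main() {
	using namespace duckdb;

	// EncodeSpaces only rewrites spaces; every other character is copied through unchanged.
	assert(EncodeSpaces("s3://bucket/my file.parquet") == "s3://bucket/my%20file.parquet");
	assert(EncodeSpaces("no-spaces") == "no-spaces");

	// FindTagContents extracts the text between <tag> and </tag> and returns the offset just
	// past the close tag, or an invalid optional_idx when the open tag is not present.
	string error_xml = "<Error><Code>NoSuchKey</Code><Message>The resource you requested does not "
	                   "exist</Message></Error>";
	string code;
	auto next_pos = FindTagContents(error_xml, "Code", 0, code);
	assert(next_pos.IsValid() && code == "NoSuchKey");

	string missing;
	assert(!FindTagContents(error_xml, "RequestId", 0, missing).IsValid());
	return 0;
}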