Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3-85557.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``s3``",
"description": "Skip the HEAD request during S3 downloads when the client is configured with ``response_checksum_validation='when_required'``, reducing latency for small-object transfers. The HEAD request remains in place by default to enable full-object checksum validation."
}
282 changes: 267 additions & 15 deletions s3transfer/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,26 @@ def _submit(
:param bandwidth_limiter: The bandwidth limiter to use when
downloading streams
"""
download_output_manager = self._get_download_output_manager_cls(
transfer_future, osutil
)(osutil, self._transfer_coordinator, io_executor)

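# Skip the HEAD request only when the client is configured with
# response_checksum_validation="when_required", i.e. the caller has
# opted out of validating response checksums by default. In every other
# case we still need the HEAD to obtain the full-object ETag/size for
# checksum validation.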
if client.meta.config.response_checksum_validation == "when_required":
self._submit_first_chunk_request(
client,
config,
osutil,
request_executor,
io_executor,
download_output_manager,
transfer_future,
bandwidth_limiter,
)
return

if (
transfer_future.meta.size is None
or transfer_future.meta.etag is None
Expand All @@ -370,10 +390,6 @@ def _submit(
# during a multipart download.
transfer_future.meta.provide_object_etag(response.get('ETag'))

download_output_manager = self._get_download_output_manager_cls(
transfer_future, osutil
)(osutil, self._transfer_coordinator, io_executor)

# If it is greater than threshold do a ranged download, otherwise
# do a regular GetObject download.
if transfer_future.meta.size < config.multipart_threshold:
Expand Down Expand Up @@ -541,6 +557,214 @@ def _calculate_range_param(self, part_size, part_index, num_parts):
range_param = f'bytes={start_range}-{end_range}'
return range_param

def _submit_first_chunk_request(
    self,
    client,
    config,
    osutil,
    request_executor,
    io_executor,
    download_output_manager,
    transfer_future,
    bandwidth_limiter,
):
    """Submit a ranged GetObject for the first chunk of the object.

    The request covers bytes ``0`` through
    ``config.multipart_chunksize - 1``. A
    ``GetObjectFirstChunkOnDoneCallback`` is attached to the task; it
    reads the response stored on the task and decides whether further
    chunks have to be requested.

    :param client: The client passed to the GetObject task
    :param config: The transfer config; supplies the chunk size, the
        number of download attempts and the io chunk size
    :param osutil: Not referenced by this method
    :param request_executor: The executor the GetObject task is
        submitted to
    :param io_executor: The executor handed to the done callback
    :param download_output_manager: Supplies the file object for io
        writes and the tag for the download task
    :param transfer_future: The transfer future associated with the
        download
    :param bandwidth_limiter: The bandwidth limiter passed to the
        GetObject task
    """
    meta = transfer_future.meta
    call_args = meta.call_args

    # File handle that receives the downloaded contents.
    fileobj = download_output_manager.get_fileobj_for_io_writes(
        transfer_future
    )

    # Progress callbacks and tag shared by every GetObject task of this
    # transfer.
    progress_callbacks = get_callbacks(transfer_future, 'progress')
    get_object_tag = download_output_manager.get_download_task_tag()

    # Work on a copy of the caller's extra args and request only the
    # first chunk; object metadata is read from the response afterwards.
    last_byte = config.multipart_chunksize - 1
    extra_args = dict(call_args.extra_args)
    extra_args['Range'] = f'bytes=0-{last_byte}'

    known_etag = meta.etag
    if known_etag is not None:
        extra_args['IfMatch'] = known_etag

    # The callback determines whether additional chunks are needed based
    # on the Content-Range header in the response.
    on_done_callback = GetObjectFirstChunkOnDoneCallback(
        transfer_future,
        download_output_manager,
        io_executor,
        self._transfer_coordinator,
        client,
        config,
        request_executor,
        bandwidth_limiter,
        fileobj,
        progress_callbacks,
        get_object_tag,
    )

    main_kwargs = {
        'client': client,
        'bucket': call_args.bucket,
        'key': call_args.key,
        'fileobj': fileobj,
        'extra_args': extra_args,
        'callbacks': progress_callbacks,
        'max_attempts': config.num_download_attempts,
        'start_index': 0,
        'download_output_manager': download_output_manager,
        'io_chunksize': config.io_chunksize,
        'bandwidth_limiter': bandwidth_limiter,
    }
    first_chunk_task = GetObjectTask(
        transfer_coordinator=self._transfer_coordinator,
        main_kwargs=main_kwargs,
        done_callbacks=[on_done_callback],
    )
    # The callback needs the task to read its response, so wire it up
    # before the task is submitted.
    on_done_callback.set_task(first_chunk_task)

    self._transfer_coordinator.submit(
        request_executor,
        first_chunk_task,
        tag=get_object_tag,
    )


class GetObjectFirstChunkOnDoneCallback:
    """Done callback for the first ranged GetObject of a download.

    When called, it reads the response stored on the first-chunk task,
    records the object's size and ETag on the transfer future, and then
    either schedules GetObject tasks for the remaining chunks or submits
    the final IO task.

    ``set_task()`` must be called with the first-chunk task before that
    task is submitted; the callback takes no arguments when invoked.
    """

    def __init__(
        self,
        transfer_future,
        download_output_manager,
        io_executor,
        transfer_coordinator,
        client,
        config,
        request_executor,
        bandwidth_limiter,
        fileobj,
        progress_callbacks,
        get_object_tag,
    ):
        """Store everything needed to schedule the follow-up tasks.

        :param transfer_future: The transfer future of the download
        :param download_output_manager: Supplies the final IO task
        :param io_executor: The executor the final IO task is submitted to
        :param transfer_coordinator: Used to submit tasks and to check
            whether the transfer is already done
        :param client: The client passed to follow-up GetObject tasks
        :param config: The transfer config (chunk size, attempts,
            io chunk size)
        :param request_executor: The executor follow-up GetObject tasks
            are submitted to
        :param bandwidth_limiter: Passed to follow-up GetObject tasks
        :param fileobj: The file object the download is written to
        :param progress_callbacks: Passed to follow-up GetObject tasks
        :param get_object_tag: Tag used when submitting follow-up
            GetObject tasks
        """
        self._transfer_future = transfer_future
        self._download_output_manager = download_output_manager
        self._io_executor = io_executor
        self._transfer_coordinator = transfer_coordinator
        self._client = client
        self._config = config
        self._request_executor = request_executor
        self._bandwidth_limiter = bandwidth_limiter
        self._fileobj = fileobj
        self._progress_callbacks = progress_callbacks
        self._get_object_tag = get_object_tag
        self._task = None

    def __call__(self):
        """Finish the download or fan out the remaining chunk requests.

        :raises RuntimeError: If ``set_task()`` was never called.
        """
        # Raise explicitly rather than assert: assert statements are
        # stripped when Python runs with -O.
        if self._task is None:
            raise RuntimeError(
                "set_task() must be called before the task is submitted"
            )

        response = self._task.get_response()
        # No response means the GET failed or was cancelled; a finished
        # coordinator means the transfer was cancelled or failed. In both
        # cases no more work is scheduled, but the final task still has
        # to be submitted to signal completion.
        if not response or self._transfer_coordinator.done():
            self._submit_final_io_task()
            return

        size, etag = self._extract_metadata(response)
        self._transfer_future.meta.provide_transfer_size(size)
        self._transfer_future.meta.provide_object_etag(etag)

        if size == 0:
            # Force-open the DeferredOpenFile so the temp file exists
            # on disk for IORenameFileTask. Without this, the deferred
            # file is never opened since no bytes are written.
            self._fileobj.write(b'')

        if size > self._config.multipart_chunksize:
            self._schedule_remaining_chunks(size, etag)
        else:
            # The first chunk already covered the whole object.
            self._submit_final_io_task()

    def set_task(self, task):
        """Attach the first-chunk task whose response this callback reads.

        :param task: The task; it must provide ``get_response()``
        """
        self._task = task

    def _submit_final_io_task(self):
        """Submit the output manager's final IO task to the IO executor."""
        final_task = self._download_output_manager.get_final_io_task()
        self._transfer_coordinator.submit(self._io_executor, final_task)

    def _extract_metadata(self, response):
        """Return ``(size, etag)`` for the object described by ``response``.

        :param response: Mapping with an optional ``ContentRange`` and
            ``ETag``; ``ContentLength`` is required when ``ContentRange``
            is missing or empty
        :returns: Tuple of the total object size and the ETag (``None``
            when the response has no ``ETag``)
        """
        content_range = response.get('ContentRange')
        if content_range:
            # Content-Range format: 'bytes 0-8388607/39542919'
            # Extract total size from the part after the slash
            size = int(content_range.split('/')[-1])
        else:
            size = response['ContentLength']
        etag = response.get('ETag')
        return size, etag

    def _schedule_remaining_chunks(self, size, etag):
        """Submit one GetObject task per chunk after the first one.

        :param size: Total size of the object in bytes
        :param etag: ETag of the object, or ``None`` if unknown
        """
        call_args = self._transfer_future.meta.call_args
        part_size = self._config.multipart_chunksize
        num_parts = calculate_num_parts(size, part_size)

        # Callback invoker to submit the final io task once all downloads
        # are complete.
        final_task = self._download_output_manager.get_final_io_task()
        finalize_download_invoker = CountCallbackInvoker(
            FunctionContainer(
                self._transfer_coordinator.submit,
                self._io_executor,
                final_task,
            )
        )

        # Start from 1 since chunk 0 was already requested
        for i in range(1, num_parts):
            range_parameter = calculate_range_parameter(
                part_size, i, num_parts
            )
            extra_args = {
                'Range': range_parameter,
            }
            # Use IfMatch to ensure object hasn't changed during download
            if etag is not None:
                extra_args['IfMatch'] = etag
            extra_args.update(call_args.extra_args)
            finalize_download_invoker.increment()

            self._transfer_coordinator.submit(
                self._request_executor,
                GetObjectTask(
                    transfer_coordinator=self._transfer_coordinator,
                    main_kwargs={
                        'client': self._client,
                        'bucket': call_args.bucket,
                        'key': call_args.key,
                        'fileobj': self._fileobj,
                        'extra_args': extra_args,
                        'callbacks': self._progress_callbacks,
                        'max_attempts': self._config.num_download_attempts,
                        'start_index': i * part_size,
                        'download_output_manager': (
                            self._download_output_manager
                        ),
                        'io_chunksize': self._config.io_chunksize,
                        'bandwidth_limiter': self._bandwidth_limiter,
                    },
                    done_callbacks=[finalize_download_invoker.decrement],
                ),
                tag=self._get_object_tag,
            )
        finalize_download_invoker.finalize()


class GetObjectTask(Task):
def _main(
Expand Down Expand Up @@ -582,6 +806,9 @@ def _main(
response = client.get_object(
Bucket=bucket, Key=key, **extra_args
)
# Store response so callback can extract metadata
self._response = response

self._validate_content_range(
extra_args.get('Range'),
response.get('ContentRange'),
Expand Down Expand Up @@ -619,6 +846,13 @@ def _main(
f'Contents of stored object "{key}" in bucket '
f'"{bucket}" did not match expected ETag.'
)
elif error_code == "InvalidRange":
self._response = {
'ContentLength': 0,
'ContentRange': None,
'ETag': None,
}
return
else:
raise
except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
Expand All @@ -643,25 +877,43 @@ def _main(
def _handle_io(self, download_output_manager, fileobj, chunk, index):
download_output_manager.queue_file_io_task(fileobj, chunk, index)

def get_response(self):
    """Return the response stored on this task, or ``None`` if unset.

    ``_response`` is only assigned while the task runs, so the attribute
    may not exist yet.
    """
    try:
        return self._response
    except AttributeError:
        return None

def _validate_content_range(self, requested_range, content_range):
if not requested_range or not content_range:
return
# Unparsed `ContentRange` looks like `bytes 0-8388607/39542919`,
# where `0-8388607` is the fetched range and `39542919` is
# the total object size.
response_range, total_size = content_range.split('/')
# Subtract `1` because range is 0-indexed.
final_byte = str(int(total_size) - 1)
# If it's the last part, the requested range will not include
# the final byte, eg `bytes=33554432-`.
if requested_range.endswith('-'):
requested_range += final_byte
# Request looks like `bytes=0-8388607`.
# Parsed response looks like `bytes 0-8388607`.
if requested_range[6:] != response_range[6:]:
# Parse requested range: `bytes=0-8388607` -> start=0, end=8388607
req_range_part = requested_range[6:] # Remove 'bytes='
if '-' not in req_range_part:
return
req_start, req_end = req_range_part.split('-', 1)
req_start = int(req_start)
# req_end might be empty for open-ended ranges
req_end = int(req_end) if req_end else int(total_size) - 1

# Parse response range: `bytes 0-8388607` -> start=0, end=8388607
resp_range_part = response_range[6:] # Remove 'bytes '
resp_start, resp_end = resp_range_part.split('-', 1)
resp_start = int(resp_start)
resp_end = int(resp_end)

# Validate that response starts where we requested
if resp_start != req_start:
raise S3ValidationError(
f"Response range start `{resp_start}` does not match "
f"requested start `{req_start}`"
)

# Validate that response doesn't exceed what we requested
if resp_end > req_end:
raise S3ValidationError(
f"Requested range: `{requested_range[6:]}` does not match "
f"content range in response: `{response_range[6:]}`"
f"Response range end `{resp_end}` exceeds "
f"requested end `{req_end}`"
)


Expand Down
Loading
Loading