Skip to content

Add a progress callback to the wait operation function #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
23 changes: 22 additions & 1 deletion src/ansys/hps/data_transfer/client/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"""

import builtins
from collections.abc import Callable
import logging
import textwrap
import time
Expand Down Expand Up @@ -281,8 +282,25 @@ def wait_for(
interval: float = 0.1,
cap: float = 2.0,
raise_on_error: bool = False,
progress_handler: Callable[[str, float], None] = None,
):
"""Wait for operations to complete."""
"""Wait for operations to complete.

Parameters
----------
operation_ids: List[str | Operation | OpIdResponse]
List of operation ids.
timeout: float | None
Timeout in seconds. Default is None.
interval: float
Interval in seconds. Default is 0.1.
cap: float
The maximum backoff value used to calculate the next wait time. Default is 2.0.
raise_on_error: bool
Raise an exception if an error occurs. Default is False.
progress_handler: Callable[[str, float], None]
A function to handle progress updates. Default is None.
"""
if not isinstance(operation_ids, list):
operation_ids = [operation_ids]
operation_ids = [op.id if isinstance(op, Operation | OpIdResponse) else op for op in operation_ids]
Expand All @@ -307,6 +325,9 @@ def wait_for(
if op.progress > 0:
fields.append(f"progress={op.progress:.3f}")
log.debug(f"- Operation '{op.description}' {' '.join(fields)}")
if progress_handler is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditions on self.client.binary_config.debug and progress_handler are not related. We should be able to print debug logs even if there's no progress_handler. I'd just do the loop twice.

for op in ops:
progress_handler(op.id, op.progress)
if all(op.state in [OperationState.Succeeded, OperationState.Failed] for op in ops):
break
except Exception as e:
Expand Down
43 changes: 32 additions & 11 deletions src/ansys/hps/data_transfer/client/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import asyncio
import builtins
from collections.abc import Callable
import logging
import textwrap
import time
Expand Down Expand Up @@ -204,8 +205,26 @@ async def wait_for(
interval: float = 0.1,
cap: float = 2.0,
raise_on_error: bool = False,
progress_handler: Callable[[str, float], None] = None,
):
"""Async interface to wait for a list of operations to complete."""
"""Async interface to wait for a list of operations to complete.

Parameters
----------
operation_ids: list[str | Operation]
The list of operation ids to wait for.
timeout: float | None
The maximum time to wait for the operations to complete.
interval: float
The interval between checks for the operations to complete.
cap: float
The maximum backoff value used to calculate the next wait time. Default is 2.0.
raise_on_error: bool
Raise an exception if an error occurs. Default is False.
progress_handler: Callable[[str, float], None]
A function to handle progress updates. Default is None.

"""
if not isinstance(operation_ids, list):
operation_ids = [operation_ids]
operation_ids = [op.id if isinstance(op, Operation | OpIdResponse) else op for op in operation_ids]
Expand All @@ -219,17 +238,19 @@ async def wait_for(
ops = await self._operations(operation_ids)
so_far = hf.format_timespan(time.time() - start)
log.debug(f"Waiting for {len(operation_ids)} operations to complete, {so_far} so far")
if self.client.binary_config.debug:
if progress_handler is not None:
for op in ops:
fields = [
f"id={op.id}",
f"state={op.state}",
f"start={op.started_at}",
f"succeeded_on={op.succeeded_on}",
]
if op.progress > 0:
fields.append(f"progress={op.progress:.3f}")
log.debug(f"- Operation '{op.description}' {' '.join(fields)}")
progress_handler(op.id, op.progress)
if self.client.binary_config.debug:
fields = [
f"id={op.id}",
f"state={op.state}",
f"start={op.started_at}",
f"succeeded_on={op.succeeded_on}",
]
if op.progress > 0:
fields.append(f"progress={op.progress:.3f}")
log.debug(f"- Operation '{op.description}' {' '.join(fields)}")
if all(op.state in [OperationState.Succeeded, OperationState.Failed] for op in ops):
break
except Exception as e:
Expand Down
154 changes: 144 additions & 10 deletions tests/large_file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
log = logging.getLogger(__name__)

num_files = 2
file_size = 5 # GB


def write_file(file_name, size):
Expand All @@ -52,9 +51,17 @@ def write_file(file_name, size):
return 0


def test_large_batch(storage_path, client):
"""Test copying a large file to a remote storage."""
api = DataTransferApi(client)
def sync_copy(storage_path, api, file_size=5):
"""Copying a large file to a remote storage.

Parameters:
storage_path: str
The path to the remote storage.
api: DataTransferApi
The DataTransferApi object.
file_size: int
The size of the file to be copied in GB.
"""
api.status(wait=True)

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
Expand All @@ -69,14 +76,20 @@ def test_large_batch(storage_path, client):

log.info("Starting copy ...")
op = api.copy(dsts)
assert op.id is not None
op = api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages
return op


async def test_async_large_batch(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi."""
api = AsyncDataTransferApi(async_client)
async def async_copy(storage_path, api, file_size=5):
"""Copying a large file to a remote storage using the AsyncDataTransferApi.

Parameters:
storage_path: str
The path to the remote storage.
api: DataTransferApi
The DataTransferApi object.
file_size: int
The size of the file to be copied in GB.
"""
api.status(wait=True)

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
Expand All @@ -91,6 +104,127 @@ async def test_async_large_batch(storage_path, async_client):

log.info("Starting copy ...")
op = await api.copy(dsts)
return op


def test_large_batch(storage_path, client):
"""Test copying a large file to a remote storage."""
api = DataTransferApi(client)
op = sync_copy(storage_path, api)
assert op.id is not None
op = api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages


def test_batch_with_wait_parameters(storage_path, client):
"""Test copying a large file to a remote storage with wait parameter progress_handler."""
api = DataTransferApi(client)
log.info("Copy with progress handler")
op = sync_copy(storage_path, api, 2)
assert op.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = api.wait_for(op.id, progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
# Check if progress data is collected
assert len(progress_data) > 0, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"


def test_batch_with_multiple_operations_to_wait(storage_path, client):
"""Test copying a large file to a remote storage with wait parameter progress_handler."""
api = DataTransferApi(client)
log.info("Copy with progress handler")
op1 = sync_copy(storage_path, api, 1)
op2 = sync_copy(storage_path, api, 1)
assert op1.id is not None
assert op2.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = api.wait_for([op1.id, op2.id], progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
assert op[1].state == OperationState.Succeeded, op[1].messages
# Check if progress data is collected at least twice
assert len(progress_data) > 2, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
assert progress_data[-2] == 1.0, "Last progress is not 100%"


async def test_async_large_batch(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi."""
api = AsyncDataTransferApi(async_client)
op = await async_copy(storage_path, api)
assert op.id is not None
op = await api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages


async def test_async_batch_with_wait_parameters(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi
with wait parameter progress_handler."""
api = AsyncDataTransferApi(async_client)
log.info("Copy with progress handler")
op = await async_copy(storage_path, api, 2)
assert op.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = await api.wait_for(op.id, progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
# Check if progress data is collected
assert len(progress_data) > 0, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"


async def test_async_batch_with_multiple_operations_to_wait(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi
with wait parameter progress_handler."""
api = AsyncDataTransferApi(async_client)
log.info("Copy with progress handler")
op1 = await async_copy(storage_path, api, 1)
op2 = await async_copy(storage_path, api, 1)
assert op1.id is not None
assert op2.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = await api.wait_for([op1.id, op2.id], progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
assert op[1].state == OperationState.Succeeded, op[1].messages
# Check if progress data is collected at least twice
assert len(progress_data) > 2, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
Loading