Add a progress callback to the wait operation function #204

Merged: 17 commits, May 28, 2025
23 changes: 22 additions & 1 deletion src/ansys/hps/data_transfer/client/api/api.py
@@ -27,6 +27,7 @@
"""

import builtins
from collections.abc import Callable
import logging
import textwrap
import time
@@ -281,8 +282,25 @@ def wait_for(
interval: float = 0.1,
cap: float = 2.0,
raise_on_error: bool = False,
progress_handler: Callable[[str, float], None] = None,
):
"""Wait for operations to complete."""
"""Wait for operations to complete.

Parameters
----------
operation_ids: list[str | Operation | OpIdResponse]
List of operation IDs, Operation objects, or OpIdResponse objects to wait for.
timeout: float | None
Timeout in seconds. Default is None.
interval: float
Interval in seconds. Default is 0.1.
cap: float
The maximum backoff value used to calculate the next wait time. Default is 2.0.
raise_on_error: bool
Raise an exception if an error occurs. Default is False.
progress_handler: Callable[[str, float], None]
A callback invoked with each operation's id and its progress as a fraction between 0.0 and 1.0. Default is None.
"""
if not isinstance(operation_ids, list):
operation_ids = [operation_ids]
operation_ids = [op.id if isinstance(op, Operation | OpIdResponse) else op for op in operation_ids]
@@ -307,6 +325,9 @@ def wait_for(
if op.progress > 0:
fields.append(f"progress={op.progress:.3f}")
log.debug(f"- Operation '{op.description}' {' '.join(fields)}")
if progress_handler is not None:
for op in ops:
progress_handler(op.id, op.progress)
if all(op.state in [OperationState.Succeeded, OperationState.Failed] for op in ops):
break
except Exception as e:
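
For reference, a minimal usage sketch of the new synchronous progress_handler parameter. It assumes api is an already-initialized DataTransferApi backed by a started client and dsts is a list of source/destination pairs, as in the tests further down; the handler and the progress_seen dictionary are illustrative only and not part of this PR.

# Sketch only: assumes `api` (DataTransferApi) and `dsts` exist as in the tests below.
progress_seen = {}

def handler(op_id: str, current: float) -> None:
    # wait_for calls this on each polling pass with the progress fraction (0.0-1.0).
    progress_seen[op_id] = current
    print(f"Operation {op_id}: {current * 100.0:.1f}% complete")

op = api.copy(dsts)
ops = api.wait_for(op.id, interval=0.1, cap=2.0, progress_handler=handler)
assert ops[0].state == OperationState.Succeeded
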
24 changes: 23 additions & 1 deletion src/ansys/hps/data_transfer/client/api/async_api.py
@@ -28,6 +28,7 @@

import asyncio
import builtins
from collections.abc import Awaitable, Callable
import logging
import textwrap
import time
@@ -204,8 +205,26 @@ async def wait_for(
interval: float = 0.1,
cap: float = 2.0,
raise_on_error: bool = False,
progress_handler: Callable[[str, float], Awaitable[None]] = None,
):
"""Provides an async interface to wait for a list of operations to complete."""
"""Provides an async interface to wait for a list of operations to complete.

Parameters
----------
operation_ids: list[str | Operation | OpIdResponse]
List of operation IDs, Operation objects, or OpIdResponse objects to wait for.
timeout: float | None
The maximum time to wait for the operations to complete.
interval: float
The interval between checks for the operations to complete.
cap: float
The maximum backoff value used to calculate the next wait time. Default is 2.0.
raise_on_error: bool
Raise an exception if an error occurs. Default is False.
progress_handler: Callable[[str, float], Awaitable[None]]
An async function awaited with each operation's id and its progress as a fraction between 0.0 and 1.0. Default is None.

"""
if not isinstance(operation_ids, list):
operation_ids = [operation_ids]
operation_ids = [op.id if isinstance(op, Operation | OpIdResponse) else op for op in operation_ids]
@@ -230,6 +249,9 @@ async def wait_for(
if op.progress > 0:
fields.append(f"progress={op.progress:.3f}")
log.debug(f"- Operation '{op.description}' {' '.join(fields)}")
if progress_handler is not None:
for op in ops:
await progress_handler(op.id, op.progress)
if all(op.state in [OperationState.Succeeded, OperationState.Failed] for op in ops):
break
except Exception as e:
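
Similarly, a minimal sketch of the asynchronous variant. It assumes api is an AsyncDataTransferApi and dsts is defined as in the tests below, running inside an async function; the coroutine handler is illustrative only.

# Sketch only: to be run inside an async function with `api` (AsyncDataTransferApi) and `dsts` defined.
progress_seen = {}

async def handler(op_id: str, current: float) -> None:
    # The async wait_for awaits this coroutine on each polling pass.
    progress_seen[op_id] = current
    print(f"Operation {op_id}: {current * 100.0:.1f}% complete")

op = await api.copy(dsts)
ops = await api.wait_for(op.id, progress_handler=handler)
assert ops[0].state == OperationState.Succeeded
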
210 changes: 183 additions & 27 deletions tests/large_file_test.py
@@ -36,29 +36,49 @@

log = logging.getLogger(__name__)

num_files = 2
file_size = 5 # GB


def write_file(file_name, size):
"""Write a file with random data."""
start_time = time.time()
log.info(f"Generating file {file_name} with size {size} GB")
gb1 = 1024 * 1024 * 1024 # 1GB
with open(file_name, "wb") as fout:
for _i in range(size):
fout.write(os.urandom(gb1))
log.info(f"File {file_name} has been generated after {(time.time() - start_time):.2f} seconds")
return 0


def test_large_batch(storage_path, client):
"""Test copying a large file to a remote storage."""
api = DataTransferApi(client)
class TempFileManager:
def __init__(self, file_name, file_size):
self.file_name = file_name
self.file_size = file_size

def write_file(self):
"""Write a file with random data."""
start_time = time.time()
log.info(f"Generating file {self.file_name} with size {self.file_size} GB")
gb1 = 1024 * 1024 * 1024 # 1GB
with open(self.file_name, "wb") as fout:
for _i in range(self.file_size):
fout.write(os.urandom(gb1))
log.info(f"File {self.file_name} has been generated after {(time.time() - start_time):.2f} seconds")
return 0

def delete_file(self):
"""Delete a file."""
log.info(f"Deleting file {self.file_name}")
try:
os.remove(self.file_name)
log.info(f"Temporary file {self.file_name} has been deleted.")
except Exception as ex:
log.warning(f"Failed to delete file {self.file_name}: {ex}")
return 0


def sync_copy(storage_path, api, file_size=5, num_files=2):
"""Copying a large file to a remote storage.

Parameters:
storage_path: str
The path to the remote storage.
api: DataTransferApi
The DataTransferApi object.
file_size: int
The size of the file to be copied in GB.
"""
api.status(wait=True)

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
write_file(temp_file.name, file_size)
manager = TempFileManager(temp_file.name, file_size)
manager.write_file()
temp_file_name = os.path.basename(temp_file.name)

src = StoragePath(path=temp_file.name, remote="local")
@@ -69,18 +89,25 @@ def test_large_batch(storage_path, client):

log.info("Starting copy ...")
op = api.copy(dsts)
assert op.id is not None
op = api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages
return op, manager


async def test_async_large_batch(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi."""
api = AsyncDataTransferApi(async_client)
async def async_copy(storage_path, api, file_size=5, num_files=2):
"""Copying a large file to a remote storage using the AsyncDataTransferApi.

Parameters:
storage_path: str
The path to the remote storage.
api: DataTransferApi
The DataTransferApi object.
file_size: int
The size of the file to be copied in GB.
"""
api.status(wait=True)

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
write_file(temp_file.name, file_size)
manager = TempFileManager(temp_file.name, file_size)
manager.write_file()
temp_file_name = os.path.basename(temp_file.name)

src = StoragePath(path=temp_file.name, remote="local")
@@ -91,6 +118,135 @@ async def test_async_large_batch(storage_path, async_client):

log.info("Starting copy ...")
op = await api.copy(dsts)
return op, manager


def test_large_batch(storage_path, client):
"""Test copying a large file to a remote storage."""
api = DataTransferApi(client)
op, manager = sync_copy(storage_path, api)
assert op.id is not None
op = api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages
manager.delete_file()


def test_batch_with_wait_parameters(storage_path, client):
"""Test copying a large file to a remote storage with wait parameter progress_handler."""
api = DataTransferApi(client)
log.info("Copy with progress handler")
op, manager = sync_copy(storage_path, api, 1, 1)
assert op.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = api.wait_for(op.id, progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
# Check if progress data is collected
assert len(progress_data) > 0, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
manager.delete_file()


def test_batch_with_multiple_operations_to_wait(storage_path, client):
"""Test copying a large file to a remote storage with wait parameter progress_handler."""
api = DataTransferApi(client)
log.info("Copy with progress handler")
op1, manager1 = sync_copy(storage_path, api, 1, 1)
op2, manager2 = sync_copy(storage_path, api, 1, 1)
assert op1.id is not None
assert op2.id is not None

# List to store progress data
progress_data = []

# test progress handler
def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = api.wait_for([op1.id, op2.id], progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
assert op[1].state == OperationState.Succeeded, op[1].messages
# Check that progress data was collected for both operations
assert len(progress_data) > 2, "Not enough progress data collected"
# Check that the final progress of each operation is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
assert progress_data[-2] == 1.0, "Second-to-last progress is not 100%"
manager1.delete_file()
manager2.delete_file()


async def test_async_large_batch(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi."""
api = AsyncDataTransferApi(async_client)
op, manager = await async_copy(storage_path, api)
assert op.id is not None
op = await api.wait_for(op.id)
assert op[0].state == OperationState.Succeeded, op[0].messages
manager.delete_file()


async def test_async_batch_with_wait_parameters(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi
with wait parameter progress_handler."""
api = AsyncDataTransferApi(async_client)
log.info("Copy with progress handler")
op, manager = await async_copy(storage_path, api, 1, 1)
assert op.id is not None

# List to store progress data
progress_data = []

# test progress handler
async def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = await api.wait_for(op.id, progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
# Check if progress data is collected
assert len(progress_data) > 0, "No progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
manager.delete_file()


async def test_async_batch_with_multiple_operations_to_wait(storage_path, async_client):
"""Test copying a large file to a remote storage using the AsyncDataTransferApi
with wait parameter progress_handler."""
api = AsyncDataTransferApi(async_client)
log.info("Copy with progress handler")
op1, manager1 = await async_copy(storage_path, api, 1, 1)
op2, manager2 = await async_copy(storage_path, api, 1, 1)
assert op1.id is not None
assert op2.id is not None

# List to store progress data
progress_data = []

# test progress handler
async def handler(id, current_progress):
progress_data.append(current_progress)
log.info(f"{current_progress * 100.0}% completed for operation id: {id}")

# Wait for the operation to complete with progress handler
op = await api.wait_for([op1.id, op2.id], progress_handler=handler)
assert op[0].state == OperationState.Succeeded, op[0].messages
assert op[1].state == OperationState.Succeeded, op[1].messages
# Check that progress data was collected for both operations
assert len(progress_data) > 2, "Not enough progress data collected"
# Check if the last progress is 100%
assert progress_data[-1] == 1.0, "Last progress is not 100%"
manager1.delete_file()
manager2.delete_file()
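
The interval and cap parameters documented in both wait_for docstrings describe a capped backoff between polling passes. The actual timing code is not part of this diff; the snippet below is only an assumed illustration of how a capped exponential backoff of that kind typically behaves, not the library's implementation.

import time

def capped_backoff(interval: float = 0.1, cap: float = 2.0):
    # Illustrative only: yield wait times that double each pass but never exceed `cap`.
    wait = interval
    while True:
        yield wait
        wait = min(wait * 2.0, cap)

# Example: the first waits are 0.1, 0.2, 0.4, 0.8, 1.6, 2.0, 2.0, ... seconds.
waits = capped_backoff()
for _ in range(6):
    time.sleep(next(waits))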