# parallel merge index #590
```diff
@@ -13,6 +13,7 @@
 import tempfile
 import urllib.parse
 from collections import OrderedDict
+from multiprocessing import Pool
 from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
 from pathlib import Path
 from time import sleep, time
```

```diff
@@ -253,6 +254,50 @@ def merge_index(*args: Any, **kwargs: Any):
     raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}')
 
 
+def _download_url(url_info):
+    """Download a file given URL information."""
+    from streaming.base.storage.download import download_file
+    src, dest, download_timeout = url_info
+    try:
+        download_file(src, dest, download_timeout)
+    except Exception as ex:
+        return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex
+    return dest, None
+
+
+def _merge_partition_indices(partition_indices):
+    """Function to be executed by each process to merge a subset of partition indices."""
+    shards = []
+    for partition_index in partition_indices:
+        p = Path(partition_index)
+        with open(partition_index, 'r') as f:
+            obj = json.load(f)
+        for shard in obj['shards']:
+            for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
```
> **Contributor:** We really ought to make this a `Shard` method, which is subject to inheritance and so on; this code won't work for parquet shards :/
>
> **Author (Collaborator):** Any specific suggestion for how to deal with this?
>
> **Contributor:** Does this work for json/xsv index files, or just for mds? Could you test that as well?
>
> **Author (Collaborator):** Do json/xsv index files have the same file format? @knighton
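One way to read the reviewer's suggestion (a hypothetical sketch, not the streaming library's actual `Shard` API) is to let each shard class own its path-rewriting logic, so a subclass such as a parquet shard can override it instead of the merge code hard-coding the mds keys:

```python
import os


class Shard:
    """Hypothetical base class; the real streaming Shard API may differ."""

    def __init__(self, info: dict) -> None:
        self.info = info

    def prepend_dirname(self, dirname: str) -> None:
        """Rewrite file basenames so they are relative to a parent directory."""
        # These keys are specific to mds-style shards.
        for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
            if self.info.get(key):
                basename = self.info[key]['basename']
                self.info[key]['basename'] = os.path.join(dirname, basename)


class ParquetShard(Shard):
    def prepend_dirname(self, dirname: str) -> None:
        # A parquet shard would rewrite its own file references here.
        raise NotImplementedError
```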
```diff
+                if shard.get(key):
+                    basename = shard[key]['basename']
+                    shard[key]['basename'] = os.path.join(os.path.basename(p.parent), basename)
+        shards.extend(obj['shards'])
+    return shards
+
+
+def _parallel_merge_partitions(partitions, n_processes=4):
+    """Divide the list of partitions among multiple processes and merge them in parallel."""
+    with Pool(processes=n_processes) as pool:
+        # Split the list of partitions into N chunks where N is the number of processes
+        chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
+        partition_chunks = [
+            partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)
+        ]
+
+        # Process each chunk in parallel
+        results = pool.map(_merge_partition_indices, partition_chunks)
+
+    # Combine the results from all processes
+    final_shards = [shard for result in results for shard in result]
+    return final_shards
+
+
 def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
                            out: Union[str, Tuple[str, str]],
                            keep_local: bool = True,
```
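The chunk-size arithmetic in `_parallel_merge_partitions` is ceiling division: for positive integers, `a // b + (a % b > 0)` equals `ceil(a / b)`, so every partition lands in one of at most `n_processes` chunks. A standalone sketch of the same splitting logic (`chunk` is an illustrative name, not part of the PR):

```python
def chunk(items: list, n_workers: int = 4) -> list:
    """Split items into at most n_workers contiguous chunks."""
    # Ceiling division: a // b + (a % b > 0) == ceil(a / b) for positive ints.
    chunk_size = len(items) // n_workers + (len(items) % n_workers > 0)
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


# 10 partitions over 4 workers -> chunk_size 3 -> chunks of 3, 3, 3, 1.
assert chunk(list(range(10))) == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```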
```diff
@@ -273,7 +318,6 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
         keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``
         download_timeout (int): The allowed time for downloading each json file. Defaults to 60.
     """
-    from streaming.base.storage.download import download_file
     from streaming.base.storage.upload import CloudUploader
 
     if not index_file_urls or not out:
```

```diff
@@ -297,10 +341,10 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
 
     # Prepare a temp folder to download index.json from remote if necessary. Removed in the end.
     with tempfile.TemporaryDirectory() as temp_root:
-        logging.warning(f'A temporary folder {temp_root} is created to store index files')
+        logging.info(f'A temporary folder {temp_root} is created to store index files')
 
         # Copy files to a temporary directory. Download if necessary
-        partitions = []
+        download_tasks = []
         for url in urls:
             if isinstance(url, tuple):
                 src = url[0] if os.path.exists(url[0]) else url[1]
```

```diff
@@ -314,30 +358,18 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
                     f'Check data availability! local index {url[0]} is not accessible.' +
                     f'remote index {url[1]} does not have a valid url format')
             dest = os.path.join(temp_root, path.lstrip('/'))
+            download_tasks.append((src, dest, download_timeout))
 
-            try:
-                download_file(src, dest, download_timeout)
-            except Exception as ex:
-                raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex
-
-            if not os.path.exists(dest):
-                raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.')
-
-            partitions.append(dest)
-
-        # merge shards from all index files
-        shards = []
-        for partition_index in partitions:
-            p = Path(partition_index)
-            obj = json.load(open(partition_index))
-            for i in range(len(obj['shards'])):
-                shard = obj['shards'][i]
-                for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
-                    if shard.get(key):
-                        basename = shard[key]['basename']
-                        obj['shards'][i][key]['basename'] = os.path.join(
-                            os.path.basename(p.parent), basename)
-            shards += obj['shards']
+        with Pool(processes=os.cpu_count()) as pool:
+            results = pool.map(_download_url, download_tasks)
+
+        partitions = []
+        for partition_index, error in results:
+            if error:
+                raise RuntimeError(partition_index)
+            partitions.append(partition_index)
+
+        shards = _parallel_merge_partitions(partitions)
 
         # Save merged index locally
         obj = {
```
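Note the error-handling convention: `_download_url` returns `(dest, None)` on success and `(message, exception)` on failure, so the parent process inspects results after `pool.map` completes rather than letting a single worker exception abort the pool mid-flight. A minimal self-contained demo of the same pattern (`fetch` and the task list are made up for illustration):

```python
from multiprocessing import Pool


def fetch(task):
    """Worker: return (payload, None) on success, (message, exc) on failure."""
    src, dest = task
    try:
        if src is None:
            raise ValueError('missing source')
        return dest, None
    except Exception as ex:
        return f'Failed to fetch {src} to {dest}: {ex}', ex


if __name__ == '__main__':
    tasks = [('s3://bucket/a/index.json', '/tmp/a/index.json'),
             ('s3://bucket/b/index.json', '/tmp/b/index.json')]
    with Pool(processes=2) as pool:
        results = pool.map(fetch, tasks)
    # Surface any worker failure in the parent process.
    for payload, error in results:
        if error:
            raise RuntimeError(payload)
        print('downloaded:', payload)
```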
---

```diff
@@ -7,7 +7,7 @@
 import time
 import urllib.parse
 from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
 
```

```diff
@@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str):
     assert obj.scheme == scheme
 
 
-@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3])
-@pytest.mark.parametrize('keep_local', [True, False])
-@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://'])
+@pytest.mark.parametrize('index_file_urls_pattern', [1])  # , 2, 3])
+@pytest.mark.parametrize('keep_local', [True])  # , False])
+@pytest.mark.parametrize('scheme', ['gs://'])  # , 's3://', 'oci://'])
 def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool,
                                      index_file_urls_pattern: int, scheme: str):
     """Validate the final merge index json for following patterns of index_file_urls:
```
```diff
@@ -206,10 +206,10 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool,
     4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all
     5. All URLs are str (remote) -> download all
     """
-    from decimal import Decimal
+    import random
+    import string
 
     from pyspark.sql import SparkSession
-    from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType
 
     from streaming.base.converters import dataframeToMDS
 
```

```diff
@@ -223,15 +223,18 @@ def not_merged_index(index_file_path: str, out: str):
     mds_out = out = local
 
     spark = SparkSession.builder.getOrCreate()  # pyright: ignore
-    schema = StructType([
-        StructField('id', IntegerType(), nullable=False),
-        StructField('name', StringType(), nullable=False),
-        StructField('amount', DecimalType(10, 2), nullable=False)
-    ])
-    data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')),
-            (3, 'Charlie', Decimal('987.65'))]
-    df = spark.createDataFrame(data=data, schema=schema).repartition(3)
-    mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True}
+
+    def random_string(length=1000):
+        """Generate a random string of fixed length."""
+        letters = string.ascii_letters + string.digits + string.punctuation + ' '
+        return ''.join(random.choice(letters) for _ in range(length))
+
+    # Generate a DataFrame with num_rows rows of random text
+    num_rows = 100
+    data = [(i, random_string(), random_string()) for i in range(num_rows)]
+    df = spark.createDataFrame(data, ['id', 'name', 'amount'])
+
+    mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True}
     dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs)
 
     local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True)
```
```diff
@@ -241,6 +244,16 @@ def not_merged_index(index_file_path: str, out: str):
 
     if index_file_urls_pattern == 1:
         merge_index(local_index_files, out, keep_local=keep_local)
+        d1 = json.load(open(os.path.join(out, 'index.json')))
+
+        _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local)
+        d2 = json.load(open(os.path.join(out, 'index.json')))
+
+        print('d1 = ', d1)
+        print('d2 = ', d2)
+
+        assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different'
+        assert d1['shards'] == d2['shards'], 'parallel and serial results different'
 
     if index_file_urls_pattern == 2:
         with tempfile.TemporaryDirectory() as a_temporary_folder:
```
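The pattern-1 branch now validates the parallel path against a serial reference: both write `index.json` to the same `out`, and the shard lists must match. For reference, a condensed sketch of how a user would invoke the public entry point (the import path and directory layout here are assumptions based on the diff context, not confirmed by the PR):

```python
from streaming.base.util import merge_index

# Merge per-partition index.json files into a single index written under `out`.
index_files = ['/data/mds/partition_0/index.json', '/data/mds/partition_1/index.json']
merge_index(index_files, out='/data/mds', keep_local=True)
```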
```diff
@@ -323,3 +336,98 @@ def flaky_function():
         return "Third time's a charm"
 
     assert flaky_function() == "Third time's a charm"
+
+
+def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
+                                  out: Union[str, Tuple[str, str]],
+                                  keep_local: bool = True,
+                                  download_timeout: int = 60) -> None:
+    import logging
+    import shutil
+    import urllib.parse
+    from collections import OrderedDict
+    from pathlib import Path
+
+    from streaming.base.format.index import get_index_basename
+    from streaming.base.storage.download import download_file
+    from streaming.base.storage.upload import CloudUploader
+
+    if not index_file_urls or not out:
+        return
+
+    # This is the index json file name, e.g., it is index.json as of 0.6.0
+    index_basename = get_index_basename()
+
+    cu = CloudUploader.get(out, keep_local=True, exist_ok=True)
+
+    # Remove duplicates, and strip '/' from right if any
+    index_file_urls = list(OrderedDict.fromkeys(index_file_urls))
+    urls = []
+    for url in index_file_urls:
+        if isinstance(url, str):
+            urls.append(url.rstrip('/').strip())
+        else:
+            urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip()))
+
+    # Prepare a temp folder to download index.json from remote if necessary. Removed in the end.
+    with tempfile.TemporaryDirectory() as temp_root:
+        logging.warning(f'A temporary folder {temp_root} is created to store index files')
+
+        # Copy files to a temporary directory. Download if necessary
+        partitions = []
+        for url in urls:
+            if isinstance(url, tuple):
+                src = url[0] if os.path.exists(url[0]) else url[1]
+            else:
+                src = url
+
+            obj = urllib.parse.urlparse(src)
+            scheme, bucket, path = obj.scheme, obj.netloc, obj.path
+            if scheme == '' and bucket == '' and path == '':
+                raise FileNotFoundError(
+                    f'Check data availability! local index {url[0]} is not accessible.' +
+                    f'remote index {url[1]} does not have a valid url format')
+            dest = os.path.join(temp_root, path.lstrip('/'))
+
+            try:
+                download_file(src, dest, download_timeout)
+            except Exception as ex:
+                raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex
+
+            if not os.path.exists(dest):
+                raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.')
+
+            partitions.append(dest)
+
+        # merge shards from all index files
+        shards = []
+        for partition_index in partitions:
+            p = Path(partition_index)
+            obj = json.load(open(partition_index))
+            for i in range(len(obj['shards'])):
+                shard = obj['shards'][i]
+                for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
+                    if shard.get(key):
+                        basename = shard[key]['basename']
```
> **Contributor:** Wait, why are you just taking the basename of the child file here? And to be clear, why the basename of the parent as well? What if the dir to merge is more than one hop deep?
>
> **Contributor:** if anyone takes
>
> **Contributor:** @XiaohanZhangCMU was this resolved?
```diff
+                        obj['shards'][i][key]['basename'] = os.path.join(
+                            os.path.basename(p.parent), basename)
+            shards += obj['shards']
+
+        # Save merged index locally
+        obj = {
+            'version': 2,
+            'shards': shards,
+        }
+        merged_index_path = os.path.join(temp_root, index_basename)
+        with open(merged_index_path, 'w') as outfile:
+            json.dump(obj, outfile)
+
+        # Move merged index from temp path to local part in out
+        # Upload merged index to remote if out has remote part
+        shutil.move(merged_index_path, cu.local)
+        if cu.remote is not None:
+            cu.upload_file(index_basename)
+
+        # Clean up
+        if not keep_local:
+            shutil.rmtree(cu.local, ignore_errors=True)
```
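To make the reviewer's depth concern concrete: `os.path.basename(p.parent)` keeps only the immediate parent directory of each partition's `index.json`, so any structure more than one hop below the merge root is dropped. A small illustration (the paths are made up):

```python
import os
from pathlib import Path

# An index file nested two directories below the merge root.
p = Path('/data/root/group_a/partition_0/index.json')
basename = 'shard.00000.mds'

# What the PR computes: only the immediate parent survives.
print(os.path.join(os.path.basename(p.parent), basename))
# -> partition_0/shard.00000.mds  ('group_a' is lost)

# One way to preserve deeper structure: a path relative to the merge root.
print(os.path.join(os.path.relpath(p.parent, '/data/root'), basename))
# -> group_a/partition_0/shard.00000.mds
```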