
[WIP] Adds gpu minhash support for RayBTSMinhashDeduplicator #644

Draft · wants to merge 1 commit into base: feat/ayushdg/gpu-minhash-poc

136 changes: 123 additions & 13 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -9,7 +9,8 @@
from loguru import logger
from pydantic import Field, PositiveInt
from typing_extensions import Annotated

import time
import pickle
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.model_utils import prepare_sentencepiece_model

@@ -80,6 +81,9 @@ def __init__(
self.max_pending_edge_buffer_task = max_pending_edge_buffer_task
self.num_edge_buffer_task_returns = num_edge_buffer_task_returns

def get_hash_table(self):
Collaborator: This method can be removed

return self.hash_table

def add_key_value_pairs(self, pairs):
for key, value in pairs:
if key not in self.hash_table:
@@ -231,6 +235,45 @@ def dup_idx(self, queries):

OP_NAME = 'ray_bts_minhash_deduplicator'

#@ray.remote
class GPUMinHashActor:
def __init__(self, width: int = 5, perm_a: np.ndarray = None, perm_b: np.ndarray = None, lowercase: bool = True):
import cudf
import rmm
rmm.reinitialize(pool_allocator=True)
self.width = width
gen = np.random.RandomState(seed=42)
if perm_a is None or perm_b is None:
perm_a, perm_b = np.array(
[(
gen.randint(1, MERSENNE_PRIME, dtype=np.uint64),
gen.randint(0, MERSENNE_PRIME, dtype=np.uint64),
) for _ in range(256)],
dtype=np.uint32,
Collaborator: Similar issue

).T
self.perm_a = cudf.Series(perm_a).astype("uint32")
self.perm_b = cudf.Series(perm_b).astype("uint32")
self.lowercase = lowercase

def compute_minhash(self, text_arr: pa.Array) -> pa.Array:
"""
Compute MinHash signatures for texts in a table
"""
import cudf
text_df = cudf.Series.from_arrow(text_arr)
if self.lowercase:
text_df = text_df.str.lower()
minhashes = text_df.str.minhash(seed=0, a=self.perm_a, b=self.perm_b, width=self.width)
del text_df
arrow_minhashes = minhashes.to_arrow()
del(minhashes)
return arrow_minhashes

def __call__(self, table: pa.Table, text_key: str = "text") -> pa.Table:
minhashes = self.compute_minhash(table[text_key])
new_table = table.append_column("_minhash", minhashes)
return new_table
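
For orientation, a minimal standalone usage sketch of this actor outside Ray Data; it assumes a CUDA-capable machine with cudf and rmm installed, and the sample strings are illustrative only:

```python
import pyarrow as pa

# Hypothetical standalone use of GPUMinHashActor from this diff
# (requires a GPU with cudf/rmm available).
actor = GPUMinHashActor(width=5, lowercase=True)
table = pa.table({"text": ["hello world", "Hello  World", "something else"]})
with_minhash = actor(table)            # appends a "_minhash" column
print(with_minhash.column_names)       # ['text', '_minhash']
```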


@OPERATORS.register_module(OP_NAME)
class RayBTSMinhashDeduplicator(Deduplicator):
@@ -245,6 +288,7 @@ class RayBTSMinhashDeduplicator(Deduplicator):
def __init__(
self,
tokenization: str = 'space',
use_gpu: bool = False,
window_size: PositiveInt = 5,
lowercase: bool = True,
ignore_pattern: Optional[str] = None,
@@ -260,6 +304,7 @@ def __init__(
max_pending_filter_tasks: Optional[int] = 20,
num_filter_task_returns: Optional[int] = 10,
merge_batch_size: Optional[int] = 1000,
minhash_batch_size: Optional[int] = "auto",
*args,
**kwargs,
):
@@ -311,13 +356,23 @@ def __init__(
:param tmp_file_name: the temporary folder name for deduplication.
"""
super().__init__(*args, **kwargs)
# about minhash computation

self.tokenization = tokenization
self.window_size = window_size
self.lowercase = lowercase
self.ignore_pattern = ignore_pattern
self.use_gpu = use_gpu
if minhash_batch_size == "auto":
if self.use_gpu:
self.minhash_batch_size = 200_000
else:
self.minhash_batch_size = 1024
else:
self.minhash_batch_size = minhash_batch_size
if self.ignore_pattern:
self.ignore_pattern = regex.compile(self.ignore_pattern)
if self.use_gpu and self.tokenization != 'character':
raise ValueError("GPU MinHash computation is only supported for character tokenization")
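
For context, a hypothetical instantiation that exercises the new flags; the values are illustrative and every other parameter keeps its default:

```python
# Hypothetical usage of the parameters added in this diff (illustrative values):
dedup = RayBTSMinhashDeduplicator(
    tokenization='character',   # the GPU path requires character tokenization
    use_gpu=True,
    minhash_batch_size='auto',  # resolves to 200_000 on GPU, 1024 on CPU
    window_size=5,
    lowercase=True,
)
```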

# check parameters
if self.ignore_pattern and self.tokenization == 'punctuation':
@@ -396,7 +451,7 @@ def tokenization_func(text):
gen.randint(1, MERSENNE_PRIME, dtype=np.uint64),
gen.randint(0, MERSENNE_PRIME, dtype=np.uint64),
) for _ in range(self.num_permutation)],
dtype=np.uint64,
dtype=np.uint32,
Collaborator: This may break some constraints, it's better to keep uint64

).T
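
To illustrate the reviewer's point: the permutation values are drawn from [1, MERSENNE_PRIME) with MERSENNE_PRIME = 2**61 - 1, so they do not fit in uint32 and the narrowing cast wraps modulo 2**32. A minimal standalone demonstration (not part of the diff):

```python
import numpy as np

# MERSENNE_PRIME = 2**61 - 1, as used elsewhere in this operator.
MERSENNE_PRIME = np.uint64((1 << 61) - 1)

gen = np.random.RandomState(seed=42)
a64 = gen.randint(1, MERSENNE_PRIME, dtype=np.uint64, size=4)
print(a64)                    # full 61-bit values
print(a64.astype(np.uint32))  # only the low 32 bits survive (values wrap)
```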

if union_find_parallel_num == 'auto':
@@ -435,9 +490,37 @@ def tokenization_func(text):
+ empty_hash_value.tobytes()
self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)

def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:

def band_minhash(self, minhash_list, uid_list):
"""
Logic for creating and pushing LSH bands to the union-find list
"""
pairs = {}
minhash_list = minhash_list.to_numpy(zero_copy_only=False)
for minhash, uid in zip(minhash_list, uid_list):
for i, (start, end) in enumerate(self.hash_ranges):
hash_value = i.to_bytes(4, 'big') \
+ minhash[start:end].tobytes()
hash_table_id = minhash[start] \
% self.union_find_parallel_num
if hash_table_id not in pairs:
pairs[hash_table_id] = []
pairs[hash_table_id].append((hash_value, uid))
result_refs = []
for i, p in pairs.items():
if len(result_refs) > self.max_pending_filter_tasks:
ready_refs, result_refs = ray.wait(
result_refs, num_returns=self.num_filter_task_returns)
ray.get(ready_refs)
result_refs.append(
self.union_find_list[i].add_key_value_pairs.remote(p))
ray.get(result_refs)
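
For readers new to the banding step, a small self-contained illustration of how a signature is split into band keys; it assumes hash_ranges partitions the signature into equal-width bands, as in the existing CPU path, and the numbers are illustrative:

```python
import numpy as np

# Illustration only: split a 256-value signature into LSH band keys the same
# way band_minhash does (band index prefix + that band's slice of the signature).
num_bands, rows_per_band = 32, 8
hash_ranges = [(i * rows_per_band, (i + 1) * rows_per_band)
               for i in range(num_bands)]

signature = np.arange(256, dtype=np.uint32)   # stand-in MinHash signature
band_keys = [
    i.to_bytes(4, 'big') + signature[start:end].tobytes()
    for i, (start, end) in enumerate(hash_ranges)
]
# Documents that share any one band key are routed to the same union-find
# actor and become candidate duplicates.
```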

def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
"""
Logic for computing minhash values for each text in the input table
"""
pairs = {}
for text, uid in zip(text_list, uid_list):
text = text.as_py()
if self.lowercase:
Expand All @@ -446,7 +529,6 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
text = self.ignore_pattern.sub('', text)

tokens = self.tokenization_func(text)

if len(tokens) > 0:
hv = np.array([sha1_hash32(token) for token in tokens],
dtype=np.uint64)
@@ -538,6 +620,16 @@ def run(self, dataset, **kwargs):
start_time = time.time()
id_generator = IdGenerator.remote()


def band_with_uid(table: pa.Table) -> pa.Table:
num_rows = len(table)
min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows))
uid_list = range(min_id, max_id)
self.band_minhash(table["_minhash"], uid_list)
new_table = table.append_column(HashKeys.uid, pa.array(list(uid_list)))
new_table = new_table.drop_columns(["_minhash"])
return new_table

def minhash_with_uid(table: pa.Table) -> pa.Table:
num_rows = len(table)
min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows))
@@ -549,22 +641,40 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:

tmp_dir = os.path.join(self.work_dir, '.tmp',
ray.get_runtime_context().get_job_id())
dataset.map_batches(
minhash_with_uid,
batch_format='pyarrow',
zero_copy_batch=True,
).write_parquet(tmp_dir)
dataset = ray.data.read_parquet(tmp_dir)
if self.use_gpu:
dataset = dataset.map_batches(
GPUMinHashActor,
batch_format='pyarrow',
zero_copy_batch=True,
num_gpus=1,
concurrency=3,
Comment: I artificially set this during testing but we would want this to be configurable. I'm not sure what the best approach/config is for this.

Collaborator: Adding a gpu_actor_concurrency parameter to the __init__ method is okay.

batch_size=self.minhash_batch_size,
)
dataset.map_batches(
Collaborator: Is there any way to merge these two map_batches into one? Adding additional map_batches may increase network overhead.

Comment: We should be able to combine the two map_batches into a single call by moving all the banding logic into the GPU MinHash actor, but since the number of GPUs / the concurrency level for this stage might be much lower than the total CPUs available, it might reduce the concurrency for banding. It's a tradeoff between networking overhead (via the object store) and fewer actors doing the banding; I'm not sure which is more optimal.

band_with_uid,
batch_format='pyarrow',
zero_copy_batch=True,
).write_parquet(tmp_dir)
del dataset
else:
dataset.map_batches(
minhash_with_uid,
batch_format='pyarrow',
zero_copy_batch=True,
).write_parquet(tmp_dir)
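
Following the discussion above about the hard-coded concurrency=3, a hedged sketch of how a configurable gpu_actor_concurrency (the name suggested in the review, not part of this diff) might be threaded through; it assumes the GPUMinHashActor defined in this file and a CUDA-capable Ray cluster:

```python
import ray.data

def apply_gpu_minhash(dataset: ray.data.Dataset,
                      minhash_batch_size: int,
                      gpu_actor_concurrency: int = 3) -> ray.data.Dataset:
    # Hypothetical helper: the deduplicator would accept the parameter in
    # __init__ and pass it here instead of the literal concurrency=3.
    return dataset.map_batches(
        GPUMinHashActor,                     # class from this diff
        batch_format='pyarrow',
        zero_copy_batch=True,
        num_gpus=1,                          # one GPU per actor
        concurrency=gpu_actor_concurrency,   # size of the GPU actor pool
        batch_size=minhash_batch_size,
    )
```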
end_time = time.time()
logger.info(f'MinHash time = {end_time - start_time}')
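
And for the thread above about merging the two map_batches: a rough sketch of the single-pass variant described in the comments, where the GPU actor also draws uids and pushes bands. GPUMinHashBandActor and the id_generator/band_fn wiring are hypothetical, and the tradeoff between object-store traffic and banding concurrency noted in the thread still applies:

```python
import ray
import pyarrow as pa

class GPUMinHashBandActor(GPUMinHashActor):
    """Hypothetical single-pass actor: MinHash + uid assignment + banding."""

    def __init__(self, id_generator, band_fn, **kwargs):
        super().__init__(**kwargs)
        self.id_generator = id_generator   # IdGenerator actor handle
        self.band_fn = band_fn             # e.g. the operator's band_minhash

    def __call__(self, table: pa.Table, text_key: str = "text") -> pa.Table:
        minhashes = self.compute_minhash(table[text_key])
        min_id, max_id = ray.get(
            self.id_generator.get_next_id.remote(len(table)))
        uid_list = range(min_id, max_id)
        self.band_fn(minhashes, uid_list)  # push band keys to union-find actors
        return table.append_column(HashKeys.uid, pa.array(list(uid_list)))
```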

new_dataset = ray.data.read_parquet(tmp_dir)
start_time = time.time()
self.merge()
end_time = time.time()
logger.info(f'merge time = {end_time - start_time}')
result = dataset.map_batches(
start_time = time.time()
result = new_dataset.map_batches(
self.filter_with_union_find,
batch_format='pyarrow',
zero_copy_batch=True,
)
end_time = time.time()
logger.info(f'filter time = {end_time - start_time}')
return result