[WIP] Adds gpu minhash support for RayBTSMinhashDeduplicator #644
base: feat/ayushdg/gpu-minhash-poc
Changes from all commits
@@ -9,7 +9,8 @@
from loguru import logger
from pydantic import Field, PositiveInt
from typing_extensions import Annotated

import time
import pickle
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.model_utils import prepare_sentencepiece_model

@@ -80,6 +81,9 @@ def __init__(
        self.max_pending_edge_buffer_task = max_pending_edge_buffer_task
        self.num_edge_buffer_task_returns = num_edge_buffer_task_returns

    def get_hash_table(self):
        return self.hash_table

    def add_key_value_pairs(self, pairs):
        for key, value in pairs:
            if key not in self.hash_table:
@@ -231,6 +235,45 @@ def dup_idx(self, queries):

OP_NAME = 'ray_bts_minhash_deduplicator'


# @ray.remote
class GPUMinHashActor:
    def __init__(self, width: int = 5, perm_a: np.ndarray = None,
                 perm_b: np.ndarray = None, lowercase: bool = True):
        import cudf
        import rmm
        rmm.reinitialize(pool_allocator=True)
        self.width = width
        gen = np.random.RandomState(seed=42)
        if perm_a is None or perm_b is None:
            perm_a, perm_b = np.array(
                [(
                    gen.randint(1, MERSENNE_PRIME, dtype=np.uint64),
                    gen.randint(0, MERSENNE_PRIME, dtype=np.uint64),
                ) for _ in range(256)],
                dtype=np.uint32,
            ).T

Review comment (on dtype=np.uint32): Similar issue (same dtype concern as the uint64 comment below).
        self.perm_a = cudf.Series(perm_a).astype("uint32")
        self.perm_b = cudf.Series(perm_b).astype("uint32")
        self.lowercase = lowercase

    def compute_minhash(self, text_arr: pa.Array) -> pa.Array:
        """
        Compute MinHash signatures for texts in a table.
        """
        import cudf
        text_df = cudf.Series.from_arrow(text_arr)
        if self.lowercase:
            text_df = text_df.str.lower()
        minhashes = text_df.str.minhash(
            seed=0, a=self.perm_a, b=self.perm_b, width=self.width)
        del text_df
        arrow_minhashes = minhashes.to_arrow()
        del minhashes
        return arrow_minhashes

    def __call__(self, table: pa.Table, text_key: str = "text") -> pa.Table:
        minhashes = self.compute_minhash(table[text_key])
        new_table = table.append_column("_minhash", minhashes)
        return new_table

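As a quick smoke test, the actor can be exercised on a small Arrow table. This snippet is illustrative only and not part of the PR; it assumes a CUDA-capable node with cudf and rmm installed:

import pyarrow as pa

# hypothetical smoke test for GPUMinHashActor
actor = GPUMinHashActor(width=5, lowercase=True)
table = pa.table({"text": ["Hello World", "hello world", "something else"]})
out = actor(table)  # appends a "_minhash" list column
print(len(out.column("_minhash")[0].as_py()))  # 256 hashes per document
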
@OPERATORS.register_module(OP_NAME)
class RayBTSMinhashDeduplicator(Deduplicator):

@@ -245,6 +288,7 @@ class RayBTSMinhashDeduplicator(Deduplicator):
    def __init__(
        self,
        tokenization: str = 'space',
        use_gpu: bool = False,
        window_size: PositiveInt = 5,
        lowercase: bool = True,
        ignore_pattern: Optional[str] = None,

@@ -260,6 +304,7 @@ def __init__(
        max_pending_filter_tasks: Optional[int] = 20,
        num_filter_task_returns: Optional[int] = 10,
        merge_batch_size: Optional[int] = 1000,
        minhash_batch_size: Union[PositiveInt, str] = 'auto',
        *args,
        **kwargs,
    ):
@@ -311,13 +356,23 @@ def __init__(
        :param tmp_file_name: the temporary folder name for deduplication.
        """
        super().__init__(*args, **kwargs)
        # about minhash computation
        self.tokenization = tokenization
        self.window_size = window_size
        self.lowercase = lowercase
        self.ignore_pattern = ignore_pattern
        self.use_gpu = use_gpu
        if minhash_batch_size == 'auto':
            if self.use_gpu:
                self.minhash_batch_size = 200_000
            else:
                self.minhash_batch_size = 1024
        else:
            self.minhash_batch_size = minhash_batch_size
        if self.ignore_pattern:
            self.ignore_pattern = regex.compile(self.ignore_pattern)
        if self.use_gpu and self.tokenization != 'character':
            raise ValueError('GPU MinHash computation is only supported '
                             'for character tokenization')
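
For reference, a minimal sketch of instantiating the operator with the GPU path enabled (parameter names follow the signature above; values are illustrative):

# illustrative: the GPU path requires character tokenization
dedup = RayBTSMinhashDeduplicator(
    tokenization='character',
    use_gpu=True,
    minhash_batch_size='auto',  # resolves to 200_000 on GPU, 1024 on CPU
)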

        # check parameters
        if self.ignore_pattern and self.tokenization == 'punctuation':
@@ -396,7 +451,7 @@ def tokenization_func(text):
                gen.randint(1, MERSENNE_PRIME, dtype=np.uint64),
                gen.randint(0, MERSENNE_PRIME, dtype=np.uint64),
            ) for _ in range(self.num_permutation)],
-           dtype=np.uint64,
+           dtype=np.uint32,
        ).T

Review comment (on the dtype change): This may break some constraints, it's better to keep …
        if union_find_parallel_num == 'auto':

@@ -435,9 +490,37 @@ def tokenization_func(text):
            + empty_hash_value.tobytes()
        self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)

    def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:

    def band_minhash(self, minhash_list, uid_list):
        """
        Logic for creating and pushing LSH bands to the union find list.
        """
        pairs = {}
        minhash_list = minhash_list.to_numpy(zero_copy_only=False)
        for minhash, uid in zip(minhash_list, uid_list):
            for i, (start, end) in enumerate(self.hash_ranges):
                hash_value = i.to_bytes(4, 'big') \
                    + minhash[start:end].tobytes()
                hash_table_id = minhash[start] \
                    % self.union_find_parallel_num
                if hash_table_id not in pairs:
                    pairs[hash_table_id] = []
                pairs[hash_table_id].append((hash_value, uid))
        result_refs = []
        for i, p in pairs.items():
            if len(result_refs) > self.max_pending_filter_tasks:
                ready_refs, result_refs = ray.wait(
                    result_refs, num_returns=self.num_filter_task_returns)
                ray.get(ready_refs)
            result_refs.append(
                self.union_find_list[i].add_key_value_pairs.remote(p))
        ray.get(result_refs)
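
To make the banding step concrete, here is a small sketch of how one signature is turned into band keys. The band sizes and values are made up; hash_ranges and the key layout follow the code above:

import numpy as np

# illustrative: an 8-value signature split into 4 bands of 2 values each
rows_per_band, num_bands = 2, 4
hash_ranges = [(i * rows_per_band, (i + 1) * rows_per_band)
               for i in range(num_bands)]
minhash = np.arange(8, dtype=np.uint32)  # stand-in signature

for i, (start, end) in enumerate(hash_ranges):
    band_key = i.to_bytes(4, 'big') + minhash[start:end].tobytes()
    # documents that share any band_key become union-find merge candidates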

    def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
        """
        Logic for computing minhash values for each text in the input table.
        """
        pairs = {}
        for text, uid in zip(text_list, uid_list):
            text = text.as_py()
            if self.lowercase:

@@ -446,7 +529,6 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
                text = self.ignore_pattern.sub('', text)

            tokens = self.tokenization_func(text)

            if len(tokens) > 0:
                hv = np.array([sha1_hash32(token) for token in tokens],
                              dtype=np.uint64)
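
For intuition, a hedged sketch of the character shingling that tokenization='character' implies; the helper name and the use of a set are illustrative, not the PR's exact tokenization_func:

# illustrative character shingling with window_size = 5
def char_shingles(text: str, window_size: int = 5) -> set:
    return {text[i:i + window_size]
            for i in range(max(len(text) - window_size + 1, 1))}

char_shingles('minhash')  # {'minha', 'inhas', 'nhash'}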
@@ -538,6 +620,16 @@ def run(self, dataset, **kwargs):
        start_time = time.time()
        id_generator = IdGenerator.remote()

        def band_with_uid(table: pa.Table) -> pa.Table:
            num_rows = len(table)
            min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows))
            uid_list = range(min_id, max_id)
            self.band_minhash(table["_minhash"], uid_list)
            new_table = table.append_column(HashKeys.uid,
                                            pa.array(list(uid_list)))
            new_table = new_table.drop_columns(["_minhash"])
            return new_table

        def minhash_with_uid(table: pa.Table) -> pa.Table:
            num_rows = len(table)
            min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows))

@@ -549,22 +641,40 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:

        tmp_dir = os.path.join(self.work_dir, '.tmp',
                               ray.get_runtime_context().get_job_id())
-       dataset.map_batches(
-           minhash_with_uid,
-           batch_format='pyarrow',
-           zero_copy_batch=True,
-       ).write_parquet(tmp_dir)
-       dataset = ray.data.read_parquet(tmp_dir)
        if self.use_gpu:
            dataset = dataset.map_batches(
                GPUMinHashActor,
                batch_format='pyarrow',
                zero_copy_batch=True,
                num_gpus=1,
                concurrency=3,
                batch_size=self.minhash_batch_size,
            )

Review comment (on concurrency=3): I artificially set this during testing but we would want this to be configurable. I'm not sure what the best approach/config is for this.
Reply: Add a …
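
One hedged way to address the comment above is to plumb the value through the constructor; the parameter name minhash_gpu_concurrency is hypothetical, not part of the PR:

# hypothetical: make the GPU actor pool size configurable
# (in __init__)
self.minhash_gpu_concurrency = kwargs.get('minhash_gpu_concurrency', 3)

# (in run())
dataset = dataset.map_batches(
    GPUMinHashActor,
    batch_format='pyarrow',
    zero_copy_batch=True,
    num_gpus=1,
    concurrency=self.minhash_gpu_concurrency,
    batch_size=self.minhash_batch_size,
)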
            dataset.map_batches(
                band_with_uid,
                batch_format='pyarrow',
                zero_copy_batch=True,
            ).write_parquet(tmp_dir)
            del dataset
        else:
            dataset.map_batches(
                minhash_with_uid,
                batch_format='pyarrow',
                zero_copy_batch=True,
            ).write_parquet(tmp_dir)

Review comment (on the banding map_batches): Is there any way to merge these two?
Reply: We should be able to combine the two map_batches into a single call by moving all the banding logic into the GPU minhash actor, but since the number of GPUs/concurrency level for this stage might be much lower than the total CPUs available, it might reduce the concurrency for banding. It might be a tradeoff between networking overhead (via the object store) vs fewer actors doing the banding. I'm not sure which is more optimal.
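
For illustration, a hedged sketch of the combined single-stage variant described in the reply. MinhashAndBandActor is hypothetical, and band_with_uid stands for the banding closure defined in run(); a real implementation would need access to the deduplicator's banding state:

# hypothetical: fold banding into the GPU actor so a single map_batches
# call does minhash + banding, at GPU-actor concurrency
class MinhashAndBandActor(GPUMinHashActor):
    def __call__(self, table: pa.Table, text_key: str = "text") -> pa.Table:
        table = super().__call__(table, text_key)  # appends "_minhash"
        return band_with_uid(table)  # bands, adds uid, drops "_minhash"

dataset.map_batches(
    MinhashAndBandActor,
    batch_format='pyarrow',
    zero_copy_batch=True,
    num_gpus=1,
    concurrency=3,
    batch_size=self.minhash_batch_size,
).write_parquet(tmp_dir)

This saves one trip through the object store, at the cost of doing the banding with only the GPU actors.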
        end_time = time.time()
        logger.info(f'MinHash time = {end_time - start_time}')

        new_dataset = ray.data.read_parquet(tmp_dir)
        start_time = time.time()
        self.merge()
        end_time = time.time()
        logger.info(f'merge time = {end_time - start_time}')
        start_time = time.time()
-       result = dataset.map_batches(
+       result = new_dataset.map_batches(
            self.filter_with_union_find,
            batch_format='pyarrow',
            zero_copy_batch=True,
        )
        end_time = time.time()
        logger.info(f'filter time = {end_time - start_time}')
        return result
Review comment: This method can be removed.