diff --git a/config/fuzzy_dedup_config.yaml b/config/fuzzy_dedup_config.yaml index a513a72f8..89ae00c1e 100644 --- a/config/fuzzy_dedup_config.yaml +++ b/config/fuzzy_dedup_config.yaml @@ -1,4 +1,5 @@ cache_dir: "./fuzzy_dedup_cache" + # Optional Params below with default values # profile_dir: null # id_field: "id" diff --git a/docs/user-guide/semdedup.rst b/docs/user-guide/semdedup.rst index 66d0f7b00..4482810c3 100644 --- a/docs/user-guide/semdedup.rst +++ b/docs/user-guide/semdedup.rst @@ -172,7 +172,8 @@ Use Individual Components embedding_creator = EmbeddingCreator( embedding_model_name_or_path="path/to/pretrained/model", embedding_batch_size=128, - embedding_output_dir="path/to/output/embeddings", + cache_dir="path/to/output", + embeddings_save_loc="embeddings", input_column="text", logger="path/to/log/dir", ) @@ -190,7 +191,8 @@ Use Individual Components id_column="doc_id", max_iter=100, n_clusters=50000, - clustering_output_dir="path/to/output/clusters", + cache_dir="path/to/output", + clustering_save_loc="clustering_results", logger="path/to/log/dir" ) clustered_dataset = clustering_model(embeddings_dataset) @@ -204,12 +206,13 @@ Use Individual Components # Step 3: Semantic Deduplication semantic_dedup = SemanticClusterLevelDedup( n_clusters=50000, - emb_by_clust_dir="path/to/embeddings/by/cluster", - sorted_clusters_dir="path/to/sorted/clusters", id_column="doc_id", id_column_type="str", which_to_keep="hard", output_dir="path/to/output/deduped", + # cache_dir and clustering_save_loc should match ClusteringModel + cache_dir="path/to/output", + clustering_save_loc="clustering_results", logger="path/to/log/dir" ) semantic_dedup.compute_semantic_match_dfs() diff --git a/examples/fuzzy_deduplication.py b/examples/fuzzy_deduplication.py index 892c52224..ab7d0a423 100644 --- a/examples/fuzzy_deduplication.py +++ b/examples/fuzzy_deduplication.py @@ -38,7 +38,7 @@ def main(args): filetype = "parquet" - # Fuzzy dup calculation only supports the cuDF/GPU backend + # Fuzzy deduplication only supports the cuDF/GPU backend backend = "cudf" assert args.device == "gpu" @@ -89,12 +89,12 @@ def main(args): if duplicates is None: print("No duplicates found") - print(f"Time taken:{time.time() - t0}s") + print(f"Time taken: {time.time() - t0}s") return result = fuzzy_dup.remove(input_dataset, duplicates) write_to_disk(result, output_dir, output_type=filetype) - print(f"Time taken:{time.time() - t0}s") + print(f"Time taken: {time.time() - t0}s") def attach_args( diff --git a/examples/semdedup_example.py b/examples/semdedup_example.py index 0e1d0cd73..b179ba89b 100644 --- a/examples/semdedup_example.py +++ b/examples/semdedup_example.py @@ -49,13 +49,18 @@ def main(args): log_level=logging.INFO, stdout=True, ) + st = time.time() + input_files = get_all_files_paths_under( root=args.input_data_dir, ) + if semdedup_config.num_files > 0: input_files = input_files[: semdedup_config.num_files] + logger.info(f"Processing {len(input_files)} files") + ddf = read_data( input_files=input_files, file_type=args.input_file_type, @@ -63,9 +68,11 @@ def main(args): backend="cudf", ) dataset = DocumentDataset(ddf) + semdup = SemDedup(semdedup_config, logger=logger) dedup_ids = semdup(dataset) print(dedup_ids.df.head()) + logger.info(f"Time taken: {time.time() - st}") client.cancel(client.futures, force=True) client.close() diff --git a/nemo_curator/cache.py b/nemo_curator/cache.py new file mode 100644 index 000000000..920a90a0f --- /dev/null +++ b/nemo_curator/cache.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA 
CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir + + +class Cache: + _instance = None + _cache_dir = None + + def __new__(cls, cache_dir=None): + if cls._instance is None: + cls._instance = super(Cache, cls).__new__(cls) + if cache_dir is not None: + cls._cache_dir = expand_outdir_and_mkdir(cache_dir) + else: + cls._cache_dir = None + elif cache_dir is not None and cls._cache_dir is None: + cls._cache_dir = expand_outdir_and_mkdir(cache_dir) + return cls._instance + + @classmethod + def get_cache_directory(cls) -> str: + """ + Retrieve the cache directory. + """ + return cls._cache_dir + + @classmethod + def delete_cache_instance(cls): + """ + Reset the Cache singleton. + """ + if cls._cache_dir is not None: + cls._cache_dir = None + + cls._instance = None diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index 9852ca6b1..7d183e90d 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -18,6 +18,8 @@ import yaml +from nemo_curator.cache import Cache + @dataclass class BaseConfig: @@ -31,44 +33,49 @@ def from_yaml(cls, file_path: str): @dataclass class FuzzyDuplicatesConfig(BaseConfig): """ - Configuration for MinHash based fuzzy duplicates detection. + Configuration for MinHash-based fuzzy duplicates detection. + Parameters ---------- - seed: Seed for minhash permutations - char_ngrams: Size of Char ngram shingles used in minhash computation - num_buckets: Number of Bands or buckets to use during Locality Sensitive Hashing - hashes_per_bucket: Number of hashes per bucket/band. + cache_dir: If specified, directory to store deduplication intermediates, such as + minhashes, buckets, etc. If None, we check if a cache_dir has been initialized + with Cache().get_cache_directory(). Default is None. + profile_dir: If specified, directory to write Dask profile. Default is None. + id_field: Column in the dataset denoting document ID. Default is "id". + text_field: Column in the dataset denoting document content. Default is "text". + perform_removal: Boolean value to specify whether calling the module should remove + the duplicates from the original dataset, or return the list of IDs denoating + duplicates. Default is False. + seed: Seed for minhash permutations. Default is 42. + char_ngrams: Size of character n-gram shingles used in minhash computation. + Default is 5. + num_buckets: Number of bands or buckets to use during Locality Sensitive Hashing. + Default is 20. + hashes_per_bucket: Number of hashes per bucket/band. Default is 13. use_64_bit_hash: Whether to use a 32bit or 64bit hash function for minhashing. - buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. - Larger values process larger batches by processing multiple bands - but might lead to memory pressures and related errors. - id_field: Column in the Dataset denoting document ID. - text_field: Column in the Dataset denoting document content. 
- perform_removal: Boolean value to specify whether calling the module should remove the duplicates from - the original dataset, or return the list of IDs denoting duplicates. - profile_dir: str, Default None - If specified directory to write dask profile - cache_dir: str, Default None - Location to store deduplcation intermediates such as minhashes/buckets etc. - false_positive_check: bool, - Whether to run a check to look for false positives within buckets. - Note: This is a computationally expensive step. - num_anchors: int - Number of documents per bucket to use as reference for computing jaccard - pairs within that bucket to identify false positives. - jaccard_threshold: float - The Jaccard similariy threshold to consider a document a near duplicate - during false positive evaluations. + Default is False. + buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. Larger values + process larger batches by processing multiple bands but might lead to memory + pressures and related errors. Default is 1. + false_positive_check: Whether to run a check to look for false positives within + buckets. Note: This is a computationally expensive step. Default is False. + num_anchors: Number of documents per bucket to use as reference for computing + Jaccard pairs within that bucket to identify false positives. Default is 2. + jaccard_threshold: The Jaccard similariy threshold to consider a document a near + duplicate during false positive evaluations. Default is 0.8. + bucket_mapping_blocksize: Default is 256. + parts_per_worker: Default is 1. + bucket_parts_per_worker: Default is 8. """ # General config - cache_dir: str + cache_dir: Optional[str] = None profile_dir: Optional[str] = None id_field: str = "id" text_field: str = "text" perform_removal: bool = False - # Minhash + LSH Config + # Minhash + LSH config seed: int = 42 char_ngrams: int = 24 num_buckets: int = 20 @@ -86,6 +93,7 @@ class FuzzyDuplicatesConfig(BaseConfig): def __post_init__(self): self.num_hashes = self.num_buckets * self.hashes_per_bucket + false_positive_defaults = { "num_anchors": 2, "jaccard_threshold": 0.8, @@ -93,46 +101,64 @@ def __post_init__(self): "parts_per_worker": 1, "bucket_parts_per_worker": 8, } + if self.false_positive_check: warnings.warn( - "Identifying false positives during the Minhash deduplication is computationally expensive." - " For improved performance consider setting this to False" + "Identifying false positives during Minhash deduplication is " + "computationally expensive. For improved performance consider setting " + "this to False." ) + for arg, default in false_positive_defaults.items(): if getattr(self, arg) is None: setattr(self, arg, default) + if self.num_anchors <= 0: - raise ValueError("Number of anchors must be greater than 0") + raise ValueError("Number of anchors must be greater than 0.") + if self.num_anchors > 2: warnings.warn( - "Using a higher number of anchor docs might lead to higher memory footprint and might impact performance", + "Using a higher number of anchor documents might lead to higher memory " + "footprint and might impact performance.", category=UserWarning, ) + if not 0 <= self.jaccard_threshold <= 1: - raise ValueError("Jaccard Threshold must be between [0,1]") + raise ValueError("Jaccard threshold must be between [0, 1].") + else: if self.char_ngrams < 20: warnings.warn( "Using a small char_ngrams value might lead to a large number (~5%) of false positives during deduplication." " Using a value of at least 20 for char_ngrams is recommended." 
) + unused_false_positive_args = [ arg for arg in false_positive_defaults.keys() if getattr(self, arg) is not None ] + if unused_false_positive_args: warnings.warn( - f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored", + f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored.", category=UserWarning, ) - if self.cache_dir is None: - raise ValueError( - "Finding fuzzy duplicates requires a cache directory accessible via all workers to store intermediates" - ) if not 1 <= self.buckets_per_shuffle <= self.num_buckets: - raise ValueError("Buckets per shuffle must be between [1, num_buckets]") + raise ValueError("Buckets per shuffle must be between [1, num_buckets].") + + if self.cache_dir is None: + cache_dir = Cache().get_cache_directory() + if cache_dir is None: + raise ValueError( + "Finding fuzzy duplicates requires a cache directory accessible via " + "all workers to store intermediates. Please use " + "Cache(cache_dir=...) or FuzzyDuplicatesConfig(cache_dir=...) to " + "set the cache directory." + ) + else: + self.cache_dir = cache_dir if not self.perform_removal: warnings.warn( @@ -146,7 +172,9 @@ class SemDedupConfig(BaseConfig): Configuration for Semantic Deduplication. Attributes: - cache_dir (str): Directory to store cache. + cache_dir (Optional[str]): If specified, directory to store cache. + If None, we check if a cache_dir has been initialized with Cache().get_cache_directory(). + Default is None. profile_dir (Optional[str]): If specified, directory to write Dask profile. Default is None. num_files (int): Number of files. Default is -1, meaning all files. @@ -190,7 +218,7 @@ class SemDedupConfig(BaseConfig): Default is 0.01. """ - cache_dir: str + cache_dir: str = None profile_dir: Optional[str] = None num_files: int = -1 @@ -216,17 +244,25 @@ class SemDedupConfig(BaseConfig): kmeans_with_cos_dist: bool = False clustering_input_partition_size: str = "2gb" - # Extract dedup config + # SemDedup eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001]) eps_to_extract: float = 0.01 def __post_init__(self): if self.cache_dir is None: - raise ValueError( - "Finding sem-dedup requires a cache directory accessible via all workers to store intermediates" - ) + cache_dir = Cache().get_cache_directory() + if cache_dir is None: + raise ValueError( + "Finding semantic duplicates requires a cache directory accessible " + "via all workers to store intermediates. Please use " + "Cache(cache_dir=...) or SemDedupConfig(cache_dir=...) to " + "set the cache directory." + ) + else: + self.cache_dir = cache_dir if self.eps_to_extract not in self.eps_thresholds: raise ValueError( - f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}" + f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds " + f"{self.eps_thresholds}." ) diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index 33b940a95..b98cf3246 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -26,6 +26,7 @@ from dask import dataframe as dd from nemo_curator._compat import DASK_P2P_ERROR +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.modules.base import BaseModule @@ -53,13 +54,17 @@ def __init__( Parameters ---------- logger: Existing logger to log to, or a path to a log directory. 
- id_field: Column in the Dataset denoting document ID. - text_field: Column in the Dataset denoting document content. - hash_method: The hashing algorithm used for identifying exact duplicates. Currently supports {"md5"} - profile_dir: str, Default None - If specified directory to write dask profile - cache_dir: str, Default None - If specified, will compute & write duplicate id's to cache directory. + id_field: Column in the dataset denoting document ID. + text_field: Column in the dataset denoting document content. + hash_method: The hashing algorithm used for identifying exact duplicates. + Currently only supports "md5". + perform_removal: Boolean value to specify whether calling the module should + remove the duplicates from the original dataset, or return the list of IDs + denoting duplicates. + profile_dir: If specified, directory to write Dask profile. Default is None. + cache_dir: If specified, will compute and write duplicate IDs to cache directory. + If None, we check if a cache_dir has been initialized with Cache().get_cache_directory(). + Default is None. """ super().__init__(input_backend="any") @@ -71,6 +76,19 @@ def __init__( self.hash_method = hash_method self.id_field = id_field self.text_field = text_field + + if cache_dir is None: + self.cache_dir = Cache().get_cache_directory() + else: + self.cache_dir = cache_dir + + if self.cache_dir is None and profile_dir is not None: + warnings.warn( + "cache_dir for intermediate outputs is required to generate profiles. " + "Please initialize with Cache(cache_dir=...) or ExactDuplicates(cache_dir=...)" + ) + self.profile_dir = profile_dir + self.perform_removal = perform_removal if not self.perform_removal: @@ -78,17 +96,12 @@ def __init__( "In future NeMo Curator releases, the default value for perform_removal will be True." ) - if self.perform_removal and cache_dir is None: - warnings.warn("cache_dir is recommended to remove duplicates.") - - if cache_dir is None and profile_dir is not None: + if self.perform_removal and self.cache_dir is None: warnings.warn( - "cache_dir for intermediate outputs is required to generate profiles" + "cache_dir is recommended to remove duplicates. " + "Please initialize with Cache(cache_dir=...) or ExactDuplicates(cache_dir=...)" ) - self.cache_dir = cache_dir - self.profile_dir = profile_dir - if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -100,7 +113,8 @@ def __init__( def _exact_dup_ids(self, df: dd.DataFrame) -> dd.DataFrame: """ - Get the id's for text/documents that are exact duplicates + Get the IDs for text/documents that are exact duplicates. + Parameters ---------- df: dask.dataframe.DataFrame @@ -130,10 +144,11 @@ def _compute_hashes( df: dd.DataFrame, ) -> dd.DataFrame: """ - Computes the hash of the text_column provided and returns a dataframe - containing the id_column and relevant hashes in the _hashes column. + Computes the hash of the text field provided and returns a DataFrame + containing the ID field and relevant hashes in the _hashes column. 
""" self._logger.info("Starting lazy hash generation") + res = df[[self.id_field]] res["_hashes"] = df[self.text_field].map_partitions(self.hash_documents) @@ -153,7 +168,7 @@ def hash_documents( return df.hash_values(method=self.hash_method) elif isinstance(df, pd.Series): - # TODO: Generalize ty using self.hash_method + # TODO: Generalize by using self.hash_method return df.apply(lambda x: md5(x.encode()).hexdigest()) else: @@ -161,7 +176,8 @@ def hash_documents( def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset: """ - Find document IDs for exact duplicates in a given DocumentDataset + Find document IDs for exact duplicates in a given DocumentDataset. + Parameters ---------- dataset: DocumentDataset @@ -176,7 +192,7 @@ def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(result) t0 = time.time() - self._logger.info("Starting execution for ExactDedup") + self._logger.info("Starting execution for ExactDuplicates") write_path = os.path.join(self.cache_dir, "_exact_duplicates.parquet") if os.path.exists(write_path): @@ -191,7 +207,8 @@ def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset: result.to_parquet(write_path, write_index=False, overwrite=True) self._logger.info( - f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}" + f"Time taken for ExactDuplicates computation = {time.time() - t0}s \n" + f"Output written at {write_path}" ) backend = "cudf" if is_cudf_type(result) else "pandas" diff --git a/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py b/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py index 20a09ed79..b0eb62216 100644 --- a/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py +++ b/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py @@ -32,12 +32,13 @@ class _MapBuckets: """ - buckets to a logical partition by using a modified bin packing algorithm. + Buckets to a logical partition by using a modified bin packing algorithm. Combines buckets generated from LSH (typically high cardinality) to more coarse lower cardinality bucket groups by mapping multiple buckets to a logical partition using document length information and a modified bin packing algorithm. - Only needed if running False Postive check to remove false positives. + + Only needed if running false positive check to remove false positives. """ def __init__( @@ -49,17 +50,18 @@ def __init__( logger: Union[logging.LoggerAdapter, str] = "./", ): """ - id_fields: list or str - id fields of df - text_field: str = "text", - bucket_column: str = "bucket_column", - num_anchors: int = 2, - logger: Union[logging.LoggerAdapter, str] = "./", + id_fields (list or str): ID fields of DataFrame. Default is "id". + text_field (str): text field of DataFrame. Default is "text". + bucket_field (str): Default is "_bucket_id". + num_anchors (int): Default is 2. + logger (Union[logging.LoggerAdapter, str]): Default is "./". 
""" + self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields self.text_field = text_field self.num_anchors = num_anchors self.bucket_field = bucket_field + if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -78,12 +80,13 @@ def _get_output_part_ids_with_approx_equal_sum( output_partition_column: str, ) -> cudf.DataFrame: """ - Create a output_series that maps the ser.index into `nparts` + Create an output_series that maps the ser.index into `nparts` so that the total sum of bucket_val_counts_df - for each output id are all most equal and - less than max_text_bytes_per_part - This is used downstream for creating equal output_ids + for each output ID are almost equal and + less than max_text_bytes_per_part. + This is used downstream for creating equal output_ids. """ + sizes = bucket_text_bytes_df[bytes_column].values bucket_output_ar = build_partition( sizes=sizes.get(), max_size=max_text_bytes_per_part @@ -104,8 +107,8 @@ def _get_output_map_from_text_bytes_per_bucket( max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3) self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}") - # Increasing in an attempt to prevent hitting - # ulimits + + # Increasing in an attempt to prevent hitting ulimits output_map_df_meta = cudf.DataFrame( {self.bucket_field: [0], output_partition_column: [1]} ) @@ -122,9 +125,11 @@ def _get_output_map_from_text_bytes_per_bucket( meta=output_map_df_meta, ) output_map_df = output_map_df.persist() + self._logger.info( - f"Step 1 of output_map_df of len: {len(output_map_df)} computed" + f"Step 1 of output_map_df of length {len(output_map_df)} computed" ) + lower_bounds = ( output_map_df[output_partition_column] .map_partitions(lambda s: (s.max() + 1)) @@ -145,9 +150,11 @@ def update_id(df, lower_bound): updated_parts.append(output_map_df.get_partition(0)) output_map_df = dask_cudf.concat(updated_parts) output_map_df = output_map_df.persist() + self._logger.info( - f"All steps of output_map_df of len: {len(output_map_df)} computed" + f"All steps of output_map_df of length {len(output_map_df)} computed" ) + return output_map_df def _get_output_map_based_on_str_bytes( @@ -156,6 +163,7 @@ def _get_output_map_based_on_str_bytes( """ Add output_partition_id to buckets_ddf """ + documents_df = documents_df.copy() documents_df[bytes_column] = documents_df[self.text_field].map_partitions( lambda s: s.str.byte_count() @@ -168,6 +176,7 @@ def _get_output_map_based_on_str_bytes( npartitions=n_partitions ) del documents_df + ddf_bk_text_bytes, agg_df_len = get_agg_text_bytes_df( df=buckets_df, agg_column=self.bucket_field, @@ -175,7 +184,9 @@ def _get_output_map_based_on_str_bytes( n_partitions=n_partitions, shuffle=True, ) - self._logger.info(f"Agg_df computed of length = {agg_df_len}") + + self._logger.info(f"agg_df of length {agg_df_len} computed") + del buckets_df output_map_df = self._get_output_map_from_text_bytes_per_bucket( ddf_bk_text_bytes=ddf_bk_text_bytes, @@ -185,8 +196,9 @@ def _get_output_map_based_on_str_bytes( def _random_select_anchor(self, buckets_df, n=2): """ - Randomly select `n` anchors from each bucket. + Randomly select n anchors from each bucket. 
""" + buckets_df = buckets_df.copy() buckets_df["_id_hash"] = buckets_df[self.id_fields].hash_values() buckets_df = buckets_df.sort_values([self.bucket_field, "_id_hash"]) @@ -194,8 +206,10 @@ def _random_select_anchor(self, buckets_df, n=2): self.bucket_field ).cumcount() buckets_df["is_anchor"] = buckets_df["_order_in_bucket"] < n + for i in range(0, n): buckets_df[f"is_anchor_id_{i}"] = buckets_df["_order_in_bucket"] == i + buckets_df = buckets_df.drop(columns=["_id_hash", "_order_in_bucket"], axis=1) buckets_df = buckets_df.reset_index(drop=True) buckets_df = buckets_df[buckets_df.is_anchor] @@ -205,14 +219,17 @@ def _add_anchor_docs(self, buckets_df, num_anchors): """ Get anchor documents for each bucket. """ + df_anchor_bk = self._random_select_anchor(buckets_df=buckets_df, n=num_anchors) df_anchor_docs = None + for i in range(num_anchors): df_anchor_bk_i = df_anchor_bk[df_anchor_bk[f"is_anchor_id_{i}"]][ [self.bucket_field] + self.id_fields ].reset_index(drop=True) column_mapping = {id: f"anchor_{i}_{id}" for id in self.id_fields} df_anchor_bk_i = df_anchor_bk_i.rename(columns=column_mapping) + if i == 0: df_anchor_docs = df_anchor_bk_i else: @@ -232,17 +249,10 @@ def map_buckets_with_anchors( shuffle_type: Union[str, bool, None] = "tasks", ) -> dask_cudf.DataFrame: """ - Get anchor docs with bucket info - Args: - input_data_paths: list of paths to input data - input_bucket_path: path to input buckets - text_ddf_blocksize: blocksize for text ddf - num_files: number of files to read - num_workers: number of workers - shuffle_type: type of shuffle to use - Returns: - ddf_anchor_docs_with_bk + Get anchor documents with bucket information. + """ + output_map_df = self._get_output_map_based_on_str_bytes( buckets_df=buckets_df, documents_df=documents_df ) @@ -253,10 +263,12 @@ def map_buckets_with_anchors( ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge( output_map_df, on=self.bucket_field ) + # Bucket is no longer needed ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop( columns=[self.bucket_field] ) + # Below removes any duplicates lying around after dropping buckets ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions( M.drop_duplicates, @@ -265,6 +277,7 @@ def map_buckets_with_anchors( transform_divisions=False, align_dataframes=False, ) + ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle( self.id_fields, ignore_index=True, @@ -276,5 +289,6 @@ def map_buckets_with_anchors( transform_divisions=False, align_dataframes=False, ) + del output_map_df return ddf_anchor_docs_with_bk diff --git a/nemo_curator/modules/fuzzy_dedup/_shuffle.py b/nemo_curator/modules/fuzzy_dedup/_shuffle.py index 218bf4a62..471ce357b 100644 --- a/nemo_curator/modules/fuzzy_dedup/_shuffle.py +++ b/nemo_curator/modules/fuzzy_dedup/_shuffle.py @@ -76,21 +76,24 @@ def shuffle_docs_on_buckets( bucket_parts_per_worker: int = 8, partition_on: str = "_output_partition_id", ): - ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read( path=bucket_w_anchors_path, blocksize=bucket_mapping_df_blocksize, ) - self._logger.info("Getting ddf_anchor_docs_with_bk completed") + + self._logger.info("Computing ddf_anchor_docs_with_bk completed") self._logger.debug( f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}" ) + st = time.time() + num_workers = get_num_workers(get_current_client()) parts_per_batch = num_workers * parts_per_worker - self._logger.debug(f"parts_per_batch = {parts_per_batch}") + + self._logger.debug(f"parts_per_batch = 
{parts_per_batch}") parts_per_bucket_batch = num_workers * bucket_parts_per_worker - self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}") + self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}") dask_profile_name = ( "suffle_docs" @@ -111,8 +114,9 @@ def shuffle_docs_on_buckets( bk_mapping=bk_mapping, num_workers=num_workers, ) + self._logger.info( - f"Time taken for Shuffle = {time.time()-st}s and output written at {output_shuffled_docs_path}" + f"Time taken for _Shuffle: {time.time()-st}s and output written at {output_shuffled_docs_path}" ) def _batched_merge_and_write( @@ -145,10 +149,9 @@ def _batched_merge_and_write( ) # Set end offsets - # NOTE: These end offsets are always set to the end - # of the data. However, we may want to be able to set - # both the start and end offsets from the command line - # in the future. + # NOTE: These end offsets are always set to the end of the data. + # However, we may want to be able to set both the start and end offsets from + # the command line in the future. bucket_part_end_offset = total_bucket_partitions text_part_end_offset = total_text_partitions @@ -158,7 +161,6 @@ def _batched_merge_and_write( assert text_part_end_offset > text_part_start_offset # Initialize "retry" variables - # # - retry_count: The number of successive batches that # we have already performed at a reduced batch size. # - retry_threshold: The number of successive batches @@ -179,11 +181,11 @@ def _batched_merge_and_write( bucket_part_start_offset, bucket_part_end_offset, parts_per_bucket_batch ) ): - # Outer loop over batches of "bucket-map" partitions end_bucket_offset = min( bucket_part_offset + parts_per_bucket_batch, bucket_part_end_offset ) + print( f"\nStarted processing bucket-map partitions {bucket_part_offset} " f"through {end_bucket_offset} of {bucket_part_end_offset}", @@ -207,13 +209,13 @@ def _batched_merge_and_write( text_part_offset = text_part_start_offset while text_part_offset < text_part_end_offset: - # Check if we are "retrying" with a smaller "parts_per_text_batch" if parts_per_text_batch_retry: parts_per_text_batch_use = parts_per_text_batch_retry else: st_text = time.time() parts_per_text_batch_use = parts_per_text_batch + print(f"Using {parts_per_text_batch_use} text partitions.", flush=True) # Select partitions for our text batch @@ -234,7 +236,9 @@ def _batched_merge_and_write( output_df = output_df.map_partitions( int_ids_to_str, id_column=self.int_to_str_id ) + batch_label = f"{end_bucket_offset}_{end_text_offset}" + if output_df is not None: written_files = output_df.map_partitions( write_partitioned_file, @@ -244,13 +248,14 @@ def _batched_merge_and_write( meta=cudf.Series([True]), ) written_files = written_files.compute() + update_restart_offsets(output_path, bucket_part_offset, end_text_offset) del output_df print( - "Text-df partition ", + "text-df partition ", f"{end_text_offset}/{text_part_end_offset} " - f"completed in {time.time()-st_text}", + f"completed in {time.time()-st_text}s", flush=True, ) @@ -268,13 +273,15 @@ def _batched_merge_and_write( # case we fail again parts_per_text_batch_retry = None retry_count, retry_threshold = 0, min(retry_threshold * 2, 16) + text_part_offset += parts_per_text_batch_use update_restart_offsets(output_path, end_bucket_offset, end_text_offset) + print( "Bucket partition ", f"{end_bucket_offset}/{bucket_part_end_offset} " - f"completed in {time.time()-st_bucket}", + f"completed in {time.time()-st_bucket}s", flush=True, ) diff --git 
a/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py b/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py index 5ff08b4c7..22e27174f 100644 --- a/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py +++ b/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py @@ -27,6 +27,7 @@ import pandas as pd import pyarrow as pa +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -34,14 +35,13 @@ class BucketsToEdges: """ - Maps buckets generated from LSH into an edgelist that - can be processed further by Connected Components to find duplicate - documents + Maps buckets generated from LSH into an edgelist that can be processed further by + Connected Components to find duplicate documents. """ def __init__( self, - cache_dir: str = None, + cache_dir: Optional[str] = None, id_fields: Union[list, str] = "id", str_id_name: str = "id", bucket_field: str = "_bucket_id", @@ -51,24 +51,29 @@ def __init__( """ Parameters ---------- - cache_dir: str or None - If specified, will compute & write the edgelist to a file - id_fields: list or str - id fields of documents in buckets_df - str_id_name: str - Ignored if there is a single id field. Multiple id fields - will be combined into a single id field with the given name. - bucket_field: str - Column denoting bucket ID - num_buckets: Number of bands/buckets to create from the minhash signature. - Hashes_per_signature = num_hashes / num_buckets + cache_dir: Directory to compute and write edgelist. Can also be set with + Cache(cache_dir=...). Default is None. + id_fields: List or string representing column(s) in buckets_df denoting + document ID. Default is "id". + str_id_name: Ignored if there is a single ID field. Multiple ID fields will be + combined into a single ID field with the given name. Default is "id". + bucket_field: Column denoting bucket ID. Default is "_bucket_id". + logger: Existing logger to log to, or a path to a log directory. + Default is "./". + profile_dir: If specified, directory to write Dask profile. Default is None. 
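+
+        Example (illustrative sketch; the cache path and dataset name are placeholders):
+            >>> from nemo_curator.cache import Cache
+            >>> Cache(cache_dir="/path/to/cache")  # register a global cache directory
+            >>> buckets_to_edges = BucketsToEdges(id_fields="id")
+            >>> edges_dataset = buckets_to_edges(buckets_dataset)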
""" - self.cache_dir = cache_dir + + if cache_dir is None: + self.cache_dir = Cache().get_cache_directory() + else: + self.cache_dir = cache_dir + self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0] self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"] self.bucket_field = bucket_field self.profile_dir = profile_dir + if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -84,12 +89,13 @@ def _combine_multiple_ids( ) -> cudf.DataFrame: if output_id_field in input_df.columns: raise ValueError( - f"Input df already contains column named: {output_id_field}" + f"Input DataFrame already contains column named {output_id_field}" ) output_df = input_df.copy()[input_df.columns.difference(input_id_fields)] output_df[output_id_field] = input_df[input_id_fields[0]].astype(str) + for input_field in input_id_fields[1:]: output_df[output_id_field] = output_df[output_id_field] = ( input_df[input_id_fields[0]].astype(str) @@ -109,23 +115,29 @@ def buckets_to_edges( .agg(list) .list.sort_values() ) + bucket_docs = grouped_buckets.to_arrow().to_pylist() edges = [] + # Create pairs of all documents within a bucket since they are near duplicates # Effectively create a edge list of all near duplicate documents for bucket_doc in bucket_docs: edges.extend(pairwise(bucket_doc)) + edges = pd.DataFrame(edges, columns=self.output_ids) edges = pa.Table.from_pandas(edges) result_df = cudf.DataFrame.from_arrow(edges) del edges + result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True) result_df["jaccard"] = np.float32(1.0) return result_df def __call__(self, dataset: DocumentDataset) -> DocumentDataset: buckets_df = dataset.df - self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist") + + self._logger.info(f"Starting conversion of LSH buckets to graph edgelist") + if len(self.id_fields) > 1: buckets_df = buckets_df.map_partitions( BucketsToEdges._combine_multiple_ids, @@ -145,14 +157,16 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: warnings.warn( f"Output path {write_path} already exists and will be overwritten" ) + t0 = time.time() with performance_report_if_with_ts_suffix( self.profile_dir, "bucket-to-edges", ): edges_df.to_parquet(write_path, write_index=False, overwrite=True) + self._logger.info( - f"Time taken for Converted Buckets To Edgelist = {time.time() - t0}s and output written at {write_path}" + f"Time taken for converted buckets to edgelist: {time.time() - t0}s and output written at {write_path}" ) return DocumentDataset( diff --git a/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py index 1394ae9a0..3ce144eff 100644 --- a/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py +++ b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py @@ -28,6 +28,7 @@ from cugraph import MultiGraph from dask.utils import M +from nemo_curator.cache import Cache from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -35,20 +36,31 @@ class ConnectedComponents: def __init__( self, - cache_dir: str, jaccard_pairs_path: str, + cache_dir: Optional[str] = None, id_column="id", jaccard_threshold: float = 0.8, logger: Union[logging.LoggerAdapter, str] = "./", profile_dir: Optional[str] = None, ): - self.cache_dir = cache_dir self.jaccard_pairs_path = jaccard_pairs_path + + if cache_dir is None: + 
self.cache_dir = Cache().get_cache_directory() + else: + self.cache_dir = cache_dir + if self.cache_dir is None: + raise ValueError( + "cache_dir is required for Connected Components. Please initialize with " + "Cache(cache_dir=...) or ConnectedComponents(cache_dir=...)" + ) + self.id_column = id_column self.left_id = f"{id_column}_x" self.right_id = f"{id_column}_y" self.jaccard_threshold = jaccard_threshold self.profile_dir = profile_dir + if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -60,12 +72,15 @@ def __init__( def cc_workflow(self, output_path): deduped_parsed_id_path = self._write_dedup_parsed_id() + encoded_jaccard_pair_path = self._write_encoded_jaccard_pair( deduped_parsed_id_path ) + deduped_encoded_jaccard_path = self._write_dedup_encoded_jaccard_pair( encoded_jaccard_pair_path ) + cc_path = self._run_connected_components( deduped_encoded_jaccard_path, deduped_parsed_id_path, output_path ) @@ -81,7 +96,6 @@ def _run_connected_components( with performance_report_if_with_ts_suffix( self.profile_dir, "connected-components-run" ): - Comms.initialize(p2p=False) df = dask_cudf.read_parquet( deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True @@ -102,6 +116,7 @@ def _run_connected_components( ) result = dcg.weakly_connected_components(G) del G + max_partitions = min(32, result.npartitions) n_components = len( result[["labels"]].drop_duplicates(split_out=max_partitions) @@ -110,6 +125,7 @@ def _run_connected_components( labels_df = labels_df.merge( result, left_on=["uid"], right_on=["vertex"], how="inner" ) + id_columns = [self.id_column] labels_df = labels_df[id_columns + ["labels"]] labels_df = labels_df.rename(columns={"labels": "group"}) @@ -119,27 +135,31 @@ def _run_connected_components( self._logger.info( "Result of connected compoinents are " - f"# of groups : {n_components}, " - f"# of docs removed : {num_labels - n_components}, " - f"# nodes = {num_nodes}, " - f"# rows in labels_df = {len(labels_df)}" + f"# of groups: {n_components}, " + f"# of documents removed: {num_labels - n_components}, " + f"# nodes: {num_nodes}, " + f"# rows in labels_df: {len(labels_df)}" ) assert num_nodes == len(labels_df) - # Ensure all docs in the same group are in the same partition + + # Ensure all documents in the same group are in the same partition labels_df = labels_df.shuffle(on=["group"], ignore_index=True) labels_df.to_parquet(output_path, write_index=False, overwrite=True) Comms.destroy() + self._logger.info( - f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}" + f"Time taken for connected components: {time.time() - t0}s and output written at {output_path}" ) @staticmethod def _sort_ids(df, id_columns): x = df[id_columns].values x = cp.sort(x, axis=1) + for i, id_column in enumerate(id_columns): df[id_column] = x[:, i] df[id_column] = df[id_column].astype("uint64") + return df @staticmethod @@ -152,10 +172,10 @@ def thresholding(df, threshold, column_to_threshold): def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path): output_path = f"{self.cache_dir}/final_dedup_encoded_jaccard_pair.parquet" t0 = time.time() + with performance_report_if_with_ts_suffix( self.profile_dir, "connected-components-dedup-encoded-jaccard-pair" ): - ddf = dask_cudf.read_parquet( encoded_jaccard_pair_path, blocksize="512MB", aggregate_files=True ) @@ -196,14 +216,17 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path): align_dataframes=False, ) ddf.to_parquet(output_path, 
write_index=False, overwrite=True) + self._logger.info( - f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}" + f"Time taken for dedupe encoding Jaccard pairs: {time.time() - t0}s and output written at {output_path}" ) + return output_path def _write_dedup_parsed_id(self): dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet" t0 = time.time() + with performance_report_if_with_ts_suffix( self.profile_dir, "connected-components-dedup-parsed-id" ): @@ -221,20 +244,24 @@ def _write_dedup_parsed_id(self): # Dask does not guard against split_out=0 split_out=max(ddf.npartitions // 4, 1) ) + unique_docs["uid"] = np.uint64(1) unique_docs["uid"] = unique_docs["uid"].cumsum() unique_docs["uid"] = unique_docs["uid"] - 1 unique_docs.to_parquet( dedup_parsed_id_path, write_index=False, overwrite=True ) + self._logger.info( - f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}" + f"Time taken for dedupe parsed ID: {time.time() - t0}s and output written at {dedup_parsed_id_path}" ) + return dedup_parsed_id_path def _write_encoded_jaccard_pair(self, dedup_parsed_id_path): output_path = f"{self.cache_dir}/encoded_jaccard_pair/" t0 = time.time() + with performance_report_if_with_ts_suffix( self.profile_dir, "connected-components-encoded-jaccard-pair" ): @@ -252,9 +279,11 @@ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path): output_path=output_path, id_column=self.id_column, ) + self._logger.info( - f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}" + f"Time taken for encoding Jaccard pairs: {time.time() - t0}s and output written at {output_path}" ) + return output_path def _merge_and_write( @@ -265,11 +294,13 @@ def _merge_and_write( id_column: str, ) -> None: st = time.time() - # Ensure 'id_columns' is a list + + # Ensure id_columns is a list ddf_id = ddf_id.set_index(id_column) + for tag in ["x", "y"]: pair_id = f"{id_column}_{tag}" - # Merge 'ddf' with 'ddf_id' to map ids to uids + # Merge ddf with ddf_id to map IDs to UIDs ddf = ddf.merge( ddf_id, left_on=pair_id, @@ -279,19 +310,22 @@ def _merge_and_write( ) ddf = ddf.drop(columns=pair_id) ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"}) + ddf = ddf[[self.left_id, self.right_id, "jaccard"]] ddf.to_parquet(output_path, write_index=False, overwrite=True) et = time.time() self._logger.info( - f"Time taken for merge and write = {et - st}s and output written at {output_path}" + f"Time taken for merge and write: {et - st}s and output written at {output_path}" ) @staticmethod def _get_unique_ids_per_partition(df, id_columns): unique_df_ls = [] + for tag in ["x", "y"]: cols_to_drop = [] + for id_col in id_columns: cols_to_drop.append(f"{id_col}_{tag}") @@ -300,6 +334,7 @@ def _get_unique_ids_per_partition(df, id_columns): columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns} ) unique_df_ls.append(subset_df) + unique_df = cudf.concat(unique_df_ls, ignore_index=True) unique_df = unique_df.drop_duplicates(ignore_index=True) return unique_df diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py index e037829dd..82e21ec63 100644 --- a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py +++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py @@ -19,6 +19,7 @@ import time from typing import Optional, Union +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from 
nemo_curator.log import create_logger from nemo_curator.modules.base import BaseModule @@ -44,15 +45,16 @@ def __init__( """ Parameters ---------- - config: FuzzyDuplicatesConfig, - Config options for finding FuzzyDuplicates + config (FuzzyDuplicatesConfig): Config options for finding fuzzy duplicates. logger: Existing logger to log to, or a path to a log directory. + Default is "./". Returns ------- DocumentDataset containing IDs of all documents and the corresponding duplicate group they belong to. Documents in the same group are near duplicates. """ + super().__init__(input_backend="cudf") if isinstance(logger, str): self._logger = create_logger( @@ -65,6 +67,16 @@ def __init__( self.config = config + if self.config.cache_dir is not None: + self.cache_dir = self.config.cache_dir + elif Cache().get_cache_directory() is not None: + self.cache_dir = Cache().get_cache_directory() + else: + raise RuntimeError( + "No cache directory specified. Please initialize with Cache(cache_dir=...) " + "or specify a cache_dir in your YAML file." + ) + self.minhash = MinHash( seed=self.config.seed, num_hashes=self.config.num_hashes, @@ -74,10 +86,11 @@ def __init__( id_field=self.config.id_field, text_field=self.config.text_field, profile_dir=self.config.profile_dir, - cache_dir=self.config.cache_dir, + cache_dir=self.cache_dir, ) + self.lsh = LSH( - cache_dir=self.config.cache_dir, + cache_dir=self.cache_dir, num_hashes=self.config.num_hashes, num_buckets=self.config.num_buckets, buckets_per_shuffle=self.config.buckets_per_shuffle, @@ -109,9 +122,10 @@ def __init__( for i in range(self.config.num_anchors) ], ) + else: self.buckets_to_edges = BucketsToEdges( - cache_dir=self.config.cache_dir, + cache_dir=self.cache_dir, id_fields=self.config.id_field, logger=self._logger, profile_dir=self.config.profile_dir, @@ -123,8 +137,8 @@ def __init__( else "_edges.parquet" ) self.connected_components = ConnectedComponents( - cache_dir=self.config.cache_dir, - jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname), + cache_dir=self.cache_dir, + jaccard_pairs_path=os.path.join(self.cache_dir, jaccard_pairs_fname), id_column=self.config.id_field, jaccard_threshold=self.config.jaccard_threshold, logger=self._logger, @@ -137,8 +151,8 @@ def identify_duplicates( """ Parameters ---------- - dataset: DocumentDataset - The input datset to compute FuzzyDuplicates. Must contain a text and unique id field. + dataset (DocumentDataset): The input dataset on which to compute fuzzy deduplication. + Must contain a text field and unique ID field. 
Returns ------- @@ -152,20 +166,22 @@ def identify_duplicates( minhashLSH = Sequential([self.minhash, self.lsh]) buckets_df = minhashLSH(dataset) print(f"Stage {stage_num}: Minhash + LSH complete!") + if buckets_df is None: print( f"Stage {stage_num}: No potential duplicate documents found during LSH" ) return None - stage_num += 1 + stage_num += 1 if self.config.false_positive_check: # Map buckets to lower cardinality distribution print(f"Stage {stage_num} (False Positive Check): Starting Map_Buckets") t0 = time.time() mapped_buckets_w_anchors_path = os.path.join( - self.config.cache_dir, "anchor_docs_with_bk.parquet" + self.cache_dir, "anchor_docs_with_bk.parquet" ) + with performance_report_if_with_ts_suffix( self.config.profile_dir, "map_buckets", @@ -178,18 +194,17 @@ def identify_duplicates( ddf_mapped_buckets_w_anchors.to_parquet( mapped_buckets_w_anchors_path, write_index=False, overwrite=True ) + self._logger.info( - f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}" + f"Time taken for Map_Buckets: {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}" ) - print(f"Stage {stage_num} (False Postive Check): Map_Buckets Complete!") + print(f"Stage {stage_num} (False Positive Check): Map_Buckets complete!") stage_num += 1 # Shuffle documents based on mapped buckets - print(f"Stage {stage_num} (False Postive Check): Shuffle docs") - shuffled_docs_path = os.path.join( - self.config.cache_dir, "shuffled_docs.parquet" - ) + print(f"Stage {stage_num} (False Positive Check): Shuffle documents") + shuffled_docs_path = os.path.join(self.cache_dir, "shuffled_docs.parquet") self.jaccard_shuffle.shuffle_docs_on_buckets( documents_df=dataset.df, bucket_w_anchors_path=mapped_buckets_w_anchors_path, @@ -198,15 +213,17 @@ def identify_duplicates( parts_per_worker=self.config.parts_per_worker, bucket_parts_per_worker=self.config.bucket_parts_per_worker, ) - print(f"Stage {stage_num} (False Postive Check): Shuffle docs complete!") + print( + f"Stage {stage_num} (False Positive Check): Shuffle documents complete!" + ) stage_num += 1 - # jaccard comparision within buckets + # Jaccard comparision within buckets print( - f"Stage {stage_num} (False Postive Check): Jaccard Similarity in Buckets" + f"Stage {stage_num} (False Positive Check): Jaccard similarity in buckets" ) jaccard_pairs_path = os.path.join( - self.config.cache_dir, "jaccard_similarity_results.parquet" + self.cache_dir, "jaccard_similarity_results.parquet" ) t0 = time.time() with performance_report_if_with_ts_suffix( @@ -223,11 +240,11 @@ def identify_duplicates( overwrite=True, ) self._logger.info( - f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}" + f"Time taken for Jaccard similarity: {time.time()-t0}s and output written at {jaccard_pairs_path}" ) print( - f"Stage {stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!" + f"Stage {stage_num} (False Positive Check): Jaccard similarity in buckets complete!" ) stage_num += 1 @@ -236,13 +253,13 @@ def identify_duplicates( print(f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist") self.buckets_to_edges(buckets_df) print( - f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist Complete!" + f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist complete!" 
) stage_num += 1 # Connected components across buckets print(f"Stage {stage_num}: Connected Components across buckets") - cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet") + cc_path = os.path.join(self.cache_dir, "connected_components.parquet") self.connected_components.cc_workflow(cc_path) print(f"Stage {stage_num}: Connected Components across buckets complete!") stage_num += 1 diff --git a/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py b/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py index 04ac73a4b..39ff044b3 100644 --- a/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py +++ b/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py @@ -46,6 +46,7 @@ def jaccard_compute(self, shuffled_docs_path): for entry in os.scandir(shuffled_docs_path) if not entry.path.endswith(".txt") ] + meta_df = cudf.DataFrame( { self.left_id: ["x"], @@ -53,9 +54,11 @@ def jaccard_compute(self, shuffled_docs_path): "jaccard": np.float32([0.0]), } ) + result_df = dd.from_map( self._compute_jaccard_on_1_partition, paths, meta=meta_df ).reset_index(drop=True) + return result_df def _compute_jaccard_on_1_partition(self, path): @@ -64,17 +67,22 @@ def _compute_jaccard_on_1_partition(self, path): pair_df = self._compute_jaccard_and_create_pair_df(df) except OverflowError: paths = [entry.path for entry in os.scandir(os.path.join(path))] + anchor_df_str_size_ls = [ self._get_anchor_docs_and_string_size(path) for path in paths ] + anchor_df = cudf.concat( [anchor_doc for anchor_doc, _ in anchor_df_str_size_ls], ignore_index=True, ).drop_duplicates() + df_str_size = [str_size for _, str_size in anchor_df_str_size_ls] + paths = JaccardSimilarity._create_bins( df_str_size, np.iinfo(np.int32).max // 10 ) + pair_dfs = [] for path in paths: print(path) @@ -82,15 +90,19 @@ def _compute_jaccard_on_1_partition(self, path): df = cudf.concat([df, anchor_df], ignore_index=True) pair_df = self._compute_jaccard_and_create_pair_df(df) pair_dfs.append(pair_df) + pair_df = cudf.concat(pair_dfs, ignore_index=True) + return pair_df def _get_anchor_docs_and_string_size(self, path): df = cudf.read_parquet(path) str_bytes = df[self.text_field].str.byte_count().sum() is_anchor_flag = df[self.id_field] == df[self.anchor_id_fields[0]] + for anchor_id in self.anchor_id_fields[1:]: is_anchor_flag = is_anchor_flag | (df[self.id_field] == df[anchor_id]) + anchor_df = df[is_anchor_flag].reset_index(drop=True) return anchor_df, {"path": path, "str_bytes": str_bytes} @@ -98,26 +110,32 @@ def _get_anchor_docs_and_string_size(self, path): def _create_bins(path_dicts, max_size): path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True) bins, bin_sizes = [], [] + for path_d in path_dicts: new_path, new_size = path_d["path"], path_d["str_bytes"] + for i, bin_size in enumerate(bin_sizes): if bin_size + new_size <= max_size: bins[i].append(new_path) bin_sizes[i] += new_size new_size = 0 break + if new_size: bins.append([new_path]) bin_sizes.append(new_size) + return bins def _compute_jaccard_and_create_pair_df(self, df): df = df.drop_duplicates( subset=[self.id_field] + self.anchor_id_fields, ignore_index=True ) + anchor_columns = self.anchor_id_fields id_field = self.id_field result_ls = [] + try: for anchor_col in anchor_columns: doc_df = df[[id_field, self.text_field, anchor_col]] @@ -128,15 +146,17 @@ def _compute_jaccard_and_create_pair_df(self, df): result_ls.append(result_df) return cudf.concat(result_ls) + except OverflowError as e: print( - "Failed with OverflowError in 
compute_jaccard_and_create_pair_df", + "Failed with OverflowError in compute_jaccard_and_create_pair_df", flush=True, ) print(df, flush=True) print("--" * 30) print("Error") print("---" * 30) + raise e def _get_anchor_df(self, df, anchor_col): @@ -150,22 +170,29 @@ def _compute_jaccard_pair(self, docs_df, anchor_df): nrows_at_once = JaccardSimilarity._get_max_num_rows_to_process_once( df=docs_df, text_field=self.text_field ) + result_ls = [] for i in range(0, docs_df.shape[0], nrows_at_once): pair_df = docs_df[i : i + nrows_at_once] pair_df = pair_df.merge(anchor_df, on=self.anchor_id) + pair_df = pair_df.rename( columns={self.id_field: self.left_id, self.anchor_id: self.right_id} ) + mask = pair_df[self.left_id] != pair_df[self.right_id] pair_df = pair_df[mask].reset_index(drop=True) + if len(pair_df) == 0: result_df = self._create_empty_jaccard_result() else: result_df = self._compute_jaccard_partition(pair_df) + result_ls.append(result_df) + if len(result_ls) == 0: return self._create_empty_jaccard_result() + df_pair = cudf.concat(result_ls) return df_pair @@ -186,10 +213,12 @@ def _compute_jaccard_partition(self, df): @staticmethod def _get_max_num_rows_to_process_once(df, text_field): nbytes = df[text_field].str.byte_count().sum() - # Number of exmploded bytes + + # Number of exploded bytes exploded_bytes = nbytes * 5 * 2 max_chars_allowed = 2_147_483_647 byte_ratio = int(exploded_bytes) // max_chars_allowed + if byte_ratio > 1: nrows_at_once = len(df) // byte_ratio else: diff --git a/nemo_curator/modules/fuzzy_dedup/lsh.py b/nemo_curator/modules/fuzzy_dedup/lsh.py index 4a38b7c61..74c03ba09 100644 --- a/nemo_curator/modules/fuzzy_dedup/lsh.py +++ b/nemo_curator/modules/fuzzy_dedup/lsh.py @@ -25,6 +25,7 @@ import dask_cudf import numpy as np +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -33,14 +34,14 @@ class LSH: """ - Performs LSH on a MinhashSignatures + Performs LSH on a Minhash signatures """ def __init__( self, - cache_dir: str, num_hashes: int, num_buckets: int, + cache_dir: Optional[str] = None, buckets_per_shuffle: int = 1, false_positive_check: bool = False, logger: Union[logging.LoggerAdapter, str] = "./", @@ -51,37 +52,48 @@ def __init__( """ Parameters ---------- - cache_dir: str - Needs to be specified, will compute & write duplicate id, bucket pairs to cache directory. - num_hashes: Length of minhash signature + num_hashes: Length of minhash signature. num_buckets: Number of bands/buckets to create from the minhash signature. - Hashes_per_signature = num_hashes / num_buckets - buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. - but might lead to memory pressures and related errors. - false_positive_check: bool - If True, writes out buckets in a format compatible with downstream false positive check. + hashes_per_signature = num_hashes / num_buckets. + cache_dir: Directory to compute and write duplicate ID, bucket pairs. + This field is required via LSH(cache_dir=...) or Cache(cache_dir=...). + buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. Larger + values process larger batches by processing multiple bands but might lead + to memory pressures and related errors. Default is 1. + false_positive_check: If True, writes out buckets in a format compatible with + downstream false positive check. Default is False. 
logger: Existing logger to log to, or a path to a log directory. - id_field: Columns in the Dataset denoting document ID. - minhash_field: Column in the Dataset denoting minhash signature. - profile_dir: str, Default None - If specified directory to write dask profile + Default is "./". + id_fields: List or string representing column(s) in the dataset denoting + document ID. Default is "id". + minhash_field: Column in the dataset denoting minhash signature. + Default is "_minhash_signature". + profile_dir: If specified, directory to write Dask profile. Default is None. """ + self.num_hashes = num_hashes self.num_buckets = num_buckets self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields self.minhash_field = minhash_field self.buckets_per_shuffle = buckets_per_shuffle + self.bucket_ranges = self._generate_bucket_ranges( self.num_buckets, self.num_hashes ) + self.buckets_as_int = false_positive_check if cache_dir is None: + self.cache_dir = Cache().get_cache_directory() + else: + self.cache_dir = cache_dir + self.profile_dir = profile_dir + + if self.cache_dir is None: raise ValueError( - "cache_dir for intermediate outputs is required for this stage" + "cache_dir for intermediate outputs is required for this stage. " + "Please initialize with Cache(cache_dir=...) or LSH(cache_dir=...)" ) - self.cache_dir = cache_dir - self.profile_dir = profile_dir if isinstance(logger, str): self._logger = create_logger( @@ -96,11 +108,12 @@ def _generate_bucket_ranges( self, num_buckets: int, num_hashes: int ) -> List[List[int]]: """ - Generates a list of indices for the minhash ranges given num_bands & - num_hashes. - eg: num_bands=3, num_hashes=6 + Generates a list of indices for the minhash ranges, given num_bands and num_hashes. + + For example: num_bands=3, num_hashes=6 [[0, 1], [2, 3], [4, 5]] """ + minhashes_per_bucket = num_hashes // num_buckets bucket_ranges = [ @@ -111,6 +124,7 @@ def _generate_bucket_ranges( ) for bucket in range(num_buckets) ] + return bucket_ranges def minhash_to_buckets( @@ -119,11 +133,13 @@ def minhash_to_buckets( bucket_ranges: List[List[int]], ) -> cudf.DataFrame: df2 = df[self.id_fields] + for i, h in enumerate(bucket_ranges): indices = cudf.Series([h]).repeat(len(df2)) df2[f"_bucket_{i}"] = f"b{i}_" + df[self.minhash_field].list.take( indices ).hash_values(method="md5") + return df2 def bucket_id_to_int( @@ -133,23 +149,28 @@ def bucket_id_to_int( start_id: int = 0, ) -> Tuple[dask_cudf.DataFrame, int]: """ - Maps bucket ids to a contigious integer range from starting from start_id. + Maps bucket IDs to a contigious integer range from starting from start_id. 
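+        For example, with start_id=0, three unique bucket hashes are remapped to
+        _bucket_id values 0, 1, and 2, and the returned end bucket ID is 2.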
""" + unique_bucket_df = ( bucket_ddf[[bucket_col_name]] .map_partitions(lambda x: x.drop_duplicates(ignore_index=True)) .persist() ) + end_bucket_id = len(unique_bucket_df) - 1 + start_id unique_bucket_df["bucket_int_id"] = np.uint64(1) unique_bucket_df["bucket_int_id"] = unique_bucket_df["bucket_int_id"].cumsum() + unique_bucket_df["bucket_int_id"] = ( unique_bucket_df["bucket_int_id"] - 1 + start_id ) + bucket_ddf = bucket_ddf.merge(unique_bucket_df, on=[bucket_col_name]) bucket_ddf = bucket_ddf.drop(columns=[bucket_col_name]) bucket_ddf = bucket_ddf.rename(columns={"bucket_int_id": "_bucket_id"}) bucket_ddf["_bucket_id"] = bucket_ddf["_bucket_id"].astype(np.uint64) + return (bucket_ddf, end_bucket_id) def _minhash_to_bucket_meta( @@ -165,29 +186,33 @@ def lsh( df: dask_cudf.DataFrame, ) -> bool: """ - Computes hash buckets for the DataFrame and writes them as parquet files to the specified path. + Computes hash buckets for the DataFrame and writes them as Parquet files to the specified path. Parameters: - - write_path (str): The directory path to write parquet files. + - write_path (str): The directory path to write Parquet files. - df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed. Returns: are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise. """ + wrote_buckets = False are_buckets_empty = True meta = self._minhash_to_bucket_meta(df) + df = df.map_partitions( self.minhash_to_buckets, bucket_ranges=self.bucket_ranges, meta=meta, ) + bucket_start_id = 0 for i in range(0, self.num_buckets, self.buckets_per_shuffle): bucket_columns = [ f"_bucket_{i}" for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle)) ] + df2 = df.melt( id_vars=self.id_fields, value_name="_bucket_id", @@ -201,17 +226,20 @@ def lsh( ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)]) df2 = df2.reset_index(drop=True) - # Buckets to Int + + # Buckets to int if self.buckets_as_int: df2, end_id = self.bucket_id_to_int( df2, bucket_col_name="_bucket_id", start_id=bucket_start_id ) - # If bucketing return empty dataframe + + # If bucketing returns empty DataFrame if end_id < bucket_start_id: self._logger.info( f"No duplicate documents found for buckets: {bucket_columns}" ) continue + bucket_start_id = end_id + 1 are_buckets_empty = False @@ -241,15 +269,17 @@ def _write_bucket_parquet( buckets_to_write: List[str], ) -> tuple[bool, bool]: """ - Utility function to write the bucketed data to parquet + Utility function to write the bucketed data to Parquet, handling cases of overwriting and appending as needed. 
""" + if not wrote_buckets: if os.path.exists(write_path): warnings.warn( f"Output path {write_path} already exists and will be overwritten" ) df.to_parquet(write_path, write_index=False, overwrite=True) + else: df.to_parquet( write_path, @@ -258,9 +288,11 @@ def _write_bucket_parquet( append=not are_buckets_empty, ignore_divisions=True, ) + # Only check if buckets written so far are empty if are_buckets_empty: are_buckets_empty = check_empty_buckets(write_path) + wrote_buckets = True if are_buckets_empty: @@ -269,21 +301,24 @@ def _write_bucket_parquet( ) else: self._logger.info(f"Wrote data for buckets: {buckets_to_write}") + return wrote_buckets, are_buckets_empty def __call__(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df - write_path = os.path.join(self.cache_dir, "_buckets.parquet") + t0 = time.time() with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"): empty_result = self.lsh(write_path=write_path, df=df) + self._logger.info( - f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}" + f"Time taken for LSH: {time.time() - t0}s and output written at {write_path}" ) if empty_result: return None buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False) + return DocumentDataset(buckets_df) diff --git a/nemo_curator/modules/fuzzy_dedup/minhash.py b/nemo_curator/modules/fuzzy_dedup/minhash.py index 28fa9aca5..0b85f6072 100644 --- a/nemo_curator/modules/fuzzy_dedup/minhash.py +++ b/nemo_curator/modules/fuzzy_dedup/minhash.py @@ -25,6 +25,7 @@ import numpy as np from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -50,17 +51,19 @@ def __init__( """ Parameters ---------- - seed: Seed for minhash permutations - num_hashes: Length of minhash signature (No. of minhash permutations) + seed: Seed for minhash permutations. Default is 42. + num_hashes: Length of minhash signature (number of minhash permutations). + Default is 260. char_ngrams: Width of text window (in characters) while computing minhashes. - use_64bit_hash: Whether to use a 64 bit hash function. + Default is 5. + use_64bit_hash: Whether to use a 64 bit hash function. Default is False. logger: Existing logger to log to, or a path to a log directory. - id_field: Column in the Dataset denoting document ID. - text_field: Column in the Dataset denoting document content. - profile_dir: str, Default None - If specified directory to write dask profile - cache_dir: str, Default None - If specified, will compute & write id, minhash pairs to directory + Default is "./". + id_field: Column in the dataset denoting document ID. Default is "id". + text_field: Column in the dataset denoting document content. Default is "text". + profile_dir: If specified, directory to write Dask profile. Default is None. + cache_dir: If specified, will compute and write "ID, minhash pairs" to + directory. Can also be set with Cache(cache_dir=...). Default is None. 
""" self.num_hashes = num_hashes self.char_ngram = char_ngrams @@ -78,12 +81,16 @@ def __init__( self.id_field = id_field self.text_field = text_field - if cache_dir is None and profile_dir is not None: + if cache_dir is None: + self.cache_dir = Cache().get_cache_directory() + else: + self.cache_dir = cache_dir + self.profile_dir = profile_dir + if self.cache_dir is None and profile_dir is not None: warnings.warn( - "cache_dir for intermediate outputs is required to generate profiles" + "cache_dir for intermediate outputs is required to generate profiles. " + "Please initialize with Cache(cache_dir=...) or MinHash(cache_dir=...)" ) - self.cache_dir = cache_dir - self.profile_dir = profile_dir if isinstance(logger, str): self._logger = create_logger( @@ -98,6 +105,7 @@ def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray: """ Generate seeds for all minhash permutations based on the given seed. """ + gen = np.random.RandomState(seed) return gen.randint(0, 1e6, size=n_seeds) @@ -107,6 +115,7 @@ def generate_hash_permutation_seeds( """ Generate seeds for all minhash permutations based on the given seed. """ + gen = np.random.RandomState(seed) if bit_width == 32: @@ -117,7 +126,7 @@ def generate_hash_permutation_seeds( MERSENNE_PRIME = np.uint64((1 << 61) - 1) dtype = np.uint64 else: - raise ValueError("Unsupported bit width. Use either 32 or 64.") + raise ValueError("Unsupported bit width. Please use either 32 or 64.") return np.array( [ @@ -134,8 +143,9 @@ def minhash32( self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int ) -> cudf.Series: """ - Compute 32bit minhashes based on the MurmurHash3 algorithm + Compute 32-bit minhashes based on the MurmurHash3 algorithm. """ + if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") @@ -146,8 +156,10 @@ def minhash32( "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`", category=FutureWarning, ) + seeds = cudf.Series(seeds, dtype="uint32") return ser.str.minhash(seeds=seeds, width=char_ngram) + else: seeds_a = cudf.Series(seeds[:, 0], dtype="uint32") seeds_b = cudf.Series(seeds[:, 1], dtype="uint32") @@ -165,10 +177,12 @@ def minhash64( self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int ) -> cudf.Series: """ - Compute 64bit minhashes based on the MurmurHash3 algorithm + Compute 64-bit minhashes based on the MurmurHash3 algorithm. """ + if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") + if MINHASH_DEPRECATED_API: warnings.warn( "Using an outdated minhash implementation, please update to cuDF version 24.12 " @@ -176,8 +190,10 @@ def minhash64( "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`", category=FutureWarning, ) + seeds = cudf.Series(seeds, dtype="uint64") return ser.str.minhash64(seeds=seeds, width=char_ngram) + else: seeds_a = cudf.Series(seeds[:, 0], dtype="uint64") seeds_b = cudf.Series(seeds[:, 1], dtype="uint64") @@ -193,16 +209,19 @@ def minhash64( def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: """ - Computes the MinHash Signatures for a given dataset. + Computes the MinHash signatures for a given dataset. + Parameters ---------- - dataset: DocumentDataset - The input datset to compute MinHashes. + dataset (DocumentDataset): The input dataset on which to compute MinHashes. 
+ Returns ------- - DocumentDataset containing IDs of all documents and the corresponding MinHash Signature + DocumentDataset containing IDs of all documents and the corresponding MinHash signature """ + result = dataset.df[[self.id_field]] + result["_minhash_signature"] = dataset.df[self.text_field].map_partitions( self.minhash_method, seeds=self.seeds, @@ -214,16 +233,20 @@ def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: t0 = time.time() self._logger.info("Starting execution for Minhashes") + write_path = os.path.join(self.cache_dir, "_minhashes.parquet") if os.path.exists(write_path): warnings.warn( f"Output path {write_path} already exists and will be overwritten" ) + with performance_report_if_with_ts_suffix(self.profile_dir, "minhash-profile"): result.to_parquet(write_path, write_index=False, overwrite=True) + self._logger.info( - f"Time taken for Minhash signature computation = {time.time() - t0}s and output written at {write_path}" + f"Time taken for Minhash signature computation: {time.time() - t0}s and output written at {write_path}" ) + return DocumentDataset( dask_cudf.read_parquet(write_path, blocksize="2GB", aggregate_files=True) ) diff --git a/nemo_curator/modules/semantic_dedup/clusteringmodel.py b/nemo_curator/modules/semantic_dedup/clusteringmodel.py index 440627297..d0182a58b 100644 --- a/nemo_curator/modules/semantic_dedup/clusteringmodel.py +++ b/nemo_curator/modules/semantic_dedup/clusteringmodel.py @@ -25,6 +25,7 @@ import numpy as np from cuml.dask.cluster import KMeans +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -53,7 +54,8 @@ def __init__( id_column: str = "id", max_iter: int = 100, n_clusters: int = 1000, - clustering_output_dir: str = "./clustering_results", + cache_dir: Optional[str] = None, + clustering_save_loc: str = "clustering_results", embedding_column: str = "embeddings", random_state: int = 1234, sim_metric: str = "cosine", @@ -72,8 +74,11 @@ def __init__( Default is "id". max_iter (int): Maximum iterations for clustering. Default is 100. n_clusters (int): Number of clusters. Default is 1000. - clustering_output_dir (str): Location to save clustering results. - Default is "./clustering_results". + cache_dir (str, optional): Directory path where clustering results will be saved. + If None, we check if a cache_dir has been initialized with Cache().get_cache_directory(). + Default is None. + clustering_save_loc (str): Location within cache_dir to save clustering results. + Default is "clustering_results". embedding_column (str): The column name that stores the embeddings. Default is "embeddings". random_state (int): KMeans random state used for reproducibility. @@ -96,7 +101,6 @@ def __init__( self.id_col = id_column self.max_iter = max_iter self.n_clusters = n_clusters - self.clustering_output_dir = clustering_output_dir self.embedding_column = embedding_column self.random_state = random_state self.sim_metric = sim_metric @@ -107,11 +111,24 @@ def __init__( self.logger = self._setup_logger(logger) self.profile_dir = profile_dir + if cache_dir is not None: + self.clustering_output_dir = os.path.join(cache_dir, clustering_save_loc) + elif Cache().get_cache_directory() is not None: + self.clustering_output_dir = os.path.join( + Cache().get_cache_directory(), clustering_save_loc + ) + else: + raise RuntimeError( + "No cache directory specified. 
Please initialize with Cache(cache_dir=...) " + "or ClusteringModel(cache_dir=...)" + ) + if not os.path.exists(self.clustering_output_dir): expand_outdir_and_mkdir(self.clustering_output_dir) else: self.logger.warning( - f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten" + f"Clustering output directory {self.clustering_output_dir} already exists" + " and will be overwritten" ) def _setup_logger(self, logger): @@ -207,7 +224,8 @@ def __call__(self, embeddings_dataset: DocumentDataset): ) if os.path.exists(clustering_output_dir): self.logger.warning( - f"Output directory {clustering_output_dir} already exists and will be overwritten" + f"Output directory {clustering_output_dir} already exists and will" + " be overwritten." ) shutil.rmtree(clustering_output_dir) diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index f95d1bdab..51d4775b4 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -135,7 +135,8 @@ def __init__( self, embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2", embedding_batch_size: int = 128, - embedding_output_dir: str = "./embeddings", + cache_dir: Optional[str] = None, + embeddings_save_loc: str = "embeddings", embedding_max_mem_gb: Optional[int] = None, embedding_pooling_strategy: str = "mean_pooling", input_column: str = "text", @@ -153,8 +154,11 @@ def __init__( Default is "sentence-transformers/all-MiniLM-L6-v2". embedding_batch_size (int): Initial batch size for processing embeddings. Default is 128. - embedding_output_dir (str): Location to save embeddings. - Default is "./embeddings". + cache_dir (str, optional): Directory path where embeddings will be saved. + If None, we check if a cache_dir has been initialized with Cache().get_cache_directory(). + Default is None. + embeddings_save_loc (str): Location within cache_dir to save embeddings. + Default is "embeddings". embedding_max_mem_gb (int, optional): Maximum memory usage in GB for the embedding process. If None, it defaults to the available GPU memory minus 4 GB. embedding_pooling_strategy: Strategy for pooling embeddings, either "mean_pooling" or "last_token". @@ -180,7 +184,6 @@ def __init__( ) self.batch_size = embedding_batch_size self.logger = self._setup_logger(logger) - self.embedding_output_dir = embedding_output_dir self.input_column = input_column self.embedding_column = embedding_column self.model = EmbeddingCrossFitModel( @@ -190,6 +193,18 @@ def __init__( self.write_to_filename = write_to_filename self.profile_dir = profile_dir + if cache_dir is not None: + self.embedding_output_dir = os.path.join(cache_dir, embeddings_save_loc) + elif Cache().get_cache_directory() is not None: + self.embedding_output_dir = os.path.join( + Cache().get_cache_directory(), embeddings_save_loc + ) + else: + raise RuntimeError( + "No cache directory specified. Please initialize with Cache(cache_dir=...) 
" + "or EmbeddingCreator(cache_dir=...)" + ) + def _setup_logger(self, logger): if isinstance(logger, str): return create_logger( diff --git a/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py index 2c40a0914..627242cb2 100644 --- a/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py +++ b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py @@ -21,6 +21,7 @@ import dask.bag as db +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.modules.config import SemDedupConfig @@ -36,13 +37,13 @@ class SemanticClusterLevelDedup: def __init__( self, n_clusters: int = 1000, - emb_by_clust_dir: str = "./clustering_results/embs_by_nearest_center", - sorted_clusters_dir: str = "./clustering_results/sorted", id_column: str = "id", id_column_type: str = "int", which_to_keep: str = "hard", - output_dir: str = "./clustering_results", + output_dir: Optional[str] = None, + cache_dir: Optional[str] = None, embedding_column: str = "embeddings", + clustering_save_loc: str = "clustering_results", logger: Union[logging.Logger, str] = "./", profile_dir: Optional[str] = None, ) -> None: @@ -51,19 +52,18 @@ def __init__( Args: n_clusters (int): Number of clusters. Default is 1000. - emb_by_clust_dir (str): Directory containing embeddings by cluster. - Default is "./clustering_results/embs_by_nearest_center". - sorted_clusters_dir (str): Directory containing sorted clusters. - Default is "./clustering_results/sorted". id_column (str): Column name used as the identifier in the dataset. Default is "id". id_column_type (str): Data type of id_column. Default is "int". which_to_keep (str): Method to determine which duplicates to keep. Default is "hard". - output_dir (str): Directory to save output files. - Default is "./clustering_results". + output_dir (str, optional): Directory to save output files. + If None, it will be saved to cache_dir/clustering_save_loc. + Default is None. + cache_dir (str, optional): Should be the same as specified in ClusteringModel. embedding_column (str): The column name that stores the embeddings. Default is "embeddings". + clustering_save_loc (str): Should be the same as specified in ClusteringModel. logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory. Default is "./". profile_dir (Optional[str]): If specified, directory to write Dask profile. @@ -71,20 +71,37 @@ def __init__( """ self.n_clusters = n_clusters - self.emb_by_clust_dir = emb_by_clust_dir - self.sorted_clusters_dir = sorted_clusters_dir self.id_col = id_column self.id_col_type = id_column_type self.which_to_keep = which_to_keep - self.output_dir = output_dir - self.semdedup_pruning_tables_dir = os.path.join( - output_dir, "semdedup_pruning_tables" - ) self.computed_semantic_match_dfs = False self.embedding_column = embedding_column self.logger = self._setup_logger(logger) self.profile_dir = profile_dir + if cache_dir is None: + if Cache().get_cache_directory() is None: + raise RuntimeError( + "No cache directory specified. Please initialize with Cache(cache_dir=...) 
" + "or SemanticClusterLevelDedup(cache_dir=...)" + ) + else: + cache_dir = Cache().get_cache_directory() + self.emb_by_clust_dir = os.path.join( + cache_dir, clustering_save_loc, "embs_by_nearest_center" + ) + self.sorted_clusters_dir = os.path.join( + cache_dir, clustering_save_loc, "sorted" + ) + + if output_dir is None: + self.output_dir = os.path.join(cache_dir, clustering_save_loc) + else: + self.output_dir = output_dir + self.semdedup_pruning_tables_dir = os.path.join( + output_dir, "semdedup_pruning_tables" + ) + def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger: """ Set up the logger. @@ -126,6 +143,7 @@ def compute_semantic_match_dfs( ) shutil.rmtree(self.semdedup_pruning_tables_dir) expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir) + t0 = time.time() with performance_report_if_with_ts_suffix( diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py index 8d9713e8a..484bfc953 100644 --- a/nemo_curator/modules/semantic_dedup/semdedup.py +++ b/nemo_curator/modules/semantic_dedup/semdedup.py @@ -17,6 +17,7 @@ import os from typing import Union +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.modules.base import BaseModule from nemo_curator.modules.config import SemDedupConfig @@ -49,14 +50,27 @@ def __init__( logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory. Default is "./". """ + super().__init__(input_backend="cudf") self.config = config self.logger = logger - cache_dir = config.cache_dir + if config.cache_dir is not None: + cache_dir = config.cache_dir + elif Cache().get_cache_directory() is not None: + cache_dir = Cache().get_cache_directory() + else: + raise RuntimeError( + "No cache directory specified. Please initialize with Cache(cache_dir=...) " + "or specify a cache_dir in your YAML file." 
+ ) + profile_dir = self.config.profile_dir + clustering_save_loc = config.clustering_save_loc + self.embedding_creator = EmbeddingCreator( embedding_model_name_or_path=config.embedding_model_name_or_path, embedding_batch_size=config.embedding_batch_size, - embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), + cache_dir=cache_dir, + embeddings_save_loc=config.embeddings_save_loc, embedding_max_mem_gb=config.embedding_max_mem_gb, embedding_pooling_strategy=config.embedding_pooling_strategy, input_column=input_column, @@ -64,13 +78,14 @@ def __init__( write_embeddings_to_disk=config.write_embeddings_to_disk, write_to_filename=config.write_to_filename, logger=logger, - profile_dir=self.config.profile_dir, + profile_dir=profile_dir, ) self.clustering_model = ClusteringModel( id_column=id_column, max_iter=config.max_iter, n_clusters=config.n_clusters, - clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc), + cache_dir=cache_dir, + clustering_save_loc=clustering_save_loc, embedding_column=config.embedding_column, sim_metric=config.sim_metric, which_to_keep=config.which_to_keep, @@ -78,23 +93,20 @@ def __init__( kmeans_with_cos_dist=config.kmeans_with_cos_dist, clustering_input_partition_size=config.clustering_input_partition_size, logger=logger, - profile_dir=self.config.profile_dir, + profile_dir=profile_dir, ) self.semantic_cluster_dedup = SemanticClusterLevelDedup( n_clusters=config.n_clusters, - emb_by_clust_dir=os.path.join( - cache_dir, config.clustering_save_loc, "embs_by_nearest_center" - ), - sorted_clusters_dir=os.path.join( - cache_dir, config.clustering_save_loc, "sorted" - ), id_column=id_column, id_column_type=id_column_type, which_to_keep=config.which_to_keep, - output_dir=os.path.join(cache_dir, config.clustering_save_loc), + cache_dir=cache_dir, embedding_column=config.embedding_column, + clustering_save_loc=clustering_save_loc, logger=logger, - profile_dir=self.config.profile_dir, + profile_dir=profile_dir, + # Hardcoded path + output_dir=os.path.join(cache_dir, clustering_save_loc), ) self.eps_thresholds = config.eps_thresholds self.eps_to_extract = config.eps_to_extract diff --git a/nemo_curator/scripts/find_exact_duplicates.py b/nemo_curator/scripts/find_exact_duplicates.py index 5f6fc2435..e0588b363 100644 --- a/nemo_curator/scripts/find_exact_duplicates.py +++ b/nemo_curator/scripts/find_exact_duplicates.py @@ -38,8 +38,10 @@ def main(args): logger.info(f"Starting workflow with args:\n {args}") assert args.hash_method == "md5", "Currently only md5 hash is supported" + client = get_client(**ArgumentHelper.parse_client_args(args)) - logger.info(f"Client Created {client}") + logger.info(f"Client created: {client}") + if args.device == "gpu": client.run(pre_imports) logger.info("Pre imports complete") @@ -48,16 +50,21 @@ def main(args): id_field = args.input_json_id_field text_field = args.input_json_text_field num_files = args.num_files + t0 = time.time() + dfs = [] for data_path in data_paths: data_path = strip_trailing_sep(data_path) + if num_files is not None and num_files <= 0: logger.info(f"Processed {num_files}... 
quitting") break + files = get_all_files_paths_under( root=data_path, recurse_subdirectories=False, keep_extensions="jsonl" ) + df = read_data( files[:num_files] if num_files else files, file_type="jsonl", @@ -65,12 +72,15 @@ def main(args): files_per_partition=args.files_per_partition, add_filename=False, )[[id_field, text_field]] + if num_files is not None: num_files -= len(files) + dfs.append(df) logger.info(f"Lazy read complete for {dfs[-1].npartitions} partitions") input_df = dask_cudf.concat(dfs, ignore_unknown_divisions=True) + exact_dups = ExactDuplicates( logger=logger, id_field=id_field, @@ -80,8 +90,10 @@ def main(args): cache_dir=args.output_dir, ) exact_dups(dataset=DocumentDataset(input_df)) + logger.info( - f"Exact deduplication computation across datasets took {time.time() - t0}s complete at {args.output_dir}" # noqa:E501 + f"Exact deduplication computation across datasets took {time.time() - t0}s \n" + f"Output written at {args.output_dir}" # noqa:E501 ) diff --git a/nemo_curator/scripts/fuzzy_deduplication/README.md b/nemo_curator/scripts/fuzzy_deduplication/README.md index 63dcdb5c8..eb56db85a 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/README.md +++ b/nemo_curator/scripts/fuzzy_deduplication/README.md @@ -1,5 +1,5 @@ ## Fuzzy Deduplication Steps -This directory consists of scripts that can be invoked directly via the command line for finding fuzzy duplicates from a group of Jsonl files consisting of text & unique ID's that are specifically formatted using the `add_id` script included as a part of NeMo-Curator. +This directory consists of scripts that can be invoked directly via the command line for finding fuzzy duplicates from a group of JSONL files consisting of text and unique IDs that are specifically formatted using the `add_id` script included as a part of NeMo Curator. > [!IMPORTANT] > The up to date documentation on running the fuzzy deduplication scripts can be found in the [NeMo Curator User Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html#id4). It is recommended to use the Python API directly rather than the CLI scripts for most cases. 
diff --git a/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py b/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py index 72103d736..253982587 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py +++ b/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py @@ -68,6 +68,7 @@ def main(args): OUTPUT_PATH = args.output_dir client = get_client(**ArgumentHelper.parse_client_args(args)) + logger.info(f"Client Created {client}") logger.info(f"Num Workers = {get_num_workers(client)}") logger.info( @@ -81,13 +82,15 @@ def main(args): bucket_field=args.input_bucket_field, logger=logger, ) + st = time.time() buckets_df = DocumentDataset( dask_cudf.read_parquet(input_bucket_path, split_row_groups=False) ) _ = buckets_to_edges(buckets_df) + et = time.time() - logger.info(f"Bucket to Edges conversion took = {et-st} s") + logger.info(f"Bucket to edges conversion took {et-st} seconds") def console_script(): diff --git a/nemo_curator/scripts/fuzzy_deduplication/compute_minhashes.py b/nemo_curator/scripts/fuzzy_deduplication/compute_minhashes.py index f15a70867..3f4969a5e 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/compute_minhashes.py +++ b/nemo_curator/scripts/fuzzy_deduplication/compute_minhashes.py @@ -39,11 +39,11 @@ def main(args): ) logger.info(f"Starting workflow with args:\n {args}") - assert args.hash_bytes in {4, 8}, "Currently only 32bit/64bit hashes are supported" + assert args.hash_bytes in {4, 8}, "Currently only 32bit/64bit hashes are supported." assert args.device == "gpu" client = get_client(**ArgumentHelper.parse_client_args(args)) - logger.info(f"Client Created {client}") + logger.info(f"Client created {client}") client.run(pre_imports) logger.info("Pre imports complete") @@ -65,7 +65,9 @@ def main(args): t0 = time.time() for data_path in data_paths: print(f"Computing minhashes for {data_path}", flush=True) + data_path = strip_trailing_sep(data_path) + if num_files is not None and num_files <= 0: print(f"Processed {args.num_files}... 
quitting") break @@ -73,6 +75,7 @@ def main(args): files = get_all_files_paths_under( root=data_path, recurse_subdirectories=False, keep_extensions="jsonl" ) + df = read_data( files[:num_files] if num_files else files, file_type="jsonl", @@ -86,10 +89,12 @@ def main(args): num_files -= len(files) res = minhasher(DocumentDataset(df)).df + logger.info( f"Lazy minhash generation complete for {res.npartitions} partitions" ) logger.info(f"Starting execution for {data_path}") + write_path = os.path.join( args.output_minhash_dir, os.path.basename(data_path), "minhashes.parquet" ) @@ -99,8 +104,9 @@ def main(args): args.profile_path, f"{os.path.basename(data_path)}-minhash-profile.html" ): res.to_parquet(write_path, write_index=False) + logger.info( - f"Minhash computation for f{data_path} took {time.time() - t1}s complete at {write_path}" # noqa:E501 + f"Minhash computation for {data_path} took {time.time() - t1}s complete at {write_path}" # noqa:E501 ) logger.info( f"Minhash computation across datasets took {time.time() - t0}s complete at {args.output_minhash_dir}" # noqa:E501 diff --git a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py index f43eec746..c86b43ba2 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py +++ b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py @@ -44,6 +44,7 @@ def main(args): profile_dir=args.profile_path, ) components_stage.cc_workflow(output_path=output_path) + print(f"All done in {time.time()-st:.1f} seconds") print(f"Results written to {output_path}") diff --git a/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py b/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py index d87be3cb1..a245c7b0d 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py +++ b/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py @@ -26,24 +26,28 @@ def main(args): from a partitioned Parquet dataset. Result is a Parquet dataset consiting of document and ID pairs, along with their Jaccard similarity scores. 
""" + OUTPUT_PATH = args.output_dir shuffled_docs_path = args.shuffled_docs_path output_final_results_path = os.path.join( OUTPUT_PATH, "jaccard_similarity_results.parquet" ) + args.enable_spilling = True client = get_client(**ArgumentHelper.parse_client_args(args)) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard compute script", flush=True) + print(f"Number of workers: {get_num_workers(client)}", flush=True) + print("Connected to Dask cluster", flush=True) + print("Running Jaccard compute script", flush=True) st = time.time() + jaccard = JaccardSimilarity( id_field=args.input_json_id_field, text_field=args.input_json_text_field, anchor_id_fields=[f"anchor_{i}_{args.input_json_id_field}" for i in range(2)], ngram_width=args.ngram_size, ) + # Run actual computation result_df = jaccard.jaccard_compute(shuffled_docs_path) @@ -52,7 +56,8 @@ def main(args): write_index=False, write_metadata_file=False, ) - print(f"Jaccard Computing+Writing time: {time.time() - st:.1f} seconds") + + print(f"Jaccard computing and writing time: {time.time() - st:.1f}s") def attach_args(): diff --git a/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py b/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py index 24c2243ab..ef52baf8b 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py +++ b/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py @@ -38,11 +38,13 @@ def main(args): client = get_client(**ArgumentHelper.parse_client_args(args)) client.run(func) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard shuffle script", flush=True) + + print(f"Number of workers: {get_num_workers(client)}", flush=True) + print("Connected to Dask cluster", flush=True) + print("Running Jaccard shuffle script", flush=True) print(f"Args = {args}") st = time.time() + text_ddf = get_text_ddf_from_json_path_with_blocksize( input_data_paths=input_data_paths, num_files=args.num_files, @@ -51,17 +53,20 @@ def main(args): text_column=args.input_json_text_field, input_meta=args.input_meta, ) + print( "Graph creation for get_text_ddf_from_json_path_with_blocksize complete.", flush=True, ) - print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True) + print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True) + shuffle = _Shuffle( id_fields=["dataset_id", "doc_id"], text_field=args.input_json_text_field, profile_dir=args.profile_path, int_to_str_id=args.input_json_id_field, ) + shuffle.shuffle_docs_on_buckets( documents_df=text_ddf, bucket_w_anchors_path=input_anchor_docs_with_bk_dir, @@ -71,8 +76,9 @@ def main(args): bucket_parts_per_worker=args.bucket_parts_per_worker, partition_on="_output_partition_id", ) + et = time.time() - print(f"Jaccard Shuffle E2E time taken = {et-st} s") + print(f"Jaccard shuffle E2E time taken: {et-st}s") def attach_args(): diff --git a/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py b/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py index fb825b1b0..b41a55bed 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py +++ b/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py @@ -38,7 +38,8 @@ def get_anchor_and_output_map_info( input_meta, ): """ - Get anchor docs with bucket info + Get anchor documents with bucket information. 
+ Args: input_data_paths: list of paths to input data input_bucket_path: path to input buckets @@ -49,6 +50,7 @@ def get_anchor_and_output_map_info( Returns: ddf_anchor_docs_with_bk """ + ddf_text = get_text_ddf_from_json_path_with_blocksize( input_data_paths=input_data_paths, num_files=num_files, @@ -57,17 +59,21 @@ def get_anchor_and_output_map_info( text_column=input_text_field, input_meta=input_meta, ) + ddf_bk = get_bucket_ddf_from_parquet_path( input_bucket_path=input_bucket_path, num_workers=num_workers ) + map_buckets = _MapBuckets( id_fields=["dataset_id", "doc_id"], bucket_field=input_bucket_field, text_field=input_text_field, ) + ddf_anchor_docs_with_bk = map_buckets.map_buckets_with_anchors( documents_df=ddf_text, buckets_df=ddf_bk, shuffle_type=shuffle_type ) + return ddf_anchor_docs_with_bk @@ -123,18 +129,21 @@ def jaccard_get_output_map_workflow( input_meta, ): """ - Workflow for jaccard shuffle + Workflow for Jaccard shuffle. + Args: - client: dask client + client: Dask client input_data_paths: list of paths to input data input_bucket_path: path to input buckets - output_anchor_docs_with_bk_path: path to save anchor docs with bucket info + output_anchor_docs_with_bk_path: path to save anchor documents with bucket + information text_ddf_blocksize: blocksize for text ddf num_files: number of files to read - parts_per_worker: number of parts per worker shuffle_type: type of shuffle to use before writing to parquet """ + num_workers = get_num_workers(client) + ddf_anchor_docs_with_bk = get_anchor_and_output_map_info( input_data_paths, input_bucket_path, @@ -147,6 +156,7 @@ def jaccard_get_output_map_workflow( input_text_field, input_meta=input_meta, ) + ddf_anchor_docs_with_bk.to_parquet( output_anchor_docs_with_bk_path, write_index=False, @@ -160,12 +170,15 @@ def main(args): output_anchor_docs_with_bk_path = os.path.join( OUTPUT_PATH, "anchor_docs_with_bk.parquet" ) + client = get_client(**ArgumentHelper.parse_client_args(args)) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard map buckets script", flush=True) + + print(f"Number of workers: {get_num_workers(client)}", flush=True) + print("Connected to Dask cluster", flush=True) + print("Running Jaccard map buckets script", flush=True) print(f"Args = {args}") st = time.time() + jaccard_get_output_map_workflow( client, input_data_paths, @@ -179,8 +192,9 @@ def main(args): args.input_json_text_field, args.input_meta, ) + et = time.time() - print(f"Bucket Mapping time taken = {et-st} s") + print(f"Bucket mapping time taken: {et-st}s") def console_script(): diff --git a/nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py b/nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py index 3312c4f9d..dbce804f8 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py +++ b/nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py @@ -40,7 +40,8 @@ def main(args): assert args.device == "gpu" client = get_client(**ArgumentHelper.parse_client_args(args)) - logger.info(f"Client Created {client}") + + logger.info(f"Client created {client}") client.run(pre_imports) logger.info("Pre imports complete") @@ -53,8 +54,10 @@ def main(args): dfs.append( dask_cudf.read_parquet(data_path, blocksize="2GB", aggregate_files=True) ) + df = dask_cudf.concat(dfs, ignore_unknown_divisions=True) df = df[~df[id_field].isna()] + df = df.map_partitions( convert_str_id_to_int, id_column=id_field, @@ -77,7 +80,7 @@ def main(args): t1 = time.time() _ = 
lsh(DocumentDataset(df)) - logger.info(f"Computing and writing buckets took {time.time() - t1} s") + logger.info(f"Computing and writing buckets took {time.time() - t1}s") def attach_args(): diff --git a/nemo_curator/scripts/semdedup/clustering.py b/nemo_curator/scripts/semdedup/clustering.py index 98968c817..a4a67f244 100644 --- a/nemo_curator/scripts/semdedup/clustering.py +++ b/nemo_curator/scripts/semdedup/clustering.py @@ -52,9 +52,6 @@ def main(args): embedding_fp = os.path.join( semdedup_config.cache_dir, semdedup_config.embeddings_save_loc ) - clustering_output_dir = os.path.join( - semdedup_config.cache_dir, semdedup_config.clustering_save_loc - ) # Switch to https://github.com/NVIDIA/NeMo-Curator/issues/50 # When we fix that @@ -65,7 +62,8 @@ def main(args): id_column=args.id_column, max_iter=semdedup_config.max_iter, n_clusters=semdedup_config.n_clusters, - clustering_output_dir=clustering_output_dir, + cache_dir=semdedup_config.cache_dir, + clustering_save_loc=semdedup_config.clustering_save_loc, embedding_column=semdedup_config.embedding_column, random_state=semdedup_config.random_state, sim_metric=semdedup_config.sim_metric, diff --git a/nemo_curator/scripts/semdedup/compute_embeddings.py b/nemo_curator/scripts/semdedup/compute_embeddings.py index 8008083dd..41965e5fc 100644 --- a/nemo_curator/scripts/semdedup/compute_embeddings.py +++ b/nemo_curator/scripts/semdedup/compute_embeddings.py @@ -76,9 +76,8 @@ def main(args): embedding_creator = EmbeddingCreator( embedding_model_name_or_path=semdedup_config.embedding_model_name_or_path, embedding_batch_size=semdedup_config.embedding_batch_size, - embedding_output_dir=os.path.join( - semdedup_config.cache_dir, semdedup_config.embeddings_save_loc - ), + cache_dir=semdedup_config.cache_dir, + embeddings_save_loc=semdedup_config.embeddings_save_loc, embedding_max_mem_gb=semdedup_config.embedding_max_mem_gb, embedding_pooling_strategy=semdedup_config.embedding_pooling_strategy, input_column=args.input_text_field, diff --git a/nemo_curator/scripts/semdedup/extract_dedup_data.py b/nemo_curator/scripts/semdedup/extract_dedup_data.py old mode 100755 new mode 100644 index 9c286a842..9a4cf9c87 --- a/nemo_curator/scripts/semdedup/extract_dedup_data.py +++ b/nemo_curator/scripts/semdedup/extract_dedup_data.py @@ -45,20 +45,17 @@ def main(args): semantic_dedup = SemanticClusterLevelDedup( n_clusters=semdedup_config.n_clusters, - emb_by_clust_dir=os.path.join( - cache_dir, semdedup_config.clustering_save_loc, "embs_by_nearest_center" - ), - sorted_clusters_dir=os.path.join( - cache_dir, semdedup_config.clustering_save_loc, "sorted" - ), id_column=args.id_column, id_column_type=args.id_column_type, which_to_keep=semdedup_config.which_to_keep, + cache_dir=semdedup_config.cache_dir, + embedding_column=semdedup_config.embedding_column, + clustering_save_loc=semdedup_config.clustering_save_loc, + logger=logger, + # Hardcoded value output_dir=os.path.join( semdedup_config.cache_dir, semdedup_config.clustering_save_loc ), - embedding_column=semdedup_config.embedding_column, - logger=logger, ) semantic_dedup.compute_semantic_match_dfs(semdedup_config.eps_thresholds) diff --git a/tests/test_exact_dedup.py b/tests/test_exact_dedup.py index 71b9a1034..21576a3a1 100644 --- a/tests/test_exact_dedup.py +++ b/tests/test_exact_dedup.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
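
The test changes below exercise the cache-precedence rule introduced by this patch. A standalone sketch of that rule, with a hypothetical path:

    from nemo_curator.cache import Cache

    Cache().delete_cache_instance()        # clear anything set earlier in the process
    Cache(cache_dir="/tmp/curator_cache")  # register the global fallback directory
    print(Cache().get_cache_directory())   # the expanded /tmp/curator_cache path

    # A module built with an explicit cache_dir keeps that value; a module built with
    # cache_dir=None falls back to Cache().get_cache_directory(); modules that require
    # a cache directory raise if neither is provided.
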
+import os from hashlib import md5 import pandas as pd import pytest from dask import dataframe as dd +from dask.dataframe.utils import assert_eq +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.modules import ExactDuplicates @@ -48,20 +51,58 @@ def test_unsupported_hash(self): with pytest.raises(ValueError): ExactDuplicates(hash_method="sha256") + @pytest.mark.parametrize("cache_method", [None, "Cache", "ExactDuplicates"]) + def test_exact_dedup_cache_method(self, exact_dedup_data, cache_method, tmpdir): + + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=tmpdir) + cache_dir = None + elif cache_method == "ExactDuplicates": + cache_dir = tmpdir + else: + cache_dir = None + + exact_dups = ExactDuplicates( + id_field="id", + text_field="text", + hash_method="md5", + cache_dir=cache_dir, + ) + + result = exact_dups(exact_dedup_data) + result = result.df.compute() + expected_df = exact_dedup_data.df.compute() + expected_df = expected_df[expected_df.text.duplicated(keep=False)] + + assert_eq(result.id, expected_df.id, check_index=False) + + # Check that the output is written when either: + # (1) Cache(cache_dir=...) is initialized, or + # (2) ExactDuplicates(cache_dir=...) is initialized. + # If there is no Cache and ExactDuplicates(cache_dir=None), + # then there should be no output file. + if cache_method in ["Cache", "ExactDuplicates"]: + assert os.path.exists(str(tmpdir / "_exact_duplicates.parquet")) + else: + assert not os.path.exists(str(tmpdir / "_exact_duplicates.parquet")) + @pytest.mark.parametrize("cache_result", [False, True]) - def test_dup(self, exact_dedup_data, cache_result, tmpdir): + def test_exact_dedup(self, exact_dedup_data, cache_result, tmpdir): exact_dups = ExactDuplicates( id_field="id", text_field="text", hash_method="md5", cache_dir=tmpdir if cache_result else None, ) + duplicates = exact_dups.identify_duplicates(exact_dedup_data) deduplicated_ds = exact_dups.remove(exact_dedup_data, duplicates) deduplicated_ids_series = deduplicated_ds.df.to_backend("pandas").compute()[ "id" ] output_deduplicated_ids = set(deduplicated_ids_series.tolist()) + assert ( len(output_deduplicated_ids) == 3 and 300 in output_deduplicated_ids diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py index 53fbb1022..c71e96a7f 100644 --- a/tests/test_fuzzy_dedup.py +++ b/tests/test_fuzzy_dedup.py @@ -24,6 +24,7 @@ from dask.dataframe.utils import assert_eq from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index from nemo_curator.utils.import_utils import gpu_only_import @@ -164,10 +165,13 @@ def test_minhash_approximation( use_64bit_hash=use_64bit_hash, ) minhashes = minhasher(fuzzy_dedup_data) + minhash_signatures = ( minhashes.df["_minhash_signature"].compute().to_pandas().values ) + strings = fuzzy_dedup_data.df["text"].compute().to_pandas().values + for (sig1, str1), (sig2, str2) in generate_all_pairs( tuple(zip(minhash_signatures, strings)) ): @@ -175,9 +179,19 @@ def test_minhash_approximation( minhash_approximation = minhash_overlap(np.array(sig1), np.array(sig2)) assert abs(true_jaccard - minhash_approximation) < THRESHOLD - def test_minhash_cache(self, fuzzy_dedup_data, tmpdir): - minhasher = MinHash(cache_dir=tmpdir) + @pytest.mark.parametrize("cache_method", 
["Cache", "MinHash"]) + def test_minhash_cache(self, fuzzy_dedup_data, tmpdir, cache_method): + + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=tmpdir) + cache_dir = None + else: + cache_dir = tmpdir + + minhasher = MinHash(cache_dir=cache_dir) result = minhasher(fuzzy_dedup_data) + assert len(result) == len(fuzzy_dedup_data) assert "_minhashes.parquet" in os.listdir(tmpdir) assert len(os.listdir(tmpdir / "_minhashes.parquet")) != 0 @@ -204,20 +218,32 @@ def minhash_data(self): self.dataset = DocumentDataset(df) @pytest.mark.parametrize("buckets_per_shuffle", [1, 2, 3]) - def test_lsh(self, tmpdir, buckets_per_shuffle): + @pytest.mark.parametrize("cache_method", ["Cache", "LSH"]) + def test_lsh(self, tmpdir, buckets_per_shuffle, cache_method): + + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=tmpdir) + cache_dir = None + else: + cache_dir = tmpdir + lsh = LSH( - cache_dir=tmpdir, + cache_dir=cache_dir, num_hashes=6, num_buckets=3, buckets_per_shuffle=buckets_per_shuffle, minhash_field="minhash_sig", id_fields="id", ) + buckets = lsh(self.dataset) buckets_df = buckets.df docs_list = buckets_df.groupby("_bucket_id").id.agg(list) expected_df = cudf.Series([[1, 2], [2, 3], [4, 5]], name="id") + assert_eq(expected_df, docs_list, check_index=False) + assert "_buckets.parquet" in os.listdir(tmpdir) def test_multiple_id_cols(self, tmpdir): lsh = LSH( @@ -229,14 +255,17 @@ def test_multiple_id_cols(self, tmpdir): minhash_field="minhash_sig", ) buckets = lsh(self.dataset) + buckets_df = buckets.df.compute().to_pandas() buckets_df["new_id"] = list( map(list, zip(buckets_df.dataset_id, buckets_df.id)) ) + docs_list = buckets_df.groupby("_bucket_id").new_id.apply(list) expected_df = cudf.Series( [[(1, 1), (1, 2)], [(1, 2), (2, 3)], [(3, 4), (4, 5)]], name="new_id" ) + assert_eq(expected_df, docs_list, check_index=False) @pytest.mark.parametrize("false_positive_check", [True, False]) @@ -265,6 +294,7 @@ def test_no_duplicates(self, tmpdir, false_positive_check): false_positive_check=false_positive_check, ) buckets = lsh(minhash_dataset) + assert buckets is None assert "_buckets.parquet" not in os.listdir(tmpdir) @@ -292,6 +322,7 @@ def test_partial_overlap(self, tmpdir, false_positive_check): false_positive_check=false_positive_check, ) buckets = lsh(minhash_dataset) + assert len(buckets) == 4 assert buckets.df["_bucket_id"].nunique().compute() == 2 assert_eq( @@ -304,13 +335,14 @@ class TestFuzzyDuplicates: @pytest.mark.parametrize("use_64_bit_hash", [False, True]) @pytest.mark.parametrize( "num_buckets,jaccard_threshold,duplicate_docs", - # Duplcated docs estimated from true_jaccard values + # Duplicated docs estimated from true_jaccard values [ (5, 0.5, [[4, -1]]), (10, 0.39, [[4, -1], [1, 2]]), (15, 0.3, [[4, -1], [1, 2, 300]]), ], ) + @pytest.mark.parametrize("cache_method", ["Cache", "FuzzyDuplicatesConfig"]) def test_fuzzy_dedup( self, fuzzy_dedup_data, @@ -319,13 +351,23 @@ def test_fuzzy_dedup( jaccard_threshold, duplicate_docs, tmpdir, + cache_method, gpu_client, ): print(gpu_client) - # Dedup might fail when indices per partition do not start from 0 + + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=tmpdir) + cache_dir = None + else: + cache_dir = tmpdir + + # Deduplication might fail when indices per partition do not start from 0 fuzzy_dedup_data.df = fuzzy_dedup_data.df.reset_index(drop=True) + config = 
FuzzyDuplicatesConfig( - cache_dir=tmpdir, + cache_dir=cache_dir, id_field="id", text_field="text", seed=42, @@ -338,25 +380,30 @@ def test_fuzzy_dedup( num_anchors=2, jaccard_threshold=jaccard_threshold, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) result = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data) result_df = result.df.compute() - # Drop non duplicated docs + + # Drop non-duplicated documents result_df = result_df[result_df.group.duplicated(keep=False)] result_df = result_df.groupby("group").id.agg(list) - # Sort to maintain uniform ordering + # Sort to maintain uniform ordering result_df = result_df.list.sort_values() result_df = result_df.sort_values() + expected_df = cudf.Series(duplicate_docs, name="id") expected_df = expected_df.list.sort_values() expected_df = expected_df.sort_values() + assert_eq(expected_df, result_df, check_index=False) def test_different_fields(self, fuzzy_dedup_data, tmpdir): fuzzy_dedup_data.df = fuzzy_dedup_data.df.reset_index(drop=True).rename( columns={"id": "col0", "text": "col1"} ) + config = FuzzyDuplicatesConfig( cache_dir=tmpdir, id_field="col0", @@ -369,23 +416,27 @@ def test_different_fields(self, fuzzy_dedup_data, tmpdir): jaccard_threshold=0.39, char_ngrams=5, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) + duplicates = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data) deduplicated_ds = fuzzy_duplicates.remove(fuzzy_dedup_data, duplicates) deduplicated_df = deduplicated_ds.df.compute() output_deduplicated_ids = set(deduplicated_df["col0"].to_arrow().to_pylist()) + assert len(deduplicated_df) == 3 - # From each of our groups we'll have atmost one document that is not duplicated + # From each of our groups we will have at most 1 document that is not duplicated assert ( 300 in output_deduplicated_ids and len({-1, 4}.intersection(output_deduplicated_ids)) == 1 and len({1, 2}.intersection(output_deduplicated_ids)) == 1 ) - # Drop non duplicated docs + # Drop non-duplicated documents duplicates_df = duplicates.df.compute() duplicates_df = duplicates_df[duplicates_df.group.duplicated(keep=False)] duplicates_df = duplicates_df.groupby("group")["col0"].agg(list) + # Sort to maintain uniform ordering duplicates_df = duplicates_df.list.sort_values() duplicates_df = duplicates_df.sort_values() @@ -394,6 +445,7 @@ def test_different_fields(self, fuzzy_dedup_data, tmpdir): expected_df = cudf.Series(duplicate_docs, name="col0") expected_df = expected_df.list.sort_values() expected_df = expected_df.sort_values() + assert_eq(expected_df, duplicates_df, check_index=False) @pytest.mark.xfail @@ -403,7 +455,8 @@ def test_non_uniform_indices( gpu_client, ): print(gpu_client) - # Dedup might fail when indices per partition do not start from 0 + + # Deduplication might fail when indices per partition do not start from 0 df = cudf.DataFrame( { "id": [1, 2, 300, 4, -1], @@ -418,7 +471,9 @@ def test_non_uniform_indices( ) df = dask_cudf.from_cudf(df, 2) data = DocumentDataset(df) + duplicate_docs = [[4, -1], [1, 2, 300]] + config = FuzzyDuplicatesConfig( cache_dir=tmpdir, id_field="id", @@ -434,34 +489,50 @@ def test_non_uniform_indices( jaccard_threshold=0.39, ) fuzzy_duplicates = FuzzyDuplicates(config=config) + duplicates = fuzzy_duplicates.identify_duplicates(data) deduplicated_ds = fuzzy_duplicates.remove(fuzzy_dedup_data, duplicates) deduplicated_df = deduplicated_ds.df.compute() output_deduplicated_ids = set(deduplicated_df["col0"].to_arrow().to_pylist()) + assert len(deduplicated_df) == 2 - # From each of our groups we'll have 
atmost one document that is not duplicated + # From each of our groups we will have at most 1 document that is not duplicated assert ( len({4, -1}.intersection(output_deduplicated_ids)) == 1 and len({1, 2, 300}.intersection(output_deduplicated_ids)) == 1 ) duplicates_df = duplicates.df.compute() - # Drop non duplicated docs + + # Drop non-duplicated documents duplicates_df = duplicates_df[duplicates_df.group.duplicated(keep=False)] duplicates_df = duplicates_df.groupby("group").id.agg(list) - # Sort to maintain uniform ordering + # Sort to maintain uniform ordering duplicates_df = duplicates_df.list.sort_values() duplicates_df = duplicates_df.sort_values() + expected_df = cudf.Series(duplicate_docs, name="id") expected_df = expected_df.list.sort_values() expected_df = expected_df.sort_values() + assert_eq(expected_df, duplicates_df, check_index=False) @pytest.mark.parametrize("num_anchors", [1, 3, 10]) - def test_num_anchors(self, large_fuzzy_dedup_data, num_anchors, tmpdir): + @pytest.mark.parametrize("cache_method", ["Cache", "FuzzyDuplicatesConfig"]) + def test_num_anchors( + self, large_fuzzy_dedup_data, num_anchors, tmpdir, cache_method + ): + + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=tmpdir) + cache_dir = None + else: + cache_dir = tmpdir + config = FuzzyDuplicatesConfig( - cache_dir=tmpdir, + cache_dir=cache_dir, id_field="id", text_field="text", seed=42, @@ -474,17 +545,20 @@ def test_num_anchors(self, large_fuzzy_dedup_data, num_anchors, tmpdir): num_anchors=num_anchors, jaccard_threshold=0.39, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) fuzzy_duplicates(large_fuzzy_dedup_data) + anchor_docs_df_cols = dask_cudf.read_parquet( tmpdir / "anchor_docs_with_bk.parquet" ).columns + assert all(f"anchor_{i}_id" in anchor_docs_df_cols for i in range(num_anchors)) @pytest.mark.parametrize("use_64_bit_hash", [False, True]) @pytest.mark.parametrize( "num_buckets,duplicate_docs", - # Duplcated docs estimated from true_jaccard values + # Duplicated docs estimated from true_jaccard values [ (10, [[4, -1], [1, 2, 300]]), (5, [[4, -1], [1, 2, 300]]), @@ -513,19 +587,23 @@ def test_no_fp_check( num_anchors=2, jaccard_threshold=0.39, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) result = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data) result_df = result.df.compute() - # Drop non duplicated docs + + # Drop non-duplicated documents result_df = result_df[result_df.group.duplicated(keep=False)] result_df = result_df.groupby("group").id.agg(list) - # Sort to maintain uniform ordering + # Sort to maintain uniform ordering result_df = result_df.list.sort_values() result_df = result_df.sort_values() + expected_df = cudf.Series(duplicate_docs, name="id") expected_df = expected_df.list.sort_values() expected_df = expected_df.sort_values() + assert_eq(expected_df, result_df, check_index=False) def test_shuffle_fail_fuzzy_dedup_data( @@ -534,7 +612,7 @@ def test_shuffle_fail_fuzzy_dedup_data( tmpdir, gpu_client, ): - # Dedup might fail when indices per partition do not start from 0 + # Deduplication might fail when indices per partition do not start from 0 shuffle_fail_fuzzy_dedup_data.df = shuffle_fail_fuzzy_dedup_data.df.reset_index( drop=True ) @@ -552,29 +630,34 @@ def test_shuffle_fail_fuzzy_dedup_data( num_anchors=2, jaccard_threshold=0.39, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) result = fuzzy_duplicates.identify_duplicates(shuffle_fail_fuzzy_dedup_data) result_df = result.df.compute() - # Drop non 
duplicated docs + + # Drop non-duplicated documents result_df = result_df[result_df.group.duplicated(keep=False)] result_df = result_df.groupby("group").id.agg(list) - # Sort to maintain uniform ordering + # Sort to maintain uniform ordering result_df = result_df.list.sort_values() result_df = result_df.sort_values() + expected_df = cudf.Series([[1, 2]], name="id") expected_df = expected_df.list.sort_values() expected_df = expected_df.sort_values() + assert_eq(expected_df, result_df, check_index=False) @pytest.mark.parametrize("false_positive_check", [True, False]) def test_fuzzy_dedup_no_duplicates( self, no_duplicates_fuzzy_dedup_data, tmpdir, false_positive_check, gpu_client ): - # Dedup might fail when indices per partition do not start from 0 + # Deduplication might fail when indices per partition do not start from 0 no_duplicates_fuzzy_dedup_data.df = ( no_duplicates_fuzzy_dedup_data.df.reset_index(drop=True) ) + config = FuzzyDuplicatesConfig( cache_dir=tmpdir, id_field="id", @@ -589,6 +672,7 @@ def test_fuzzy_dedup_no_duplicates( num_anchors=2, jaccard_threshold=0.39, ) + fuzzy_duplicates = FuzzyDuplicates(config=config) result = fuzzy_duplicates.identify_duplicates(no_duplicates_fuzzy_dedup_data) assert result is None @@ -596,6 +680,7 @@ def test_fuzzy_dedup_no_duplicates( class TestFuzzyDuplicatesConfig: def test_bad_inputs(self, tmpdir): + with pytest.raises(ValueError): FuzzyDuplicatesConfig( cache_dir=tmpdir, num_anchors=0, false_positive_check=True @@ -610,23 +695,27 @@ def test_bad_inputs(self, tmpdir): FuzzyDuplicatesConfig( cache_dir=None, num_anchors=0, false_positive_check=True ) + with pytest.warns( - UserWarning, match="Using a higher number of anchor docs might" + UserWarning, match="Using a higher number of anchor documents might" ): FuzzyDuplicatesConfig( cache_dir=tmpdir, num_anchors=3, false_positive_check=True ) + with pytest.warns( UserWarning, match="Using a small char_ngrams value might lead" ): FuzzyDuplicatesConfig( cache_dir=tmpdir, char_ngrams=10, false_positive_check=False ) + with pytest.warns( UserWarning, - match="Identifying false positives during the Minhash deduplication is computationally expensive", + match="Identifying false positives during Minhash deduplication is computationally expensive", ): FuzzyDuplicatesConfig(cache_dir=tmpdir, false_positive_check=True) + with pytest.warns( UserWarning, match="False positive check is disabled. Unused arguments", @@ -638,6 +727,11 @@ def test_bad_inputs(self, tmpdir): jaccard_threshold=0.8, ) + # Need to specify either Cache(cache_dir=...) or FuzzyDuplicatesConfig(cache_dir=...) 
+ with pytest.raises(ValueError): + Cache().delete_cache_instance() # Fresh start for new PyTest + FuzzyDuplicatesConfig(cache_dir=None) + def test_from_yaml(self, tmpdir): yaml_params = { "cache_dir": "./", @@ -647,9 +741,12 @@ def test_from_yaml(self, tmpdir): "buckets_per_shuffle": 1, "char_ngrams": 20, } + with open(tmpdir / "config.yaml", "w") as f: yaml.dump(yaml_params, f) + config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml") + for param in yaml_params: assert getattr(config, param) == yaml_params[param] @@ -674,8 +771,7 @@ def add_partition_info(df, partition_info=None): return df with config.set({"dataframe.backend": backend}): - - # Create a random `unshuffled` DataFrame with a + # Create a random unshuffled DataFrame with a # "part_id" column to be used as the shuffle index npartitions_left = 7 unshuffled = dd.from_dict( @@ -683,7 +779,7 @@ def add_partition_info(df, partition_info=None): npartitions=npartitions_left, ) - # Create a `bk_mapping` DataFrame that defines + # Create a bk_mapping DataFrame that defines # the "correct" mapping beween "part_id" and # the destination partition ("file_id") npartitions_right = 5 @@ -697,7 +793,7 @@ def add_partition_info(df, partition_info=None): .compute() ) - # Use `extract_partitioning_index` to calculate + # Use extract_partitioning_index to calculate # the partitioning index and assign it as a new # "_partitions" column result, _ = extract_partitioning_index( diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 61f4403b3..1aecea882 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import numpy as np @@ -21,6 +22,7 @@ from transformers import AutoConfig, AutoModel, AutoTokenizer from nemo_curator import SemDedup, SemDedupConfig +from nemo_curator.cache import Cache from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from @@ -68,17 +70,25 @@ def non_dedup_data(): @pytest.mark.gpu class TestSemDuplicates: + @pytest.mark.parametrize("cache_method", ["Cache", "SemDedupConfig"]) @pytest.mark.parametrize("n_clusters", [3, 10]) def test_sem_dedup( self, dedup_data, tmpdir, + cache_method, n_clusters, gpu_client, ): print("client", gpu_client) - cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache") + Cache().delete_cache_instance() # Fresh start for new PyTest + if cache_method == "Cache": + Cache(cache_dir=os.path.join(tmpdir, "test_sem_dedup_cache")) + cache_dir = None + else: + cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache") + config = SemDedupConfig( cache_dir=cache_dir, n_clusters=n_clusters, @@ -106,6 +116,58 @@ def test_sem_dedup( expected_df = cudf.Series(duplicate_docs, name="id") assert_eq(result_df["id"].sort_values(), expected_df, check_index=False) + # Check that the output is written when either: + # (1) Cache(cache_dir=...) is initialized, or + # (2) SemDedupConfig(cache_dir=...) is initialized. + # Either way, their output files should be identical. 
+ cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache") + + assert os.path.exists(cache_dir) + assert os.path.exists(cache_dir + "/embeddings/part.0.parquet") + assert os.path.exists(cache_dir + "/embeddings/part.1.parquet") + assert os.path.exists( + cache_dir + "/clustering_results/dedup_summary_0.1.csv" + ) + assert os.path.exists( + cache_dir + "/clustering_results/kmeans_centroids.npy" + ) + assert os.path.exists( + cache_dir + "/clustering_results/sorted/cluster_0.npy" + ) + assert os.path.exists( + cache_dir + "/clustering_results/sorted/cluster_1.npy" + ) + assert os.path.exists( + cache_dir + "/clustering_results/sorted/cluster_2.npy" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/embs_by_nearest_center/nearest_cent=0/part.0.parquet" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/embs_by_nearest_center/nearest_cent=1/part.0.parquet" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/embs_by_nearest_center/nearest_cent=2/part.0.parquet" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/semdedup_pruning_tables/cluster_0.parquet" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/semdedup_pruning_tables/cluster_1.parquet" + ) + assert os.path.exists( + cache_dir + + "/clustering_results/semdedup_pruning_tables/cluster_2.parquet" + ) + assert os.path.exists( + cache_dir + "/clustering_results/unique_ids_0.1.parquet" + ) + @pytest.mark.parametrize("n_clusters", [2, 3]) def test_no_sem_dedup( self, @@ -157,9 +219,10 @@ def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy): embedding_creator = EmbeddingCreator( embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", embedding_batch_size=32, + cache_dir=cache_dir, + embeddings_save_loc="mean_embeddings", embedding_pooling_strategy=pooling_strategy, input_column="text", - embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"), ) embeddings = embedding_creator.create_embeddings(ddf).compute() diff --git a/tutorials/image-curation/image-curation.ipynb b/tutorials/image-curation/image-curation.ipynb index 47bed501b..4f24c0292 100644 --- a/tutorials/image-curation/image-curation.ipynb +++ b/tutorials/image-curation/image-curation.ipynb @@ -641,7 +641,8 @@ " embedding_col=\"image_embedding\",\n", " max_iter=10,\n", " n_clusters=1,\n", - " clustering_output_dir=clustering_output,\n", + " cache_dir=semantic_dedup_outputs,\n", + " clustering_save_loc=\"cluster_output\",\n", ")\n", "clustered_dataset = clustering_model(embeddings_dataset)" ] @@ -669,19 +670,17 @@ ], "source": [ "# Run cluster-level dedup\n", - "emb_by_cluster_output = os.path.join(clustering_output, \"embs_by_nearest_center\")\n", - "sorted_cluster_output = os.path.join(clustering_output, \"sorted\")\n", "duplicate_output = os.path.join(semantic_dedup_outputs, \"duplicates\")\n", "\n", "semantic_dedup = SemanticClusterLevelDedup(\n", " n_clusters=1,\n", - " emb_by_clust_dir=emb_by_cluster_output,\n", - " sorted_clusters_dir=sorted_cluster_output,\n", " id_column=id_col,\n", " id_column_type=\"str\",\n", " embedding_col=\"image_embedding\",\n", " which_to_keep=\"hard\",\n", " output_dir=duplicate_output,\n", + " cache_dir=semantic_dedup_outputs,\n", + " clustering_save_loc=\"cluster_output\",\n", ")\n", "semantic_dedup.compute_semantic_match_dfs([0.01, 0.001])\n", "deduplicated_dataset_ids = semantic_dedup.extract_dedup_data(eps_to_extract=0.01)" diff --git a/tutorials/peft-curation-with-sdg/main.py b/tutorials/peft-curation-with-sdg/main.py index 
1ec231f8a..a52a50c0e 100644 --- a/tutorials/peft-curation-with-sdg/main.py +++ b/tutorials/peft-curation-with-sdg/main.py @@ -131,6 +131,7 @@ def semantic_dedupe(dataset): os.path.join(CONFIG_DIR, "sem_dedup_config.yaml") ) expand_outdir_and_mkdir(semdedup_config.cache_dir) + semdup = SemDedup( config=semdedup_config, input_column="text", @@ -138,6 +139,7 @@ def semantic_dedupe(dataset): id_column_type="str", ) dedup_ids = semdup(dataset) + # When there are few duplicates we can compute the results to a list and use `isin`. result = dataset.df[dataset.df["id"].isin(dedup_ids.df["id"].compute())] return DocumentDataset(result) diff --git a/tutorials/pretraining-vietnamese-data-curation/pretraining-vietnamese-data-curation.ipynb b/tutorials/pretraining-vietnamese-data-curation/pretraining-vietnamese-data-curation.ipynb index 6181a8a74..3389fc6ee 100644 --- a/tutorials/pretraining-vietnamese-data-curation/pretraining-vietnamese-data-curation.ipynb +++ b/tutorials/pretraining-vietnamese-data-curation/pretraining-vietnamese-data-curation.ipynb @@ -806,7 +806,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -832,7 +832,7 @@ " id_field=exact_dedup_dataset_id_field,\n", " text_field=exact_dedup_dataset_text_field,\n", " hash_method=\"md5\",\n", - " cache_dir=exact_dedup_output_dir\n", + " cache_dir=exact_dedup_output_dir,\n", ")\n", "duplicates = exact_dup(dataset=input_dataset)\n", "\n", diff --git a/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb b/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb index 510d6d258..0b2581856 100644 --- a/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb +++ b/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb @@ -768,15 +768,15 @@ "id": "1baf027e", "metadata": {}, "source": [ - "## 4.Exact Deduplication\n", + "## 4. Exact Deduplication\n", "\n", - "In exact deduplication, the document text is hashed into unique string using certain hashing algorithm, such as 'md5'. The documents with exact hashed values are having identical text. We will output the `ID` of duplicated documents for removal later. The function used is `ExactDuplicates()`. Arguments for this function include:\n", - "- `id_field`: Key in input file for identifying document ID\n", - "- `text_field`: Key in input file which contains document text.\n", - "- `hash_method`: Hashing algorithm used. Default is `md5`\n", - "- `cache_dir`: If specified, the duplicated document IDs will be output to the `cache_dir`. Otherwise, the IDs will not be saved\n", + "In exact deduplication, the document text is hashed into a unique string by using a hashing algorithm such as md5. The documents with exact hashed values are identified as having identical text. We will output the ID of duplicated documents for removal later. The class used for exact deduplication in NeMo Curator is called `ExactDuplicates`. Fields for this class include:\n", + "- `id_field`: Column in input file which contains a unique ID.\n", + "- `text_field`: Column in input file which contains document text.\n", + "- `hash_method`: Hashing algorithm used. Default is \"md5\".\n", + "- `cache_dir`: If specified via `ExactDuplicates(cache_dir=...)` or `Cache(cache_dir=...)`, the duplicated document IDs will be output to the cache directory. 
Otherwise, the IDs will not be saved.\n",
     "\n",
-    "Also, we are going to use GPU dask cluster to accelerate computation for deduplication (both exact and fuzzy)\n"
+    "We are going to use a GPU-based Dask cluster to accelerate computation for deduplication (both exact and fuzzy deduplication).\n"
   ]
  },
  {
@@ -901,19 +901,19 @@
     "# Read input dataset\n",
     "input_dataset = DocumentDataset.read_json(exact_dedup_input_dataset_dir, backend='cudf')\n",
     "\n",
-    "#Run exact deduplication to the input\n",
+    "# Run exact deduplication to the input\n",
     "exact_dup = ExactDuplicates(\n",
     "    logger=exact_dedup_log_dir,\n",
     "    id_field=exact_dedup_dataset_id_field,\n",
     "    text_field=exact_dedup_dataset_text_field,\n",
     "    hash_method=\"md5\",\n",
-    "    cache_dir=exact_dedup_output_dir #Duplicated document ID list is output to the cache_dir\n",
+    "    cache_dir=exact_dedup_output_dir, # Duplicated document ID list is output to the cache_dir\n",
     ")\n",
     "duplicates = exact_dup(dataset=input_dataset)\n",
     "\n",
-    "print(f\"Number of exact duplicated file:{len(duplicates)}\")\n",
+    "print(f\"Number of exact duplicated documents: {len(duplicates)}\")\n",
     "\n",
-    "print(f\"Time taken for exact duplicate:{time.time()-t0}\")"
+    "print(f\"Time taken for exact deduplication: {time.time()-t0}s\")"
   ]
  },
  {
@@ -1249,23 +1249,21 @@
   "id": "5df73743",
   "metadata": {},
   "source": [
-    "### 5.2 Minhash\n",
+    "### 5.2 MinHash\n",
     "\n",
-    "Run `MinHash()` for this section. The output of a minhash is a parquet file which contains document ID and hashed value which is an array contains 260 32-bit integer data. To obtain such hashed values we need to go through the following steps:\n",
     "1. Generate a set of n-gram components of a document. For example, doc = `Nemo Curator is a data curation tool`, a 3-gram set of this document will be `['Nemo Curator is','Curator is a','is a data','a data curation','data curation tool']`\n",
-    "2. Hashed each n-gram into numerical values\n",
+    "2. Hash each n-gram into numerical values\n",
     "3. Generate a random hash function $H_1()$ which will hash each numeric n-gram into a 32-bit integer and take the minimum integer to use as minhash value for $H_1()$\n",
-    "4. Repeat step 2 and 3 with hash function $H_x()$ until desired minhash length is reached. Minhash value of each iteration will be append together to form the final minhash array. \n",
+    "4. Repeat steps 2 and 3 with hash function $H_x()$ until the desired minhash length is reached. The minhash value of each iteration will be appended together to form the final minhash array. \n",
     "\n",
     "Arguments include:\n",
-    "- `seed`:Random seed used for initializing the hash functions used to compute the MinHashes. It's advised to keep this value the same for different experiment for reproducibility\n",
-    "- `num_hashes`:Length of each minhash array. Default is 260. Longer minhash length will have better estimate of actual Jaccard similarity, but require more computational power\n",
-    "- `char_ngrams`:n-gram length. Assuming an average of 4.5 chars per word it's recommended to use `char_ngrams>=24` to use ~5 word ngrams or greater.\n",
-    "- `use_64bit_hash`:Whether to use 64bit or 32bit hash function\n",
-    "- `id_field`: Key in input file for identifying document ID\n",
-    "- `text_field`: Key in input file which contains document text.\n",
-    "- `cache_dir`: If specified, the intermediate result will be output to the `cache_dir`. \n",
-    "\n"
+    "- `seed`: Random seed used for initializing the hash functions used to compute the minhashes. 
It is advised to keep this value the same for different experiments for reproducibility.\n",
+    "- `num_hashes`: Length of each minhash array. Default is 260. A longer minhash array gives a better estimate of the actual Jaccard similarity, but requires more computational power.\n",
+    "- `char_ngrams`: n-gram length. Assuming an average of 4.5 characters per word, it is recommended to use `char_ngrams>=24` to use ~5 word n-grams or greater.\n",
+    "- `use_64bit_hash`: Whether to use a 64-bit or 32-bit hash function.\n",
+    "- `id_field`: Column in input file which contains a unique ID.\n",
+    "- `text_field`: Column in input file which contains document text.\n",
+    "- `cache_dir`: If specified via `MinHash(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate result will be output to the cache directory."
   ]
  },
  {
@@ -1297,18 +1295,12 @@
    },
    "outputs": [],
    "source": [
-    "#Input\n",
+    "# Input\n",
     "minhash_data_path = added_id_output_path\n",
-    "#Output\n",
-    "minhash_base_output_path = os.path.join(data_dir,\"fuzzy/minhash\")\n",
-    "minhash_log_dir = os.path.join(minhash_base_output_path,'log')\n",
-    "minhash_output_dir = os.path.join(minhash_base_output_path,'data')\n",
-    "#Specify dataset name\n",
-    "dataset_name = 'TH_wikipedia'\n",
     "\n",
-    "#Relevant parameters\n",
-    "minhash_id_field = 'id'\n",
-    "minhash_text_field = 'text'\n",
+    "# Relevant parameters\n",
+    "minhash_id_field = \"id\"\n",
+    "minhash_text_field = \"text\"\n",
     "seed = 10 # Using the same value as the wrapper above for consistency\n",
     "minhash_length = 260\n",
     "char_ngram = 24\n",
@@ -1339,7 +1331,8 @@
     "t0 = time.time()\n",
     "print(f\"Computing minhashes for {minhash_data_path}\")\n",
     "\n",
-    "# Load data. Only the [minhash_id_field, text_field] columns are needed\n",
+    "# Load data\n",
+    "# Only the [minhash_id_field, minhash_text_field] columns are needed\n",
     "files = get_all_files_paths_under(\n",
     "    root=minhash_data_path, recurse_subdirectories=False, keep_extensions=\"jsonl\"\n",
     ")\n",
@@ -1351,7 +1344,7 @@
     "    add_filename=False,\n",
     ")[[minhash_id_field, minhash_text_field]]\n",
     "\n",
-    "# Run MinHash() on input data\n",
+    "# Run MinHash on input data\n",
     "minhasher = MinHash(\n",
     "    seed=seed,\n",
     "    num_hashes=minhash_length,\n",
@@ -1360,11 +1353,11 @@
     "    logger=minhash_log_dir,\n",
     "    id_field=minhash_id_field,\n",
     "    text_field=minhash_text_field,\n",
-    "    cache_dir=minhash_output_dir\n",
+    "    cache_dir=minhash_output_dir,\n",
     ")\n",
     "res = minhasher(DocumentDataset(df)).df\n",
     "\n",
-    "print(f\"Time taken for MinHash:{time.time()-t0}\")"
+    "print(f\"Time taken for MinHash: {time.time()-t0}s\")"
   ]
  },
  {
@@ -1394,19 +1387,18 @@
   "metadata": {},
   "source": [
     "### 5.3 LSH\n",
-    "`LSH()` implements LSH algorithm which includes the following steps:\n",
+    "`LSH` implements the Locality Sensitive Hashing algorithm, which includes the following steps:\n",
     "1. Divide the minhash array into `X` different portions. \n",
     "2. For each portions, hash the minhash values into buckets. One document will be assigned to `X` buckets.\n",
     "3. Documents within the same bucket will be deemed similar. Since every document will be assigned `X` buckets and as long as two documents share 1 or more buckets they are deemed similar.\n",
     "\n",
     "Arguments include:\n",
-    "- `minhash_length`:Length of minhash signature. 
Must be consistent with `MinHash()`\n",
-    "- `num_buckets`: Number of buckets\n",
-    "- `buckets_per_shuffle`: Number of buckets to shuffle concurrently\n",
-    "- `id_field`: Key in input file for identifying document ID\n",
-    "- `minhash_field`: Key in input file for identifying document MinHash signature \n",
-    "- `cache_dir`:If specified, the intermediate result will be output to the `cache_dir`.\n",
-    "\n"
+    "- `minhash_length`: Length of minhash signature. Must be consistent with `MinHash`.\n",
+    "- `num_buckets`: Number of buckets.\n",
+    "- `buckets_per_shuffle`: Number of buckets to shuffle concurrently.\n",
+    "- `id_field`: Column in input file which contains a unique ID.\n",
+    "- `minhash_field`: Column in input file for identifying a document's MinHash signature.\n",
+    "- `cache_dir`: If specified via `LSH(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate result will be output to the cache directory."
   ]
  },
  {
@@ -1438,20 +1430,20 @@
    },
    "outputs": [],
    "source": [
-    "#Input\n",
+    "# Input\n",
     "lsh_input_data_path = minhash_output_dir\n",
     "\n",
-    "#Output\n",
-    "lsh_base_output_path = os.path.join(data_dir,\"fuzzy/lsh\")\n",
-    "lsh_log_dir = os.path.join(lsh_base_output_path,'log')\n",
-    "lsh_output_dir = os.path.join(lsh_base_output_path,'data')\n",
+    "# Output\n",
+    "lsh_base_output_path = os.path.join(data_dir, \"fuzzy/lsh\")\n",
+    "lsh_log_dir = os.path.join(lsh_base_output_path, \"log\")\n",
+    "lsh_output_dir = os.path.join(lsh_base_output_path, \"data\")\n",
     "\n",
-    "#Relevant parameters\n",
-    "lsh_id_field = 'id'\n",
-    "minhash_field = '_minhash_signature'\n",
-    "minhash_length=260\n",
-    "num_bands=20\n",
-    "buckets_per_shuffle=1\n",
+    "# Relevant parameters\n",
+    "lsh_id_field = \"id\"\n",
+    "minhash_field = \"_minhash_signature\"\n",
+    "minhash_length = 260\n",
+    "num_bands = 20\n",
+    "buckets_per_shuffle = 1\n",
     "\n",
     "!mkdir -p {lsh_log_dir}\n",
     "!mkdir -p {lsh_output_dir}"
@@ -1476,10 +1468,10 @@
    "source": [
     "t0 = time.time()\n",
     "\n",
-    "#Load MinHash output\n",
+    "# Load MinHash output\n",
     "df = dask_cudf.read_parquet(lsh_input_data_path, blocksize=\"2GB\", aggregate_files=True, backend = \"cudf\")\n",
     "\n",
-    "#Run LSH()\n",
+    "# Run LSH\n",
     "lsh = LSH(\n",
     "    cache_dir=lsh_output_dir,\n",
     "    num_hashes=minhash_length,\n",
@@ -1492,7 +1484,7 @@
     "res = lsh(DocumentDataset(df))\n",
     "\n",
     "t1 = time.time()\n",
-    "print(f\"Time taken for LSH:{time.time()-t0}\")"
+    "print(f\"Time taken for LSH: {time.time()-t0}s\")"
   ]
  },
 {
@@ -1632,13 +1624,13 @@
  "metadata": {},
  "source": [
     "### 5.5 Connected Components\n",
-    "This section uses `ConnectedComponents()`.This section takes a dataset consisting of document pairs and their corresponding jaccard similarity to construct a non-directed graph. A edge will be formed between documents whose Jaccard similarity is higher than the threshold. It will then identify the connected components in this graph. Documents within the same connected components are deemed duplicated.\n",
+    "This section uses the `ConnectedComponents` class, which takes a dataset consisting of document pairs and their corresponding Jaccard similarity scores to construct an undirected graph. An edge will be formed between documents whose Jaccard similarity is higher than a given threshold (0.8 in this example). It will then identify the connected components in this graph. 
Documents within the same connected components are deemed duplicates.\n", "\n", "Arguments include:\n", - "- `cache_dir`: Output path for intermediate results\n", - "- `jaccard_pairs_path`: Input path for `jaccard_similarity_results.parquet`\n", - "- `id_column`: prefix of ID column in `jaccard_similarity_results.parquet`\n", - "- `jaccard_threshold`: Threshold to determine if an edge exists between two documents" + "- `cache_dir`: If specified via `ConnectedComponents(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate results will be output to the cache directory.\n", + "- `jaccard_pairs_path`: Input path for `jaccard_similarity_results.parquet`.\n", + "- `id_column`: Prefix of ID column in `jaccard_similarity_results.parquet`.\n", + "- `jaccard_threshold`: Threshold to determine if an edge exists between two documents." ] }, { @@ -1670,17 +1662,17 @@ }, "outputs": [], "source": [ - "#Input\n", + "# Input\n", "jaccard_pairs_path = edgelist_output_dir\n", "\n", - "#Output\n", - "connected_component_base_output_path = os.path.join(data_dir,\"fuzzy/cc\")\n", + "# Output\n", + "connected_component_base_output_path = os.path.join(data_dir, \"fuzzy/cc\")\n", "connected_component_output_path = os.path.join(connected_component_base_output_path, \"connected_components.parquet\")\n", "connected_component_cache_dir = os.path.join(connected_component_base_output_path, \"cache\")\n", "connected_component_log_path = os.path.join(connected_component_base_output_path,\"log\")\n", "\n", - "#Relevant parameters\n", - "input_id_field = 'id'\n", + "# Relevant parameters\n", + "input_id_field = \"id\"\n", "\n", "!mkdir -p {connected_component_base_output_path}\n", "!mkdir -p {connected_component_log_path}" @@ -1712,9 +1704,9 @@ " logger=connected_component_log_path,\n", ")\n", "\n", - "#Load and run connected component\n", + "# Load and run connected components\n", "components_stage.cc_workflow(output_path=connected_component_output_path)\n", - "print(f\"Time taken for Connected Component: {time.time()-t0} s\")" + "print(f\"Time taken for Connected Components: {time.time()-t0}s\")" ] }, { diff --git a/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py b/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py index 457556734..21783256b 100644 --- a/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py +++ b/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py @@ -39,7 +39,6 @@ cache_dir=buckets_to_edges_out, id_fields=["dataset_id", "doc_id"], ) - ddf_b2e = buckets_to_edges(DocumentDataset(ddf_bk)) logging.info(f"Time taken for Buckets to Edges: {time.time() - t0} s")
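Taken together, the tests and tutorial changes above exercise one pattern: a deduplication config or module can receive its cache directory either explicitly through its own cache_dir argument or implicitly from the Cache singleton, and construction fails if neither is set. The following is a minimal, illustrative sketch of both wiring paths; the top-level import location of FuzzyDuplicates/FuzzyDuplicatesConfig and the parameter values shown are assumptions, not part of this patch.

# Illustrative sketch only; import paths and parameter values are assumptions.
from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig
from nemo_curator.cache import Cache

# Option 1: register a global cache directory once; configs created with
# cache_dir=None then fall back to it (the "Cache" path in the tests above).
Cache().delete_cache_instance()          # reset the singleton between runs/tests
Cache(cache_dir="./fuzzy_dedup_cache")   # placeholder path
config = FuzzyDuplicatesConfig(
    cache_dir=None,                      # resolved from the Cache singleton
    id_field="id",
    text_field="text",
    false_positive_check=False,
)

# Option 2: pass the cache directory explicitly and skip the singleton.
# config = FuzzyDuplicatesConfig(cache_dir="./fuzzy_dedup_cache", id_field="id",
#                                text_field="text", false_positive_check=False)

# If neither Cache(cache_dir=...) nor FuzzyDuplicatesConfig(cache_dir=...) is set,
# FuzzyDuplicatesConfig raises a ValueError (see the test added above).
fuzzy_duplicates = FuzzyDuplicates(config=config)
# duplicates = fuzzy_duplicates.identify_duplicates(dataset)  # dataset: a cuDF-backed DocumentDataset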