13 changes: 11 additions & 2 deletions configs/config_all.yaml
@@ -649,8 +649,17 @@ process:
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
- general_field_filter: # Filter to keep samples based on a general field filter condition.
filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
- general_field_filter: # Filter to keep samples based on a general field filter condition.
filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
- group_diversity_filter: # filter samples based on their semantic diversity within a group.
api_or_hf_model: "text-embedding-v3" # API or huggingface embedding model name.
is_hf_model: false # indicates if the model is from HuggingFace.
api_endpoint: "/embeddings" # embedding URL endpoint for the API.
response_path: "data.0.embedding" # path to extract content from the API response.
ebd_dim: 512 # the embedding's dimension via API.
min_score: 0.0 # the min score of filter range
max_score: 1.0 # the max score of filter range
norm_ratio: 0.5 # ratio to normalize the score.
- image_aesthetics_filter: # filter samples according to the aesthetics score of images.
hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE # Huggingface model name for the aesthetics predictor
min_score: 0.3 # the min aesthetics score of filter range
4 changes: 3 additions & 1 deletion data_juicer/ops/filter/__init__.py
@@ -6,6 +6,7 @@
from .character_repetition_filter import CharacterRepetitionFilter
from .flagged_words_filter import FlaggedWordFilter
from .general_field_filter import GeneralFieldFilter
from .group_diversity_filter import GroupDiversityFilter
from .image_aesthetics_filter import ImageAestheticsFilter
from .image_aspect_ratio_filter import ImageAspectRatioFilter
from .image_face_count_filter import ImageFaceCountFilter
@@ -63,6 +64,8 @@
"AverageLineLengthFilter",
"CharacterRepetitionFilter",
"FlaggedWordFilter",
"GeneralFieldFilter",
"GroupDiversityFilter",
"ImageAestheticsFilter",
"ImageAspectRatioFilter",
"ImageFaceCountFilter",
@@ -109,7 +112,6 @@
"VideoWatermarkFilter",
"WordRepetitionFilter",
"WordsNumFilter",
"GeneralFieldFilter",
]

NON_STATS_FILTERS = [
140 changes: 140 additions & 0 deletions data_juicer/ops/filter/group_diversity_filter.py
@@ -0,0 +1,140 @@
import sys
from typing import Dict, List

import numpy as np
from jsonargparse.typing import NonNegativeFloat, PositiveInt
from tqdm import tqdm

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

# Lazy load torch to improve startup time
torch = LazyLoader("torch")


@OPERATORS.register_module("group_diversity_filter")
class GroupDiversityFilter(Filter):
"""
Filter samples based on their semantic diversity within a group.
"""

_accelerator = "cuda"
_batched_op = True

def __init__(
self,
api_or_hf_model: str = "text-embedding-v3",
is_hf_model: bool = False,
api_endpoint: str = "/embeddings",
response_path: str = "data.0.embedding",
model_params: Dict = {},
ebd_dim: PositiveInt = 512,
min_score: NonNegativeFloat = 0.0,
max_score: NonNegativeFloat = 1.0,
norm_ratio: NonNegativeFloat = 0.5,
*args,
**kwargs,
):
"""
Initialization method.

:param api_or_hf_model: API or huggingface embedding model name.
:param is_hf_model: Indicates if the model is from HuggingFace.
:param api_endpoint: Embedding URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'data.0.embedding' for embedding model.
:param model_params: Parameters for initializing the API model.
:param ebd_dim: The embedding's dimension via API.
:param min_score: Minimum score for filtering.
:param max_score: Maximum score for filtering.
:param norm_ratio: Ratio to normalize the score.
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault("mem_required", "20GB")
super().__init__(*args, **kwargs)

self.min_score = min_score
self.max_score = max_score
self.norm_ratio = norm_ratio
self.is_hf_model = is_hf_model
self.ebd_dim = ebd_dim

if self.is_hf_model:
self.model_key = prepare_model(model_type="embedding", model_path=api_or_hf_model, **model_params)
else:
self.model_key = prepare_model(
model_type="api",
model=api_or_hf_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params,
)

def _embed_texts(self, texts: List[str], rank: int) -> np.ndarray:
# Embed a list of texts using the initialized model
embeddings = []
model = get_model(self.model_key, rank, self.use_cuda())

for text in tqdm(texts, desc="Embedding texts", leave=False):
Reviewer comment (high):

This loop processes texts one by one, which is inefficient for a batched operator. Most embedding models, including Hugging Face sentence-transformers, are optimized for batch processing. Given that this operator processes the entire dataset in a single batch (num_proc=1), this loop can become a significant performance bottleneck.

Consider refactoring this to process texts in batches. For Hugging Face models, you can pass the entire list of texts to model.encode() outside the loop. For API models, check if batching is supported by the underlying API wrapper.

if self.is_hf_model:
    try:
        # Use batch encoding for efficiency with Hugging Face models
        embeddings = model.encode(texts, show_progress_bar=False)
        return np.array(embeddings, dtype=np.float32)
    except Exception as e:
        logger.error(f"Failed to embed texts in batch. Error: {e}. Using zero vectors for all.")
        dim = model.get_sentence_embedding_dimension()
        return np.zeros((len(texts), dim), dtype=np.float32)

try:
if self.is_hf_model:
embedding = model.encode(text)
else:
embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
embeddings.append(np.array(embedding, dtype=np.float32))
except Exception as e:
dim = model.get_sentence_embedding_dimension() if self.is_hf_model else self.ebd_dim
embeddings.append(np.zeros(dim, dtype=np.float32))
print(f"Failed to embed text: '{text}'. Error: {e}. Using zero vector.", file=sys.stderr)
Reviewer comment (medium):

The code uses print(..., file=sys.stderr) for logging errors. The tests for this OP use loguru.logger. For consistency with the rest of the project, it's better to use loguru.logger.error() here.

logger.error(f"Failed to embed text: '{text}'. Error: {e}. Using zero vector.")


return np.array(embeddings)

def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict:
stats_list = samples[Fields.stats]
if stats_list and StatsKeys.text_ebd_diversity_score in stats_list[0]:
return samples

texts_to_embed = samples[self.text_key]
if not texts_to_embed:
for stat in stats_list:
stat[StatsKeys.text_ebd_diversity] = 0.0
stat[StatsKeys.text_ebd_diversity_score] = 0.0
return samples

embeddings_array = self._embed_texts(texts_to_embed, rank=rank)

avg_embedding = np.mean(embeddings_array, axis=0)

cos_sims = (
torch.nn.functional.cosine_similarity(
torch.from_numpy(embeddings_array), torch.from_numpy(avg_embedding).unsqueeze(0), dim=1
)
.cpu()
.numpy()
.tolist()
)

min_sim, max_sim = min(cos_sims), max(cos_sims)
range_sim = max_sim - min_sim

normalized_scores = []
if range_sim < 1e-8:
normalized_scores = [0.0] * len(cos_sims)
else:
for sim in cos_sims:
normalized_sim = self.norm_ratio * (max_sim - sim) / range_sim
normalized_scores.append(normalized_sim)
Reviewer comment on lines +124 to +130 (medium):

This loop for calculating normalized_scores can be vectorized using numpy for better performance and readability. This avoids iterating through the similarities one by one in Python.

if range_sim < 1e-8:
    normalized_scores = [0.0] * len(cos_sims)
else:
    cos_sims_np = np.array(cos_sims)
    normalized_scores_np = self.norm_ratio * (max_sim - cos_sims_np) / range_sim
    normalized_scores = normalized_scores_np.tolist()


for i, stat in enumerate(stats_list):
stat[StatsKeys.text_ebd_diversity] = cos_sims[i]
stat[StatsKeys.text_ebd_diversity_score] = normalized_scores[i]

return samples

def process_batched(self, samples: Dict) -> List[bool]:
stats_list = samples[Fields.stats]
return [self.min_score <= stat[StatsKeys.text_ebd_diversity_score] <= self.max_score for stat in stats_list]
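
For readers skimming the diff, here is a minimal standalone sketch (not part of the PR) of the scoring and filtering logic above, written in plain numpy rather than torch. The function name, the random embeddings, and the 0.0–1.0 keep range are illustrative only.

```python
# Standalone sketch of the operator's scoring math, using plain numpy
# (the operator itself uses torch.nn.functional.cosine_similarity).
import numpy as np


def diversity_scores(embeddings: np.ndarray, norm_ratio: float = 0.5):
    """embeddings: (n_samples, dim) array for one group of texts."""
    avg = embeddings.mean(axis=0)
    # Cosine similarity of each sample to the group's mean embedding.
    cos_sims = embeddings @ avg / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(avg) + 1e-12
    )
    min_sim, max_sim = cos_sims.min(), cos_sims.max()
    range_sim = max_sim - min_sim
    if range_sim < 1e-8:
        # Degenerate group: all samples equally similar to the mean.
        scores = np.zeros_like(cos_sims)
    else:
        # Samples far from the group mean (low similarity) get higher scores.
        scores = norm_ratio * (max_sim - cos_sims) / range_sim
    return cos_sims, scores


# Illustrative usage with random embeddings and the default filter range.
rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 8)).astype(np.float32)
sims, scores = diversity_scores(emb)
keep = (0.0 <= scores) & (scores <= 1.0)  # min_score / max_score from the config
print(sims, scores, keep)
```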
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
@@ -236,6 +236,8 @@ class StatsKeysConstant(object):
llm_perplexity = "llm_perplexity"
llm_task_relevance = "llm_task_relevance"
llm_task_relevance_record = "llm_task_relevance_record"
text_ebd_diversity = "text_ebd_diversity"
text_ebd_diversity_score = "text_ebd_diversity_score"

# === image ===
aspect_ratios = "aspect_ratios"
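
With the new stats keys registered, a rough, untested sketch (not part of the PR) of exercising GroupDiversityFilter directly with a local HuggingFace embedding model might look like the following. The model name, the toy samples, the empty initial stats dicts, and the assumption that the default text_key is "text" are all placeholders.

```python
# Rough sketch only: assumes sentence-transformers is installed and that
# prepare_model(model_type="embedding", ...) accepts this HF model name.
from data_juicer.ops.filter import GroupDiversityFilter
from data_juicer.utils.constant import Fields, StatsKeys

op = GroupDiversityFilter(
    api_or_hf_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    is_hf_model=True,
    min_score=0.0,
    max_score=1.0,
    norm_ratio=0.5,
)

# A toy "group" of samples; Fields.stats is pre-populated with empty dicts
# the way the Data-Juicer pipeline normally does before stats are computed.
samples = {
    "text": [
        "a recipe for sourdough bread",
        "a recipe for banana cake",
        "an introduction to quantum field theory",
    ],
    Fields.stats: [{}, {}, {}],
}

samples = op.compute_stats_batched(samples)
keep = op.process_batched(samples)
for stat, kept in zip(samples[Fields.stats], keep):
    print(stat[StatsKeys.text_ebd_diversity_score], kept)
```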
7 changes: 4 additions & 3 deletions docs/Operators.md
@@ -96,6 +96,7 @@ All the specific operators are listed below, each featured with several capabili
| character_repetition_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with char-level n-gram repetition ratio within a specific range. 过滤器将具有char级n-gram重复比率的样本保持在特定范围内。 | [code](../data_juicer/ops/filter/character_repetition_filter.py) | [tests](../tests/ops/filter/test_character_repetition_filter.py) |
| flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio less than a specific max value. 过滤以保持标记词比率小于特定最大值的样本。 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) |
| general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [code](../data_juicer/ops/filter/general_field_filter.py) | [tests](../tests/ops/filter/test_general_field_filter.py) |
| group_diversity_filter | 🔤Text 💻CPU 🔗API 🟡Beta | Filter samples based on their semantic diversity within a group. 基于样本在组内的语义多样性来过滤样本。 | [code](../data_juicer/ops/filter/group_diversity_filter.py) | [tests](../tests/ops/filter/test_group_diversity_filter.py) |
| image_aesthetics_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) |
| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) |
| image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) |
@@ -153,7 +154,7 @@ All the specific operators are listed below, each featured with several capabili
| local_formatter | 🟢Stable | The class is used to load a dataset from local files or local directory. 类用于从本地文件或本地目录加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) |
| parquet_formatter | 🟢Stable | The class is used to load and format parquet-type files. 该类用于加载和格式化镶木地板类型的文件。 | [code](../data_juicer/format/parquet_formatter.py) | [tests](../tests/format/test_parquet_formatter.py) |
| remote_formatter | 🟢Stable | The class is used to load a dataset from repository of huggingface hub. 该类用于从huggingface hub的存储库加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) |
| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型文件。 | [code](../data_juicer/format/text_formatter.py) | - |
| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型的文件。 | [code](../data_juicer/format/text_formatter.py) | - |
| tsv_formatter | 🟢Stable | The class is used to load and format tsv-type files. 该类用于加载和格式化tsv类型的文件。 | [code](../data_juicer/format/tsv_formatter.py) | [tests](../tests/format/test_tsv_formatter.py) |

## grouper <a name="grouper"/>
@@ -170,7 +171,7 @@ All the specific operators are listed below, each featured with several capabili
|----------|------|-------------|-------------|------------|
| audio_add_gaussian_noise_mapper | 📣Audio 💻CPU 🟡Beta | Mapper to add gaussian noise to audio. 映射器向音频添加高斯噪声。 | [code](../data_juicer/ops/mapper/audio_add_gaussian_noise_mapper.py) | [tests](../tests/ops/mapper/test_audio_add_gaussian_noise_mapper.py) |
| audio_ffmpeg_wrapped_mapper | 📣Audio 💻CPU 🟢Stable | Simple wrapper for FFmpeg audio filters. FFmpeg音频滤波器的简单包装。 | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) |
| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问题-答案对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) |
| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问答对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) |
| calibrate_query_mapper | 💻CPU 🟢Stable | Mapper to calibrate query in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的查询。 | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) |
| calibrate_response_mapper | 💻CPU 🟢Stable | Mapper to calibrate response in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的响应。 | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) |
| chinese_convert_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. 映射器在繁体中文,简体中文和日语汉字之间转换中文。 | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) |
@@ -201,7 +202,7 @@ All the specific operators are listed below, each featured with several capabili
| image_diffusion_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Generate image by diffusion model. 通过扩散模型生成图像。 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) |
| image_face_blur_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to blur faces detected in images. 映射器模糊图像中检测到的人脸。 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) |
| image_remove_background_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to remove background of images. 映射器删除图像的背景。 | [code](../data_juicer/ops/mapper/image_remove_background_mapper.py) | [tests](../tests/ops/mapper/test_image_remove_background_mapper.py) |
| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 对图像执行segment-任何操作并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
| image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) |
| imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) |
| imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) |