3 changes: 2 additions & 1 deletion data_juicer/config/config.py
@@ -28,6 +28,7 @@
from data_juicer.utils.constant import RAY_JOB_ENV_VAR
from data_juicer.utils.logger_utils import setup_logger
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.ray_utils import is_ray_mode

global_cfg = None
global_parser = None
@@ -749,7 +750,7 @@ def init_setup_from_cfg(cfg: Namespace, load_configs_only=False):
"audio_key": cfg.get("audio_key", "audios"),
"video_key": cfg.get("video_key", "videos"),
"image_bytes_key": cfg.get("image_bytes_key", "image_bytes"),
"num_proc": cfg.get("np", None),
"num_proc": cfg.get("np", None) if not is_ray_mode() else None,
"turbo": cfg.get("turbo", False),
"skip_op_error": cfg.get("skip_op_error", True),
"work_dir": cfg.work_dir,
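A minimal sketch of the gating introduced above (illustrative only; `cfg` stands for the parsed config namespace used in `init_setup_from_cfg`): in Ray mode the dataset-level `num_proc` is dropped so that the per-operator concurrency computed in `RayDataset.set_resource_for_ops` (see the next file) takes effect instead.

from data_juicer.utils.ray_utils import is_ray_mode

def resolve_num_proc(cfg):
    # In Ray mode, leave num_proc unset; per-operator concurrency is
    # derived later by RayDataset.set_resource_for_ops.
    return cfg.get("np", None) if not is_ray_mode() else None
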
293 changes: 265 additions & 28 deletions data_juicer/core/data/ray_dataset.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import itertools
import math
import os
import sys
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union

@@ -16,11 +17,11 @@
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import is_remote_path
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.process_utils import calculate_np
from data_juicer.utils.resource_utils import cuda_device_count
from data_juicer.utils.webdataset_utils import _custom_default_decoder

ray = LazyLoader("ray")
_OPS_MEMORY_LIMIT_FRACTION = 0.7


def get_abs_path(path, dataset_dir):
@@ -90,6 +91,73 @@ def filter_batch(batch, filter_func):
return batch.filter(mask)


def find_optimal_concurrency(resource_ratios, total_resource):
Collaborator: can we have perf test results to go with the optimization?

Collaborator Author: We will conduct detailed perf testing and provide reports afterward.
"""
Search for the concurrency allocation that achieves the highest total
resource utilization and the most balanced processing capacity.

Args:
resource_ratios (list[float]): list of single-process resource ratios, one per operator
total_resource (float): total available resource

Returns:
tuple: (list of optimal concurrencies, total resource usage, standard deviation of processing capacity).
If there is no valid combination, returns (None, 0, 0).
"""
n = len(resource_ratios)
if n == 0:
return (None, 0, 0)

sum_r_squared = sum(r * r for r in resource_ratios)
if sum_r_squared == 0:
return (None, 0, 0)

c_floats = [(total_resource * r) / sum_r_squared for r in resource_ratios]

# generate candidate concurrency
candidates = []
for cf in c_floats:
floor_cf = math.floor(cf)
ceil_cf = math.ceil(cf)
possible = set()
if floor_cf >= 1:
possible.add(floor_cf)
possible.add(ceil_cf)
possible = [max(1, v) for v in possible]
candidates.append(sorted(list(set(possible))))

# traverse all combinations
best_combination = None
max_resource_usage = 0
min_std = float("inf")

for combo in itertools.product(*candidates):
total_used = sum(c * r for c, r in zip(combo, resource_ratios))
if total_used > total_resource:
continue

# calculate the standard deviation of processing capacity
processing_powers = [c / r for c, r in zip(combo, resource_ratios)]
mean = sum(processing_powers) / n
variance = sum((x - mean) ** 2 for x in processing_powers) / n
std = math.sqrt(variance)

# update the optimal solution (maximize resource utilization first, then minimize the standard deviation)
if total_used > max_resource_usage:
max_resource_usage = total_used
best_combination = combo
min_std = std
elif total_used == max_resource_usage and std < min_std:
best_combination = combo
min_std = std

return (
list(best_combination) if best_combination else None,
max_resource_usage,
min_std if best_combination else 0,
)
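
A small usage sketch of the helper above (illustrative numbers only): two operators whose single-process resource fractions are 0.2 and 0.3 of the cluster, searched against the full cluster.

# Illustrative call, not part of the diff.
ratios = [0.2, 0.3]  # single-process resource fraction of each operator
concurrency, used, std = find_optimal_concurrency(ratios, total_resource=1.0)
# concurrency -> [2, 2]; 2 * 0.2 + 2 * 0.3 = 1.0, the highest feasible utilization
print(concurrency, used, std)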


class RayDataset(DJDataset):
def __init__(self, dataset: ray.data.Dataset, dataset_path: str = None, cfg: Optional[Namespace] = None) -> None:
self.data = preprocess_dataset(dataset, dataset_path, cfg)
@@ -143,32 +211,205 @@ def get_column(self, column: str, k: Optional[int] = None) -> List[Any]:

return [row[column] for row in self.data.take()]

@staticmethod
def set_resource_for_ops(operators):
"""
Automatically calculates the optimal concurrency for Ray Data operators.
This function handles both task- and actor-based operators, taking resource
requirements and user specifications into account. The computation follows Ray Data's
concurrency semantics while optimizing resource utilization.

Key Concepts:
- Resource Ratio: Individual operator's resource requirement (GPU/CPU/memory)
compared to total cluster resources, using max(cpu_ratio, gpu_ratio, adjusted_mem_ratio)
- Fixed Allocation: Portion of resources reserved by operators with user-specified num_proc
- Dynamic Allocation: Remaining resources distributed among auto-scaling operators

Design Logic:
1. User Specification Priority:
- If the user provides a concurrency setting, it is used directly
- Applies to both task- and actor-based operators
2. Task Operators (equivalent to a CPU operator in DJ):
a. When unspecified: return None and let Ray determine the concurrency implicitly
b. Auto-calculation: returns the maximum concurrency based on available
resources and operator requirements
3. Actor Operators (equivalent to a GPU operator in DJ):
a. Concurrency is mandatory: if unspecified, the required GPUs default to 1 and the
concurrency is then derived automatically from this setting, as described in `b`
b. Auto-calculation returns a tuple (min_concurrency, max_concurrency):
i. Minimum: ensures a baseline allocation out of the remaining resources
when all operators are active simultaneously (proportionally)
ii. Maximum: allows a single operator to fully utilize the remaining
resources while the others are idle
"""
from data_juicer.utils.ray_utils import (
ray_available_gpu_memories,
ray_available_memories,
ray_cpu_count,
ray_gpu_count,
)
from data_juicer.utils.resource_utils import is_cuda_available

# TODO: split to cpu resources and gpu resources
cuda_available = is_cuda_available()
total_cpu = ray_cpu_count()
total_gpu = ray_gpu_count()
available_mem = sum(ray_available_memories()) * _OPS_MEMORY_LIMIT_FRACTION / 1024 # Convert MB to GB
available_gpu_mem = sum(ray_available_gpu_memories()) * _OPS_MEMORY_LIMIT_FRACTION / 1024 # Convert MB to GB
resource_configs = {}

for op in operators:
cpu_req = op.cpu_required
mem_req = op.mem_required
gpu_req = 0
gpu_mem_req = 0
base_resource_frac = 0.0

if op.gpu_required:
if not op.use_cuda():
raise ValueError(
f"Op[{op._name}] attempted to request GPU resources (gpu_required={op.gpu_required}), "
"but appears to lack GPU support. If you have verified this operator support GPU acceleration, "
'please explicitly set its property: `_accelerator = "cuda"`.'
)
if not cuda_available:
raise ValueError(
f"Op[{op._name}] attempted to request GPU resources (gpu_required={op.gpu_required}), "
"but the gpu is unavailable. Please check whether your environment is installed correctly"
" and whether there is a gpu in the resource pool."
)
# for a CUDA operator, mem_required is treated as GPU memory;
# for a CPU operator, it is treated as main memory.
auto_proc = False if op.num_proc else True

# GPU operator calculations
if op.use_cuda():
gpu_req = op.gpu_required
gpu_mem_req = op.mem_required
if not gpu_req and not gpu_mem_req:
logger.warning(
f"The required cuda memory and gpu of Op[{op._name}] "
f"has not been specified. "
f"Please specify the `mem_required` field or `gpu_required` field in the "
f"config file. You can reference the `config_all.yaml` file."
f"Set the `gpu_required` to 1 now."
)
gpu_req = 1

base_resource_frac = max(
cpu_req / total_cpu if cpu_req else 0,
gpu_req / total_gpu if gpu_req else 0,
gpu_mem_req / available_gpu_mem if gpu_mem_req else 0,
)

if not gpu_req:
gpu_req = math.ceil(base_resource_frac * total_gpu * 100) / 100
# CPU operator calculations
else:
if cpu_req or mem_req:
base_resource_frac = max(
cpu_req / total_cpu if cpu_req else 0, mem_req / available_mem if mem_req else 0
)
else:
logger.warning(
f"The required memory and cpu of Op[{op._name}] "
f"has not been specified. "
f"We recommend specifying the `mem_required` field or `cpu_required` field in the "
f"config file. You can reference the `config_all.yaml` file."
)
# Default to single CPU if no requirements specified
base_resource_frac = 1 / total_cpu

resource_configs[op._name] = {
"cpu_required": cpu_req,
"gpu_required": gpu_req,
"mem_required": mem_req,
"gpu_mem_required": gpu_mem_req,
"base_resource_frac": base_resource_frac,
"num_proc": tuple(op.num_proc) if isinstance(op.num_proc, list) else op.num_proc,
"auto_proc": auto_proc,
}

fixed_min_resources = 0
fixed_max_resources = 0
auto_resource_frac_map = {}
for op_name, cfg in resource_configs.items():
if cfg["auto_proc"]:
auto_resource_frac_map[op_name] = cfg["base_resource_frac"]
else:
num_proc = cfg["num_proc"]
min_proc = num_proc[0] if isinstance(num_proc, (tuple, list)) else num_proc
max_proc = num_proc[1] if isinstance(num_proc, (tuple, list)) else num_proc
fixed_min_resources += cfg["base_resource_frac"] * min_proc
fixed_max_resources += cfg["base_resource_frac"] * max_proc

# Validate resource availability
total_auto_base_resource = sum(list(auto_resource_frac_map.values()))
total_required_min = fixed_min_resources + total_auto_base_resource
if total_required_min > 1:
raise ValueError(
f"Insufficient cluster resources: "
f"At least {total_required_min:.2f}x the current resource is required. "
f"Add resources or reduce operator requirements."
)
if len(auto_resource_frac_map) > 0:
remaining_min_frac = 1 - fixed_max_resources
remaining_max_frac = 1 - fixed_min_resources

op_names, op_resources = [], []
for k, v in auto_resource_frac_map.items():
op_names.append(k)
op_resources.append(v)
best_combination, _, _ = find_optimal_concurrency(op_resources, remaining_min_frac)
best_combination = dict(zip(op_names, best_combination))

for op_name, cfg in resource_configs.items():
if cfg["auto_proc"]:
# TODO: revisit this choice (see https://github.com/ray-project/ray/issues/55307);
# an alternative is min_proc = 1
min_proc = best_combination[op_name]
max_proc = int(max(1, remaining_max_frac / cfg["base_resource_frac"]))
# alternative: max_proc = int(max(1, 1 / cfg["base_resource_frac"])) to use all remaining resources
cfg["num_proc"] = min_proc if min_proc == max_proc else (min_proc, max_proc)

for op in operators:
cfg = resource_configs[op._name]
auto_proc, num_proc = cfg["auto_proc"], cfg["num_proc"]
if op.use_cuda():
op.cpu_required = cfg["cpu_required"]
op.gpu_required = cfg["gpu_required"]
op.num_proc = num_proc
else:
# * If ``fn`` is a function and ``concurrency`` is an int ``n``, Ray Data
# launches *at most* ``n`` concurrent tasks.
op.cpu_required = cfg["cpu_required"]
op.gpu_required = None
# if concurrency is left as None, Ray's automatically chosen concurrency may be slightly higher, which could lead to OOM
op.num_proc = num_proc[1] if (auto_proc and isinstance(num_proc, (tuple, list))) else num_proc
# op.num_proc = None if auto_proc else num_proc

logger.info(
f"Op[{op._name}] will be executed with the following resources: "
f"num_cpus: {op.cpu_required}, "
f"num_gpus: {op.gpu_required}, "
f"concurrency: {op.num_proc}, "
)
return operators
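
A worked sketch of the allocation above, with assumed numbers (not from the diff): a cluster exposing 8 GPUs, one user-pinned operator holding 4 of them, and two auto-scaled operators that each need 1 GPU per process (`base_resource_frac = 1/8`).

base_frac = 1 / 8                      # each auto-scaled op needs 1 of 8 GPUs per process
remaining = 1 - 4 * base_frac          # the pinned op reserves half of the cluster
mins, _, _ = find_optimal_concurrency([base_frac, base_frac], remaining)
max_proc = int(max(1, remaining / base_frac))
# mins -> [2, 2] and max_proc -> 4, so each auto-scaled op ends up with num_proc = (2, 4)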

def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset:
if operators is None:
return self
if not isinstance(operators, list):
operators = [operators]

RayDataset.set_resource_for_ops(operators)

for op in operators:
self._run_single_op(op)
return self

def _run_single_op(self, op):
# TODO: optimize auto proc
auto_parallel = False
if op.num_proc:
op_proc = op.num_proc
else:
auto_parallel = True
op_proc = sys.maxsize
auto_op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, op.use_cuda(), op.gpu_required)
op_proc = min(op_proc, auto_op_proc)

# use ray default parallelism in cpu mode if op.num_proc is not specified
if op.use_cuda() or not auto_parallel:
logger.info(f"Op [{op._name}] running with number of procs:{op_proc}")

num_gpus = op.gpu_required if op.gpu_required else get_num_gpus(op, op_proc)

if op._name in TAGGING_OPS.modules and Fields.meta not in self.data.columns():

def process_batch_arrow(table: pyarrow.Table):
@@ -193,8 +434,8 @@ def process_batch_arrow(table: pyarrow.Table):
fn_constructor_kwargs=op_kwargs,
batch_size=batch_size,
num_cpus=op.cpu_required,
num_gpus=num_gpus,
concurrency=op_proc,
num_gpus=op.gpu_required,
concurrency=op.num_proc,
batch_format="pyarrow",
)
else:
@@ -203,9 +444,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
batch_format="pyarrow",
num_cpus=op.cpu_required,
concurrency=(
None if auto_parallel else op_proc
), # use ray default parallelism in cpu mode if num_proc is not specified
concurrency=op.num_proc,
)
elif isinstance(op, Filter):
columns = self.data.columns()
@@ -229,8 +468,8 @@ def process_batch_arrow(table: pyarrow.Table):
fn_constructor_kwargs=op_kwargs,
batch_size=batch_size,
num_cpus=op.cpu_required,
num_gpus=num_gpus,
concurrency=op_proc,
num_gpus=op.gpu_required,
concurrency=op.num_proc,
batch_format="pyarrow",
)
else:
@@ -239,9 +478,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
batch_format="pyarrow",
num_cpus=op.cpu_required,
concurrency=(
None if auto_parallel else op_proc
), # use ray default parallelism in cpu mode if num_proc is not specified
concurrency=op.num_proc,
)
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path, force_ascii=False)
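A minimal sketch of the Ray Data semantics the comments above rely on (assumes a recent Ray release where `map_batches` accepts the `concurrency` argument): for a plain function, an int `concurrency` is an upper bound on concurrent tasks; for a callable class, a `(min, max)` tuple creates an autoscaling actor pool.

import ray

ds = ray.data.range(1000)

def double(batch):
    # task-based stage: concurrency=4 caps the number of concurrent tasks
    batch["id"] = batch["id"] * 2
    return batch

class Identity:
    # actor-based stage: concurrency=(2, 4) autoscales the actor pool
    def __call__(self, batch):
        return batch

ds = ds.map_batches(double, concurrency=4)
ds = ds.map_batches(Identity, concurrency=(2, 4), num_gpus=1)  # executing this stage needs GPUs
# ds.take(1) would trigger execution of both stages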