Skip to content

Commit c57539f

Browse files
authored
Feature: Add Mock HuggingFace Dataset Support for TorchTitan (#260)
1 parent ee0dec9 commit c57539f

File tree

7 files changed

+133
-1
lines changed

7 files changed

+133
-1
lines changed

primus/configs/modules/torchtitan/pre_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ profiling:
3333
save_traces_folder: profile_traces
3434

3535
training:
36+
mock_data: true
3637
dataset: c4
3738
dataset_path: null
3839
deterministic: false
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
import numpy as np
8+
from datasets import Dataset
9+
10+
11+
def _create_mock_text_dataset(num_samples: int = 128) -> Dataset:
12+
"""Create a lightweight text dataset for validation mock."""
13+
texts = [f"validation sample {i}" for i in range(num_samples)]
14+
return Dataset.from_dict({"text": texts})
15+
16+
17+
def _create_mock_token_dataset(
18+
seq_len: int = 2048,
19+
vocab_size: int = 32000,
20+
num_samples: int = 256,
21+
) -> Dataset:
22+
"""
23+
Create fake tokenized text dataset (Titan-compatible).
24+
25+
Each "text" field is a string of roughly `seq_len // 8` space-separated integers.
26+
Titan's tokenizer.encode() will parse these into tokens and reconstruct
27+
proper seq_len-sized sequences from multiple samples if needed.
28+
29+
This lightweight mock simulates a streaming dataset and avoids heavy memory usage.
30+
"""
31+
rng = np.random.default_rng(42)
32+
token_per_sample = seq_len # shorter text, Titan will concatenate internally
33+
34+
samples = []
35+
for _ in range(num_samples):
36+
token_ids = rng.integers(0, vocab_size, size=token_per_sample, dtype=np.int32)
37+
text = " ".join(map(str, token_ids))
38+
samples.append({"text": text})
39+
40+
return Dataset.from_list(samples)
41+
42+
43+
def patch_mock_hf_dataset() -> None:
44+
from primus.core.utils import logger
45+
46+
try:
47+
import datasets
48+
49+
logger.warning("[Primus Mock] Enabling mock HuggingFace dataset mode.")
50+
51+
def mock_load_dataset(path: str, *args, **kwargs) -> Dataset:
52+
"""
53+
Replacement for datasets.load_dataset().
54+
Intercepts Titan calls like load_dataset('allenai/c4', ...).
55+
Returns a fake Dataset of text samples.
56+
"""
57+
logger.warning(f"[Primus Mock] load_dataset('{path}') is mocked.")
58+
# Shorter dataset for validation split
59+
if "validation" in path.lower():
60+
return _create_mock_text_dataset(num_samples=32)
61+
else:
62+
return _create_mock_token_dataset(seq_len=8192, vocab_size=32000, num_samples=256)
63+
64+
datasets.load_dataset = mock_load_dataset
65+
logger.warning("[PrimusPath][Dataset] Patched datasets.load_dataset successfully.")
66+
67+
except Exception as e:
68+
logger.error(f"[PrimusPath][Dataset] Failed to patch datasets.load_dataset: {e}")

primus/modules/trainer/torchtitan/pre_trainer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ def __init__(self, *args, **kwargs):
2626
pre_trainer_cfg = self.primus_cfg.get_module_config("pre_trainer")
2727
cfg_dict = nested_namespace_to_dict(pre_trainer_cfg)
2828

29+
patch_mock = getattr(pre_trainer_cfg.training, "mock_data", False)
30+
if patch_mock:
31+
from primus.modules.trainer.torchtitan.patch_utils import (
32+
patch_mock_hf_dataset,
33+
)
34+
35+
patch_mock_hf_dataset()
36+
2937
self.patch_torchtitan_embedding_amp(cfg_dict["primus_turbo"]["enable_embedding_autocast"])
3038
self.patch_titan_train_spec(pre_trainer_cfg.model.name, pre_trainer_cfg.model.flavor, extra_args)
3139

@@ -460,15 +468,28 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any:
460468
if not is_dataclass(cls):
461469
return data
462470

471+
# collect valid field names
472+
field_names = {f.name for f in fields(cls)}
463473
init_values = {}
474+
475+
# only use known fields for constructor
464476
for f in fields(cls):
465477
if f.name in data:
466478
val = data[f.name]
467479
if is_dataclass(f.type) and isinstance(val, dict):
468480
init_values[f.name] = self._dict_to_dataclass(f.type, val)
469481
else:
470482
init_values[f.name] = val
471-
return cls(**init_values)
483+
484+
# instantiate dataclass
485+
obj = cls(**init_values)
486+
487+
# attach unknown fields dynamically
488+
for k, v in data.items():
489+
if k not in field_names:
490+
setattr(obj, k, v)
491+
492+
return obj
472493

473494
def patch_torchtitan_embedding_amp(self, enable_patch: bool):
474495
"""

tests/modules/__init__.py

Whitespace-only changes.

tests/modules/trainer/__init__.py

Whitespace-only changes.

tests/modules/trainer/torchtitan/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
8+
from primus.modules.trainer.torchtitan.patch_utils import patch_mock_hf_dataset
9+
from tests.utils import PrimusUT
10+
11+
12+
class TestTorchtitanPatch(PrimusUT):
13+
def __init__(self, *args, **kwargs):
14+
super().__init__(*args, **kwargs)
15+
16+
def setUp(self):
17+
pass
18+
19+
def tearDown(self):
20+
pass
21+
22+
def test_mock_hf_dataset_patch(self):
23+
"""
24+
Test that enable_mock_hf_dataset() successfully patches datasets.load_dataset
25+
and returns a fake HuggingFace Dataset.
26+
"""
27+
# from primus.utils import mock_hf_dataset
28+
29+
patch_mock_hf_dataset()
30+
31+
# Reimport datasets and call load_dataset
32+
import datasets
33+
34+
ds = datasets.load_dataset("allenai/c4", split="train")
35+
36+
# Verify that this is an in-memory Dataset with expected content
37+
assert isinstance(ds, datasets.Dataset)
38+
assert "text" in ds.column_names
39+
assert len(ds) > 0
40+
sample = ds[0]
41+
assert isinstance(sample["text"], str)
42+
assert len(sample["text"].split()) > 0

0 commit comments

Comments
 (0)