Skip to content

Commit fd13dcd

Browse files
Make DataSource implement __len__ to standardize the data source contract (#1518)
Co-authored-by: Nan Jiang <59716405+nanjiangwill@users.noreply.github.com>
1 parent 5073e32 commit fd13dcd

File tree

4 files changed

+18
-1
lines changed

4 files changed

+18
-1
lines changed

docs/en/get_started/customization.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ class CustomDataSource(DataSource):
350350

351351
def load(self, rollout_id=None):
352352
"""Load state from checkpoint"""
353+
354+
def __len__(self):
355+
"""Length of the data source. May change when samples are added/fetched."""
353356
```
354357

355358
---

docs/zh/get_started/customization.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ class CustomDataSource(DataSource):
350350

351351
def load(self, rollout_id=None):
352352
"""从 ckpt 加载状态"""
353+
354+
def __len__(self) -> int:
355+
"""
356+
返回当前数据源中可用样本的数量。该数量可能会随着样本的获取或添加而变化。
357+
"""
353358
```
354359

355360
---

slime/ray/rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_rollout_engines_and_lock(self):
125125

126126
def get_num_rollout_per_epoch(self):
127127
assert self.args.rollout_global_dataset
128-
return len(self.data_source.dataset) // self.args.rollout_batch_size
128+
return len(self.data_source) // self.args.rollout_batch_size
129129

130130
def generate(self, rollout_id):
131131
start_time = time.time()

slime/rollout/data_source.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def load(self, rollout_id=None):
3939
Load the state of the data source
4040
"""
4141

42+
@abc.abstractmethod
43+
def __len__(self) -> int:
44+
"""
45+
Length of the data source. May change when samples are added/fetched.
46+
"""
47+
4248

4349
# TODO may further refactor data-loading part later
4450
class RolloutDataSource(DataSource):
@@ -153,6 +159,9 @@ def load(self, rollout_id=None):
153159
if self.args.rollout_global_dataset and self.args.rollout_shuffle:
154160
self.dataset.shuffle(self.epoch_id)
155161

162+
def __len__(self) -> int:
163+
return len(self.dataset)
164+
156165

157166
class RolloutDataSourceWithBuffer(RolloutDataSource):
158167
def __init__(self, args):

0 commit comments

Comments
 (0)