Skip to content

Commit c7dc52e

Browse files
authored
Tiny add Sample.group_index (#475)
1 parent 83a21c2 commit c7dc52e

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

slime/ray/rollout_data_source.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, args):
1616
self.args = args
1717

1818
self.epoch_id = 0
19+
self.sample_group_index = 0
1920
self.sample_index = 0
2021
self.sample_offset = 0
2122
# TODO remove this
@@ -66,9 +67,11 @@ def get_samples(self, num_samples):
6667
group = []
6768
for _ in range(self.args.n_samples_per_prompt):
6869
sample = copy.deepcopy(prompt_sample)
70+
sample.group_index = self.sample_group_index
6971
sample.index = self.sample_index
7072
self.sample_index += 1
7173
group.append(sample)
74+
self.sample_group_index += 1
7275
samples.append(group)
7376
return samples
7477

@@ -82,6 +85,7 @@ def save(self, rollout_id):
8285
state_dict = {
8386
"sample_offset": self.sample_offset,
8487
"epoch_id": self.epoch_id,
88+
"sample_group_index": self.sample_group_index,
8589
"sample_index": self.sample_index,
8690
"metadata": self.metadata,
8791
}
@@ -106,6 +110,7 @@ def load(self, rollout_id=None):
106110
state_dict = torch.load(path)
107111
self.sample_offset = state_dict.get("sample_offset", 0)
108112
self.epoch_id = state_dict.get("epoch_id", 0)
113+
self.sample_group_index = state_dict.get("sample_group_index", 0)
109114
self.sample_index = state_dict.get("sample_index", 0)
110115
self.metadata = state_dict.get("metadata", {})
111116

slime/utils/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class Sample:
1010
"""The sample generated"""
1111

12+
group_index: Optional[int] = None
1213
index: Optional[int] = None
1314
# prompt
1415
prompt: Union[str, list[dict[str, str]]] = ""

0 commit comments

Comments
 (0)