@@ -16,6 +16,7 @@ def __init__(self, args):
1616 self .args = args
1717
1818 self .epoch_id = 0
19+ self .sample_group_index = 0
1920 self .sample_index = 0
2021 self .sample_offset = 0
2122 # TODO remove this
@@ -66,9 +67,11 @@ def get_samples(self, num_samples):
6667 group = []
6768 for _ in range (self .args .n_samples_per_prompt ):
6869 sample = copy .deepcopy (prompt_sample )
70+ sample .group_index = self .sample_group_index
6971 sample .index = self .sample_index
7072 self .sample_index += 1
7173 group .append (sample )
74+ self .sample_group_index += 1
7275 samples .append (group )
7376 return samples
7477
@@ -82,6 +85,7 @@ def save(self, rollout_id):
8285 state_dict = {
8386 "sample_offset" : self .sample_offset ,
8487 "epoch_id" : self .epoch_id ,
88+ "sample_group_index" : self .sample_group_index ,
8589 "sample_index" : self .sample_index ,
8690 "metadata" : self .metadata ,
8791 }
@@ -106,6 +110,7 @@ def load(self, rollout_id=None):
106110 state_dict = torch .load (path )
107111 self .sample_offset = state_dict .get ("sample_offset" , 0 )
108112 self .epoch_id = state_dict .get ("epoch_id" , 0 )
113+ self .sample_group_index = state_dict .get ("sample_group_index" , 0 )
109114 self .sample_index = state_dict .get ("sample_index" , 0 )
110115 self .metadata = state_dict .get ("metadata" , {})
111116
0 commit comments