Skip to content

Commit 3cbc68e

Browse files
committed
gemini
1 parent f48250d commit 3cbc68e

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,10 +2062,7 @@ def get_chunked_inputs(self, inputs, start_idx, end_idx):
20622062
# for LLM, slice the inputs
20632063
for key, val in inputs.items():
20642064
if isinstance(val, torch.Tensor):
2065-
if val.ndim == 0:
2066-
chunk_inputs[key] = val
2067-
else:
2068-
chunk_inputs[key] = val[start_idx:end_idx]
2065+
chunk_inputs[key] = val if val.ndim == 0 else val[start_idx:end_idx]
20692066
else:
20702067
chunk_inputs[key] = val
20712068
if self.is_multimodal:

0 commit comments

Comments
 (0)