Skip to content

Commit 016a5d3

Browse files
committed
move chunk_list
1 parent da95963 commit 016a5d3

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

swift/pipelines/infer/rollout.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
UpdateWeightsRequest)
4141
from swift.rlhf_trainers.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest,
4242
UpdateAdapterRequest, UpdateFlattenedAdapterRequest,
43-
UpdateFlattenedParamsRequest, check_vllm_version_ge, patch_vllm_load_adapter,
44-
patch_vllm_moe_model_weight_loader)
43+
UpdateFlattenedParamsRequest, check_vllm_version_ge, chunk_list,
44+
patch_vllm_load_adapter, patch_vllm_moe_model_weight_loader)
4545
from swift.rollout import RolloutScheduler, multi_turns
4646
from swift.utils import get_logger, get_seed, is_vllm_ascend_available
4747
from ..base import SwiftPipeline
@@ -57,8 +57,6 @@
5757
if is_vllm_ascend_available():
5858
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator # noqa
5959

60-
from trl.scripts.vllm_serve import chunk_list
61-
6260
except ImportError:
6361
pass
6462
"""

swift/rlhf_trainers/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,26 @@ def embeddings(self):
5757
return self.lora_embeddings
5858

5959

60+
def chunk_list(lst: list, n: int) -> list[list]:
61+
"""
62+
Split list `lst` into `n` evenly distributed sublists.
63+
64+
Example:
65+
```python
66+
>>> chunk_list([1, 2, 3, 4, 5, 6], 2)
67+
[[1, 2, 3], [4, 5, 6]]
68+
69+
>>> chunk_list([1, 2, 3, 4, 5, 6], 4)
70+
[[1, 2], [3, 4], [5], [6]]
71+
72+
>>> chunk_list([1, 2, 3, 4, 5, 6], 8)
73+
[[1], [2], [3], [4], [5], [6], [], []]
74+
```
75+
"""
76+
k, r = divmod(len(lst), n)
77+
return [lst[i * k + min(i, r):(i + 1) * k + min(i + 1, r)] for i in range(n)]
78+
79+
6080
def is_valid_ipv6_address(address: str) -> bool:
6181
"""Check if the given address is a valid IPv6 address."""
6282
try:

0 commit comments

Comments
 (0)