diff --git a/checkpoint_engine/distributed/vllm_nccl.py b/checkpoint_engine/distributed/vllm_nccl.py index 2ffe253..b2eb1aa 100644 --- a/checkpoint_engine/distributed/vllm_nccl.py +++ b/checkpoint_engine/distributed/vllm_nccl.py @@ -12,11 +12,21 @@ ncclResult_t, ) from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import current_stream from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object +try: + from vllm.utils.torch_utils import current_stream +except ImportError: + try: + from vllm.utils import current_stream + except ImportError: + raise ImportError( + "Could not find 'current_stream' in vllm. Please check your vllm version." + ) from None + + class NcclConfigT(ctypes.Structure): _fields_: ClassVar[list[tuple[str, Any]]] = [ ("size", ctypes.c_size_t),