
DeepSpeed two-node, single-GPU-per-node streaming inference issues #7734

@Haneeeeeeee

Description

Hi team,

I recently ran inference tests on a two-node cluster (each node has a single GPU) using DeepSpeed. I encountered two issues and would appreciate your guidance.

Issue 1: Slower Token Generation Speed and Failure to Load Larger Models
Each of my servers has enough GPU memory to run the Qwen-14B model individually. However, when running in a two-node setup, I first tried the Qwen-7B model. Although it loaded successfully, the token generation speed dropped significantly:

Single-node: ~10 tokens/sec
Two-node: ~5 tokens/sec

Then I attempted to load the Qwen-14B model in the same two-node configuration, but the process hung indefinitely (the terminal became unresponsive). My system specs per node are:

RAM: 16 GB
Swap space: 8 GB

Could this hang be related to insufficient system RAM or swap during model loading? I'm using device_map="cpu" during initial loading, so the model is first placed on CPU before DeepSpeed distributes it.
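
A small diagnostic along these lines (assuming psutil is available on the nodes) should show whether the checkpoint can even fit in host RAM before DeepSpeed shards it; MODEL_DIR is the same path used in the full script further down:

# Quick host-RAM sanity check before loading (diagnostic only, assumes psutil is installed).
import os
import psutil

MODEL_DIR = "/mnt/data/models/Qwen2.5-7B-Instruct"  # same path as in the script below

# Total size of the checkpoint shards on disk.
ckpt_bytes = sum(
    os.path.getsize(os.path.join(root, f))
    for root, _, files in os.walk(MODEL_DIR)
    for f in files
    if f.endswith((".safetensors", ".bin"))
)
avail_bytes = psutil.virtual_memory().available
print(f"checkpoint size on disk: {ckpt_bytes / 1e9:.1f} GB, "
      f"available host RAM: {avail_bytes / 1e9:.1f} GB")
# With device_map="cpu" the full fp16 model must fit in host RAM before DeepSpeed
# shards it, so an available value well below the checkpoint size would point to
# swap thrashing as the cause of the hang when loading Qwen-14B.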

Issue 2: Streaming Output Behavior Differs Between Single- and Multi-Node
In single-node mode, streaming works as expected: tokens are output one by one in real time via TextIteratorStreamer.

However, in two-node mode the output is chunked by full sentences: each chunk ends with a period (.) rather than arriving token by token.

The nodes are connected via a low-latency internal network, so I don’t suspect network bandwidth or latency as the root cause. Both nodes use identical code and environment (including the same DeepSpeed and Transformers versions).

Streaming is handled by a background thread that calls model.generate(..., streamer=streamer); the main thread on rank 0 consumes the streamer and prints chunks as they arrive.
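
To make the chunking measurable, rank 0's consumption loop could be swapped for a small timing sketch like the one below (it reuses the streamer object created in the script further down and is only meant as a diagnostic):

# Diagnostic consumption loop for rank 0: timestamp every chunk coming out of the
# streamer, to distinguish "tokens generated slowly but delivered one by one"
# from "tokens buffered and delivered in sentence-sized bursts".
import time

def consume_with_timing(streamer):
    prev = time.time()
    for i, chunk in enumerate(streamer):
        now = time.time()
        print(f"[chunk {i:3d}] +{now - prev:.3f}s  {len(chunk):3d} chars  {chunk!r}", flush=True)
        prev = now

If each chunk spans a whole sentence and the inter-chunk gaps are correspondingly long, that would suggest the text is being buffered before it reaches the streamer rather than being generated more slowly.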

If needed, I can share the full script, logs, or system diagnostics.

Thank you very much for your help!

[root@localhost deepspeed]# ds_report
[2025-12-16 07:30:06,298] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
WARNING:tensorflow:Deprecation warnings have been disabled. Set TF_ENABLE_DEPRECATION_WARNINGS=1 to re-enable them.
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [YES] ...... [NO]
fused_adam ............. [YES] ...... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [YES] ...... [OKAY]
cpu_lion ............... [YES] ...... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [YES] ...... [OKAY]
fused_layernorm ........ [YES] ...... [OKAY]
fused_lion ............. [YES] ...... [OKAY]
fused_rope ............. [YES] ...... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [YES] ...... [OKAY]
quantizer .............. [YES] ...... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [YES] ...... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [YES] ...... [OKAY]
transformer ............ [YES] ...... [OKAY]
stochastic_transformer . [YES] ...... [OKAY]
swiglu ................. [YES] ...... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/corex-4.1.3/lib64/python3/dist-packages/torch']
torch version .................... 2.1.1
deepspeed install path ........... ['/usr/local/corex-4.1.3/lib64/python3/dist-packages/deepspeed']
deepspeed info ................... 0.14.3+corex.4.1.3, unknown, unknown
torch cuda version ............... 10.2
torch hip version ................ None
nvcc version ..................... 10.2
deepspeed wheel compiled w. ...... torch 2.1, cuda 10.2
shared memory (/dev/shm) size .... 32.00 GB
import os
import time
import torch
import deepspeed
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

MODEL_DIR = "/mnt/data/models/Qwen2.5-7B-Instruct"
PROMPT = "Who are you?"
MAX_NEW_TOKENS = 256

def maybe_init_torch_dist():
    # Usually deepspeed launcher already initialized torch.distributed.
    if not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend)

def main():
    # env info (deepspeed launcher should set these)
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))

    # Ensure distributed backend initialized (safe-guard)
    if world_size > 1:
        maybe_init_torch_dist()

    if rank == 0:
        print(f"Rank0: loading tokenizer & model from {MODEL_DIR} (world_size={world_size})")

    # 1) tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2) load model (placed on CPU first, deepspeed will manage device)
    if rank == 0:
        print("Loading model (to CPU) ...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="cpu"
    )

    if rank == 0:
        print("Model loaded, initializing DeepSpeed inference engine...")

    # 3) DeepSpeed init
    ds_engine = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        replace_with_kernel_inject=False,
        injection_policy={ Qwen2DecoderLayer: ('self_attn.o_proj', 'mlp.down_proj') }
    )
    model = ds_engine.module
    ds_device = model.device  # device where ds-engine lives (cuda:0 typically)

    if rank == 0:
        print(f"DeepSpeed engine init done. device: {ds_device}")

    # 4) Prepare input on rank 0, broadcast to others
    if rank == 0:
        text = tokenizer.apply_chat_template([{"role":"user","content":PROMPT}], tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text=[text], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(ds_device)
        attention_mask = inputs.attention_mask.to(ds_device)
        length = torch.tensor([input_ids.shape[1]], dtype=torch.long, device=ds_device)
    else:
        # placeholder tensors on the target device (they will be overwritten by broadcast)
        length = torch.tensor([0], dtype=torch.long, device=ds_device)

    # Broadcast length
    if world_size > 1:
        torch.distributed.broadcast(length, src=0)
    L = int(length.item())

    # Prepare input tensors on non-zero ranks
    if rank != 0:
        input_ids = torch.zeros((1, L), dtype=torch.long, device=ds_device)
        attention_mask = torch.zeros((1, L), dtype=torch.long, device=ds_device)

    # Broadcast actual input tensors (all ranks call)
    if world_size > 1:
        torch.distributed.broadcast(input_ids, src=0)
        torch.distributed.broadcast(attention_mask, src=0)

    # Sanity barrier
    if world_size > 1:
        torch.distributed.barrier()

    # 5) Create streamer on all ranks
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=120  # seconds of inactivity before raising StopIteration
    )

    generation_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "streamer": streamer,
    }

    # 6) Start generation on all ranks (each rank gets its own streamer object)
    def gen_thread_fn():
        with torch.no_grad():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                # Often deep learning libs may raise on exit; print on rank0 for debugging
                if rank == 0:
                    print(f"[rank {rank}] generate() exception: {e}")

    gen_thread = Thread(target=gen_thread_fn, daemon=False)
    gen_thread.start()

    # 7) Handle streamer consumption:
    #    - Rank0: print streaming chunks as they arrive.
    #    - Other ranks: drain/consume streamer immediately in background (to avoid buffering).
    start_time = time.time()
    token_count = 0
    generated_text = ""

    if rank == 0:
        print("\nQuestion:", PROMPT)
        print("Answer (streaming): ", end="", flush=True)
        try:
            for chunk in streamer:
                # chunk is a text piece (may be 1 token or multiple tokens depending on internals)
                print(chunk, end="", flush=True)
                generated_text += chunk
                token_count += 1
        except Exception as e:
            print(f"\n[rank0] streaming error: {e}")
    else:
        # drain in background to avoid blocking others
        def drain():
            try:
                for _ in streamer:
                    # discard; optionally count
                    pass
            except Exception:
                return
        drain_thread = Thread(target=drain, daemon=True)
        drain_thread.start()
        # wait for generation to finish (gen_thread will end when generate returns)
        gen_thread.join()
        # join drain_thread with timeout to avoid hang
        drain_thread.join(timeout=1.0)

    # wait for generation thread on rank0 as well
    if rank == 0:
        gen_thread.join()

        # stats
        gen_time = time.time() - start_time
        toks_per_s = token_count / gen_time if gen_time > 0 else 0.0
        print("\n\n" + "-" * 60)
        print("Generation finished.")
        print(f"Tokens printed (chunks): {token_count}")
        print(f"Generation time: {gen_time:.2f}s")
        print(f"Throughput (chunks/sec): {toks_per_s:.2f}")
        print(f"Generated text length (chars): {len(generated_text)}")
        print("-" * 60)

    # final barrier and exit
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        print("All done.")

if __name__ == "__main__":
    main()
