@@ -133,15 +133,19 @@ def load_customized(
     ):
         init_params["quant_config"] = nunchaku_config
 
-    # Load the model using FSDP loader
+    # Load the model using FSDP loader.
+    # When dit_layerwise_offload is enabled, dit_cpu_offload is forced False at runtime,
+    # but we still need cpu_offload at load time so weights are loaded to CPU first and
+    # do not OOM the GPU (layerwise offload moves them to GPU per layer during inference).
     model = maybe_load_fsdp_model(
         model_cls=model_cls,
         init_params=init_params,
         weight_dir_list=safetensors_list,
         device=get_local_torch_device(),
         hsdp_replicate_dim=server_args.hsdp_replicate_dim,
         hsdp_shard_dim=server_args.hsdp_shard_dim,
-        cpu_offload=server_args.dit_cpu_offload,
+        cpu_offload=server_args.dit_cpu_offload
+        or server_args.dit_layerwise_offload,
         pin_cpu_memory=server_args.pin_cpu_memory,
         fsdp_inference=server_args.use_fsdp_inference,
         # TODO(will): make these configurable
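The new comment describes a load-time/run-time split: weights land on CPU at load, then layerwise offload streams them through the GPU during inference. A minimal sketch of that pattern, using a plain nn.ModuleList stand-in (the real maybe_load_fsdp_model internals are not shown in this diff):

```python
# Illustrative sketch of the load-time / run-time split described in the new
# comment; not the actual FSDP loader implementation.
import torch
import torch.nn as nn

def load_layers_to_cpu(layers: nn.ModuleList) -> None:
    # Load time: keep every layer on CPU so the full model never OOMs the GPU.
    for layer in layers:
        layer.to("cpu")

def run_with_layerwise_offload(layers: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    # Run time: stream one layer at a time through the GPU, then evict it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for layer in layers:
        layer.to(device)
        x = layer(x.to(device))
        layer.to("cpu")
    return x.cpu()

layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
load_layers_to_cpu(layers)
out = run_with_layerwise_offload(layers, torch.randn(2, 16))
```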
24 changes: 22 additions & 2 deletions python/sglang/multimodal_gen/runtime/server_args.py
@@ -26,6 +26,7 @@
 from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (
     NunchakuConfig,
 )
+from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
     current_platform,
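BYTES_PER_GB itself is not shown in this diff; assuming the usual definition of 1024**3, the threshold check added in the next hunk amounts to:

```python
# Assumption: BYTES_PER_GB = 1024**3; the real constant lives in
# sglang.multimodal_gen.runtime.loader.utils and is not shown in this diff.
BYTES_PER_GB = 1024**3

def is_low_memory_gpu(total_memory_bytes: int, threshold_gb: float = 30.0) -> bool:
    # Mirrors the check added in _adjust_offload below.
    return total_memory_bytes / BYTES_PER_GB < threshold_gb

print(is_low_memory_gpu(24 * 1024**3))  # True  (e.g. a 24 GB consumer card)
print(is_low_memory_gpu(80 * 1024**3))  # False (e.g. an 80 GB datacenter card)
```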
@@ -411,7 +412,18 @@ def _adjust_quant_config(self):
             )
 
     def _adjust_offload(self):
-        if self.pipeline_config.task_type.is_image_gen():
+        # TODO: to be handled by each platform
+        if current_platform.get_device_total_memory() / BYTES_PER_GB < 30:
+            logger.info("Enabling all offloading for GPU with low device memory")
+            if self.dit_cpu_offload is None:
+                self.dit_cpu_offload = True
+            if self.text_encoder_cpu_offload is None:
+                self.text_encoder_cpu_offload = True
+            if self.image_encoder_cpu_offload is None:
+                self.image_encoder_cpu_offload = True
+            if self.vae_cpu_offload is None:
+                self.vae_cpu_offload = True
Comment on lines +416 to +425
Contributor (severity: medium)

The added logic for low-memory GPUs is nearly identical to the else block on lines 438-447, which duplicates code and makes future modifications more error-prone.

Additionally, the value 30 is a magic number; defining it as a constant, for example LOW_GPU_MEMORY_THRESHOLD_GB = 30, would improve readability and make it easier to change.

Please consider refactoring the _adjust_offload method to eliminate this duplication.
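Combining both observations (named constant plus attribute loop) could look like the sketch below; the constant name, helper name, and class wrapper are illustrative only, not code from this PR:

```python
# Illustrative refactor combining both review suggestions.
LOW_GPU_MEMORY_THRESHOLD_GB = 30

OFFLOAD_ATTRS = (
    "dit_cpu_offload",
    "text_encoder_cpu_offload",
    "image_encoder_cpu_offload",
    "vae_cpu_offload",
)

class ServerArgsSketch:
    def __init__(self) -> None:
        for attr in OFFLOAD_ATTRS:
            setattr(self, attr, None)  # None means "user left it unset"

    def _enable_unset_offloads(self) -> None:
        # Flip only the flags the user left unset; never override explicit choices.
        for attr in OFFLOAD_ATTRS:
            if getattr(self, attr) is None:
                setattr(self, attr, True)

args = ServerArgsSketch()
args.vae_cpu_offload = False  # explicit user choice survives
args._enable_unset_offloads()
assert args.dit_cpu_offload is True and args.vae_cpu_offload is False
```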

Comment on lines +416 to +425
Contributor (severity: medium)

The hardcoded value 30 for the memory threshold can be extracted into a named constant to improve readability. Additionally, the series of if statements that enables offloading for the different components is repetitive; the block can be refactored into a loop to make the code more concise and easier to maintain.

Suggested change
-        if current_platform.get_device_total_memory() / BYTES_PER_GB < 30:
-            logger.info("Enabling all offloading for GPU with low device memory")
-            if self.dit_cpu_offload is None:
-                self.dit_cpu_offload = True
-            if self.text_encoder_cpu_offload is None:
-                self.text_encoder_cpu_offload = True
-            if self.image_encoder_cpu_offload is None:
-                self.image_encoder_cpu_offload = True
-            if self.vae_cpu_offload is None:
-                self.vae_cpu_offload = True
+        if current_platform.get_device_total_memory() / BYTES_PER_GB < 30:
+            logger.info("Enabling all offloading for GPU with low device memory")
+            offload_attrs = [
+                "dit_cpu_offload",
+                "text_encoder_cpu_offload",
+                "image_encoder_cpu_offload",
+                "vae_cpu_offload",
+            ]
+            for attr in offload_attrs:
+                if getattr(self, attr) is None:
+                    setattr(self, attr, True)

+        elif self.pipeline_config.task_type.is_image_gen():
             logger.info(
                 "Disabling some offloading (except dit, text_encoder) for image generation model"
             )
@@ -495,6 +507,14 @@ def _adjust_parallelism(self):
         # adjust sp_degree: allocate all remaining GPUs after TP and DP
         if self.sp_degree is None:
             num_gpus_per_group = self.dp_size * self.tp_size
+            if (
+                self.enable_cfg_parallel is None
+                and (self.num_gpus // num_gpus_per_group) >= 2
+            ):
+                logger.info(
+                    f"Automatically set enable_cfg_parallel for best performance"
+                )
+                self.enable_cfg_parallel = True
             if self.enable_cfg_parallel:
                 num_gpus_per_group *= 2
             if self.num_gpus % num_gpus_per_group == 0:
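A quick worked example of the new auto-enable arithmetic, with illustrative values; the final sp_degree line presumes the remaining factor becomes the sequence-parallel degree, per the comment above it:

```python
# Illustrative walk-through of the added auto-enable logic; all values are
# made up, and the sp_degree computation is an assumption.
num_gpus = 8
dp_size, tp_size = 2, 2

num_gpus_per_group = dp_size * tp_size  # 4
enable_cfg_parallel = None  # user left it unset

# 8 // 4 == 2 >= 2: enough GPUs to run the two CFG branches in parallel.
if enable_cfg_parallel is None and (num_gpus // num_gpus_per_group) >= 2:
    enable_cfg_parallel = True

if enable_cfg_parallel:
    num_gpus_per_group *= 2  # 8

# Presumably the remaining factor becomes the sequence-parallel degree.
assert num_gpus % num_gpus_per_group == 0
sp_degree = num_gpus // num_gpus_per_group  # 1
```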
@@ -1086,7 +1106,7 @@ def _validate_offload(self):
             )
             self.use_fsdp_inference = False
 
-        if self.dit_cpu_offload:
+        if self.dit_cpu_offload is None:
Contributor (severity: critical)

This change from if self.dit_cpu_offload: to if self.dit_cpu_offload is None: appears to introduce a bug. When dit_layerwise_offload is enabled, dit_cpu_offload must be disabled to prevent conflicts. With this change, however, if _adjust_offload sets dit_cpu_offload=True on a low-memory GPU, the condition is false and dit_cpu_offload incorrectly remains True, creating a configuration conflict.

The previous logic handled this correctly by disabling dit_cpu_offload whenever it was enabled. Please revert this change so the conflict is always resolved.

Suggested change
-        if self.dit_cpu_offload is None:
+        if self.dit_cpu_offload:

             logger.warning(
                 "dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload."
             )
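A toy reproduction of the reviewer's scenario (standalone sketch, not PR code):

```python
# Toy reproduction of the conflict the reviewer describes.
dit_layerwise_offload = True
dit_cpu_offload = None

# _adjust_offload on a low-memory GPU fills in unset flags:
if dit_cpu_offload is None:
    dit_cpu_offload = True

# With the PR's `is None` guard, _validate_offload never fires now,
# so both offload modes stay enabled at once:
if dit_layerwise_offload and dit_cpu_offload is None:
    dit_cpu_offload = False
assert dit_cpu_offload is True  # conflict: both flags enabled

# The reviewer's suggested guard resolves it whenever the flag is truthy:
if dit_layerwise_offload and dit_cpu_offload:
    dit_cpu_offload = False
assert dit_cpu_offload is False
```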
7 changes: 6 additions & 1 deletion python/sglang/multimodal_gen/runtime/utils/perf_logger.py
@@ -1,6 +1,7 @@
 # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
 import dataclasses
 import json
+import logging
 import os
 import subprocess
 import sys
@@ -15,6 +16,7 @@
 
 import sglang
 import sglang.multimodal_gen.envs as envs
+from sglang.multimodal_gen.runtime.platforms import current_platform
 from sglang.multimodal_gen.runtime.utils.logging_utils import (
     _SGLDiffusionLogger,
     get_is_main_process,
@@ -199,7 +201,10 @@ def __init__(
 
     def __enter__(self):
         if self.log_stage_start_end:
-            self.logger.info(f"[{self.stage_name}] started...")
+            msg = f"[{self.stage_name}] started..."
+            if self.logger.isEnabledFor(logging.DEBUG):
+                msg += f" ({round(current_platform.get_available_gpu_memory(), 2)} GB left)"
+            self.logger.info(msg)
 
         if (self.log_timing and self.timings) or self.log_stage_start_end:
             if (
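The isEnabledFor guard is the stock stdlib-logging idiom for skipping expensive message construction (here, a GPU-memory query) when the level is disabled; a generic sketch with a stand-in probe:

```python
import logging

logger = logging.getLogger("perf")
logging.basicConfig(level=logging.INFO)

def expensive_probe() -> float:
    # Stand-in for a costly call such as querying free GPU memory.
    return 12.34

msg = "[stage] started..."
if logger.isEnabledFor(logging.DEBUG):
    # Only pay for the probe when DEBUG output would actually be emitted.
    msg += f" ({round(expensive_probe(), 2)} GB left)"
logger.info(msg)  # at INFO level, prints without the memory suffix
```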