Skip to content

Commit d73f06f

Browse files
authored
[diffusion] chore: improve memory usage on consumer-level GPU (#18997)
1 parent 963def7 commit d73f06f

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

python/sglang/multimodal_gen/runtime/loader/utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -148,13 +148,15 @@ def _list_safetensors_files(model_path: str) -> list[str]:
148148
return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors")))
149149

150150

151+
BYTES_PER_GB = 1024**3
152+
153+
151154
def get_memory_usage_of_component(module) -> float | None:
152155
"""
153156
returned value is in GB, rounded to 2 decimal digits
154157
"""
155158
if not isinstance(module, nn.Module):
156159
return None
157-
BYTES_PER_GB = 1024**3
158160
if hasattr(module, "get_memory_footprint"):
159161
usage = module.get_memory_footprint() / BYTES_PER_GB
160162
else:

python/sglang/multimodal_gen/runtime/server_args.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (
2727
NunchakuConfig,
2828
)
29+
from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB
2930
from sglang.multimodal_gen.runtime.platforms import (
3031
AttentionBackendEnum,
3132
current_platform,
@@ -411,7 +412,18 @@ def _adjust_quant_config(self):
411412
)
412413

413414
def _adjust_offload(self):
414-
if self.pipeline_config.task_type.is_image_gen():
415+
# TODO: to be handled by each platform
416+
if current_platform.get_device_total_memory() / BYTES_PER_GB < 30:
417+
logger.info("Enabling all offloading for GPU with low device memory")
418+
if self.dit_cpu_offload is None:
419+
self.dit_cpu_offload = True
420+
if self.text_encoder_cpu_offload is None:
421+
self.text_encoder_cpu_offload = True
422+
if self.image_encoder_cpu_offload is None:
423+
self.image_encoder_cpu_offload = True
424+
if self.vae_cpu_offload is None:
425+
self.vae_cpu_offload = True
426+
elif self.pipeline_config.task_type.is_image_gen():
415427
logger.info(
416428
"Disabling some offloading (except dit, text_encoder) for image generation model"
417429
)
@@ -1086,7 +1098,7 @@ def _validate_offload(self):
10861098
)
10871099
self.use_fsdp_inference = False
10881100

1089-
if self.dit_cpu_offload:
1101+
if self.dit_cpu_offload is None:
10901102
logger.warning(
10911103
"dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload."
10921104
)

0 commit comments

Comments
 (0)