|
26 | 26 | from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( |
27 | 27 | NunchakuConfig, |
28 | 28 | ) |
| 29 | +from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB |
29 | 30 | from sglang.multimodal_gen.runtime.platforms import ( |
30 | 31 | AttentionBackendEnum, |
31 | 32 | current_platform, |
@@ -411,7 +412,18 @@ def _adjust_quant_config(self): |
411 | 412 | ) |
412 | 413 |
|
413 | 414 | def _adjust_offload(self): |
414 | | - if self.pipeline_config.task_type.is_image_gen(): |
| 415 | + # TODO: to be handled by each platform |
| 416 | + if current_platform.get_device_total_memory() / BYTES_PER_GB < 30: |
| 417 | + logger.info("Enabling all offloading for GPU with low device memory") |
| 418 | + if self.dit_cpu_offload is None: |
| 419 | + self.dit_cpu_offload = True |
| 420 | + if self.text_encoder_cpu_offload is None: |
| 421 | + self.text_encoder_cpu_offload = True |
| 422 | + if self.image_encoder_cpu_offload is None: |
| 423 | + self.image_encoder_cpu_offload = True |
| 424 | + if self.vae_cpu_offload is None: |
| 425 | + self.vae_cpu_offload = True |
| 426 | + elif self.pipeline_config.task_type.is_image_gen(): |
415 | 427 | logger.info( |
416 | 428 | "Disabling some offloading (except dit, text_encoder) for image generation model" |
417 | 429 | ) |
@@ -1086,7 +1098,7 @@ def _validate_offload(self): |
1086 | 1098 | ) |
1087 | 1099 | self.use_fsdp_inference = False |
1088 | 1100 |
|
1089 | | - if self.dit_cpu_offload: |
| 1101 | + if self.dit_cpu_offload is None: |
1090 | 1102 | logger.warning( |
1091 | 1103 | "dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload." |
1092 | 1104 | ) |
|
0 commit comments