
SDXL ControlNet inpaint costs more GPU memory than PyTorch #605

Open
@285220927

Description


Describe the bug

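When testing SDXL ControlNet inpainting, OneDiff used noticeably more GPU memory than PyTorch: 18382 MiB versus 13380 MiB in the nvidia-smi snapshots below, roughly 37% (about 5 GB) more.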

Your environment

OS: Ubuntu 20.04
GPU: NVIDIA GeForce RTX 4090
Python: 3.8
diffusers: 0.23.0
OneDiff: 0.12.1.dev202401310124
PyTorch: 2.0.1
CUDA: 12.2

OneDiff git commit id

OneFlow version info

Output of python -m oneflow --doctor:
version: 0.9.1.dev20240125+cu122
git_commit: 6458a12
cmake_build_type: Release
rdma: True
mlir: True
enterprise: False

How To Reproduce

Steps to reproduce the behavior (code or script):

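import torch
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from diffusers.utils import load_image

from onediff.infer_compiler import oneflow_compile

device = "cuda:0"

# Placeholders: the report does not say which ControlNet checkpoint,
# prompt, or images were used, so fill these in before running.
controlnet_path = "..."            # local path or Hub id of an SDXL ControlNet
prompt = "..."
init_image = load_image("...")     # image to inpaint
mask_image = load_image("...")     # white pixels mark the region to repaint
control_image = load_image("...")  # ControlNet conditioning image

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stablediffusionapi/dreamshaper-xl",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)

# Compile the UNet and the ControlNet with OneDiff.
# Skip these two lines to get the plain PyTorch baseline.
pipe.unet = oneflow_compile(pipe.unet)
pipe.controlnet = oneflow_compile(pipe.controlnet)

image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    strength=1.0,
    controlnet_conditioning_scale=0.5,
    num_inference_steps=30,
).images[0]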
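The nvidia-smi numbers reported below are single snapshots taken while the GPU was idle, so they show memory still held by the process rather than the peak during inference. A fairer comparison might record the peak device-level usage while pipe(...) runs, once with and once without oneflow_compile. PyTorch's own counters (torch.cuda.max_memory_allocated) would presumably not see buffers that OneFlow allocates through its own allocator, so this sketch polls a device-level reader instead. PeakMemorySampler and read_used_mib are hypothetical names for illustration, not part of OneDiff or diffusers.

```python
import threading
import time
from typing import Callable


class PeakMemorySampler:
    """Hypothetical helper: poll a memory reader in a background thread, keep the max.

    read_used_mib is any zero-argument callable returning device memory in use
    (MiB). How to build it is an assumption; see the note after this block.
    """

    def __init__(self, read_used_mib: Callable[[], int], interval_s: float = 0.05):
        self._read = read_used_mib
        self._interval = interval_s
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self.peak_mib = 0

    def _sample(self) -> None:
        self.peak_mib = max(self.peak_mib, self._read())

    def _run(self) -> None:
        # Event.wait returns False on timeout (keep polling) and True once stopped.
        while not self._stop.wait(self._interval):
            self._sample()

    def __enter__(self) -> "PeakMemorySampler":
        self._sample()  # reading before the workload starts (thread not yet running)
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        self._thread.join()
        self._sample()  # reading after the workload, once the poller has exited
```

On the GPU, read_used_mib could be built with the nvidia-ml-py package: call pynvml.nvmlInit(), take handle = pynvml.nvmlDeviceGetHandleByIndex(0), and pass lambda: pynvml.nvmlDeviceGetMemoryInfo(handle).used // (1024 * 1024); then wrap the pipe(...) call in with PeakMemorySampler(read) as sampler: and compare sampler.peak_mib across the two runs. This is a device-wide reading, so other processes on the same GPU would be counted too.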

The complete result

PyTorch

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:82:00.0 Off |                  Off |
|  0%   38C    P2              64W / 450W |  13380MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

OneDiff

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:82:00.0 Off |                  Off |
|  0%   36C    P8              22W / 450W |  18382MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
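The gap between the two tables can be computed directly from their Memory-Usage cells. The small sketch below parses the "used / total" pair from each nvidia-smi row and reports the absolute and relative difference. The function names are illustrative, and both inputs are idle snapshots (GPU-Util 0%), so the result describes memory held after the run, not peak usage.

```python
import re
from typing import Tuple

# Matches the "used / total" pair in an nvidia-smi Memory-Usage cell.
_MEM_RE = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")


def parse_memory_usage(cell: str) -> Tuple[int, int]:
    """Return (used_mib, total_mib) from text such as '13380MiB / 24564MiB'."""
    match = _MEM_RE.search(cell)
    if match is None:
        raise ValueError(f"no Memory-Usage field found in {cell!r}")
    return int(match.group(1)), int(match.group(2))


def compare_snapshots(baseline_row: str, candidate_row: str) -> dict:
    """Extra memory of candidate over baseline, plus each one's share of the card."""
    base_used, total = parse_memory_usage(baseline_row)
    cand_used, _ = parse_memory_usage(candidate_row)
    extra = cand_used - base_used
    return {
        "extra_mib": extra,
        "extra_ratio": extra / base_used,
        "baseline_share": base_used / total,
        "candidate_share": cand_used / total,
    }


# Rows copied from the PyTorch and OneDiff tables above.
pytorch_row = "|  0%   38C    P2              64W / 450W |  13380MiB / 24564MiB |      0%      Default |"
onediff_row = "|  0%   36C    P8              22W / 450W |  18382MiB / 24564MiB |      0%      Default |"
result = compare_snapshots(pytorch_row, onediff_row)
print(f"extra: {result['extra_mib']} MiB ({result['extra_ratio']:.1%})")
# prints: extra: 5002 MiB (37.4%)
```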

Additional context
