Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions pipelines/pipeline_infu_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from PIL import Image
from optimum.quanto import freeze, qint8, quantize
from optimum.quanto import freeze, qfloat8, qint4, qint8, quantize
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit
BNB_AVAILABLE = True
except ImportError:
BNB_AVAILABLE = False
from transformers import T5EncoderModel

from .pipeline_flux_infusenet import FluxInfuseNetPipeline
Expand Down Expand Up @@ -133,6 +139,7 @@ def __init__(
infu_flux_version='v1.0',
model_version='aes_stage2',
quantize_8bit=False,
quantize_infusenet=None,
cpu_offload=False,
):

Expand All @@ -150,7 +157,43 @@ def __init__(
infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
insightface_root_path = './models/InfiniteYou/supports/insightface'
if quantize_8bit:
# Quantize InfuseNet (independent of FLUX quantization)
# quantize_infusenet options:
# 'nf4' : bitsandbytes NF4 4-bit (true peak VRAM reduction, requires pip install bitsandbytes)
# 'fp8' : optimum.quanto FP8 weight-only
# 'int4' : optimum.quanto INT4 weight-only
# 'int8' : optimum.quanto INT8 weight-only
if quantize_infusenet == 'nf4':
if not BNB_AVAILABLE:
raise ImportError(
'--infusenet_quant nf4 requires bitsandbytes. '
'Install with: pip install bitsandbytes'
)
print('[InfuseNet] Applying NF4 quantization via bitsandbytes...')
from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
bnb_config = DiffusersBnBConfig(
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Reload InfuseNet with BnB quantization config
infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
self.infusenet = FluxControlNetModel.from_pretrained(
infusenet_path,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
)
elif quantize_infusenet == 'fp8':
print('[InfuseNet] Applying FP8 quantization via optimum.quanto...')
quantize(self.infusenet, weights=qfloat8)
freeze(self.infusenet)
elif quantize_infusenet == 'int4':
print('[InfuseNet] Applying INT4 quantization via optimum.quanto...')
quantize(self.infusenet, weights=qint4)
freeze(self.infusenet)
elif quantize_infusenet == 'int8' or quantize_8bit:
print('[InfuseNet] Applying INT8 quantization via optimum.quanto...')
quantize(self.infusenet, weights=qint8)
freeze(self.infusenet)
try:
Expand Down
9 changes: 9 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def main():
# Memory reduction options
parser.add_argument('--quantize_8bit', action='store_true')
parser.add_argument('--cpu_offload', action='store_true')
parser.add_argument('--infusenet_quant', default=None, choices=['nf4', 'fp8', 'int4', 'int8'],
help="""Quantize InfuseNet independently of FLUX: nf4 | fp8 | int4 | int8.
Approximate peak VRAM savings vs bf16 full (~43GB):
int8 : ~32GB (same as --quantize_8bit but InfuseNet only)
fp8 : ~32GB (slightly lower precision loss than int8)
int4 : ~26GB (most aggressive, may affect ID similarity slightly)
Combine with --quantize_8bit and --cpu_offload for maximum savings.
Requires: pip install optimum-quanto""")
args = parser.parse_args()

# Check arguments
Expand All @@ -63,6 +71,7 @@ def main():
infu_flux_version=args.infu_flux_version,
model_version=args.model_version,
quantize_8bit=args.quantize_8bit,
quantize_infusenet=args.infusenet_quant,
cpu_offload=args.cpu_offload,
)
# Load LoRAs (optional)
Expand Down