
[Feature]【Hackathon 10th Spring No.48】SD3 and Flux diffusion model implementation#7505

Open
bobby-cloudforge wants to merge 1 commit into PaddlePaddle:develop from CloudForge-Solutions:task/048-sd3-flux-diffusion1

Conversation

@bobby-cloudforge

Motivation

Implements the core deliverable of the RFC design document (community#1316): a new self-contained SD3 / Flux diffusion model inference module for FastDeploy (Hackathon 10th Spring No.48).

This PR provides standalone Stable Diffusion 3 and FLUX.1 inference pipelines:

  • DiT (Diffusion Transformer) networks (Flux double/single-stream blocks + SD3 MMDiT joint attention)
  • AutoencoderKL (VAE encoder/decoder, 16-channel latent, 8× spatial compression)
  • FlowMatchEulerDiscrete scheduler (with the SD3 shift parameter and an HF-aligned sigma schedule)
  • TextEncoderPipeline (CLIP-L/G + T5-XXL, zero-vector fallback for missing encoders)
  • DiffusionEngine unified inference entry point (with SD3 Classifier-Free Guidance)
  • TP/quantization layer scanning and tagging (actual replacement reserved for a follow-up contribution; no-op on a single GPU)

The module is self-contained, placed under fastdeploy/model_executor/diffusion_models/, and does not modify the existing FD inference flow. Deep integration with the FD framework (FDConfig registration, ModelCategory routing, the fdctl serve entry point) is reserved as interfaces for follow-up community contributions.

Modifications

Deliverable 1: SD3 + Flux diffusion model networks

File | Lines | Description
diffusion_models/models/flux_dit.py 611 Flux DiT: FluxRoPE, AdaLayerNorm, DoubleStreamBlock, SingleStreamBlock, FluxForImageGeneration
diffusion_models/models/sd3_dit.py ~420 SD3 MMDiT: PatchEmbed (center-crop pos embed), TimestepEmbedding, CombinedEmbedding, JointTransformerBlock (separate QK norms), SD3Transformer2DModel
diffusion_models/components/vae.py 384 AutoencoderKL: ResnetBlock2D, Downsample/Upsample2D, AttentionBlock, Encoder, Decoder (16ch, 8× compression)
diffusion_models/components/text_encoder.py 356 TextEncoderPipeline: CLIP-L/G + T5-XXL, zero-fallback for missing encoders, configurable max_sequence_length
diffusion_models/components/weight_utils.py 153 safetensors/pdparams weight loading (incl. multi-shard), missing/unexpected key logging
diffusion_models/schedulers/flow_matching.py ~140 FlowMatchEulerDiscreteScheduler: sigma schedule (aligned with HF scheduling_flow_match_euler_discrete), Euler step, shift parameter
diffusion_models/engine.py 374 DiffusionEngine: unified entry point, _generate_flux() + _generate_sd3() (with CFG), latent unpack
diffusion_models/config.py 96 DiffusionConfig: model_type, image dimensions, steps, guidance, dtype
diffusion_models/README.md 204 Usage documentation covering model architecture, weight formats, and TP/quantization notes

Deliverable 2: Custom operators

Analysis showed that the SD3/Flux inference flow can be implemented entirely with standard PaddlePaddle operators; no additional custom CUDA operators are needed.

Deliverable 3: Parallel/quantization layer scan (preparatory work, not a full adaptation)

File | Lines | Description
diffusion_models/parallel.py 162 TP layer tagging scan: identifies candidate layer names for ColumnParallel (QKV, MLP-up) and RowParallel (out, MLP-down). Quantization scan: counts linear layers with ≥256 columns. Note: this PR implements only the scan/tagging logic; it is a no-op on a single GPU, and the actual layer replacement is reserved for a follow-up contribution.
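
The scan described above boils down to matching layer names against suffix patterns. The following is a minimal standalone sketch of that idea; the pattern tuples, function name, and layer names are illustrative assumptions, not the PR's exact code.

```python
# Hypothetical sketch of a TP-candidate scan over named layers.
# Pattern tuples and names below are illustrative, not from parallel.py.
_COLUMN_PARALLEL_PATTERNS = ("attn_qkv", "attn_qkv_context", "ff.0")
_ROW_PARALLEL_PATTERNS = ("attn_out", "attn_out_context", "ff.2")


def scan_tp_candidates(layer_names):
    """Classify linear-layer names into column/row-parallel candidates."""
    column, row = [], []
    for name in layer_names:
        if any(name.endswith(p) for p in _COLUMN_PARALLEL_PATTERNS):
            column.append(name)
        elif any(name.endswith(p) for p in _ROW_PARALLEL_PATTERNS):
            row.append(name)
    return column, row


# Names in the style produced by a named_modules()-like traversal
names = [
    "transformer_blocks.0.attn_qkv",
    "transformer_blocks.0.ff.0",
    "transformer_blocks.0.attn_out",
    "transformer_blocks.0.norm1",
]
col, row = scan_tp_candidates(names)
# col -> the QKV and MLP-up layers; row -> the attention output layer
```

On a single GPU the result is only logged, which is why the scan is a no-op today.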

Key implementation details (verified against HF diffusers)

  1. SD3 positional embedding: uses center crop (not top-left crop) to match HF PatchEmbed.cropped_pos_embed, with ValueError protection for out-of-bounds sizes
  2. SD3 QK norms: the image and context streams use separate QK norm instances (4 independent RMSNorms)
  3. SD3 norm_out: elementwise_affine=False (weight_attr=False, bias_attr=False) to match HF AdaLayerNormContinuous
  4. Sigma schedule: linspace(1, 1/N_train, N) + [0], matching HF scheduling_flow_match_euler_discrete.py
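
Point 4 can be illustrated with a short NumPy sketch. The shift transform below mirrors the SD3 time-shift formula used in HF diffusers' flow-match scheduler; the function name and default values are assumptions for illustration and may differ from this PR's implementation.

```python
import numpy as np


# Hedged sketch of the sigma schedule in point 4 (illustrative, not the
# PR's code). Base schedule: linspace(1, 1/N_train, N), then the SD3
# shift transform, then a terminal sigma of 0.
def make_sigmas(num_steps, num_train_timesteps=1000, shift=3.0):
    sigmas = np.linspace(1.0, 1.0 / num_train_timesteps, num_steps)
    # SD3 time shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
    return np.append(sigmas, 0.0)  # terminal sigma = 0


sigmas = make_sigmas(28)  # 28 steps -> 29 sigma values, 1.0 down to 0.0
```

The schedule starts at sigma = 1 (pure noise), decreases monotonically, and ends at 0, which is what the Euler step consumes.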

Usage or Command

from fastdeploy.model_executor.diffusion_models import DiffusionConfig, DiffusionEngine

# Flux inference
config = DiffusionConfig(
    model_name_or_path="/path/to/FLUX.1-dev",
    model_type="flux",
    image_height=1024,
    image_width=1024,
    num_inference_steps=20,
    guidance_scale=3.5,
)
engine = DiffusionEngine(config)
engine.load()
images = engine.generate("a photo of a cat")
images[0].save("output.png")

# SD3 inference (with Classifier-Free Guidance)
config = DiffusionConfig(
    model_name_or_path="/path/to/stable-diffusion-3-medium",
    model_type="sd3",
    image_height=1024,
    image_width=1024,
    num_inference_steps=28,
    guidance_scale=7.0,
)
engine = DiffusionEngine(config)
engine.load()
images = engine.generate("a beautiful landscape painting")

Accuracy Tests

All tests pass on an AI Studio A800-SXM4-80GB (SM80, CUDA 13.0, PaddlePaddle 3.3.0).

Test files

File | Tests | Type | Coverage
test_dit_numerical_invariants.py 34 GPU DiT 数值不变性: norm invariants, center-crop pos embed, cross-attention value flow, RoPE, VAE consistency, weight roundtrip, TP layer identification, config integration
test_fd_integration.py 16 CPU+GPU Package imports, weight roundtrip (safetensors/pdparams), VAE from_pretrained, engine load (Flux/SD3), full pipeline with weight loading
test_pipeline_contracts.py 17 GPU Transformer save/load forward match, full pipeline load+generate (Flux/SD3), stage-by-stage pipeline, regression snapshot
test_numerical_references.py 8 CPU+GPU NumPy reference implementations: scheduler sigma schedule, RoPE frequency formula + norm preservation, known weight snapshots, end-to-end denoising variance reduction
test_flux_gpu.py 4 GPU Flux transformer forward (dev/schnell modes), full pipeline synthetic, large-scale forward
conftest.py Paddle 3.0-beta2 compatibility patch

Test results

$ pytest tests/diffusion_models/ -v
============================= 79 passed in 10.08s ==============================

All 79/79 tests pass, 0 failures.

File | Passed | Failed
test_dit_numerical_invariants.py 34 0
test_fd_integration.py 16 0
test_pipeline_contracts.py 17 0
test_numerical_references.py 8 0
test_flux_gpu.py 4 0
conftest.py
Total 79 0

Checklist

  • SD3 + Flux model network code (fastdeploy/model_executor/diffusion_models/)
  • Model usage documentation (README.md)
  • TP/quantization layer scanning and tagging (parallel.py — candidate layer name mapping, single-GPU no-op)
  • Actual TP replacement (nn.Linear → ColumnParallelLinear — follow-up contribution)
  • Actual quantization replacement (nn.Linear → QuantizedLinear — follow-up contribution)
  • SD3 Classifier-Free Guidance implementation
  • HF diffusers alignment verification (center-crop pos embed, sigma schedule, QK norms, norm_out)
  • Missing/unexpected key logging checks for weight loading
  • Unit tests (79 tests: 5 test files + conftest)
  • NumPy reference implementations (scheduler, RoPE, denoising)
  • Verified on an AI Studio A800 GPU
  • All pre-commit hooks pass (black, isort, flake8, ruff)
  • RFC design document merged (community#1316)

@paddle-bot

paddle-bot bot commented Apr 20, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Apr 20, 2026
PaddlePaddle-bot

This comment was marked as outdated.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-20 14:16 CST

📋 Review Summary

PR overview: adds a self-contained SD3 / Flux diffusion model inference module to FastDeploy, including DiT networks, a VAE, a scheduler, text encoders, and a unified inference engine.
Scope of change: fastdeploy/model_executor/diffusion_models/ (new module), tests/diffusion_models/, scripts/diffusion_models/
Impact tag: Models

📝 PR convention check

The PR title and description follow the conventions: the [Feature] tag is valid, and the Motivation/Modifications/Usage/Tests sections are complete.

Issues

Level | File | Summary
🔴 Bug parallel.py:56 The TP scan pattern names mlp.0/mlp.2 do not match the actual model attribute names ff.0/ff.2, so zero layers would be identified when TP>1
🟡 Suggestion engine.py:223 paddle.seed() is a global operation and mutates the caller's random state
🟡 Suggestion engine.py:325 SD3 CFG runs two independent transformer forwards per step; they could be merged into one batched forward

Overall assessment

The implementation quality is high: the module is self-contained, the architecture is clean, the documentation is thorough, and test coverage is solid (79 tests). The key SD3/Flux alignment details (center-crop pos embed, QK norms, sigma schedule) are all handled. The main issue is that the TP pattern names in parallel.py do not match the actual model attribute names; this is harmless today (single-GPU no-op) but will cause bugs during the follow-up TP integration. Fixing it in this PR is recommended so later contributors do not trip over it.

_COLUMN_PARALLEL_PATTERNS = (
    "attn_qkv",          # Flux/SD3: joint QKV projection
    "attn_qkv_context",  # Flux/SD3: context stream QKV
    "mlp.0",             # MLP gate (first linear in Sequential)

🔴 Bug: TP pattern names do not match the actual model attribute names

The "mlp.0" / "mlp.2" / "mlp_context.0" / "mlp_context.2" entries in _COLUMN_PARALLEL_PATTERNS and _ROW_PARALLEL_PATTERNS do not match the actual attribute names in the Flux/SD3 DiT models:

  • FluxDoubleStreamBlock / SD3JointTransformerBlock use self.ff (not self.mlp) and self.ff_context (not self.mlp_context)
  • FluxSingleStreamBlock uses self.proj_mlp (not self.mlp)

As a result, the layer names returned by named_modules() look like transformer_blocks.0.ff.0 rather than transformer_blocks.0.mlp.0, so the TP scan misses every MLP layer. This is a no-op on a single GPU today, but it would prevent the MLP layers from being replaced correctly during the follow-up TP integration.

Suggested fix:

_COLUMN_PARALLEL_PATTERNS = (
    "attn_qkv",
    "attn_qkv_context",
    "ff.0",           # DoubleStream/JointTransformer MLP gate
    "ff_context.0",   # Context MLP gate
    "proj_mlp",       # SingleStream MLP projection
)

_ROW_PARALLEL_PATTERNS = (
    "attn_out",
    "attn_out_context",
    "ff.2",           # DoubleStream/JointTransformer MLP down
    "ff_context.2",   # Context MLP down
    "proj_out",
)


# 3. Initialize noise
if seed is not None:
    paddle.seed(seed)

🟡 Suggestion: paddle.seed() is a global operation and pollutes the caller's random state

paddle.seed(seed) sets the global random seed, affecting every subsequent paddle.randn() call — including the caller's own code. If a user calls engine.generate() repeatedly in a loop, the random state from the second call onward may be changed unexpectedly.

Consider isolating the random state with a local generator (if the PaddlePaddle version supports one), or generating the latent noise immediately after seeding and restoring the original random state afterwards. The same issue exists at line 305 in _generate_sd3.
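
To illustrate the isolation idea in general terms, here is a concept sketch in NumPy (not Paddle): a local Generator produces reproducible noise without touching the caller's global RNG state. The function name and shapes are illustrative, not from the PR.

```python
import numpy as np


# Concept sketch (NumPy, not Paddle) of the reviewer's suggestion:
# drawing noise from a local Generator leaves the caller's global RNG
# state untouched. Name and shape are illustrative assumptions.
def sample_noise(shape, seed=None):
    rng = np.random.default_rng(seed)  # local generator, not global state
    return rng.standard_normal(shape)


np.random.seed(123)                            # caller's global state
state_before = np.random.get_state()[1].copy()
noise_a = sample_noise((4,), seed=0)
noise_b = sample_noise((4,), seed=0)           # same seed -> same noise
state_after = np.random.get_state()[1]         # unchanged by sampling
```

The same pattern would apply to the engine if Paddle exposes an equivalent local-generator API; otherwise saving and restoring the global RNG state around the latent draw achieves the same effect.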

)

# Classifier-free guidance
if do_cfg:

🟡 Suggestion: SD3 CFG performs two transformer forwards per step and could be optimized

The current CFG implementation calls self.transformer() twice per denoising step (conditional + unconditional). A more efficient approach is to concatenate the conditional and unconditional inputs along the batch dimension, run a single forward, and split the output:

# Merge along the batch dimension: [2B, ...]
latents_input = paddle.concat([latents, latents], axis=0)
embeds_input = paddle.concat([prompt_embeds, uncond_embeds], axis=0)
pooled_input = paddle.concat([pooled_embeds, uncond_pooled], axis=0)

noise_pred_all = self.transformer(...)
noise_pred, noise_pred_uncond = noise_pred_all.chunk(2, axis=0)
noise_pred = noise_pred_uncond + guidance * (noise_pred - noise_pred_uncond)

This exploits batch parallelism and can reduce denoising latency by roughly 40% (especially when GPU utilization is otherwise low). It could be left as a follow-up optimization TODO.


Labels

contributor External developers
