Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f552f9a
Add CDC-FM (Carré du Champ Flow Matching) support
rockerBOO Oct 9, 2025
e03200b
Optimize: Cache CDC shapes in memory to eliminate I/O bottleneck
rockerBOO Oct 9, 2025
0d822b2
Refactor: Extract CDC noise transformation to separate function
rockerBOO Oct 9, 2025
88af208
Fix: Enable gradient flow through CDC noise transformation
rockerBOO Oct 9, 2025
ce17007
Add warning throttling for CDC shape mismatches
rockerBOO Oct 9, 2025
ee8ceee
Add device consistency validation for CDC transformation
rockerBOO Oct 9, 2025
4bea582
Fix: Prevent false device mismatch warnings for cuda vs cuda:0
rockerBOO Oct 9, 2025
1d4c4d4
Fix: Replace CDC integer index lookup with image_key strings
rockerBOO Oct 9, 2025
7a7110c
Use logger instead of print for CDC loading messages
rockerBOO Oct 9, 2025
c8a4e99
Add --cdc_debug flag and tqdm progress for CDC preprocessing
rockerBOO Oct 9, 2025
f128f5a
Formatting cleanup
rockerBOO Oct 9, 2025
20c6ae5
Add faiss to github action
rockerBOO Oct 9, 2025
f450443
Add CDC-FM parameters to model metadata
rockerBOO Oct 10, 2025
7ca799c
Add adaptive k_neighbors support for CDC-FM
rockerBOO Oct 10, 2025
8458a56
Add graceful fallback when FAISS is not installed
rockerBOO Oct 10, 2025
aa3a216
Slight cleanup
rockerBOO Oct 11, 2025
8089cb6
Improve dimension mismatch warning for CDC Flow Matching
rockerBOO Oct 11, 2025
1f79115
Consolidate and simplify CDC test files
rockerBOO Oct 11, 2025
83c17de
Remove faiss, save per image cdc file
rockerBOO Oct 18, 2025
c820ace
Fix CDC tests to new format and deprecate old tests
rockerBOO Oct 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 faiss-cpu==1.12.0
pip install -r requirements.txt
- name: Test with pytest
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
benchmark_*.py
90 changes: 85 additions & 5 deletions flux_train_network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union

import torch
Expand Down Expand Up @@ -36,6 +34,7 @@ def __init__(self):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None
self.gamma_b_dataset = None # CDC-FM Γ_b dataset

def assert_extra_args(
self,
Expand Down Expand Up @@ -327,9 +326,15 @@ def get_noise_pred_and_target(
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# get noisy model input and timesteps
# Get CDC parameters if enabled
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None

# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
)

# pack latents and get img_ids
Expand Down Expand Up @@ -456,6 +461,15 @@ def update_metadata(self, metadata, args):
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

# CDC-FM metadata
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)

def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

Expand Down Expand Up @@ -494,7 +508,7 @@ def forward(hidden_states):
module.forward = forward_hook(module)

if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
Expand Down Expand Up @@ -533,6 +547,72 @@ def setup_parser() -> argparse.ArgumentParser:
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)

# CDC-FM arguments
parser.add_argument(
"--use_cdc_fm",
action="store_true",
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
" / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用",
)
parser.add_argument(
"--cdc_k_neighbors",
type=int,
default=256,
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
" / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)",
)
parser.add_argument(
"--cdc_k_bandwidth",
type=int,
default=8,
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
" / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)",
)
parser.add_argument(
"--cdc_d_cdc",
type=int,
default=8,
help="Dimension of CDC subspace (default: 8)"
" / CDCサブ空間の次元(デフォルト: 8)",
)
parser.add_argument(
"--cdc_gamma",
type=float,
default=1.0,
help="CDC strength parameter (default: 1.0)"
" / CDC強度パラメータ(デフォルト: 1.0)",
)
parser.add_argument(
"--force_recache_cdc",
action="store_true",
help="Force recompute CDC cache even if valid cache exists"
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
)
parser.add_argument(
"--cdc_debug",
action="store_true",
help="Enable verbose CDC debug output showing bucket details"
" / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
)
parser.add_argument(
"--cdc_adaptive_k",
action="store_true",
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
)
parser.add_argument(
"--cdc_min_bucket_size",
type=int,
default=16,
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。"
"--cdc_adaptive_k有効時のみ関連(デフォルト: 16)",
)

return parser


Expand Down
Loading