Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dataset/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ImageDatasetParams(BaseDatasetParams):
fp_1f_clean_indices: Optional[Sequence[int]] = None
fp_1f_target_index: Optional[int] = None
fp_1f_no_post: Optional[bool] = False
fp_1f_image_embedding_source: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -117,6 +118,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"fp_1f_clean_indices": [int],
"fp_1f_target_index": int,
"fp_1f_no_post": bool,
"fp_1f_image_embedding_source": str,
}
VIDEO_DATASET_DISTINCT_SCHEMA = {
"video_directory": str,
Expand Down Expand Up @@ -302,6 +304,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
fp_1f_clean_indices: {dataset.fp_1f_clean_indices}
fp_1f_target_index: {dataset.fp_1f_target_index}
fp_1f_no_post: {dataset.fp_1f_no_post}
fp_1f_image_embedding_source: "{dataset.fp_1f_image_embedding_source}"
\n"""
),
" ",
Expand Down
9 changes: 9 additions & 0 deletions dataset/dataset_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ The number of control images should match the number of indices specified in `fp

The default values mean that the first image (control image) is at index `0`, and the target image (the changed image) is at index `9`.

`fp_1f_image_embedding_source` sets which image the image embedding is taken from. `control_0` means to get it from the first control image, `control_1` means to get it from the second control image (and so on). `target` means to get it from the target image. If not specified, it defaults to `control_0`.

For training with 1f-mc, set `fp_1f_clean_indices` to `[0, 1]` and `fp_1f_target_index` to `9` (or another value). This allows you to use multiple control images to train a single generated image. The control images will be two in this case.

```toml
Expand All @@ -416,11 +418,14 @@ fp_1f_no_post = false

For training with kisekaeichi, set `fp_1f_clean_indices` to `[0, 10]` and `fp_1f_target_index` to `1` (or another value). This allows you to use the starting image (the image just before the generation section) and the image following the generation section (equivalent to `clean_latent_post`) to train the first image of the generated video. The control images will be two in this case. `fp_1f_no_post` should be set to `true`.

It is recommended to set `fp_1f_image_embedding_source` to `target` for training kisekaeichi (however, if the inference environment always gets the image embedding from the first control image, you can set it to `control_0`).

```toml
[[datasets]]
fp_1f_clean_indices = [0, 10]
fp_1f_target_index = 1
fp_1f_no_post = true
fp_1f_image_embedding_source = "target" # recommended for kisekaeichi training
```

With `fp_1f_clean_indices` and `fp_1f_target_index`, you can specify any number of control images and any index of the target image for training.
Expand All @@ -446,10 +451,14 @@ The 2x indices are `1 + fp1_latent_window_size + 1` for two indices (usually `11

デフォルトの1フレーム学習では、開始画像(制御画像)1枚をインデックス`0`、生成対象の画像(変化後の画像)をインデックス`9`に設定しています。

`fp_1f_image_embedding_source`は、制御画像の埋め込みをどこから取得するかを設定します。`control_0`を指定すると、制御画像の最初の画像から取得し、`control_1`を指定すると2枚目の制御画像から取得します(以下同様)。`target`を指定すると生成対象の画像から取得します。未指定時は`control_0`となります。

1f-mcの学習を行う場合は、`fp_1f_clean_indices`に `[0, 1]`を、`fp_1f_target_index`に`9`を設定してください。これにより動画の先頭の2枚の制御画像を使用して、後続の1枚の生成画像を学習します。制御画像は2枚になります。

kisekaeichiの学習を行う場合は、`fp_1f_clean_indices`に `[0, 10]`を、`fp_1f_target_index`に`1`(または他の値)を設定してください。これは、開始画像(生成セクションの直前の画像)(`clean_latent_pre`に相当)と、生成セクションに続く1枚の画像(`clean_latent_post`に相当)を使用して、生成動画の先頭の画像(`target_index=1`)を学習します。制御画像は2枚になります。`fp_1f_no_post`は`true`に設定してください。

kisekaeichiの学習では、`fp_1f_image_embedding_source`は`target`に設定することをお勧めします(ただし推論環境によってはimage embeddingを常に最初の制御画像から取得する場合もあり、それらの環境を前提とする場合は`control_0`に設定してください)。

`fp_1f_clean_indices`と`fp_1f_target_index`を応用することで、任意の枚数の制御画像を、任意のインデックスを指定して学習することが可能です。

`fp_1f_no_post`を`false`に設定すると、`clean_latent_post_index`は `1 + fp1_latent_window_size` になります。
Expand Down
4 changes: 4 additions & 0 deletions dataset/image_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(
self.fp_1f_clean_indices: Optional[list[int]] = None # indices of clean latents for 1f
self.fp_1f_target_index: Optional[int] = None # target index for 1f clean latents
self.fp_1f_no_post: Optional[bool] = None # whether to add zero values as clean latent post
self.fp_1f_image_embedding_source: Optional[str] = None # source of image embedding for 1f

def __str__(self) -> str:
return (
Expand Down Expand Up @@ -1330,6 +1331,7 @@ def __init__(
fp_1f_clean_indices: Optional[list[int]] = None,
fp_1f_target_index: Optional[int] = None,
fp_1f_no_post: Optional[bool] = False,
fp_1f_image_embedding_source: Optional[str] = None,
debug_dataset: bool = False,
architecture: str = "no_default",
):
Expand All @@ -1351,6 +1353,7 @@ def __init__(
self.fp_1f_clean_indices = fp_1f_clean_indices
self.fp_1f_target_index = fp_1f_target_index
self.fp_1f_no_post = fp_1f_no_post
self.fp_1f_image_embedding_source = fp_1f_image_embedding_source

control_count_per_image = 1
if fp_1f_clean_indices is not None:
Expand Down Expand Up @@ -1415,6 +1418,7 @@ def aggregate_future(consume_all: bool = False):
item_info.fp_1f_clean_indices = self.fp_1f_clean_indices
item_info.fp_1f_target_index = self.fp_1f_target_index
item_info.fp_1f_no_post = self.fp_1f_no_post
item_info.fp_1f_image_embedding_source = self.fp_1f_image_embedding_source

if self.architecture == ARCHITECTURE_FRAMEPACK:
# we need to split the bucket with latent window size and optional 1f clean indices, zero post
Expand Down
20 changes: 19 additions & 1 deletion fpack_cache_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,25 @@ def encode_and_save_batch_one_frame(
latents[b : b + 1, :, i : i + 1] *= content_mask

# Vision encoding per-item (once): the image is selected by fp_1f_image_embedding_source ("target" or "control_N"); when unset, the first control image (the start image) is used
images = [item.control_content[0] for item in batch] # list of [H, W, C]
# images = [item.control_content[0] for item in batch] # list of [H, W, C]
images = []
for item in batch:
# print(f"Processing item {item.item_key} with image embedding source: {item.fp_1f_image_embedding_source}")
if item.fp_1f_image_embedding_source == "target":
images.append(item.content) # Use target image
elif item.fp_1f_image_embedding_source is None or item.fp_1f_image_embedding_source.startswith("control"):
image_embedding_source = item.fp_1f_image_embedding_source or "control_0"
control_index = int(image_embedding_source.split("_")[-1])
# print(f" Using control index: {control_index} for item {item.item_key}")
if control_index < len(item.control_content):
images.append(item.control_content[control_index])
else:
raise ValueError(
f"Control index {control_index} out of range for item {item.item_key}, available: {len(item.control_content)}"
)
else:
raise ValueError(f"Unknown image embedding source: {item.fp_1f_image_embedding_source}")
# print(f" Using image embedding source: {item.fp_1f_image_embedding_source}, image shape: {images[-1].shape}")

# encode image with image encoder
image_embeddings = []
Expand Down