-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathflux_model.py
More file actions
245 lines (219 loc) · 9.33 KB
/
flux_model.py
File metadata and controls
245 lines (219 loc) · 9.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.flux.flux_layers import DoubleStreamBlock
from keras_hub.src.models.flux.flux_layers import EmbedND
from keras_hub.src.models.flux.flux_layers import LastLayer
from keras_hub.src.models.flux.flux_layers import MLPEmbedder
from keras_hub.src.models.flux.flux_layers import SingleStreamBlock
from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
@keras_hub_export("keras_hub.models.FluxBackbone")
class FluxBackbone(Backbone):
"""Transformer model for flow matching on sequences.
The model processes image and text data with associated positional and
timestep embeddings, and optionally applies guidance embedding.
Double-stream blocks handle separate image and text streams, while
single-stream blocks combine these streams. Ported from:
https://github.com/black-forest-labs/flux
Args:
input_channels: int. The number of input channels.
hidden_size: int. The hidden size of the transformer, must be divisible
by `num_heads`.
mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
num_heads: int. The number of attention heads.
depth: int. The number of double-stream blocks.
depth_single_blocks: int. The number of single-stream blocks.
axes_dim: list[int]. A list of dimensions for the positional embedding
axes.
theta: int. The base frequency for positional embeddings.
use_bias: bool. Whether to apply bias to the query, key, and value
projections.
guidance_embed: bool. If True, applies guidance embedding in the model.
image_shape: tuple[int]. Shape of the image input tensor. Defaults to
[(None, 768, 3072)].
text_shape: tuple[int]. Shape of the text input tensor. Defaults to
[(None, 768, 3072)].
image_ids_shape: tuple[int]. Shape of the image IDs input tensor.
Defaults to [(None, 768, 3072)].
text_ids_shape: tuple[int]. Shape of the text IDs input tensor.
Defaults to [(None, 768, 3072)].
y_shape: tuple[int]. Shape of the additional vector input tensor.
Defaults to `(None, 128)`.
Call arguments:
image: KerasTensor. Image input tensor of shape (N, L, D) where N is the
batch size, L is the sequence length, and D is the feature
dimension.
image_ids: KerasTensor. Image ID input tensor of shape (N, L, D)
corresponding to the image sequences.
text: KerasTensor. Text input tensor of shape (N, L, D).
text_ids: KerasTensor. Text ID input tensor of shape (N, L, D)
corresponding to the text sequences.
timesteps: KerasTensor. Timestep tensor used to compute positional
embeddings.
y: KerasTensor. Additional vector input, such as target values.
guidance: KerasTensor, optional. Guidance input tensor used
in guidance-embedded models.
Raises:
ValueError: If `hidden_size` is not divisible by `num_heads`, or if
`sum(axes_dim)` is not equal to the positional embedding dimension.
"""
def __init__(
self,
input_channels,
hidden_size,
mlp_ratio,
num_heads,
depth,
depth_single_blocks,
axes_dim,
theta,
use_bias,
guidance_embed=False,
# These will be inferred from the CLIP/T5 encoders later
image_shape=(None, 768, 3072),
text_shape=(None, 768, 3072),
image_ids_shape=(None, 768, 3072),
text_ids_shape=(None, 768, 3072),
y_shape=(None, 128),
**kwargs,
):
# === Layers ===
self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
self.image_input_embedder = keras.layers.Dense(
hidden_size, use_bias=True
)
self.time_input_embedder = MLPEmbedder(hidden_dim=hidden_size)
self.vector_embedder = MLPEmbedder(hidden_dim=hidden_size)
self.guidance_input_embedder = (
MLPEmbedder(hidden_dim=hidden_size)
if guidance_embed
else keras.layers.Identity()
)
self.text_input_embedder = keras.layers.Dense(hidden_size)
self.double_blocks = [
DoubleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
use_bias=use_bias,
)
for _ in range(depth)
]
self.single_blocks = [
SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(depth_single_blocks)
]
self.final_layer = LastLayer(hidden_size, 1, input_channels)
self.timestep_embedding = TimestepEmbedding()
self.guidance_embed = guidance_embed
# === Functional Model ===
image_input = keras.Input(shape=image_shape, name="image")
image_ids = keras.Input(shape=image_ids_shape, name="image_ids")
text_input = keras.Input(shape=text_shape, name="text")
text_ids = keras.Input(shape=text_ids_shape, name="text_ids")
y = keras.Input(shape=y_shape, name="y")
timesteps_input = keras.Input(shape=(), name="timesteps")
guidance_input = keras.Input(shape=(), name="guidance")
# running on sequences image
image = self.image_input_embedder(image_input)
modulation_encoding = self.time_input_embedder(
self.timestep_embedding(timesteps_input, dim=256)
)
if self.guidance_embed:
if guidance_input is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
modulation_encoding = (
modulation_encoding
+ self.guidance_input_embedder(
self.timestep_embedding(guidance_input, dim=256)
)
)
modulation_encoding = modulation_encoding + self.vector_embedder(y)
text = self.text_input_embedder(text_input)
ids = keras.ops.concatenate((text_ids, image_ids), axis=1)
positional_encoding = self.positional_embedder(ids)
for block in self.double_blocks:
image, text = block(
image=image,
text=text,
modulation_encoding=modulation_encoding,
positional_encoding=positional_encoding,
)
image = keras.ops.concatenate((text, image), axis=1)
for block in self.single_blocks:
image = block(
image,
modulation_encoding=modulation_encoding,
positional_encoding=positional_encoding,
)
image = image[:, text.shape[1] :, ...]
image = self.final_layer(
image, modulation_encoding
) # (N, T, patch_size ** 2 * output_channels)
super().__init__(
inputs={
"image": image_input,
"image_ids": image_ids,
"text": text_input,
"text_ids": text_ids,
"y": y,
"timesteps": timesteps_input,
"guidance": guidance_input,
},
outputs=image,
**kwargs,
)
# === Config ===
self.input_channels = input_channels
self.output_channels = self.input_channels
self.hidden_size = hidden_size
self.num_heads = num_heads
self.image_shape = image_shape
self.text_shape = text_shape
self.image_ids_shape = image_ids_shape
self.text_ids_shape = text_ids_shape
self.y_shape = y_shape
self.mlp_ratio = mlp_ratio
self.depth = depth
self.depth_single_blocks = depth_single_blocks
self.axes_dim = axes_dim
self.theta = theta
self.use_bias = use_bias
def get_config(self):
config = super().get_config()
config.update(
{
"input_channels": self.input_channels,
"hidden_size": self.hidden_size,
"mlp_ratio": self.mlp_ratio,
"num_heads": self.num_heads,
"depth": self.depth,
"depth_single_blocks": self.depth_single_blocks,
"axes_dim": self.axes_dim,
"theta": self.theta,
"use_bias": self.use_bias,
"guidance_embed": self.guidance_embed,
"image_shape": self.image_shape,
"text_shape": self.text_shape,
"image_ids_shape": self.image_ids_shape,
"text_ids_shape": self.text_ids_shape,
"y_shape": self.y_shape,
}
)
return config
def encode_text_step(self, token_ids, negative_token_ids):
raise NotImplementedError("Not implemented yet")
def encode(token_ids):
raise NotImplementedError("Not implemented yet")
def encode_image_step(self, images):
raise NotImplementedError("Not implemented yet")
def add_noise_step(self, latents, noises, step, num_steps):
raise NotImplementedError("Not implemented yet")
def denoise_step(
self,
):
raise NotImplementedError("Not implemented yet")
def decode_step(self, latents):
raise NotImplementedError("Not implemented yet")