Skip to content

Commit cce8ef6

Browse files
Flux Autoencoder (#2098)
1 parent 38bf427 commit cce8ef6

File tree

6 files changed

+627
-0
lines changed

6 files changed

+627
-0
lines changed
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import pytest
8+
import torch
9+
10+
from torchtune.models.flux import flux_1_autoencoder
11+
from torchtune.training.seed import set_seed
12+
13+
BSZ = 32
14+
CH_IN = 3
15+
RESOLUTION = 16
16+
CH_MULTS = [1, 2]
17+
CH_Z = 4
18+
RES_Z = RESOLUTION // len(CH_MULTS)
19+
20+
21+
@pytest.fixture(autouse=True)
22+
def random():
23+
set_seed(0)
24+
25+
26+
class TestFluxAutoencoder:
27+
@pytest.fixture
28+
def model(self):
29+
model = flux_1_autoencoder(
30+
resolution=RESOLUTION,
31+
ch_in=CH_IN,
32+
ch_out=3,
33+
ch_base=32,
34+
ch_mults=CH_MULTS,
35+
ch_z=CH_Z,
36+
n_layers_per_resample_block=2,
37+
scale_factor=1.0,
38+
shift_factor=0.0,
39+
)
40+
41+
for param in model.parameters():
42+
param.data.uniform_(0, 0.1)
43+
44+
return model
45+
46+
@pytest.fixture
47+
def img(self):
48+
return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION)
49+
50+
@pytest.fixture
51+
def z(self):
52+
return torch.randn(BSZ, CH_Z, RES_Z, RES_Z)
53+
54+
def test_forward(self, model, img):
55+
actual = model(img)
56+
assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
57+
58+
actual = torch.mean(actual, dim=(0, 2, 3))
59+
expected = torch.tensor([0.4286, 0.4276, 0.4054])
60+
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
61+
62+
def test_backward(self, model, img):
63+
y = model(img)
64+
loss = y.mean()
65+
loss.backward()
66+
67+
def test_encode(self, model, img):
68+
actual = model.encode(img)
69+
assert actual.shape == (BSZ, CH_Z, RES_Z, RES_Z)
70+
71+
actual = torch.mean(actual, dim=(0, 2, 3))
72+
expected = torch.tensor([0.6150, 0.7959, 0.7178, 0.7011])
73+
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
74+
75+
def test_decode(self, model, z):
76+
actual = model.decode(z)
77+
assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
78+
79+
actual = torch.mean(actual, dim=(0, 2, 3))
80+
expected = torch.tensor([0.4246, 0.4241, 0.4014])
81+
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

torchtune/models/flux/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from ._model_builders import flux_1_autoencoder
7+
8+
__all__ = [
9+
"flux_1_autoencoder",
10+
]

0 commit comments

Comments
 (0)