Commit d4b88a8

fix branch
1 parent 5267198 commit d4b88a8

4 files changed: +128 −18 lines changed


taming/modules/diffusionmodules/model_mergevq.py

Lines changed: 44 additions & 17 deletions
@@ -300,13 +300,15 @@ def apply_merge(self, model, merge_ratio=None, merge_num=None, cand_distribution
         merge_ratio = tome.check_parse_r(num_layers, merge_num, self.embed_res ** 2, inflect)
         if isinstance(cand_distribution, str):
             cand_distribution = cand_distribution.split('-')
-            if len(cand_distribution) == 2:
-                cand_distribution, r_cand_num = cand_distribution[0], int(cand_distribution[-1])
+            if len(cand_distribution) == 3:
+                cand_distribution, r_cand_num, bias = cand_distribution[0], int(cand_distribution[1]), int(cand_distribution[-1])
+            elif len(cand_distribution) == 2:
+                cand_distribution, r_cand_num, bias = cand_distribution[0], int(cand_distribution[-1]), 1
             else:
-                cand_distribution, r_cand_num = cand_distribution[0], 5
+                cand_distribution, r_cand_num, bias = cand_distribution[0], 5, 1
         if cand_distribution is not None:
             # generate candidate list with the center index
-            bias = 1 if cand_distribution.lower() == "gaussian" else 0
+            bias = bias if cand_distribution.lower() == "gaussian" else 0
             remain_list = [int(((self.embed_res**2 - merge_num) **0.5 + i) **2) \
                            for i in range(int(-bias), r_cand_num - bias)]
             merged_list = [self.embed_res**2 - num for num in remain_list]
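
Note: a minimal standalone sketch of the candidate-string parsing introduced above, assuming the same split('-') convention. The function name and the three-field example string ('gaussian-6-2') are illustrative only and are not part of the repository; the two-field format keeps its previous behaviour, with the bias defaulting to 1 and only taking effect for the gaussian distribution.

def parse_cand_distribution(spec, default_num=5):
    parts = spec.split('-')
    if len(parts) == 3:
        dist, r_cand_num, bias = parts[0], int(parts[1]), int(parts[-1])
    elif len(parts) == 2:
        dist, r_cand_num, bias = parts[0], int(parts[-1]), 1
    else:
        dist, r_cand_num, bias = parts[0], default_num, 1
    # the bias is only honoured for the gaussian distribution, mirroring the diff
    bias = bias if dist.lower() == "gaussian" else 0
    return dist, r_cand_num, bias

print(parse_cand_distribution('gaussian-6'))    # ('gaussian', 6, 1)
print(parse_cand_distribution('gaussian-6-2'))  # ('gaussian', 6, 2)
print(parse_cand_distribution('uniform-6'))     # ('uniform', 6, 0)
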
@@ -359,27 +361,51 @@ def load_pretrained_vit(self, model_path=None):
         weight_selection = {}
         for key in student_weights.keys():
             if ('block' in key or 'cls_token' in key) and key in teacher_weights.keys():
-                weight_selection[key] = uniform_element_selection(teacher_weights[key], student_weights[key].shape)
+                weight_selection[key] = uniform_element_selection(teacher_weights[key], student_weights[key])
         # load to attention
         print("load pre-trained model for encoder:\n",
               self.attn.load_state_dict(weight_selection, strict=False))
 
 
-def uniform_element_selection(tea_weights, stu_shape):
-    """Large Model Initialization (https://arxiv.org/abs/2311.18823)"""
-    assert tea_weights.dim() == len(stu_shape), "Tensors have different number of dimensions"
+def uniform_element_selection(tea_weights, stu_weights):
+    """Modified and borrowed from `Large Model Initialization` (https://arxiv.org/abs/2311.18823)"""
+    assert tea_weights.dim() == stu_weights.dim(), "Tensors have different number of dimensions"
     tea_weights = tea_weights.clone()
-    if tea_weights.shape != stu_shape:
+
+    def interpolate_1d(x, dim, size, up_mode='nearest'):
+        if x.shape[dim] == size:
+            return x
+        permute_order = list(range(x.dim()))
+        permute_order[dim] = x.dim()-1
+        permute_order[-1] = dim
+        x = x.permute(permute_order).contiguous()
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1]).unsqueeze(1)
+        # upsampling
+        x = torch.nn.functional.interpolate(x, size=size, mode=up_mode) if up_mode == 'nearest' else \
+            torch.nn.functional.interpolate(x, size=size, mode=up_mode, align_corners=False)
+        # reshape back
+        x = x.squeeze(1).view(*input_shape[:-1], size)
+        inv_order = [0] * len(permute_order)
+        for i, o in enumerate(permute_order):
+            inv_order[o] = i
+        return x.permute(inv_order).contiguous()
+
+    if tea_weights.shape != stu_weights.shape:
         for dim in range(tea_weights.dim()):
-            assert tea_weights.shape[dim] >= stu_shape[dim], "Teacher's dimension should not be smaller than students'"
-            if tea_weights.shape[dim] % stu_shape[dim] == 0:
-                step = tea_weights.shape[dim] // stu_shape[dim]
-                indices = torch.arange(stu_shape[dim]) * step
+            if tea_weights.shape[dim] >= stu_weights.shape[dim]:
+                # Teacher's dimension >= students' dimensions
+                if tea_weights.shape[dim] % stu_weights.shape[dim] == 0:
+                    step = tea_weights.shape[dim] // stu_weights.shape[dim]
+                    indices = torch.arange(stu_weights.shape[dim]) * step
+                else:
+                    indices = torch.round(torch.linspace(0, tea_weights.shape[dim]-1, stu_weights.shape[dim])).long()
+                tea_weights = torch.index_select(tea_weights, dim, indices)
             else:
-                indices = torch.round(torch.linspace(0, tea_weights.shape[dim]-1, stu_shape[dim])).long()
-            tea_weights = torch.index_select(tea_weights, dim, indices)
+                # Teacher's dimension < students' dimensions
+                tea_weights = interpolate_1d(tea_weights, dim, stu_weights.shape[dim])
     else:
-        assert tea_weights.shape == stu_shape, "Selected weight should be the same as student"
+        assert tea_weights.shape == stu_weights.shape, "Selected weight should be the same as student"
     return tea_weights
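
Note: a quick self-contained check of the interpolation path added above, for the new case where a teacher dimension is smaller than the student's. The toy tensor sizes are arbitrary; the snippet only mirrors what interpolate_1d does for dim=0 with the default nearest mode and is not the repository's API.

import torch
import torch.nn.functional as F

tea = torch.randn(4, 8)   # toy "teacher" weight; its first dim is smaller than the student's
target = 6                # student size along dim 0

# mirror interpolate_1d for dim=0: move the dim to the end, interpolate, move it back
x = tea.permute(1, 0).contiguous()            # (8, 4)
x = x.view(-1, x.shape[-1]).unsqueeze(1)      # (8, 1, 4)
x = F.interpolate(x, size=target, mode='nearest')
x = x.squeeze(1).view(8, target).permute(1, 0).contiguous()
print(x.shape)                                # torch.Size([6, 8])
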

@@ -732,7 +758,7 @@ def forward(self, x, quantizer):
     from fvcore.nn import FlopCountAnalysis, flop_count_table
     resolution = 256
     cand_distribution = 'gaussian-6'
-    cand_sample_times = 0
+    cand_sample_times = 10
 
     # ch, ch_mult = 64, (1, 2, 4, 8)
     # num_att_blocks, num_res_blocks, r, merge_num = 12, 4, None, 768
@@ -773,6 +799,7 @@ def forward(self, x, quantizer):
     print('encoder (r={}): {}'.format(model.attn.r, y.shape))
     print(flop_count_table(flop, max_depth=4))
     print('MACs (G) of Encoder: {:.3f}'.format(flop.total() / 1e9))
+    # print(model)
 
     if source is not None:
         print('encoder source matrix:', source.shape)
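
Note: the __main__ block above profiles the encoder with fvcore; below is a minimal sketch of the same profiling pattern on a placeholder module. The Sequential model and input shape here are stand-ins, not the repository's encoder.

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, padding=1), torch.nn.ReLU())
x = torch.randn(1, 3, 256, 256)

flop = FlopCountAnalysis(model, x)
print(flop_count_table(flop, max_depth=4))
print('MACs (G): {:.3f}'.format(flop.total() / 1e9))
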

taming/modules/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from taming.modules.losses.vqperceptual import DummyLoss

taming/modules/losses/diff_aug.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+
+import torch
+import torch.nn.functional as F
+
+from functools import partial
+
+
+def DiffAugment(x, policy='', prob=0.5, channels_first=True):
+    if policy:
+        if not channels_first:
+            x = x.permute(0, 3, 1, 2)
+        for p in policy.split(','):
+            if torch.rand(1) > prob:
+                continue
+
+            for f in AUGMENT_FNS[p]:
+                x = f(x)
+
+        if not channels_first:
+            x = x.permute(0, 2, 3, 1)
+        x = x.contiguous()
+    return x
+
+
+def rand_brightness(x):
+    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+    return x
+
+
+def rand_saturation(x):
+    x_mean = x.mean(dim=1, keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+    return x
+
+
+def rand_contrast(x):
+    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+    return x
+
+
+def rand_translation(x, ratio=0.125):
+    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(x.size(2), dtype=torch.long, device=x.device),
+        torch.arange(x.size(3), dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+    return x
+
+
+def rand_cutout(x, ratio=0.2):
+    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+    mask[grid_batch, grid_x, grid_y] = 0
+    x = x * mask.unsqueeze(1)
+    return x
+
+
+AUGMENT_FNS = {
+    'color': [rand_brightness, rand_saturation, rand_contrast],
+    'translation': [rand_translation],
+    'cutout_0.2': [partial(rand_cutout, ratio=0.2)],
+    'cutout_0.5': [partial(rand_cutout, ratio=0.5)],
+}
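
Note: a hedged usage sketch for the new module. The batch shape and policy string below are illustrative; the policy keys come from AUGMENT_FNS as defined above, and the import path matches the file added in this commit.

import torch
from taming.modules.losses.diff_aug import DiffAugment

imgs = torch.randn(8, 3, 256, 256)  # NCHW; channels_first=True is the default
aug = DiffAugment(imgs, policy='color,translation,cutout_0.2', prob=0.5)
print(aug.shape)                    # torch.Size([8, 3, 256, 256])
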

taming/modules/losses/vqperceptual.py

Lines changed: 0 additions & 1 deletion
@@ -99,7 +99,6 @@ def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type="patchg
         super().__init__()
         # discriminator loss
         assert disc_loss in ["hinge", "vanilla", "non_saturate"]
-        assert disc_type == "patchgan"
         self.discriminator = NLayerDiscriminator(
             input_nc=disc_in_channels,
             n_layers=disc_num_layers,
