-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrecon_keyframe_vcflow.py
More file actions
520 lines (386 loc) · 18.5 KB
/
recon_keyframe_vcflow.py
File metadata and controls
520 lines (386 loc) · 18.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import cv2
sys.path.append('generative_models/')
from generative_models.sgm.models.diffusion import DiffusionEngine
from omegaconf import OmegaConf
from model_variants.VCFlowModel import (Neurons,Fusion,fMRIBackbone,BrainNetwork,RedistributionHead ,PriorNetwork, BrainDiffusionPrior,
CLIPProj, TextDecoder, TextDrivenDecoder, MotionProj, MultiLabelClassifier)
from model_variants.VCFlow_dataset import CC2017_Dataset
from tqdm import tqdm
torch.backends.cuda.matmul.allow_tf32 = True
from transformers import GPT2Tokenizer
import utils
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_Tokenizer = _Tokenizer()
from diffusers import AutoencoderKL
import torch.nn.functional as F
from animatediff.utils.util import save_videos_grid
from einops import rearrange
def parse_arg(argv=None):
    """Parse the command-line options of the keyframe/video reconstruction script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly like a bare ``parse_args()`` call; an explicit
            list lets notebooks and tests drive the parser.

    Returns:
        argparse.Namespace holding the options below. Note that ``model_name``,
        ``data_path``, ``blurry_recon``, ``n_blocks`` and ``hidden_dim`` are
        accepted but not read by the code in this script.
    """
    parser = argparse.ArgumentParser(description="Model Training Configuration")
    parser.add_argument(
        "--model_name", type=str, default="testing",
        help="will load ckpt for model found in ../train_logs/model_name",
    )
    parser.add_argument(
        "--data_path", type=str, default=os.getcwd(),
        help="Path to where NSD data is stored / where to download it to",
    )
    parser.add_argument(
        "--root_dir", type=str, default='./cc2017_dataset',
    )
    parser.add_argument(
        "--weights_dir", type=str, default='./pretrained_weights',
    )
    # Used as the suffix of the output/checkpoint folder EXP/exp_<exp>/...
    parser.add_argument(
        "--exp", type=str, default='./saved_weights',
    )
    parser.add_argument(
        "--subj", type=int, default=1, choices=[1, 2, 3],
        help="Validate on which subject?",
    )
    parser.add_argument(
        "--blurry_recon", action=argparse.BooleanOptionalAction, default=False,
    )
    parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument(
        "--n_blocks", type=int, default=4,
    )
    parser.add_argument(
        "--n_frames", type=int, default=6,
    )
    parser.add_argument(
        "--batch_size", type=int, default=20,
    )
    parser.add_argument(
        "--hidden_dim", type=int, default=4096,
    )
    parser.add_argument(
        "--seed", type=int, default=42,
    )
    return parser.parse_args(argv)
def Decoding(model, clip_features, *, entry_length=30, eos_token_id=49407, tokenizer=None):
    """Greedily decode a caption from a single feature vector.

    The features are projected by ``model.clip_project`` and reshaped into a
    one-position prefix ``(1, 1, -1)``. ``model.decoder`` is then called with
    ``inputs_embeds=``; at each step the highest-scoring next token is taken,
    its embedding (``model.decoder.transformer.wte``) is appended to the
    prefix, and generation stops after ``entry_length`` tokens or right after
    ``eos_token_id`` is produced (the EOS token itself is kept).

    Args:
        model: module exposing ``clip_project``, ``decoder`` (returning an
            object with ``.logits``) and ``decoder.transformer.wte``.
        clip_features: features of ONE sample (flattened by the reshape above).
        entry_length: maximum number of generated tokens.
        eos_token_id: id that ends generation; 49407 is CLIP's ``<|endoftext|>``.
        tokenizer: object with ``decode(list_of_ids) -> str``; defaults to the
            module-level CLIP ``_Tokenizer``.

    Returns:
        The decoded string, or the placeholder ``'None'`` when nothing was
        generated or the ids cannot be decoded.
    """
    if tokenizer is None:
        tokenizer = _Tokenizer
    model.eval()
    generated = []
    with torch.no_grad():
        embedding_cat = model.clip_project(clip_features).reshape(1, 1, -1)
        for _ in range(entry_length):
            logits = model.decoder(inputs_embeds=embedding_cat).logits[:, -1, :]
            # Greedy choice: argmax(logits) == argmax(softmax(logits / T)) for
            # every T > 0, so neither softmax nor temperature is needed.
            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)  # (1, 1)
            token_id = next_token.item()
            generated.append(token_id)
            if token_id == eos_token_id:
                break
            next_token_embed = model.decoder.transformer.wte(next_token)
            embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)

    if not generated:
        return 'None'
    try:
        # A plain list of ints: a 1-token result used to be squeezed to a 0-d
        # array, which made list() fail and silently produced 'None'.
        return tokenizer.decode(generated)
    except Exception:
        # Best effort, as before: the decoder may emit ids the CLIP BPE
        # vocabulary does not contain (presumably a GPT-2 sized vocab — confirm),
        # which raises in decode. Keep the placeholder instead of aborting.
        return 'None'
def prepare_dataset(args):
    """Load the CC2017 test split for ``args.subj`` and wrap it in a DataLoader.

    Three tensors are read from ``args.root_dir``, all mapped to CPU: the
    subject's visual-cortex fMRI test responses (averaged over dim 1), the
    ground-truth test frames and the ground-truth caption embeddings.

    Returns:
        ``(test_dl, voxel_test)``: an unshuffled DataLoader with batch size
        ``args.batch_size`` that keeps the final partial batch, and the
        averaged fMRI tensor.
    """
    root = args.root_dir

    def _load_cpu(path):
        return torch.load(path, map_location='cpu')

    # dim 1 is averaged away (presumably repeated presentations — TODO confirm).
    voxel_test = _load_cpu(
        f'{root}/origin_data/fmri_vc_new/subject{args.subj}_test_fmri_vc.pt'
    ).mean(dim=1)
    print("Loaded all fmri test frames to cpu!", voxel_test.shape)

    test_images = _load_cpu(f'{root}/GT_test_3fps.pt')
    test_text = _load_cpu(f'{root}/GT_test_caption_emb.pt')
    print("Loaded all crucial test frames to cpu!", test_images.shape)

    loader = torch.utils.data.DataLoader(
        CC2017_Dataset([voxel_test], test_images, test_text),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
    return loader, voxel_test
def prepare_masks(return_shape=False):
roi_names = [
'V1','V2','V3','V3A','V3B','V3CD','V4','LO1','LO2','LO3','PIT','V4t',
'V6','V6A','V7','V8','PH','FFC','IP0','MT','MST','FST','VVC','VMV1',
'VMV2','VMV3','PHA1','PHA2','PHA3','TE2p','IPS1'
]
roi_id = {n:i for i,n in enumerate(roi_names,1)}
low_names = ['V1'] + ['V2','V3','V4']
mid_names = ['V3A','V3B','V6','V6A','V7','IPS1'] + [
'LO1', 'LO2', 'LO3', 'FST', 'MT', 'MST', 'V3CD', 'V4t',
'PH', 'IP0'
]
high_names = ['FFC','PIT','V8','VMV1','VMV2','VMV3','VVC'] + [
'PHA1', 'PHA2', 'PHA3',
'TE2p'
]
low_idx = [roi_id[n] for n in low_names]
mid_idx = [roi_id[n] for n in mid_names]
high_idx = [roi_id[n] for n in high_names]
label_img = np.load("NSD/vc_masks.npy") # shape (H,W) or (31,H,W)
if label_img.ndim == 3:
label_img = np.argmax(label_img,0)
label_img = label_img+1
mask_low = np.isin(label_img, low_idx)
mask_mid = np.isin(label_img, mid_idx)
mask_high = np.isin(label_img, high_idx)
mask_low = torch.tensor(mask_low, dtype=torch.bool)
mask_mid = torch.tensor(mask_mid, dtype=torch.bool)
mask_high = torch.tensor(mask_high, dtype=torch.bool)
if return_shape == True:
return sum(mask_low.flatten()==1).item(), sum(mask_mid.flatten()==1).item(), sum(mask_high.flatten()==1).item()
return mask_low, mask_mid, mask_high
def add_hook(clip_img_embedder, hook_layers=(10, 26, 44)):
    """Register forward hooks that capture intermediate vision-tower tokens.

    The outputs of the blocks at ``hook_layers`` are stored, in order, under
    the keys "low", "mid" and "high" of ``clip_img_embedder.mid_feats``. The
    hook handles are kept in ``clip_img_embedder._hooks``. Both attributes are
    reset on each call; hooks registered by an earlier call are not removed.

    Each stored tensor is the block output (element 1 if the output is a
    tuple) with position 0 of dim 0 dropped, dims 0 and 1 swapped, and detached.
    NOTE(review): this presumes sequence-first (L, N, D) outputs with a
    leading class token — confirm for the backbone in use.

    Raises:
        RuntimeError: if no list of transformer blocks can be found.
        IndexError: if more than three layers are requested.
    """
    visual = clip_img_embedder.model.visual
    if hasattr(visual, "blocks"):  # timm / OpenCLIP style
        layers = visual.blocks
    elif not hasattr(visual, "transformer"):
        raise RuntimeError("Did not find visual transformer layer list")
    elif hasattr(visual.transformer, "resblocks"):  # OpenAI CLIP style, e.g. ViT-bigG-14
        layers = visual.transformer.resblocks
    else:
        layers = visual.transformer

    clip_img_embedder.mid_feats = {}
    clip_img_embedder._hooks = []
    feat_names = ["low", "mid", "high"]

    def _make_hook(owner, key):
        def _hook(module, inputs, output):
            hidden = output[1] if isinstance(output, tuple) else output
            owner.mid_feats[key] = hidden[1:].permute(1, 0, 2).detach()
        return _hook

    for position, layer_idx in enumerate(hook_layers):
        handle = layers[layer_idx].register_forward_hook(
            _make_hook(clip_img_embedder, feat_names[position])
        )
        clip_img_embedder._hooks.append(handle)
def _build_brain_model(args):
    """Assemble the VCFlow ``Neurons`` container and its (untrained) sub-networks."""
    clip_seq_dim = 256
    clip_emb_dim = 1664
    clip_txt_emb_dim = 1280

    model = Neurons()
    model.clipproj = CLIPProj()
    low_shape, mid_shape, high_shape = prepare_masks(return_shape=True)
    model.backbone = fMRIBackbone(
        dim=1024,
        vision_dim=clip_emb_dim,
        clip_txt_emb_dim=clip_txt_emb_dim,
        emb_dropout=0.1,
    )
    model.distribution_head = RedistributionHead(domain_classes=2)
    # Each fusion module is sized by the voxel count of its ROI group.
    model.fusion_low = Fusion(voxel_len=low_shape)
    model.fusion_high = Fusion(voxel_len=high_shape)
    model.fusion_motion = Fusion(voxel_len=mid_shape)

    # Diffusion prior network; its depth is taken from args.n_frames.
    dim_head = 52
    prior_network = PriorNetwork(
        dim=clip_emb_dim,
        depth=args.n_frames,
        dim_head=dim_head,
        heads=clip_emb_dim // dim_head,  # heads * dim_head == clip_emb_dim
        causal=False,
        num_tokens=clip_seq_dim,
        learned_query_mode="pos_emb",
    )
    model.diffusion_prior = BrainDiffusionPrior(
        net=prior_network,
        image_embed_dim=clip_emb_dim,
        condition_on_text_encodings=False,
        timesteps=100,
        cond_drop_prob=0.2,
        image_embed_scale=None,
    )
    model.text_seg_dec = TextDrivenDecoder(clip_emb_dim, clip_txt_emb_dim)
    model.text_dec = TextDecoder(clip_txt_emb_dim)
    model.motion_proj = MotionProj(n_frames=args.n_frames, clip_size=clip_emb_dim)
    return model


def _load_sd_vae(args):
    """Load the frozen Stable Diffusion VAE (``subfolder="vae"``) onto ``device``."""
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_path, cache_dir=args.weights_dir, subfolder="vae"
    ).to(device)
    print("\033[92m vae loaded \033[0m")
    vae.eval()
    vae.requires_grad_(False)
    utils.count_params(vae)
    return vae


def _load_image_var_autoenc(args):
    """Build the 4-level AutoencoderKL and load ``sd_image_var_autoenc.pth`` into it."""
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    # map_location='cpu' like every other checkpoint in this file, so loading
    # does not depend on the device the weights were saved from.
    ckpt = torch.load(f'{args.weights_dir}/sd_image_var_autoenc.pth', map_location='cpu')
    autoenc.load_state_dict(ckpt)
    del ckpt
    print("\033[92m autoenc loaded \033[0m")
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)
    return autoenc


def _load_unclip_engine(args):
    """Build the unCLIP ``DiffusionEngine`` from unclip6.yaml and load its weights."""
    config = OmegaConf.load("./generative_models/configs/unclip6.yaml")
    config = OmegaConf.to_container(config, resolve=True)
    unclip_params = config["model"]["params"]
    first_stage_config = unclip_params["first_stage_config"]
    first_stage_config['target'] = 'sgm.models.autoencoder.AutoencoderKL'
    sampler_config = unclip_params["sampler_config"]
    sampler_config['params']['num_steps'] = 38
    diffusion_engine = DiffusionEngine(
        network_config=unclip_params["network_config"],
        denoiser_config=unclip_params["denoiser_config"],
        first_stage_config=first_stage_config,
        conditioner_config=unclip_params["conditioner_config"],
        sampler_config=sampler_config,
        scale_factor=unclip_params["scale_factor"],
        disable_first_stage_autocast=unclip_params["disable_first_stage_autocast"],
    )
    # Inference only.
    diffusion_engine.eval().requires_grad_(False)
    diffusion_engine.to(device)
    ckpt = torch.load(f'{args.weights_dir}/unclip6_epoch0_step110000.ckpt', map_location='cpu')
    diffusion_engine.load_state_dict(ckpt['state_dict'])
    del ckpt
    return diffusion_engine


def prepare_brain_model(args):
    """Build all networks used for reconstruction and load their pretrained weights.

    Reads the module-level ``device`` (assigned in the ``__main__`` block).

    Returns:
        ``(model, diffusion_engine, vae, autoenc)``: the brain model restored
        from ``EXP/exp_<exp>/subj_<subj>/checkpoints/brain_model_prior.pth``,
        the unCLIP diffusion engine, the Stable Diffusion VAE and the
        image-variation autoencoder.
    """
    model = _build_brain_model(args)
    model.to(device)
    utils.count_params(model.diffusion_prior)
    utils.count_params(model)

    ckpt_path = os.path.join("EXP", f"exp_{args.exp}", f"subj_{args.subj}", "checkpoints", "brain_model_prior.pth")
    print("---resuming from brain_model_prior.pth ckpt---")
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # strict=False tolerates key mismatches; report them instead of staying
    # silent, since a mis-named checkpoint would otherwise load nothing.
    result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    del checkpoint
    if result.missing_keys or result.unexpected_keys:
        print(f"[warn] brain_model checkpoint: {len(result.missing_keys)} missing / "
              f"{len(result.unexpected_keys)} unexpected keys")
    print(f"\033[92m Pretrained brain_model loaded from {ckpt_path} \033[0m")

    vae = _load_sd_vae(args)
    autoenc = _load_image_var_autoenc(args)
    diffusion_engine = _load_unclip_engine(args)
    return model, diffusion_engine, vae, autoenc
def _chw_to_uint8_hwc(img):
    """Convert a CHW image tensor scaled to [0, 1] into an HWC uint8 numpy array."""
    arr = img.permute(1, 2, 0).cpu().numpy() * 255
    # Clip first: casting out-of-range floats to uint8 would wrap around.
    return np.clip(arr, 0, 255).astype('uint8')


def inference(args, model, diffusion_engine, vae, test_dl):
    """Reconstruct blurry videos, unCLIP keyframes and captions for the test set.

    Reads the module-level ``device``. For every sample it prints the running
    index and the decoded caption, and writes under
    ``EXP/exp_<exp>/subj_<subj>/``:
      * ``frames_generated_video/video_<index>.gif``: blurry reconstruction and
        ground-truth clip as a two-video grid;
      * ``frames_generated_img/frame_<index>.jpg``: the unCLIP keyframe
        (resized to 224x224) stacked above the first ground-truth frame.

    Returns:
        ``(all_recons, all_gts, all_generated_texts, all_blurryrecons)``: the
        unCLIP samples, the ground-truth clips and the blurry videos, each
        concatenated along dim 0, plus a numpy string array of captions. All
        four are ``None`` when ``test_dl`` yields nothing.
    """
    out_dir = f"EXP/exp_{args.exp}/subj_{args.subj}"
    placeholder = {
        "jpg": torch.randn(1, 3, 1, 1).to(device),  # jpg doesnt get used, it's just a placeholder
        "original_size_as_tuple": torch.ones(1, 2).to(device) * 768,
        "crop_coords_top_left": torch.zeros(1, 2).to(device),
    }
    vector_suffix = diffusion_engine.conditioner(placeholder)["vector"].to(device)
    print("vector_suffix", vector_suffix.shape)

    model.to(device)
    model.eval().requires_grad_(False)

    # The ROI masks are identical for every batch: load them once.
    low_mask, mid_mask, high_mask = (m.bool().to(device) for m in prepare_masks())

    recons, gts, texts, blurry = [], [], [], []
    index = 0
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        for batch in tqdm(test_dl, desc='batches'):
            video = batch['pixel_values'].to(device)
            voxel = batch['voxel'][:, 0].unsqueeze(1).to(device).float()
            test_voxel_low = voxel[:, :, low_mask]
            test_voxel_mid = voxel[:, :, mid_mask]
            test_voxel_high = voxel[:, :, high_mask]

            clip_vision_embeds, _ = model.distribution_head(model.backbone(voxel))
            prior_out = model.diffusion_prior.p_sample_loop(
                clip_vision_embeds.shape,
                text_cond=dict(text_embed=clip_vision_embeds),
                cond_scale=1., timesteps=100,
            )
            prior_out_low = model.fusion_low(prior_out, test_voxel_low)
            prior_out_high = model.fusion_high(prior_out, test_voxel_high)
            prior_out_motion = model.fusion_motion(prior_out_low, test_voxel_mid)
            motion_embeds = model.motion_proj(prior_out_motion)
            clip_text_embeds = model.clipproj(prior_out_high)
            clip_text_embeds_norm = nn.functional.normalize(clip_text_embeds.flatten(1), dim=-1)

            vae_embeds = model.text_seg_dec(
                rearrange(motion_embeds, "b f n c -> (b f) n c"),
                model.clipproj(motion_embeds.mean(1)),
                # Use the actual batch size: the loader keeps the last, possibly
                # smaller batch, so args.batch_size can overstate it.
                time=len(voxel) * args.n_frames,
                is_seg=False,
            )
            vae_embeds = F.interpolate(vae_embeds, (28, 28), mode="nearest")
            blurry_recon_images = (vae.decode(vae_embeds / 0.18215).sample / 2 + 0.5).clamp(0, 1)
            blurry_recon_images = rearrange(blurry_recon_images, "(b f) c h w -> b f c h w", f=args.n_frames)
            # Per-batch layouts for the GIF writer, moved to CPU once per batch.
            recon_videos = rearrange(blurry_recon_images, "b f c h w -> b c f h w").cpu()
            gt_videos = video.permute(0, 2, 1, 3, 4).cpu()

            for i in range(len(voxel)):
                print(index)
                im = blurry_recon_images[i]
                save_videos_grid(
                    torch.cat((recon_videos[i].unsqueeze(0), gt_videos[i].unsqueeze(0))),
                    f"{out_dir}/frames_generated_video/video_{index}.gif")

                generated_text = Decoding(model.text_dec, clip_text_embeds_norm[i])
                generated_text = generated_text.replace('<|startoftext|>', '').replace('<|endoftext|>', '')
                print(f"{generated_text}")

                gt = video[i]
                samples = utils.unclip_recon(prior_out_high[[i]],
                                             diffusion_engine,
                                             vector_suffix,
                                             num_samples=1,
                                             device=device)
                keyframe = cv2.resize(_chw_to_uint8_hwc(samples[0]), (224, 224), interpolation=cv2.INTER_LINEAR)
                keyframe = cv2.cvtColor(keyframe, cv2.COLOR_RGB2BGR)
                gt_frame = cv2.cvtColor(_chw_to_uint8_hwc(gt[0]), cv2.COLOR_RGB2BGR)
                cv2.imwrite(f"{out_dir}/frames_generated_img/frame_{index}.jpg",
                            np.concatenate((keyframe, gt_frame), axis=0))

                recons.append(samples.cpu())
                gts.append(gt.cpu())
                texts.append(generated_text)
                blurry.append(im[None].cpu())
                index += 1

    if not recons:
        return None, None, None, None
    # Concatenate once at the end instead of re-stacking after every sample
    # (quadratic); captions are always returned as a numpy string array.
    return torch.vstack(recons), torch.vstack(gts), np.array(texts), torch.vstack(blurry)
if __name__ == "__main__":
    args = parse_arg()
    # Seed every RNG before any data or model construction.
    utils.seed_everything(args.seed)

    ### Multi-GPU config ###
    # RANK is only reported; all work below runs on a single GPU.
    local_rank = int(os.getenv('RANK', 0))
    print("LOCAL RANK ", local_rank)

    # Module-level global read by prepare_brain_model() and inference().
    device = 'cuda:0'
    print("device:", device)
    model_name = f'video_subj0{args.subj}'

    test_dl, voxel_test = prepare_dataset(args)
    model, diffusion_engine, vae, autoenc = prepare_brain_model(args)

    out_root = f"EXP/exp_{args.exp}/subj_{args.subj}"
    for sub_dir in ("frames_generated", "frames_generated_img"):
        os.makedirs(f"{out_root}/{sub_dir}", exist_ok=True)

    all_recons, all_gts, all_generated_texts, all_blurryrecons = inference(
        args, model, diffusion_engine, vae, test_dl)

    # Bring reconstructions to a common resolution before saving.
    imsize = 256
    all_recons = transforms.Resize((imsize, imsize))(all_recons).float()
    print(all_recons.shape)

    save_dir = f"{out_root}/frames_generated"
    torch.save(all_recons, f"{save_dir}/{model_name}_all_recons.pt")
    torch.save(all_gts, f"{save_dir}/{model_name}_all_gts.pt")
    torch.save(all_generated_texts, f'{save_dir}/pred_test_caption_self.pt')
    torch.save(all_blurryrecons, f'{save_dir}/recon_videos.pt')
    print(f"saved {model_name} outputs!")

    if not utils.is_interactive():
        sys.exit(0)