-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdar_tool.py
More file actions
93 lines (73 loc) · 3.01 KB
/
dar_tool.py
File metadata and controls
93 lines (73 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from datasets import load_dataset
from config_util import Config
from wds_wrapper import GenericWDS
from tokenizer.tokenizer_image.vq_model import VQModel
from dataset.build import build_dataset
def load_visual_tokenizer(config=None, ckpt=None, device="cpu"):
    """Build a VQModel visual tokenizer and optionally load checkpoint weights.

    Args:
        config: value accepted by ``Config`` describing the tokenizer. May be
            overridden by a ``'config'`` entry embedded in the checkpoint.
        ckpt: path to a checkpoint loadable with ``torch.load``, or None to
            return a randomly initialized model.
        device: device the model is moved to (default ``"cpu"``).

    Returns:
        The VQModel in eval mode on ``device``.

    Raises:
        AssertionError: if neither ``config`` nor a checkpoint-embedded
            config is available.
    """
    tokenizer_config = Config(config) if config is not None else None
    tokenizer_ckpt = None
    if ckpt is not None:
        # weights_only=False: the checkpoint may carry a pickled config object.
        tokenizer_ckpt = torch.load(ckpt, map_location="cpu", weights_only=False)
        # A config stored inside the checkpoint takes precedence over the argument.
        if 'config' in tokenizer_ckpt:
            tokenizer_config = tokenizer_ckpt['config']
    assert tokenizer_config is not None, "need `config` or a ckpt containing one"
    # Create the model; load weights only when a checkpoint was actually given.
    # BUG FIX: the original referenced `tokenizer_ckpt` even when ckpt was None,
    # which raised NameError instead of returning a fresh model.
    vq_model = VQModel(tokenizer_config['model'])
    if tokenizer_ckpt is not None:
        if 'ema' in tokenizer_ckpt or 'model' in tokenizer_ckpt:
            # Prefer EMA weights when present, else the raw model weights.
            key = 'ema' if 'ema' in tokenizer_ckpt else 'model'
            print(vq_model.load_state_dict(tokenizer_ckpt[key], strict=False))
        else:
            # Checkpoint is a bare state dict.
            print(vq_model.load_state_dict(tokenizer_ckpt, strict=False))
    vq_model = vq_model.to(device)
    vq_model.eval()
    # vq_model.requires_grad_(False)
    return vq_model
class CustomCollateFn:
    """Collate function for dict-style samples: transform and stack images,
    gather captions (or dummy zero tensors when no text key is configured).
    """

    def __init__(self, img_key, txt_key, transform):
        # Key under which each sample stores its image.
        self.img_key = img_key
        # Key for the caption/text; None means "no text in this dataset".
        self.txt_key = txt_key
        # Callable applied to each raw image before stacking.
        self.transform = transform

    def __call__(self, batch):
        """Collate a list of sample dicts into (stacked_images, texts).

        Returns:
            Tuple of a stacked image tensor (batch on dim 0) and a list of
            texts — per-sample values of ``txt_key``, or ``torch.zeros(1)``
            placeholders when ``txt_key`` is None.
        """
        imgs = [self.transform(sample[self.img_key]) for sample in batch]
        if self.txt_key is not None:
            txts = [sample[self.txt_key] for sample in batch]
        else:
            # Placeholder "captions" so downstream code always gets a list.
            txts = [torch.zeros(1) for _ in batch]
        return torch.stack(imgs), txts
def make_generic_dataset_loader(data_path, **kwargs):
    """Build a ``(dataset, loader)`` pair, dispatching on the path's scheme prefix.

    Supported prefixes:
      * ``wds://<path>``      — WebDataset via ``GenericWDS``; the dataset builds
        its own loader.
      * ``datasets://<name>`` — HuggingFace streaming dataset (``train`` split).
      * anything else         — ``ImageFolder``, or ``build_dataset`` when an
        ``args`` kwarg is supplied.

    Args:
        data_path: dataset location, optionally prefixed as above.
        **kwargs: forwarded to the dataset / DataLoader. Recognized keys:
            ``transform`` (consumed in the datasets:// and default branches),
            ``accelerator`` (datasets:// only), ``args`` (default branch only);
            the rest are passed to the DataLoader (or GenericWDS).

    Returns:
        Tuple of (dataset, DataLoader-like iterable).
    """
    wds_prefix = "wds://"
    hf_prefix = "datasets://"

    if data_path.startswith(wds_prefix):
        # Prefix stripped via len() so the slice can never desync from the prefix.
        dataset = GenericWDS(data_path[len(wds_prefix):], **kwargs)
        return dataset, dataset.make_loader()

    if data_path.startswith(hf_prefix):
        real_data_path = data_path[len(hf_prefix):]
        print(real_data_path)
        img_key = 'jpg'
        txt_key = None
        transform = kwargs.pop("transform")
        # A streaming IterableDataset cannot be shuffled by DataLoader.
        kwargs.pop("shuffle", None)
        accelerator = kwargs.pop("accelerator", None)
        if accelerator is not None:
            # Let the main process download/prepare first to avoid races.
            with accelerator.main_process_first():
                dataset = load_dataset(real_data_path, split="train", streaming=True)
        else:
            dataset = load_dataset(real_data_path, split="train", streaming=True)
        collate_fn = CustomCollateFn(img_key, txt_key, transform)
        return dataset, DataLoader(dataset, collate_fn=collate_fn, **kwargs)

    # Default: plain image folder, or a project-specific dataset when args given.
    args = kwargs.pop("args", None)
    if args is None:
        dataset = datasets.ImageFolder(data_path, transform=kwargs.pop("transform"))
    else:
        dataset = build_dataset(args, transform=kwargs.pop("transform"))
    loader = DataLoader(dataset, **kwargs)
    return dataset, loader