-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathaugmentation.py
More file actions
91 lines (71 loc) · 2.69 KB
/
augmentation.py
File metadata and controls
91 lines (71 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
from torchvision import transforms
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from transforms import GaussianBlur, make_normalize_transform
class DataAugmentationDINO:
    """Data augmentation class for DINO-based detection.

    Every call cuts ``local_crops_number`` square crops from the input image
    and returns a dict holding two parallel lists of the same length:

    * ``"source"`` -- each crop converted to a tensor and normalized;
    * ``"global_crops"`` -- each crop after random horizontal flip, random
      color jitter and the blur stage, then converted to a tensor and
      normalized.
    """

    def __init__(self, local_crops_number, crop_size=224):
        """Build the transform pipelines.

        Args:
            local_crops_number: Number of crops produced per call.
            crop_size: Side length of the square crops. Defaults to 224,
                the value that used to be hard-coded.
        """
        self.local_crops_number = local_crops_number
        self.crop_size = crop_size

        # Geometric augmentation: horizontal flip with probability 0.8.
        self.geometric_augmentation_global1 = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.8),
        ])

        # Color jittering, applied as a whole with probability 0.8; every
        # jitter strength is damped by ``scale``.
        scale = 0.3
        color_jittering1 = transforms.Compose([
            transforms.RandomApply(
                [transforms.ColorJitter(
                    brightness=0.6 * scale,
                    contrast=0.8 * scale,
                    saturation=0.4 * scale,
                    hue=0.1 * scale,
                )],
                p=0.8,
            ),
        ])

        # NOTE(review): p=0 presumably means the blur is never applied;
        # confirm against GaussianBlur in the local ``transforms`` module.
        global_transfo2_extra = transforms.Compose([
            GaussianBlur(p=0),
        ])

        # Normalization: tensor conversion followed by the project's
        # normalize transform.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform(),
        ])
        # Same pipeline as ``self.normalize``; applied to the un-augmented
        # crops that go into ``"source"``.
        self.source_trans = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform(),
        ])

        self.crop = transforms.Compose([
            transforms.RandomCrop(crop_size),
        ])
        self.centercrop = transforms.Compose([
            transforms.CenterCrop(crop_size),
        ])

        self.global_transfo_all = transforms.Compose([
            self.geometric_augmentation_global1,
            color_jittering1,
            global_transfo2_extra,
            self.normalize,
        ])

    def __call__(self, image):
        """Crop and augment ``image``.

        Args:
            image: Input image. Presumably a PIL image, since it is fed
                straight to torchvision crops and ``ToTensor``; confirm with
                callers. ``np.asarray(image).shape[:2]`` is read as
                (height, width).

        Returns:
            dict with keys ``"source"`` and ``"global_crops"``. Each maps to
            a list of ``local_crops_number`` transformed crops, and the two
            lists are in the same crop order.
        """
        # Convert only once, and only to read the spatial size.
        height, width = np.asarray(image).shape[:2]
        if height < self.crop_size or width < self.crop_size:
            # RandomCrop rejects images smaller than the crop size, whereas
            # CenterCrop zero-pads them. CenterCrop is deterministic, so one
            # crop is computed and reused for every slot.
            crops = [self.centercrop(image)] * self.local_crops_number
        else:
            crops = [self.crop(image) for _ in range(self.local_crops_number)]
        return {
            "source": [self.source_trans(crop) for crop in crops],
            "global_crops": [self.global_transfo_all(crop) for crop in crops],
        }