-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathprepare_pascal_ctx_sem_seg.py
More file actions
81 lines (64 loc) · 2.92 KB
/
prepare_pascal_ctx_sem_seg.py
File metadata and controls
81 lines (64 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
Reference: https://github.com/NVlabs/ODISE/blob/main/datasets/prepare_pascal_ctx_sem_seg.py
"""
import os
from pathlib import Path
import shutil
import numpy as np
import tqdm
from PIL import Image
import multiprocessing as mp
import functools
from detail import Detail
# fmt: off
_mapping = np.sort(
np.array([
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
]))
# fmt: on
_key = np.array(range(len(_mapping))).astype("uint8")
def generate_labels(img_info, detail_api, out_dir):
def _class_to_index(mask, _mapping, _key):
# assert the values
values = np.unique(mask)
for i in range(len(values)):
assert values[i] in _mapping
index = np.digitize(mask.ravel(), _mapping, right=True)
return _key[index].reshape(mask.shape)
sem_seg = _class_to_index(detail_api.getMask(img_info), _mapping=_mapping, _key=_key)
sem_seg = sem_seg - 1 # 0 (ignore) becomes 255. others are shifted by 1
filename = img_info["file_name"]
Image.fromarray(sem_seg).save(out_dir / filename.replace("jpg", "png"))
def copy_images(img_info, img_dir, out_dir):
filename = img_info["file_name"]
shutil.copy2(img_dir / filename, out_dir / filename)
if __name__ == "__main__":
dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "pascal_ctx_d2"
voc_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "VOCdevkit/VOC2010"
for split in ["training", "validation"]:
img_dir = voc_dir / "JPEGImages"
if split == "training":
detail_api = Detail(voc_dir / "trainval_merged.json", img_dir, "train")
else:
detail_api = Detail(voc_dir / "trainval_merged.json", img_dir, "val")
img_infos = detail_api.getImgs()
output_img_dir = dataset_dir / "images" / split
output_ann_dir = dataset_dir / "annotations_ctx59" / split
output_img_dir.mkdir(parents=True, exist_ok=True)
output_ann_dir.mkdir(parents=True, exist_ok=True)
pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4))
pool.map(
functools.partial(copy_images, img_dir=img_dir, out_dir=output_img_dir),
tqdm.tqdm(img_infos, desc=f"Writing {split} images to {output_img_dir} ..."),
chunksize=100,
)
pool.map(
functools.partial(generate_labels, detail_api=detail_api, out_dir=output_ann_dir),
tqdm.tqdm(img_infos, desc=f"Writing {split} images to {output_ann_dir} ..."),
chunksize=100,
)