-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
132 lines (101 loc) · 3.92 KB
/
generate.py
File metadata and controls
132 lines (101 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
# import matplotlib.pyplot as plt
from functools import partial
from importlib import import_module
from absl import flags, app
from spriteworld.renderers.pil_renderer import PILRenderer
from spriteworld.renderers.color_maps import hsv_to_rgb
from spriteworld.sprite_generators import create_iterator, chain_generators, \
create_sampler, product_iterator
# Command-line options for dataset generation (parsed by absl before main()).
FLAGS = flags.FLAGS

# Dotted import path of the config module; main() calls its `load()`.
flags.DEFINE_string('config', 'spriteworld.configs.simple',
                    'Configuration to generate')
# Side length in pixels; rendered images are square (size x size).
flags.DEFINE_integer('image_size', 64,
                     'Size of rendered image')
# Forwarded to PILRenderer as its `anti_aliasing` argument.
flags.DEFINE_integer('aa', 5,
                     'Anti aliasing of rendered image')
# When true, hsv_to_rgb is installed as the renderer's color converter.
flags.DEFINE_boolean('hsv_color', True,
                     'Whether the task config uses HSV as color factors.')
# Output directory (created if missing) that receives `<dataset>.npz`.
flags.DEFINE_string('save_folder', 'datasets',
                    'Folder where to save the data')
def _build_generator(config):
    """Create the sprite generator described by `config`.

    Args:
      config: Loaded config mapping. Reads 'distribs', 'generation_type'
        and, when sampling, 'n_samples'.

    Returns:
      A tuple (generator, n_sprites, prefix). Calling `generator()` yields
      lists of sprites; `prefix` is a format string applied to the sprite
      index to build per-sprite factor names ('' for a single sprite).

    Raises:
      ValueError: if 'distribs' is an empty list.
    """
    distribs = config['distribs']
    gentype = config['generation_type']
    if gentype == 'sample':
        create_generator = partial(create_sampler,
                                   num_samples=config['n_samples'])
    else:
        create_generator = create_iterator

    # Normalize to a list with one distribution per sprite.
    if not isinstance(distribs, list):
        distribs = [distribs]
    n_sprites = len(distribs)
    if n_sprites == 0:
        raise ValueError("config['distribs'] must not be an empty list.")

    if n_sprites == 1:
        # Unwrap a one-element list: create_generator takes a single
        # distribution, exactly as in the multi-sprite branch below.
        return create_generator(distribs[0]), n_sprites, ''

    obj_generators = [create_generator(d) for d in distribs]
    if gentype == 'sample':
        generator = chain_generators(*obj_generators)
    else:
        generator = product_iterator(*obj_generators)
    return generator, n_sprites, 'obj{}-'


def _factor_row(sprites, params):
    """Concatenate each sprite's `params` factors, always in `params` order.

    The column order must match the order used by `_encode_factors`, so it is
    taken from `params` rather than from the sprite's factor mapping.
    Raises KeyError if a sprite lacks one of `params`.
    """
    row = []
    for sprite in sprites:
        sprite_factors = sprite.factors(dtype=float)
        row.extend(sprite_factors[param] for param in params)
    return row


def _encode_factors(factor_values, params, value_maps, n_sprites, prefix):
    """Apply `value_maps` column-wise (in place) and collect factor metadata.

    Columns are laid out sprite-major: for sprite j, one column per entry of
    `params`, matching `_factor_row`.

    Returns:
      (factors, factor_sizes, unique_values): the factor names, the number of
      distinct values per factor, and a name -> unique-values mapping.
    """
    factors, factor_sizes, unique_values = [], [], {}
    for j in range(n_sprites):
        for param in params:
            col = len(factors)
            factor_values[:, col] = value_maps[param](factor_values[:, col])
            uv = np.unique(factor_values[:, col])
            factor_name = prefix.format(j) + param
            factors.append(factor_name)
            factor_sizes.append(len(uv))
            unique_values[factor_name] = uv
    return factors, factor_sizes, unique_values


def main(argv):
    """Render every sprite configuration of FLAGS.config and save a dataset.

    Loads the config module, renders each sprite list yielded by its
    generator, records the factors named in config['params'] (after
    config['value_maps']), removes duplicate factor rows, and writes
    `<save_folder>/<dataset>.npz` with np.savez_compressed.
    """
    del argv  # Unused; all options come from absl flags.
    config = import_module(FLAGS.config).load()
    name = config['dataset']
    params = config['params']
    value_maps = config['value_maps']
    image_size = (FLAGS.image_size, FLAGS.image_size)

    renderer = PILRenderer(image_size=image_size,
                           anti_aliasing=FLAGS.aa,
                           bg_color=(0, 0, 0),  # black background
                           color_to_rgb=hsv_to_rgb if FLAGS.hsv_color else None)
    generator, n_sprites, prefix = _build_generator(config)

    images, factor_values = [], []
    for sprites in generator():
        images.append(renderer.render(sprites))
        factor_values.append(_factor_row(sprites, params))
    if not images:
        raise ValueError(
            'Config {!r} produced no samples.'.format(FLAGS.config))
    images = np.array(images)
    factor_values = np.array(factor_values)

    factors, factor_sizes, unique_values = _encode_factors(
        factor_values, params, value_maps, n_sprites, prefix)

    # Drop duplicate factor rows. np.unique sorts the rows; `idx` holds the
    # first occurrence of each, which keeps images aligned with their factors.
    factor_values, idx = np.unique(factor_values, return_index=True, axis=0)
    images = images[idx]
    print(unique_values)
    print(len(factor_values))

    save_folder = FLAGS.save_folder
    os.makedirs(save_folder, exist_ok=True)
    np.savez_compressed(os.path.join(save_folder, '{}.npz'.format(name)),
                        # metadata
                        name=name,
                        # NOTE(review): records channels-first (3, H, W);
                        # confirm this matches the array layout returned by
                        # PILRenderer.render, which may be (H, W, 3).
                        img_size=(3, *image_size),
                        n_factors=len(factors),
                        factors=factors,
                        factor_sizes=tuple(factor_sizes),
                        # A dict is stored as an object array; loading it
                        # presumably requires allow_pickle=True.
                        unique_values=unique_values,
                        # data
                        images=images,
                        factor_values=factor_values)
if __name__ == '__main__':
    # Hand control to absl: it parses the command-line flags, then calls main.
    app.run(main)