-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_dataset.py
More file actions
55 lines (45 loc) · 1.85 KB
/
generate_dataset.py
File metadata and controls
55 lines (45 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pandas as pd
import numpy as np
class DatasetGenerator(object):
    """Accumulate per-step transition records and write them to disk.

    Every call to :meth:`append_data` adds one entry to each of the
    parallel lists held in ``self.data`` (``observations``, ``actions``,
    ``terminals``, ``rewards``, ``infos`` and, when ``goal`` is enabled,
    ``goal``).  :meth:`reformat_data_next_obs` derives a
    ``next_observations`` list of the same length, and :meth:`write_data`
    dumps everything through a pandas ``DataFrame``.

    Args:
        goal: When true, a ``goal`` list is recorded alongside the other
            fields.
        observationSpaceType: ``"array"`` when each observation exposes a
            shape (e.g. a NumPy array), ``"dict"`` when each observation
            is a mapping of keys to arrays.  Only consulted by
            :meth:`reformat_data_next_obs`.
    """

    def __init__(self, goal=False, observationSpaceType="array"):
        self.goal = goal
        self.data = self._reset_data()
        self._num_samples = 0
        self._observationSpaceType = observationSpaceType

    def _reset_data(self):
        """Return a fresh, empty storage dict (with ``goal`` if enabled)."""
        data = {
            'observations': [],
            'actions': [],
            'terminals': [],
            'rewards': [],
            'infos': [],
        }
        if self.goal:
            data['goal'] = []
        return data

    def __len__(self):
        """Return the number of samples appended so far."""
        return self._num_samples

    def append_data(self, s, a, rew, done, info, goal=None):
        """Record a single step.

        Args:
            s: Observation for this step.
            a: Action for this step.
            rew: Reward for this step.
            done: Terminal flag for this step.
            info: Extra info object for this step.
            goal: Goal for this step; stored only when the generator was
                created with ``goal=True``.
        """
        self._num_samples += 1
        self.data['observations'].append(s)
        self.data['actions'].append(a)
        self.data['terminals'].append(done)
        self.data['rewards'].append(rew)
        self.data['infos'].append(info)
        if self.goal:
            self.data['goal'].append(goal)

    def reformat_data_next_obs(self):
        """Populate ``self.data['next_observations']``.

        Entry ``i`` is observation ``i + 1``; the final entry, which has
        no successor, is an all-zero placeholder shaped like the first
        observation.  The recorded observations themselves are never
        modified.  With no recorded observations the result is an empty
        list.

        NOTE: the shift ignores ``terminals``, so the "next observation"
        of an episode's last step is whatever observation was appended
        after it.

        Raises:
            ValueError: If ``observationSpaceType`` is neither ``"array"``
                nor ``"dict"`` (a placeholder cannot be built, and the
                columns would end up with different lengths).
        """
        observations = self.data['observations']
        if not observations:
            # Nothing recorded yet: there is no template for a placeholder.
            self.data['next_observations'] = []
            return

        template = observations[0]
        if self._observationSpaceType == "array":
            # np.shape also accepts plain sequences, not only ndarrays.
            placeholder = np.zeros(np.shape(template))
        elif self._observationSpaceType == "dict":
            # Build NEW zero arrays; zeroing the template in place would
            # wipe the first recorded observation.
            placeholder = {
                key: np.zeros_like(template[key]) for key in template
            }
        else:
            raise ValueError(
                'unsupported observationSpaceType: {!r} '
                '(expected "array" or "dict")'.format(
                    self._observationSpaceType)
            )

        # Slicing copies the list, so 'observations' keeps its length.
        self.data['next_observations'] = observations[1:] + [placeholder]

    def write_data(self, filename='recorded_data_templatename', filetype="pickle"):
        """Write the recorded data to ``<filename>.csv`` and ``<filename>.pkl``.

        Args:
            filename: Output path without extension.
            filetype: Accepted for interface compatibility but not
                consulted; both the CSV and the pickle file are always
                written.
        """
        print('writing dataset to file ...')
        df = pd.DataFrame.from_dict(self.data)

        csv_path = filename + '.csv'
        df.to_csv(csv_path)
        print('dataset saved as {}'.format(csv_path))

        pickle_path = filename + '.pkl'
        df.to_pickle(pickle_path)
        print('dataset saved as {}'.format(pickle_path))