Skip to content

Commit 463bd16

Browse files
authored
check for converted data as specified by config and generate it if necessary (#9)
1 parent 6e28aae commit 463bd16

File tree

3 files changed

+49
-9
lines changed

3 files changed

+49
-9
lines changed

envs/data.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def wrap_env(env: MinerRLDataEnv, env_config):
5454
return wrap(env, **env_config)
5555

5656

57-
def write_jsons(environment, data_dir, env_config, save_path, overwrite=False, **kwargs):
57+
def write_jsons(environment, data_dir, env_config, save_path, overwrite=False, fail_safe=True, **kwargs):
5858
data_pipeline = minerl.data.make(environment, data_dir, **kwargs)
5959
env = MinerRLDataEnv(data_pipeline)
6060
env = wrap_env(env, env_config)
@@ -67,8 +67,12 @@ def write_jsons(environment, data_dir, env_config, save_path, overwrite=False, *
6767
print(f'Overwriting! {abs_save_path}')
6868
shutil.rmtree(abs_save_path)
6969
else:
70-
raise ValueError(f'Directory {abs_save_path} not empty!'
71-
f'Cannot overwrite existing data automatically, please delete old data if unused.')
70+
if fail_safe:
71+
print(f'Json data already exists at {abs_save_path}')
72+
return
73+
else:
74+
raise ValueError(f'Directory {abs_save_path} not empty!'
75+
f'Cannot overwrite existing data automatically, please delete old data if unused.')
7276

7377
batch_builder = SampleBatchBuilder()
7478
writer = JsonWriter(save_path)

generate_kmeans.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
parser.add_argument('--env', default=None)
1212
parser.add_argument('--num-actions', type=int, default=32)
1313
parser.add_argument('--data-dir', default=os.getenv('MINERL_DATA_ROOT', 'data'))
14+
parser.add_argument('--overwrite', action='store_true')
1415

1516

1617
def main():
@@ -25,6 +26,15 @@ def main():
2526

2627
for env_name in env_list:
2728
print(f'Generating {args.num_actions}-means for {env_name}')
29+
30+
file_dir = os.path.join(args.data_dir, f'{args.num_actions}-means')
31+
file = os.path.join(file_dir, env_name + '.npy')
32+
if os.path.exists(file) and not args.overwrite:
33+
print(f'k-means file already exists at {file}')
34+
continue
35+
if not os.path.exists(file_dir):
36+
os.mkdir(file_dir)
37+
2838
data = minerl.data.make(env_name)
2939
actions = []
3040
for trajectory_name in tqdm(list(data.get_trajectory_names())):
@@ -37,10 +47,6 @@ def main():
3747
print('computing k-means...')
3848
kmeans = KMeans(n_clusters=args.num_actions, verbose=1, random_state=0).fit(actions)
3949
print(kmeans)
40-
file_dir = os.path.join(args.data_dir, f'{args.num_actions}-means')
41-
if not os.path.exists(file_dir):
42-
os.mkdir(file_dir)
43-
file = os.path.join(file_dir, env_name + '.npy')
4450
np.save(file, kmeans.cluster_centers_)
4551

4652

rllib_train.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
22
Runs the training process through the RLLib train script
3-
Example command:
4-
$ python rllib_train.py -f config/minerl-impala-debug.yaml
53
See help with:
64
$ python rllib_train.py --help
75
"""
6+
import subprocess
87

8+
import yaml
99
from ray.rllib.train import create_parser, run
1010

1111
import envs
@@ -15,9 +15,39 @@
1515
models.register()
1616

1717

18+
def generate_kmeans(env):
    """Run ``generate_kmeans.py`` for ``env`` in a child process.

    Args:
        env: Environment name forwarded to the script as its ``--env`` option.
    """
    import sys  # local import keeps this function self-contained

    # Build argv as a list instead of splitting a formatted string, so a value
    # containing whitespace stays a single argument. Use the interpreter that
    # is running this script so the child sees the same installed packages.
    command = [sys.executable or 'python', 'generate_kmeans.py', '--env', str(env)]
    print('running:', ' '.join(command))
    result = subprocess.run(command)
    if result.returncode != 0:
        # Surface the failure instead of ignoring it silently; the caller
        # still proceeds, as before.
        print(f'generate_kmeans.py exited with status {result.returncode}')
22+
23+
24+
def convert_data(args):
    """Run ``convert_data.py`` on the training config in a child process.

    Args:
        args: Parsed command-line arguments; ``args.config_file`` is forwarded
            to the script as its ``-f`` option.
    """
    import sys  # local import keeps this function self-contained

    # Pass argv as a list so a config path containing spaces is not split into
    # several arguments. Use the interpreter that is running this script so the
    # child sees the same installed packages.
    command = [sys.executable or 'python', 'convert_data.py', '-f', str(args.config_file)]
    print('running:', ' '.join(command))
    result = subprocess.run(command)
    if result.returncode != 0:
        # Surface the failure instead of ignoring it silently; the caller
        # still proceeds, as before.
        print(f'convert_data.py exited with status {result.returncode}')
28+
29+
30+
def check_data(args):
    """Generate any derived data the training config depends on.

    Loads the YAML file named by ``args.config_file`` and inspects its first
    top-level entry. When that entry's ``config.env_config.discrete`` is
    truthy, the k-means generation script is run for the configured
    environment; when ``config`` contains an ``input`` key, the data
    conversion script is run.

    Args:
        args: Parsed command-line arguments; only ``config_file`` is read
            here. Nothing happens when it is ``None``.
    """
    if args.config_file is None:
        return

    # Close the file deterministically instead of leaking the handle.
    with open(args.config_file) as config_stream:
        config = yaml.safe_load(config_stream)
    if not config:
        # Empty YAML document: there is nothing to inspect.
        return

    settings = next(iter(config.values())) or {}
    trainer_config = settings.get('config') or {}
    # An env given under ``config`` takes precedence over the top-level one.
    env = trainer_config.get('env', settings.get('env'))

    env_config = trainer_config.get('env_config') or {}
    if env_config.get('discrete', False):
        if env is None:
            # Previously this path crashed with UnboundLocalError.
            print('discrete env_config found but no env specified; '
                  'skipping k-means generation')
        else:
            generate_kmeans(env)
    if 'input' in trainer_config:
        convert_data(args)
45+
46+
1847
def main():
    """Parse the RLlib train arguments, prepare data, then start training.

    The parser comes from ``create_parser``; ``check_data`` is called on the
    parsed arguments before they are handed to ``run``.
    """
    cli_parser = create_parser()
    cli_args = cli_parser.parse_args()
    # Data preparation has to happen before the training run begins.
    check_data(cli_args)
    run(cli_args, cli_parser)
2252

2353

0 commit comments

Comments
 (0)