Skip to content

Commit 6da7eff

Browse files
committed
[rllib] Properly flatten 2-d observations as input to FCnet (#5733)
1 parent 7131166 commit 6da7eff

File tree

6 files changed

+29
-6
lines changed

6 files changed

+29
-6
lines changed

rllib/models/catalog.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -450,7 +450,7 @@ def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
450450
else:
451451
obs_rank = len(obs_space.shape)
452452

453-
if obs_rank > 1:
453+
if obs_rank > 2:
454454
return PyTorchVisionNet(obs_space, action_space, num_outputs,
455455
model_config, name)
456456

@@ -506,7 +506,7 @@ def _get_model(input_dict, obs_space, action_space, num_outputs, options,
506506

507507
obs_rank = len(input_dict["obs"].shape) - 1
508508

509-
if obs_rank > 1:
509+
if obs_rank > 2:
510510
return VisionNetwork(input_dict, obs_space, action_space,
511511
num_outputs, options)
512512

@@ -521,7 +521,7 @@ def _get_v2_model(obs_space, options):
521521
if options.get("use_lstm"):
522522
return None # TODO: default LSTM v2 not implemented
523523

524-
if obs_rank > 1:
524+
if obs_rank > 2:
525525
return VisionNetV2
526526

527527
return FCNetV2

rllib/models/model.py

+12
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,18 @@ def _validate_output_shape(self):
190190
self._num_outputs, shape))
191191

192192

193+
@DeveloperAPI
194+
def flatten(obs, framework):
195+
"""Flatten the given tensor."""
196+
if framework == "tf":
197+
return tf.layers.flatten(obs)
198+
elif framework == "torch":
199+
import torch
200+
return torch.flatten(obs, start_dim=1)
201+
else:
202+
raise NotImplementedError("flatten", framework)
203+
204+
193205
@DeveloperAPI
194206
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
195207
"""Unpacks Dict and Tuple space observations into their original form.

rllib/models/modelv2.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from __future__ import print_function
44

55
from ray.rllib.policy.sample_batch import SampleBatch
6-
from ray.rllib.models.model import restore_original_dimensions
6+
from ray.rllib.models.model import restore_original_dimensions, flatten
77
from ray.rllib.utils.annotations import PublicAPI
88

99

@@ -146,7 +146,10 @@ def __call__(self, input_dict, state=None, seq_lens=None):
146146
restored = input_dict.copy()
147147
restored["obs"] = restore_original_dimensions(
148148
input_dict["obs"], self.obs_space, self.framework)
149-
restored["obs_flat"] = input_dict["obs"]
149+
if len(input_dict["obs"].shape) > 2:
150+
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
151+
else:
152+
restored["obs_flat"] = input_dict["obs"]
150153
with self.context():
151154
res = self.forward(restored, state or [], seq_lens)
152155
if ((not isinstance(res, list) and not isinstance(res, tuple))

rllib/models/tf/fcnet_v1.py

+3
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@ def _build_layers(self, inputs, num_outputs, options):
2525
hiddens = options.get("fcnet_hiddens")
2626
activation = get_activation_fn(options.get("fcnet_activation"))
2727

28+
if len(inputs.shape) > 2:
29+
inputs = tf.layers.flatten(inputs)
30+
2831
with tf.name_scope("fc_net"):
2932
i = 1
3033
last_layer = inputs

rllib/models/tf/fcnet_v2.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
from __future__ import division
33
from __future__ import print_function
44

5+
import numpy as np
6+
57
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
68
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
79
from ray.rllib.utils import try_import_tf
@@ -22,8 +24,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
2224
no_final_linear = model_config.get("no_final_linear")
2325
vf_share_layers = model_config.get("vf_share_layers")
2426

27+
# we are using obs_flat, so take the flattened shape as input
2528
inputs = tf.keras.layers.Input(
26-
shape=obs_space.shape, name="observations")
29+
shape=(np.product(obs_space.shape), ), name="observations")
2730
last_layer = inputs
2831
i = 1
2932

rllib/tests/test_supported_spaces.py

+2
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
OBSERVATION_SPACES_TO_TEST = {
3232
"discrete": Discrete(5),
3333
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
34+
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
3435
"image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
3536
"atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
3637
"tuple": Tuple([Discrete(10),
@@ -106,6 +107,7 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
106107
def check_support_multiagent(alg, config):
107108
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
108109
register_env("multi_cartpole", lambda _: MultiCartpole(2))
110+
config["log_level"] = "ERROR"
109111
if "DDPG" in alg:
110112
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
111113
else:

0 commit comments

Comments
 (0)