Commit a1483cf

Directly using and serializing gym.spaces and their value (#116)
* Directly using and serializing gym.spaces and their value
* Introducing debug inspector of received observation on the web side
* Take into account review
1 parent bf8ebbc commit a1483cf
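
The recurring pattern across the actor diffs below: each actor deserializes its EnvironmentSpecs from the config once, then uses the resulting space objects to decode observations and encode actions. A minimal sketch of that lifecycle, distilled from the diffs (the EnvironmentSpecs deserialize/serialize round trip, flat_value, and sample(mask=...) are the APIs this commit introduces; the session objects come from Cogment):

import cogment

from cogment_verse.specs import EnvironmentSpecs


async def impl(actor_session):
    actor_session.start()

    # One-time setup: rebuild the typed spaces from the serialized protobuf config
    environment_specs = EnvironmentSpecs.deserialize(actor_session.config.environment_specs)
    observation_space = environment_specs.get_observation_space()
    action_space = environment_specs.get_action_space()

    async for event in actor_session.all_events():
        if event.observation and event.type == cogment.EventType.ACTIVE:
            # Protobuf -> space value; .flat_value is the flattened numeric view
            observation = observation_space.deserialize(event.observation.observation)

            # Pick an action (here a random legal one) and send it back serialized
            action = action_space.sample(mask=observation.action_mask)
            actor_session.do_action(action_space.serialize(action))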

74 files changed: +1922 −1873 lines

actors/ppo.py

+31 −26
@@ -20,20 +20,17 @@
 import numpy as np
 import torch
 from torch.distributions.normal import Normal
+from gym.spaces import Box, utils

 from cogment_verse import Model
 from cogment_verse.run.run_session import RunSession
 from cogment_verse.run.sample_producer_worker import SampleProducerSession
 from cogment_verse.specs import (
-    PLAYER_ACTOR_CLASS,
     AgentConfig,
+    cog_settings,
     EnvironmentConfig,
     EnvironmentSpecs,
-    PlayerAction,
-    cog_settings,
-    flatten,
-    flattened_dimensions,
-    unflatten,
+    PLAYER_ACTOR_CLASS,
 )

 torch.multiprocessing.set_sharing_strategy("file_system")
@@ -308,14 +305,15 @@ def get_actor_classes(self):
     async def impl(self, actor_session):
         # Start a session
         actor_session.start()
+
         config = actor_session.config
-        assert config.environment_specs.num_players == 1
-        assert len(config.environment_specs.action_space.properties) == 1
-        assert config.environment_specs.action_space.properties[0].WhichOneof("type") == "box"

-        # Get observation and action space
-        observation_space = config.environment_specs.observation_space
-        action_space = config.environment_specs.action_space
+        environment_specs = EnvironmentSpecs.deserialize(config.environment_specs)
+        observation_space = environment_specs.get_observation_space()
+        action_space = environment_specs.get_action_space()
+
+        assert isinstance(action_space.gym_space, Box)
+        assert config.environment_specs.num_players == 1

         # Get model
         model, _, _ = await actor_session.model_registry.retrieve_version(
@@ -324,9 +322,9 @@ async def impl(self, actor_session):

         async for event in actor_session.all_events():
             if event.observation and event.type == cogment.EventType.ACTIVE:
-                obs_tensor = torch.tensor(
-                    flatten(observation_space, event.observation.observation.value), dtype=self._dtype
-                ).view(1, -1)
+                observation = observation_space.deserialize(event.observation.observation)
+
+                obs_tensor = torch.tensor(observation.flat_value, dtype=self._dtype).view(1, -1)

                 # Normalize the observation
                 if model.state_normalization is not None:
@@ -339,11 +337,11 @@ async def impl(self, actor_session):
                 # Get action from policy network
                 with torch.no_grad():
                     dist, _ = model.policy_network(obs_tensor)
-                    action = dist.sample().cpu().numpy()[0]
+                    action_value = dist.sample().cpu().numpy()[0]

                 # Send action to environment
-                action_value = unflatten(action_space, action)
-                actor_session.do_action(PlayerAction(value=action_value))
+                action = action_space.create(value=action_value)
+                actor_session.do_action(action_space.serialize(action))


 class PPOTraining:
@@ -392,8 +390,8 @@ def __init__(self, environment_specs: EnvironmentSpecs, cfg: EnvironmentConfig)
         self.model = PPOModel(
             model_id="",
             environment_implementation=self._environment_specs.implementation,
-            num_input=flattened_dimensions(self._environment_specs.observation_space),
-            num_output=flattened_dimensions(self._environment_specs.action_space),
+            num_input=utils.flatdim(self._environment_specs.get_observation_space().gym_space),
+            num_output=utils.flatdim(self._environment_specs.get_action_space().gym_space),
             learning_rate=self._cfg.learning_rate,
             n_iter=self._cfg.num_epochs,
             policy_network_hidden_nodes=self._cfg.policy_network.num_hidden_nodes,
@@ -404,15 +402,20 @@ def __init__(self, environment_specs: EnvironmentSpecs, cfg: EnvironmentConfig)

     async def trial_sample_sequences_producer_impl(self, sample_producer_session: SampleProducerSession):
         """Collect sample from the trial"""
+
+        # Share with A2C
+
         observation = []
         action = []
         reward = []
         done = []

         player_actor_params = sample_producer_session.trial_info.parameters.actors[0]
+
         player_actor_name = player_actor_params.name
-        player_observation_space = player_actor_params.config.environment_specs.observation_space
-        player_action_space = player_actor_params.config.environment_specs.action_space
+        player_environment_specs = EnvironmentSpecs.deserialize(player_actor_params.config.environment_specs)
+        player_observation_space = player_environment_specs.get_observation_space()
+        player_action_space = player_environment_specs.get_action_space()

         async for sample in sample_producer_session.all_trial_samples():
             if sample.trial_state == cogment.TrialState.ENDED:
@@ -423,9 +426,10 @@ async def trial_sample_sequences_producer_impl(self, sample_producer_session: Sa

             actor_sample = sample.actors_data[player_actor_name]
             observation.append(
-                torch.tensor(flatten(player_observation_space, actor_sample.observation.value), dtype=self._dtype)
+                torch.tensor(player_observation_space.deserialize(actor_sample.observation).value, dtype=self._dtype)
             )
-            action.append(torch.tensor(flatten(player_action_space, actor_sample.action.value), dtype=self._dtype))
+
+            action.append(torch.tensor(player_action_space.deserialize(actor_sample.action).value, dtype=self._dtype))
             reward.append(
                 torch.tensor(actor_sample.reward if actor_sample.reward is not None else 0, dtype=self._dtype)
             )
@@ -438,8 +442,9 @@ async def impl(self, run_session: RunSession) -> dict:
         """Train and publish model the model"""

         model_id = f"{run_session.run_id}_model"
+
         assert self._environment_specs.num_players == 1
-        assert len(self._environment_specs.action_space.properties) == 1
+        assert isinstance(self._environment_specs.get_action_space().gym_space, Box)

         # Initalize model
         self.model.model_id = model_id
@@ -462,7 +467,7 @@ def create_trial_params(trial_idx: int, iter_idx: int):
                     implementation="actors.ppo.PPOActor",
                     config=AgentConfig(
                         run_id=run_session.run_id,
-                        environment_specs=self._environment_specs,
+                        environment_specs=self._environment_specs.serialize(),
                         model_id=model_id,
                         model_version=version_info["version_number"],
                     ),
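
Worth noting in the hunks above: the custom flattened_dimensions helper is replaced by gym's own gym.spaces.utils.flatdim, which returns the length of a space's flattened representation. A quick standalone illustration (plain gym, no cogment_verse needed):

from gym.spaces import Box, Discrete, utils

# A 4-dimensional continuous space flattens to 4 network inputs...
assert utils.flatdim(Box(low=-1.0, high=1.0, shape=(4,))) == 4

# ...while a 3-way discrete space flattens to a 3-wide one-hot encoding
assert utils.flatdim(Discrete(3)) == 3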

actors/random_actor.py

+11 −16
@@ -13,13 +13,8 @@
 # limitations under the License.

 import cogment
-import numpy as np

-from cogment_verse.specs import (
-    PLAYER_ACTOR_CLASS,
-    PlayerAction,
-    sample_space,
-)
+from cogment_verse.specs import PLAYER_ACTOR_CLASS, EnvironmentSpecs


 class RandomActor:
@@ -33,19 +28,19 @@ async def impl(self, actor_session):
         actor_session.start()

         config = actor_session.config
+        environment_specs = EnvironmentSpecs.deserialize(config.environment_specs)
+        observation_space = environment_specs.get_observation_space()
+        action_space = environment_specs.get_action_space()

-        action_space = config.environment_specs.action_space
-
-        rng = np.random.default_rng(config.seed if config.seed is not None else 0)
+        action_space.gym_space.seed(config.seed if config.seed is not None else 0)

         async for event in actor_session.all_events():
             if event.observation and event.type == cogment.EventType.ACTIVE:
-                if (
-                    event.observation.observation.HasField("current_player")
-                    and event.observation.observation.current_player != actor_session.name
-                ):
+                observation = observation_space.deserialize(event.observation.observation)
+                if observation.current_player is not None and observation.current_player != actor_session.name:
                     # Not the turn of the agent
-                    actor_session.do_action(PlayerAction())
+                    actor_session.do_action(action_space.serialize(action_space.create()))
                     continue
-                [action_value] = sample_space(action_space, rng=rng, mask=event.observation.observation.action_mask)
-                actor_session.do_action(PlayerAction(value=action_value))
+
+                action = action_space.sample(mask=observation.action_mask)
+                actor_session.do_action(action_space.serialize(action))
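
The random actor also drops its private np.random.default_rng in favor of seeding the gym space itself, so every later sample() call is reproducible. In plain gym terms (Space.seed() and Space.sample() are standard API; the mask argument to Discrete.sample() needs a recent gym release):

import numpy as np
from gym.spaces import Discrete

space = Discrete(4)
space.seed(0)           # seed the space's internal RNG
first = space.sample()  # deterministic given the seed

space.seed(0)
assert space.sample() == first  # same seed, same draw

# Recent gym versions also accept a mask of legal actions (int8 flags):
masked = space.sample(mask=np.array([1, 0, 1, 0], dtype=np.int8))
assert masked in (0, 2)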

actors/simple_a2c.py

+26 −43
@@ -12,27 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# pylint: disable=C0303
-# pylint: disable=W0611
-# pylint: disable=W0612
-
 import logging

 import cogment
 import torch

+from gym.spaces import utils, Discrete
+
 from cogment_verse import Model
 from cogment_verse.specs import (
-    PLAYER_ACTOR_CLASS,
     AgentConfig,
-    EnvironmentConfig,
-    PlayerAction,
-    SpaceValue,
     cog_settings,
-    flatten,
-    flattened_dimensions,
-    unflatten,
+    EnvironmentConfig,
+    EnvironmentSpecs,
 )
+from cogment_verse.constants import PLAYER_ACTOR_CLASS

 torch.multiprocessing.set_sharing_strategy("file_system")

@@ -132,12 +126,9 @@ async def impl(self, actor_session):

         config = actor_session.config

-        assert config.environment_specs.num_players == 1
-        assert len(config.environment_specs.action_space.properties) == 1
-        # assert config.environment_specs.action_space.properties[0].WhichOneof("type") == "discrete"
-
-        observation_space = config.environment_specs.observation_space
-        action_space = config.environment_specs.action_space
+        environment_specs = EnvironmentSpecs.deserialize(config.environment_specs)
+        observation_space = environment_specs.get_observation_space()
+        action_space = environment_specs.get_action_space(seed=config.seed)

         model, _, _ = await actor_session.model_registry.retrieve_version(
             SimpleA2CModel, config.model_id, config.model_version
@@ -147,22 +138,17 @@ async def impl(self, actor_session):

         async for event in actor_session.all_events():
             if event.observation and event.type == cogment.EventType.ACTIVE:
-                obs_tensor = torch.tensor(
-                    flatten(observation_space, event.observation.observation.value), dtype=self._dtype
-                )
-                if config.environment_specs.action_space.properties[0].WhichOneof("type") == "discrete":
-                    probs = torch.softmax(model.actor_network(obs_tensor), dim=-1)
-                    discrete_action_tensor = torch.distributions.Categorical(probs).sample()
-                    action_value = SpaceValue(
-                        properties=[SpaceValue.PropertyValue(discrete=discrete_action_tensor.item())]
-                    )
+                observation = observation_space.deserialize(event.observation.observation)

+                if isinstance(action_space.gym_space, Discrete):
+                    observation_tensor = torch.tensor(observation.flat_value, dtype=self._dtype)
+                    probs = torch.softmax(model.actor_network(observation_tensor), dim=-1)
+                    discrete_action_tensor = torch.distributions.Categorical(probs).sample()
+                    action = action_space.create(value=discrete_action_tensor.numpy())
                 else:
-                    action = torch.rand((1,) + (action_space.properties[0].box.shape[0],))
-                    action = action.cpu().numpy()[0]
-                    action_value = unflatten(action_space, action)
+                    action = action_space.sample()

-                actor_session.do_action(PlayerAction(value=action_value))
+                actor_session.do_action(action_space.serialize(action))


 class SimpleA2CTraining:
@@ -195,7 +181,9 @@ async def trial_sample_sequences_producer_impl(self, sample_producer_session):
         player_actor_params = sample_producer_session.trial_info.parameters.actors[0]

         player_actor_name = player_actor_params.name
-        player_observation_space = player_actor_params.config.environment_specs.observation_space
+        player_environment_specs = EnvironmentSpecs.deserialize(player_actor_params.config.environment_specs)
+        player_observation_space = player_environment_specs.get_observation_space()
+        player_action_space = player_environment_specs.get_action_space()

         async for sample in sample_producer_session.all_trial_samples():
             if sample.trial_state == cogment.TrialState.ENDED:
@@ -206,14 +194,10 @@ async def trial_sample_sequences_producer_impl(self, sample_producer_session):

             actor_sample = sample.actors_data[player_actor_name]
             observation.append(
-                torch.tensor(flatten(player_observation_space, actor_sample.observation.value), dtype=self._dtype)
-            )
-            action_value = actor_sample.action.value
-            action.append(
-                torch.tensor(
-                    action_value.properties[0].discrete if len(action_value.properties) > 0 else 0, dtype=self._dtype
-                )
+                torch.tensor(player_observation_space.deserialize(actor_sample.observation).value, dtype=self._dtype)
             )
+
+            action.append(torch.tensor(player_action_space.deserialize(actor_sample.action).value, dtype=self._dtype))
             reward.append(
                 torch.tensor(actor_sample.reward if actor_sample.reward is not None else 0, dtype=self._dtype)
             )
@@ -227,14 +211,13 @@ async def impl(self, run_session):
         model_id = f"{run_session.run_id}_model"

         assert self._environment_specs.num_players == 1
-        assert len(self._environment_specs.action_space.properties) == 1
-        # assert self._environment_specs.action_space.properties[0].WhichOneof("type") == "discrete"
+        assert isinstance(self._environment_specs.get_action_space().gym_space, Discrete)

         model = SimpleA2CModel(
             model_id,
             environment_implementation=self._environment_specs.implementation,
-            num_input=flattened_dimensions(self._environment_specs.observation_space),
-            num_output=flattened_dimensions(self._environment_specs.action_space),
+            num_input=utils.flatdim(self._environment_specs.get_observation_space().gym_space),
+            num_output=utils.flatdim(self._environment_specs.get_action_space().gym_space),
             actor_network_num_hidden_nodes=self._cfg.actor_network.num_hidden_nodes,
             critic_network_num_hidden_nodes=self._cfg.critic_network.num_hidden_nodes,
             dtype=self._dtype,
@@ -285,7 +268,7 @@ async def impl(self, run_session):
                         run_id=run_session.run_id,
                         model_id=model_id,
                         model_version=version_info["version_number"],
-                        environment_specs=self._environment_specs,
+                        environment_specs=self._environment_specs.serialize(),
                     ),
                 )
             ],
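
In the discrete branch above, the A2C actor now samples from a Categorical over the actor network's softmax output instead of building a SpaceValue by hand. The sampling step in isolation (the logits tensor stands in for model.actor_network(observation_tensor)):

import torch

logits = torch.tensor([1.0, 2.0, 0.5])  # stand-in for the actor network's output
probs = torch.softmax(logits, dim=-1)   # normalize logits into action probabilities
action = torch.distributions.Categorical(probs).sample()
assert action.item() in range(3)        # an index into the 3 discrete actions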
