@@ -20,20 +20,17 @@
 import numpy as np
 import torch
 from torch.distributions.normal import Normal
+from gym.spaces import Box, utils

 from cogment_verse import Model
 from cogment_verse.run.run_session import RunSession
 from cogment_verse.run.sample_producer_worker import SampleProducerSession
 from cogment_verse.specs import (
-    PLAYER_ACTOR_CLASS,
     AgentConfig,
+    cog_settings,
     EnvironmentConfig,
     EnvironmentSpecs,
-    PlayerAction,
-    cog_settings,
-    flatten,
-    flattened_dimensions,
-    unflatten,
+    PLAYER_ACTOR_CLASS,
 )

 torch.multiprocessing.set_sharing_strategy("file_system")
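
For context (not part of the commit itself): the new `gym.spaces` import brings in `Box`, the continuous space type the actor now asserts on, and `utils`, whose `flatdim` sizes the networks further down. A standalone sketch of both helpers, using a toy space:

```python
# Toy illustration of the gym helpers this commit starts using;
# the space here is made up for demonstration.
import numpy as np
from gym.spaces import Box, utils

space = Box(low=-1.0, high=1.0, shape=(3, 2), dtype=np.float32)
assert isinstance(space, Box)     # the actor's continuous-control guard
assert utils.flatdim(space) == 6  # flattened size, used to size the MLPs
```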
@@ -308,14 +305,15 @@ def get_actor_classes(self):
     async def impl(self, actor_session):
         # Start a session
         actor_session.start()
+
         config = actor_session.config
-        assert config.environment_specs.num_players == 1
-        assert len(config.environment_specs.action_space.properties) == 1
-        assert config.environment_specs.action_space.properties[0].WhichOneof("type") == "box"

-        # Get observation and action space
-        observation_space = config.environment_specs.observation_space
-        action_space = config.environment_specs.action_space
+        environment_specs = EnvironmentSpecs.deserialize(config.environment_specs)
+        observation_space = environment_specs.get_observation_space()
+        action_space = environment_specs.get_action_space()
+
+        assert isinstance(action_space.gym_space, Box)
+        assert config.environment_specs.num_players == 1

         # Get model
         model, _, _ = await actor_session.model_registry.retrieve_version(
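
The block above is the heart of the change: the actor no longer reads raw space protos off its config but rebuilds an `EnvironmentSpecs` object and asks it for typed space helpers. A hedged sketch of the round trip, with the `serialize`/`deserialize` signatures inferred from their usage in this diff:

```python
# Assumed round trip (names taken from this diff, signatures inferred):
# the training side packs specs into AgentConfig, the actor unpacks them.
from cogment_verse.specs import EnvironmentSpecs

def unpack_specs(serialized_specs):
    """Rebuild typed space helpers from the serialized specs in a config."""
    environment_specs = EnvironmentSpecs.deserialize(serialized_specs)
    return (
        environment_specs.get_observation_space(),
        environment_specs.get_action_space(),
    )
```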
@@ -324,9 +322,9 @@ async def impl(self, actor_session):

         async for event in actor_session.all_events():
             if event.observation and event.type == cogment.EventType.ACTIVE:
-                obs_tensor = torch.tensor(
-                    flatten(observation_space, event.observation.observation.value), dtype=self._dtype
-                ).view(1, -1)
+                observation = observation_space.deserialize(event.observation.observation)
+
+                obs_tensor = torch.tensor(observation.flat_value, dtype=self._dtype).view(1, -1)

                 # Normalize the observation
                 if model.state_normalization is not None:
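
In plain gym terms, `observation.flat_value` takes over the job of the removed `flatten(observation_space, value)` helper. A rough gym-only equivalent of the deserialize-then-flatten step (an illustration with a toy space, not the cogment_verse code path):

```python
# Rough stand-in for the step above: flatten a sampled Box value and
# shape it into a (1, n) tensor, as flat_value is used in the actor.
import numpy as np
import torch
from gym.spaces import Box, utils

space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
value = space.sample()
flat = utils.flatten(space, value)  # 1-D np.ndarray, like flat_value
obs_tensor = torch.tensor(flat, dtype=torch.float32).view(1, -1)
assert obs_tensor.shape == (1, 4)
```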
@@ -339,11 +337,11 @@ async def impl(self, actor_session):
                 # Get action from policy network
                 with torch.no_grad():
                     dist, _ = model.policy_network(obs_tensor)
-                    action = dist.sample().cpu().numpy()[0]
+                    action_value = dist.sample().cpu().numpy()[0]

                 # Send action to environment
-                action_value = unflatten(action_space, action)
-                actor_session.do_action(PlayerAction(value=action_value))
+                action = action_space.create(value=action_value)
+                actor_session.do_action(action_space.serialize(action))


 class PPOTraining:
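
The rename keeps the raw numpy sample (`action_value`) distinct from the typed action object (`action`) that `action_space.create(...)` builds for serialization. The sampling itself is unchanged; a standalone sketch with a hypothetical 2-dimensional Gaussian head:

```python
# Toy version of the sampling step; in the real code the (dist, value)
# pair comes from model.policy_network(obs_tensor).
import torch
from torch.distributions.normal import Normal

dist = Normal(loc=torch.zeros(1, 2), scale=torch.ones(1, 2))
with torch.no_grad():
    action_value = dist.sample().cpu().numpy()[0]  # np.ndarray, shape (2,)
```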
@@ -392,8 +390,8 @@ def __init__(self, environment_specs: EnvironmentSpecs, cfg: EnvironmentConfig)
         self.model = PPOModel(
             model_id="",
             environment_implementation=self._environment_specs.implementation,
-            num_input=flattened_dimensions(self._environment_specs.observation_space),
-            num_output=flattened_dimensions(self._environment_specs.action_space),
+            num_input=utils.flatdim(self._environment_specs.get_observation_space().gym_space),
+            num_output=utils.flatdim(self._environment_specs.get_action_space().gym_space),
             learning_rate=self._cfg.learning_rate,
             n_iter=self._cfg.num_epochs,
             policy_network_hidden_nodes=self._cfg.policy_network.num_hidden_nodes,
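
`utils.flatdim` replaces the project's own `flattened_dimensions` helper: both reduce a possibly multi-dimensional space to the single integer the MLP layers need. A toy sizing example (the spaces are invented for illustration):

```python
# How num_input/num_output are derived, with made-up spaces.
import numpy as np
from gym.spaces import Box, utils

obs_space = Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
act_space = Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
num_input = utils.flatdim(obs_space)   # 8
num_output = utils.flatdim(act_space)  # 2
```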
@@ -404,15 +402,20 @@ def __init__(self, environment_specs: EnvironmentSpecs, cfg: EnvironmentConfig)

     async def trial_sample_sequences_producer_impl(self, sample_producer_session: SampleProducerSession):
         """Collect sample from the trial"""
+
+        # Share with A2C
+
         observation = []
         action = []
         reward = []
         done = []

         player_actor_params = sample_producer_session.trial_info.parameters.actors[0]
+
         player_actor_name = player_actor_params.name
-        player_observation_space = player_actor_params.config.environment_specs.observation_space
-        player_action_space = player_actor_params.config.environment_specs.action_space
+        player_environment_specs = EnvironmentSpecs.deserialize(player_actor_params.config.environment_specs)
+        player_observation_space = player_environment_specs.get_observation_space()
+        player_action_space = player_environment_specs.get_action_space()

         async for sample in sample_producer_session.all_trial_samples():
             if sample.trial_state == cogment.TrialState.ENDED:
@@ -423,9 +426,10 @@ async def trial_sample_sequences_producer_impl(self, sample_producer_session: Sa

             actor_sample = sample.actors_data[player_actor_name]
             observation.append(
-                torch.tensor(flatten(player_observation_space, actor_sample.observation.value), dtype=self._dtype)
+                torch.tensor(player_observation_space.deserialize(actor_sample.observation).value, dtype=self._dtype)
             )
-            action.append(torch.tensor(flatten(player_action_space, actor_sample.action.value), dtype=self._dtype))
+
+            action.append(torch.tensor(player_action_space.deserialize(actor_sample.action).value, dtype=self._dtype))
             reward.append(
                 torch.tensor(actor_sample.reward if actor_sample.reward is not None else 0, dtype=self._dtype)
             )
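
Each sample thus contributes one flat observation tensor, one flat action tensor, and a scalar reward that defaults to 0 when absent. Downstream these per-step lists are typically stacked into batch tensors; a sketch with illustrative shapes (the stacking itself is outside this hunk):

```python
# Assumed downstream use of the collected lists (toy shapes).
import torch

observation = [torch.zeros(4), torch.ones(4)]
action = [torch.zeros(2), torch.ones(2)]
reward = [torch.tensor(0.0), torch.tensor(1.0)]

obs_batch = torch.stack(observation)  # (T, obs_dim)
act_batch = torch.stack(action)       # (T, act_dim)
rew_batch = torch.stack(reward)       # (T,)
```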
@@ -438,8 +442,9 @@ async def impl(self, run_session: RunSession) -> dict:
         """Train and publish model the model"""

         model_id = f"{run_session.run_id}_model"
+
         assert self._environment_specs.num_players == 1
-        assert len(self._environment_specs.action_space.properties) == 1
+        assert isinstance(self._environment_specs.get_action_space().gym_space, Box)

         # Initalize model
         self.model.model_id = model_id
@@ -462,7 +467,7 @@ def create_trial_params(trial_idx: int, iter_idx: int):
                 implementation="actors.ppo.PPOActor",
                 config=AgentConfig(
                     run_id=run_session.run_id,
-                    environment_specs=self._environment_specs,
+                    environment_specs=self._environment_specs.serialize(),
                     model_id=model_id,
                     model_version=version_info["version_number"],
                 ),