1818import torch
1919from torch import nn
2020
21+
2122class DMCNet (nn .Module ):
22- def __init__ (
23- self ,
24- state_shape ,
25- action_shape ,
26- mlp_layers = None
27- ):
23+ def __init__ (self , state_shape , action_shape , mlp_layers = None ):
2824 super ().__init__ ()
2925 if mlp_layers is None :
3026 mlp_layers = [512 , 512 , 512 , 512 , 512 ]
3127 input_dim = np .prod (state_shape ) + np .prod (action_shape )
3228 layer_dims = [input_dim ] + mlp_layers
3329 fc = []
34- for i in range (len (layer_dims )- 1 ):
35- fc .append (nn .Linear (layer_dims [i ], layer_dims [i + 1 ]))
30+ for i in range (len (layer_dims ) - 1 ):
31+ fc .append (nn .Linear (layer_dims [i ], layer_dims [i + 1 ]))
3632 fc .append (nn .ReLU ())
3733 fc .append (nn .Linear (layer_dims [- 1 ], 1 ))
3834 self .fc_layers = nn .Sequential (* fc )
@@ -44,19 +40,13 @@ def forward(self, obs, actions):
4440 values = self .fc_layers (x ).flatten ()
4541 return values
4642
43+
4744class DMCAgent :
48- def __init__ (
49- self ,
50- state_shape ,
51- action_shape ,
52- mlp_layers = None ,
53- exp_epsilon = 0.01 ,
54- device = "0" ,
55- ):
45+ def __init__ (self , state_shape , action_shape , mlp_layers = None , exp_epsilon = 0.01 , device = "0" ):
5646 if mlp_layers is None :
5747 mlp_layers = [512 , 512 , 512 , 512 , 512 ]
5848 self .use_raw = False
59- self .device = 'cuda:' + device if device != "cpu" else "cpu"
49+ self .device = 'cuda:' + device if device != "cpu" else "cpu"
6050 self .net = DMCNet (state_shape , action_shape , mlp_layers ).to (self .device )
6151 self .exp_epsilon = exp_epsilon
6252 self .action_shape = action_shape
@@ -78,8 +68,7 @@ def eval_step(self, state):
7868 action_idx = np .argmax (values )
7969 action = action_keys [action_idx ]
8070
81- info = {}
82- info ['values' ] = {state ['raw_legal_actions' ][i ]: float (values [i ]) for i in range (len (action_keys ))}
71+ info = {'values' : {state ['raw_legal_actions' ][i ]: float (values [i ]) for i in range (len (action_keys ))}}
8372
8473 return action , info
8574
@@ -125,15 +114,9 @@ def state_dict(self):
125114 def set_device (self , device ):
126115 self .device = device
127116
117+
128118class DMCModel :
129- def __init__ (
130- self ,
131- state_shape ,
132- action_shape ,
133- mlp_layers = None ,
134- exp_epsilon = 0.01 ,
135- device = 0
136- ):
119+ def __init__ (self , state_shape , action_shape , mlp_layers = None , exp_epsilon = 0.01 , device = 0 ):
137120 if mlp_layers is None :
138121 mlp_layers = [512 , 512 , 512 , 512 , 512 ]
139122 self .agents = []
0 commit comments