@@ -21,11 +21,13 @@ class NNConfig:
2121 type : str = "mlp" # mlp or resnet
2222 resnet : Optional [ResnetConfig ] = None
2323
24+ mask_training_predictions : bool = False
25+
2426 # TODO: AlphaZeroParams should have an instance of this class instead of using different keys,
2527 # but this requires significant changes (e.g. hierarchical subargs)
2628 @staticmethod
2729 def from_alphazero_params (params : "AlphaZeroParams" ) -> "NNConfig" : # type: ignore
28- config = NNConfig (type = params .nn_type )
30+ config = NNConfig (type = params .nn_type , mask_training_predictions = params .nn_mask_training_predictions )
2931 if params .nn_type == "resnet" :
3032 resnet_config = ResnetConfig ()
3133 resnet_config .num_blocks = params .nn_resnet_num_blocks
@@ -50,6 +52,7 @@ class NNEvaluator:
5052 def __init__ (self , action_encoder : ActionEncoder , device , config : NNConfig , max_cache_size : int ):
5153 self .action_encoder = action_encoder
5254 self .device = device
55+ self .config = config
5356 self .network = create_network (action_encoder , device , config )
5457 self .max_cache_size = max_cache_size
5558
@@ -229,22 +232,30 @@ def compute_losses(self, batch_data):
229232
230233 target_values = []
231234
235+ action_masks = []
236+
232237 for data in batch_data :
233238 inputs .append (torch .from_numpy (data ["input_array" ]))
234239 target_policies .append (torch .FloatTensor (data ["mcts_policy" ]))
235240 target_values .append (torch .FloatTensor ([data ["value" ]]))
241+ action_masks .append (torch .FloatTensor (data ["action_mask" ]))
236242
237243 inputs = torch .stack (inputs ).to (self .device )
238244 target_policies = torch .stack (target_policies ).to (self .device )
239245 target_values = torch .stack (target_values ).to (self .device )
246+ action_masks = torch .stack (action_masks ).to (self .device )
240247
241248 assert not (inputs .isnan ().any () or target_policies .isnan ().any () or target_values .isnan ().any ()), (
242249 "NaN in training data"
243250 )
244251
245252 # Forward pass
246253 pred_logits , pred_values = self .network (inputs )
247- # TODO: Should we apply masking before calculating cross-entropy here?
254+
255+ if self .config .mask_training_predictions :
256+ # Apply masking - this means that even if the network gives a high probability to an invalid
257+ # action in the policy, we don't penalize it.
258+ pred_logits = pred_logits * action_masks + INVALID_ACTION_VALUE * (1 - action_masks )
248259
249260 # Compute losses
250261 policy_loss = F .cross_entropy (pred_logits , target_policies , reduction = "mean" )