Skip to content

Commit d41e908

Browse files
authored
Merge pull request #308 from jonbinney/jdb/masked-policy-loss
Option to mask predicted policies during training
2 parents 4befbf7 + a191d63 commit d41e908

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

deep_quoridor/src/agents/alphazero/alphazero.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ class AlphaZeroParams(SubargsBase):
157157
# Alphazero used 256. It's set lower here to make training faster, but we should try a higher value.
158158
nn_resnet_num_channels: int = 32
159159

160+
# Whether to mask the policies predicted by the NN during training (before computing the loss). When this is
161+
# False, the loss function penalizes the network producing a non-zero probability for any action which is
162+
# illegal.
163+
nn_mask_training_predictions: bool = False
164+
160165
# Maximum size for entries in worker cache
161166
max_cache_size: int = 200000
162167

@@ -661,12 +666,15 @@ def store_training_data(self, game, mcts_policy, player, game_idx):
661666
"""Store training data for later use in training."""
662667
game, is_rotated = self.evaluator.rotate_if_needed_to_point_downwards(game)
663668
input_array = self.evaluator.game_to_input_array(game)
669+
action_mask = game.get_action_mask()
664670
if is_rotated:
665671
mcts_policy = self.evaluator.rotate_policy_from_original(mcts_policy)
672+
666673
self.replay_buffers_in_progress[game_idx].append(
667674
{
668675
"input_array": input_array,
669676
"mcts_policy": mcts_policy,
677+
"action_mask": action_mask,
670678
"value": None, # Will be filled in at end of episode
671679
"player": player,
672680
}

deep_quoridor/src/agents/alphazero/nn_evaluator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ class NNConfig:
2121
type: str = "mlp" # mlp or resnet
2222
resnet: Optional[ResnetConfig] = None
2323

24+
mask_training_predictions: bool = False
25+
2426
# TO DO: AlphaZeroParams should have an instance of this class instead of using different keys,
2527
# but this requires significant changes (e.g. hierarchical subargs)
2628
@staticmethod
2729
def from_alphazero_params(params: "AlphaZeroParams") -> "NNConfig": # type: ignore
28-
config = NNConfig(type=params.nn_type)
30+
config = NNConfig(type=params.nn_type, mask_training_predictions=params.nn_mask_training_predictions)
2931
if params.nn_type == "resnet":
3032
resnet_config = ResnetConfig()
3133
resnet_config.num_blocks = params.nn_resnet_num_blocks
@@ -50,6 +52,7 @@ class NNEvaluator:
5052
def __init__(self, action_encoder: ActionEncoder, device, config: NNConfig, max_cache_size: int):
5153
self.action_encoder = action_encoder
5254
self.device = device
55+
self.config = config
5356
self.network = create_network(action_encoder, device, config)
5457
self.max_cache_size = max_cache_size
5558

@@ -229,22 +232,30 @@ def compute_losses(self, batch_data):
229232

230233
target_values = []
231234

235+
action_masks = []
236+
232237
for data in batch_data:
233238
inputs.append(torch.from_numpy(data["input_array"]))
234239
target_policies.append(torch.FloatTensor(data["mcts_policy"]))
235240
target_values.append(torch.FloatTensor([data["value"]]))
241+
action_masks.append(torch.FloatTensor(data["action_mask"]))
236242

237243
inputs = torch.stack(inputs).to(self.device)
238244
target_policies = torch.stack(target_policies).to(self.device)
239245
target_values = torch.stack(target_values).to(self.device)
246+
action_masks = torch.stack(action_masks).to(self.device)
240247

241248
assert not (inputs.isnan().any() or target_policies.isnan().any() or target_values.isnan().any()), (
242249
"NaN in training data"
243250
)
244251

245252
# Forward pass
246253
pred_logits, pred_values = self.network(inputs)
247-
# TODO: Should we apply masking before calculating cross-entropy here?
254+
255+
if self.config.mask_training_predictions:
256+
# Apply masking - this means that even if the network gives a high probability to an invalid
257+
# action in the policy, we don't penalize it.
258+
pred_logits = pred_logits * action_masks + INVALID_ACTION_VALUE * (1 - action_masks)
248259

249260
# Compute losses
250261
policy_loss = F.cross_entropy(pred_logits, target_policies, reduction="mean")

deep_quoridor/test/agents/alphazero_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ def test_evaluator_training_with_deterministic_policy():
8686
replay_buffer = []
8787
for _ in range(100):
8888
replay_buffer.append(
89-
{"input_array": evaluator.game_to_input_array(game), "mcts_policy": target_policy, "value": 1.0}
89+
{
90+
"input_array": evaluator.game_to_input_array(game),
91+
"action_mask": game.get_action_mask(),
92+
"mcts_policy": target_policy,
93+
"value": 1.0,
94+
}
9095
)
9196

9297
evaluator.train_prepare(learning_rate, batch_size, optimizer_iterations, weight_decay)
@@ -118,7 +123,12 @@ def test_evaluator_training_with_probabilistic_policy():
118123
replay_buffer = []
119124
for _ in range(100):
120125
replay_buffer.append(
121-
{"input_array": evaluator.game_to_input_array(game), "mcts_policy": target_policy, "value": 1.0}
126+
{
127+
"input_array": evaluator.game_to_input_array(game),
128+
"action_mask": game.get_action_mask(),
129+
"mcts_policy": target_policy,
130+
"value": 1.0,
131+
}
122132
)
123133

124134
evaluator.train_prepare(learning_rate, batch_size, optimizer_iterations, weight_decay)

0 commit comments

Comments
 (0)