Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 60 additions & 43 deletions open_spiel/algorithms/alpha_zero_torch/vpnet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "open_spiel/algorithms/alpha_zero_torch/vpnet.h"

#include <torch/torch.h>
#include <torch/types.h>

#include <fstream> // For ifstream/ofstream.
#include <string>
Expand Down Expand Up @@ -148,28 +149,36 @@ std::vector<VPNetModel::InferenceOutputs> VPNetModel::Inference(
const std::vector<InferenceInputs>& inputs) {
int inference_batch_size = inputs.size();

// Torch tensors by default use a dense, row-aligned memory layout.
// - Their default data type is a 32-bit float
// - Use the byte data type for boolean

torch::Tensor torch_inf_inputs =
torch::empty({inference_batch_size, flat_input_size_}, torch_device_);
torch::Tensor torch_inf_legal_mask = torch::full(
{inference_batch_size, num_actions_}, false,
torch::TensorOptions().dtype(torch::kByte).device(torch_device_));
// Format the data outside of torch. Per-element (random-access) assignments
// into torch::Tensor objects can be very slow, so the values are staged in
// plain std::vector buffers first; this approach is _much_ faster.
std::vector<float> raw_observations(inference_batch_size * flat_input_size_);
std::vector<uint8_t> raw_legal_mask(inference_batch_size * num_actions_, 0);

for (int batch = 0; batch < inference_batch_size; ++batch) {
// Copy legal mask(s) to a Torch tensor.
for (Action action : inputs[batch].legal_actions) {
torch_inf_legal_mask[batch][action] = true;
}

// Copy the observation(s) to a Torch tensor.
for (int i = 0; i < inputs[batch].observations.size(); ++i) {
torch_inf_inputs[batch][i] = inputs[batch].observations[i];
raw_legal_mask[batch * num_actions_ + action] = 1;
}
std::copy(inputs[batch].observations.begin(),
inputs[batch].observations.end(),
raw_observations.begin() + (batch * flat_input_size_));
}

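// Torch tensors by default use a dense, row-major (contiguous) memory layout,
// which matches the [batch * row_length + column] indexing of the raw buffers.
// - Their default data type is a 32-bit float
// - Booleans are stored using the byte data type (torch::kByte)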

torch::Tensor torch_inf_inputs =
torch::from_blob(raw_observations.data(),
{inference_batch_size, flat_input_size_})
.to(torch_device_)
.clone();
torch::Tensor torch_inf_legal_mask =
torch::from_blob(raw_legal_mask.data(),
{inference_batch_size, num_actions_},
torch::TensorOptions().dtype(torch::kByte))
.to(torch_device_)
.clone();

// Run the inference.
model_->eval();
std::vector<torch::Tensor> torch_outputs =
Expand Down Expand Up @@ -200,40 +209,48 @@ std::vector<VPNetModel::InferenceOutputs> VPNetModel::Inference(
VPNetModel::LossInfo VPNetModel::Learn(const std::vector<TrainInputs>& inputs) {
int training_batch_size = inputs.size();

// Torch tensors by default use a dense, row-aligned memory layout.
// - Their default data type is a 32-bit float
// - Use the byte data type for boolean

torch::Tensor torch_train_inputs =
torch::empty({training_batch_size, flat_input_size_}, torch_device_);
torch::Tensor torch_train_legal_mask = torch::full(
{training_batch_size, num_actions_}, false,
torch::TensorOptions().dtype(torch::kByte).device(torch_device_));
torch::Tensor torch_policy_targets =
torch::zeros({training_batch_size, num_actions_}, torch_device_);
torch::Tensor torch_value_targets =
torch::empty({training_batch_size, 1}, torch_device_);
std::vector<float> raw_train_inputs(training_batch_size * flat_input_size_);
std::vector<uint8_t> raw_legal_mask(training_batch_size * num_actions_, 0);
std::vector<float> raw_policy_targets(training_batch_size * num_actions_, 0);
std::vector<float> raw_value_targets(training_batch_size);

for (int batch = 0; batch < training_batch_size; ++batch) {
// Copy the legal mask(s) to a Torch tensor.
std::copy(inputs[batch].observations.begin(),
inputs[batch].observations.end(),
raw_train_inputs.begin() + (batch * flat_input_size_));
for (Action action : inputs[batch].legal_actions) {
torch_train_legal_mask[batch][action] = true;
raw_legal_mask[num_actions_ * batch + action] = 1;
}

// Copy the observation(s) to a Torch tensor.
for (int i = 0; i < inputs[batch].observations.size(); ++i) {
torch_train_inputs[batch][i] = inputs[batch].observations[i];
}

// Copy the policy target(s) to a Torch tensor.
for (const auto& [action, probability] : inputs[batch].policy) {
torch_policy_targets[batch][action] = probability;
for (const auto &[action, probability] : inputs[batch].policy) {
raw_policy_targets[num_actions_ * batch + action] = probability;
}

// Copy the value target(s) to a Torch tensor.
torch_value_targets[batch][0] = inputs[batch].value;
raw_value_targets[batch] = inputs[batch].value;
}

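// Torch tensors by default use a dense, row-major (contiguous) memory layout,
// which matches the [batch * row_length + column] indexing of the raw buffers.
// - Their default data type is a 32-bit float
// - Booleans are stored using the byte data type (torch::kByte)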
torch::Tensor torch_train_inputs =
torch::from_blob(raw_train_inputs.data(),
{training_batch_size, flat_input_size_})
.to(torch_device_)
.clone();
torch::Tensor torch_train_legal_mask =
torch::from_blob(raw_legal_mask.data(),
{training_batch_size, num_actions_},
torch::TensorOptions().dtype(torch::kByte))
.to(torch_device_)
.clone();
torch::Tensor torch_policy_targets =
torch::from_blob(raw_policy_targets.data(),
{training_batch_size, num_actions_})
.to(torch_device_)
.clone();
torch::Tensor torch_value_targets =
torch::from_blob(raw_value_targets.data(), {training_batch_size, 1})
.to(torch_device_)
.clone();

// Run a training step and get the losses.
model_->train();
model_->zero_grad();
Expand Down
Loading