forked from cO-Oe/SurakartaAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.h
More file actions
107 lines (83 loc) · 3.51 KB
/
Copy pathtrain.h
File metadata and controls
107 lines (83 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#pragma once
#include "episode.h"
#include <string.h>
#include <cassert>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <utility>
#include <vector>
/// Torch dataset over a recorded sequence of boards.
///
/// Sample `index` stacks the `stack_size` consecutive boards
/// states_[index .. index + stack_size - 1] (encoded by generate_states) into one
/// input tensor; its target is labels_[index + stack_size - 1], i.e. the label of
/// the LAST board in the window. Hence size() == states_.size() - (stack_size - 1).
class BoardDataSet : public torch::data::Dataset<BoardDataSet> {
private:
    std::vector<board> states_;
    std::vector<board> states_flip;  // NOTE(review): never read in this class.
    std::vector<PIECE> pieces_;
    std::vector<int> labels_;
    const unsigned stack_size = 3;  // number of consecutive boards per sample
public:
    /// @param st     recorded boards; consecutive entries are stacked into one sample.
    /// @param pieces per-board value passed to generate_states (pieces[index]).
    /// @param label  per-board target, indexed in step with `st`.
    /// Arguments are taken by value and moved in, so callers passing temporaries
    /// avoid a second copy of each vector.
    explicit BoardDataSet(std::vector<board> st, std::vector<PIECE> pieces, std::vector<int> label)
        : states_(std::move(st)),
          pieces_(std::move(pieces)),
          labels_(std::move(label)) {}

    /// Builds the (input, target) pair for the window starting at `index`.
    torch::data::Example<> get(size_t index) override {
        // NOTE(review): the piece comes from the FIRST board of the window while the
        // label comes from the LAST one; confirm this pairing is intended.
        const auto piece = pieces_[index];

        std::vector<board> input_board;
        input_board.reserve(stack_size);
        for (unsigned i = 0; i < stack_size; ++i) {
            input_board.push_back(states_[index + i]);
        }

        // A runtime-sized local array here would be a non-standard VLA (stack_size is
        // a non-static member); the vector is zero-initialised, replacing memset.
        const int64_t channels = static_cast<int64_t>(stack_size) * 2 + 1;
        std::vector<float> tensor_stack(static_cast<size_t>(board::SIZE) * static_cast<size_t>(channels), 0.0f);
        // Convert board values to a flat float buffer.
        generate_states(tensor_stack.data(), input_board, piece);

        // from_blob does NOT copy, and .to(device) returns the same tensor when device
        // is the CPU; clone() gives the tensor its own storage so it does not alias
        // tensor_stack after this function returns.
        // Assumes board::SIZE == 6 * 6 (the original also hard-coded 6x6) — TODO confirm.
        torch::Tensor state_tensor =
            torch::from_blob(tensor_stack.data(), {channels, 6, 6}).clone().to(device);

        // Explicit float dtype: newer LibTorch infers an integer dtype from an integer
        // fill value, which mse_loss against a float prediction rejects.
        const int64_t label = labels_[index + (stack_size - 1)];
        torch::Tensor label_tensor =
            torch::full({1}, static_cast<float>(label), torch::kFloat).to(device);
        return {state_tensor, label_tensor};
    }

    /// Number of full windows of `stack_size` boards; 0 when there are too few boards.
    torch::optional<size_t> size() const override {
        // Guard instead of assert: under NDEBUG the subtraction below would wrap
        // around to a huge size_t when fewer than stack_size boards are stored.
        if (states_.size() < stack_size) {
            return size_t(0);
        }
        return states_.size() - (stack_size - 1);
    }
};
void train_Net(const episode &game) {
// Set arguments
const int num_epoch = 10;
const int64_t batch_size = 64;
const double learning_rate = 0.001;
// Package board and label to train dataset
auto data_set = BoardDataSet(game.train_boards_, game.train_pieces_, game.train_result).map(torch::data::transforms::Stack<>());
auto set_size = data_set.size().value();
auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(data_set, torch::data::DataLoaderOptions().batch_size(batch_size));
// construct optimizer
torch::optim::Adam optimizer(Net->parameters(), torch::optim::AdamOptions( learning_rate ));
std::cerr << "Start to train Network: \n\n";
for(int epoch=1; epoch <= num_epoch; epoch++) {
double mse = 0.0;
int batch_num = 0;
for (torch::data::Example<>& batch : *data_loader) {
auto boards_ = batch.data.to(device);
auto labels_ = batch.target.squeeze().to(device); // reduce dim from (1, x) to (x)
auto output = Net->forward(boards_).to(device);
// std::cout << boards_ << '\n';
// std::cout << "output: " << output << '\n';
// std::cout << "labels: " << labels_ << '\n';
auto loss = torch::mse_loss(output, labels_).to(device);
mse += loss.item<double>() * boards_.size(0);
// update SGD
optimizer.zero_grad();
loss.backward();
optimizer.step();
batch_num++;
}
// mse /= batch_num;
mse /= set_size;
std::cout << "Epoch " << std::setw(2) << epoch << ": " << "Batch Nums = " << std::setw(2) << batch_num << " Mean square error= " << mse << '\n';
}
}