// adam.h
// Adam optimizer
// https://arxiv.org/abs/1412.6980

#ifndef TINYTENSOR_NN_OPTIMIZER_ADAM_H_
#define TINYTENSOR_NN_OPTIMIZER_ADAM_H_

#include <tt/export.h>
#include <tt/optim/optimizer.h>
#include <tt/tensor.h>

#include <functional>
#include <string>
#include <vector>
namespace tinytensor::optim {

// Exponential decay rates for the first and second moment estimates
struct TINYTENSOR_EXPORT AdamBetas {
    double beta1;
    double beta2;
};

// Options for Adam
// @note See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
struct TINYTENSOR_EXPORT AdamOptions {
    RegularizationMode regularization_mode = RegularizationMode::l2;
    double weight_decay = 0;
    AdamBetas betas = {.beta1 = 0.9, .beta2 = 0.999};
    double eps = 1e-8;
    bool use_amsgrad = false;
    bool maximize = false;
};
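// Example: AdamOptions is an aggregate, so callers can override individual fields
// with designated initializers and keep the remaining defaults, e.g.
//   AdamOptions options{.weight_decay = 1e-2, .use_amsgrad = true};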
class TINYTENSOR_EXPORT Adam : public Optimizer {
    using TensorRefList = std::vector<std::reference_wrapper<Tensor>>;

public:
    /**
     * Create an Adam optimizer
     * @note See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
     * @param params Parameters the optimizer should optimize
     * @param learning_rate The learning rate
     * @param options Additional options for Adam
     */
    Adam(const TensorRefList &params, double learning_rate, const AdamOptions &options = {});

    /**
     * Save the internal state of the optimizer
     * @param path The path to save the optimizer state
     */
    void save(const std::string &path) const override;

    /**
     * Load the internal state of the optimizer
     * @param path The path to the saved optimizer state
     */
    void load(const std::string &path) override;

    /**
     * Add parameters to the optimizer
     * @param params The parameters to add
     */
    void add_parameters(const std::vector<std::reference_wrapper<Tensor>> &params) override;

    /**
     * Perform a single optimization step of the optimizer algorithm
     * @note See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
     */
    void step() override;

protected:
    double learning_rate_;
    AdamOptions options_;
    std::vector<Tensor> first_moments_;
    std::vector<Tensor> second_moments_;
    std::vector<Tensor> second_moments_max_;
    std::vector<int> steps_;
};
} // namespace tinytensor::optim
#endif // TINYTENSOR_NN_OPTIMIZER_ADAM_H_
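For reference, step() applies the update rule from the Adam paper linked at the top of this header. The sketch below is a minimal, self-contained illustration of that update for one scalar parameter, not the library's implementation: the names AdamScalarState and adam_update are hypothetical, the state fields simply mirror the members declared above (first_moments_, second_moments_, second_moments_max_, steps_), and weight decay, the regularization mode, and the maximize flag are omitted.

#include <algorithm>
#include <cmath>

// State kept per parameter, mirroring the per-tensor state in the Adam class.
struct AdamScalarState {
    double first_moment = 0.0;       // m_t, exponential moving average of the gradient
    double second_moment = 0.0;      // v_t, exponential moving average of the squared gradient
    double second_moment_max = 0.0;  // running maximum of v_hat (used only with AMSGrad)
    int step = 0;                    // t, number of updates applied so far
};

// One Adam update for a single scalar parameter.
inline void adam_update(double &param, double grad, AdamScalarState &state, double learning_rate,
                        double beta1 = 0.9, double beta2 = 0.999, double eps = 1e-8,
                        bool use_amsgrad = false) {
    state.step += 1;

    // Update biased first and second moment estimates
    state.first_moment = beta1 * state.first_moment + (1.0 - beta1) * grad;
    state.second_moment = beta2 * state.second_moment + (1.0 - beta2) * grad * grad;

    // Bias-corrected estimates
    const double m_hat = state.first_moment / (1.0 - std::pow(beta1, state.step));
    double v_hat = state.second_moment / (1.0 - std::pow(beta2, state.step));

    // AMSGrad keeps the largest second moment estimate seen so far
    if (use_amsgrad) {
        state.second_moment_max = std::max(state.second_moment_max, v_hat);
        v_hat = state.second_moment_max;
    }

    param -= learning_rate * m_hat / (std::sqrt(v_hat) + eps);
}

Given the constructor declared above, a caller would create the optimizer with something like Adam opt(params, 1e-3, AdamOptions{.use_amsgrad = true}); and invoke opt.step() once per iteration after gradients have been computed; how params and their gradients are produced is outside the scope of this header.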