-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsgd.h
More file actions
72 lines (58 loc) · 1.99 KB
/
sgd.h
File metadata and controls
72 lines (58 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// sgd.h
// Stochastic Gradient Descent optimizer
#ifndef TINYTENSOR_NN_OPTIMIZER_SGD_H_
#define TINYTENSOR_NN_OPTIMIZER_SGD_H_
#include <tt/export.h>
#include <tt/optim/optimizer.h>
#include <tt/tensor.h>
#include <functional>
#include <string>
#include <vector>
namespace tinytensor::optim {
// Options for SGD
// @note See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
// Hyperparameters for the SGD optimizer.
// @note Semantics mirror https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
struct TINYTENSOR_EXPORT SGDOptions {
    RegularizationMode regularization_mode{RegularizationMode::l2};  // how weight decay is applied
    double weight_decay{0.0};    // regularization strength (0 disables weight decay)
    double momentum{0.0};        // momentum factor (0 disables momentum)
    bool use_nesterov{false};    // use Nesterov momentum instead of classical momentum
    bool maximize{false};        // ascend the objective rather than descend
};
// Stochastic Gradient Descent optimizer.
// @note See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
class TINYTENSOR_EXPORT SGD : public Optimizer {
    using TensorRefList = std::vector<std::reference_wrapper<Tensor>>;

public:
    /**
     * Create an SGD optimizer
     * @note See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
     * @param params Parameters the optimizer should optimize
     * @param learning_rate The learning rate
     * @param options Additional options for SGD
     */
    SGD(const TensorRefList &params, double learning_rate, const SGDOptions &options = {});

    /**
     * Save the internal state of the optimizer
     * @param path The path to save the optimizer state
     */
    void save(const std::string &path) const override;

    /**
     * Load the internal state of the optimizer
     * @param path The path to the saved optimizer state
     */
    void load(const std::string &path) override;

    /**
     * Add parameters to the optimizer
     * @param params The parameters to add
     */
    void add_parameters(const TensorRefList &params) override;

    /**
     * Perform a single optimization step of the optimizer algorithm
     * @note See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
     */
    void step() override;

public:
    // NOTE(review): trailing-underscore names suggest these were meant to be
    // private; kept public to preserve the existing interface — confirm intent.
    double learning_rate_;           // step size used by step()
    SGDOptions options_;             // hyperparameters supplied at construction
    std::vector<Tensor> velocities_; // presumably one momentum buffer per parameter — verify in step()
};
} // namespace tinytensor::optim
#endif // TINYTENSOR_NN_OPTIMIZER_SGD_H_