// dbn.h
//
// Created by whiterose on 12/2/22.
//
#ifndef DBN_DBN_H
#define DBN_DBN_H
#include <Eigen/Dense>
#include <vector>
#include <cmath>
#include "dataset.h"
#include <iostream>
// The convention adopted follows the one used by Hinton in the paper "A Fast Learning Algorithm for Deep Belief Nets" (2006)
// We will consider the vectors as row vectors
// The name of the layers are: lab <--> top <--> pen --> hid --> vis
class dbn {
public: // Weight matrices
MatrixXd hidvis;
MatrixXd vishid; // not part of generative model
MatrixXd penhid;
MatrixXd hidpen; // not part of generative model
MatrixXd toppen;
MatrixXd toplab;
const int BOLTZEPOCHS = 50;
const int BATCHSIZE = 10;
static double sigmoid(const double);
static RowVectorXd softmax(RowVectorXd);
void trainboltz(MatrixXd &, std::vector<RowVectorXd>);
void trainmemboltz(MatrixXd &, MatrixXd &, std::vector<RowVectorXd>, std::vector<RowVectorXd>);
dbn();
void fit(dataset *);
int predict(RowVectorXd vis);
void test(dataset *);
};
#endif //DBN_DBN_H