Skip to content

Commit d324650

Browse files
committed
feat(model - framework): add a Model class
- Added a Model class, which is a higher - level data structure than ComputeGraph. - The Model class helps users build their own models. Users can inherit from the Model class in their model classes. - Users only need to add external data to the member list of their model classes and use simple syntax to build models in the constructor. - The computation graph is transparent to users and is implicitly managed by the Model class.
1 parent 7b21123 commit d324650

3 files changed

Lines changed: 480 additions & 0 deletions

File tree

include/NeuZephyr/Model.cuh

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#ifndef MODEL_CUH
2+
#define MODEL_CUH
3+
#include "ComputeGraph.cuh"
4+
5+
using namespace nz::nodes;
6+
7+
namespace nz {
8+
class DL_API Model {
9+
public:
10+
friend DL_API std::ostream& operator<<(std::ostream& os, Model& model);
11+
12+
Model();
13+
14+
~Model();
15+
16+
Tensor& forward();
17+
18+
void backward();
19+
20+
void update(opt::Optimizer* optimizer) const;
21+
22+
Tensor::value_type getLoss() const;
23+
private:
24+
std::vector<Node*> hiddenNodes;
25+
26+
graph::ComputeGraph computeGraph;
27+
protected:
28+
Node* Add(Node* lhs, Node* rhs);
29+
30+
Node* Sub(Node* lhs, Node* rhs);
31+
32+
Node* Mul(Node* lhs, Node* rhs);
33+
34+
Node* Bias(Node* input);
35+
36+
Node* Reshape(Node* input, const Tensor::shape_type& shape);
37+
38+
Node* Linear(Node* input, size_t outSize);
39+
40+
Node* ReLU(Node* input);
41+
42+
Node* Sigmoid(Node* input);
43+
44+
Node* Tanh(Node* input);
45+
46+
Node* LeakyReLU(Node* input, float alpha = 0.01f);
47+
48+
Node* Swish(Node* input);
49+
50+
Node* ELU(Node* input, float alpha = 1.0f);
51+
52+
Node* HardSigmoid(Node* input, float alpha = 0.2f, float beta = 0.5f);
53+
54+
Node* HardSwish(Node* input, float alpha = 0.2f, float beta = 0.5f);
55+
56+
Node* Softmax(Node* input);
57+
58+
Node* TargetExpand(Node* input, const Tensor::shape_type& shape);
59+
60+
void MSELoss(Node* input, Node* target);
61+
62+
void BCELoss(Node* input, Node* target);
63+
64+
void defaultOutput(Node* input);
65+
};
66+
}
67+
68+
69+
#endif //MODEL_CUH

src/Model.cu

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
#include "NeuZephyr/Model.cuh"
2+
3+
nz::Model::Model() = default;
4+
5+
nz::Model::~Model() {
6+
for (const auto* node : hiddenNodes) {
7+
delete node;
8+
}
9+
}
10+
11+
Tensor& nz::Model::forward() {
12+
computeGraph.forward();
13+
return *computeGraph.getOutputNode()->output;
14+
}
15+
16+
void nz::Model::backward() {
17+
computeGraph.backward();
18+
}
19+
20+
void nz::Model::update(opt::Optimizer* optimizer) const {
21+
computeGraph.update(optimizer);
22+
}
23+
24+
Tensor::value_type nz::Model::getLoss() const {
25+
return computeGraph.getLoss();
26+
}
27+
28+
Node* nz::Model::Add(Node* lhs, Node* rhs) {
29+
if (!computeGraph.inGraph(lhs)) {
30+
computeGraph.addNode(lhs);
31+
}
32+
if (!computeGraph.inGraph(rhs)) {
33+
computeGraph.addNode(rhs);
34+
}
35+
auto* addNode = new calc::AddNode(lhs, rhs);
36+
hiddenNodes.push_back(addNode);
37+
computeGraph.addNode(addNode);
38+
return addNode;
39+
}
40+
41+
Node* nz::Model::Sub(Node* lhs, Node* rhs) {
42+
if (!computeGraph.inGraph(lhs)) {
43+
computeGraph.addNode(lhs);
44+
}
45+
if (!computeGraph.inGraph(rhs)) {
46+
computeGraph.addNode(rhs);
47+
}
48+
auto* subNode = new calc::SubNode(lhs, rhs);
49+
hiddenNodes.push_back(subNode);
50+
computeGraph.addNode(subNode);
51+
return subNode;
52+
}
53+
54+
Node* nz::Model::Mul(Node* lhs, Node* rhs) {
55+
if (!computeGraph.inGraph(lhs)) {
56+
computeGraph.addNode(lhs);
57+
}
58+
if (!computeGraph.inGraph(rhs)) {
59+
computeGraph.addNode(rhs);
60+
}
61+
auto* mulNode = new calc::MatMulNode(lhs, rhs);
62+
hiddenNodes.push_back(mulNode);
63+
computeGraph.addNode(mulNode);
64+
return mulNode;
65+
}
66+
67+
Node* nz::Model::Bias(Node* input) {
68+
auto* param = new io::InputNode(
69+
{1, input->output->shape()[1], input->output->shape()[2], input->output->shape()[3]}, true);
70+
param->output->randomize();
71+
hiddenNodes.push_back(param);
72+
computeGraph.addNode(param);
73+
return Add(input, param);
74+
}
75+
76+
Node* nz::Model::Reshape(Node* input, const Tensor::shape_type& shape) {
77+
if (!computeGraph.inGraph(input)) {
78+
computeGraph.addNode(input);
79+
}
80+
auto* reshapeNode = new calc::ReshapeNode(input, shape);
81+
hiddenNodes.push_back(reshapeNode);
82+
computeGraph.addNode(reshapeNode);
83+
return reshapeNode;
84+
}
85+
86+
Node* nz::Model::Linear(Node* input, size_t outSize) {
87+
auto inputSize = input->output->shape()[1] * input->output->shape()[2] * input->output->shape()[3];
88+
Node* shapedInput;
89+
if (input->output->shape()[2] != inputSize) {
90+
shapedInput = Reshape(input, {input->output->shape()[0], 1, inputSize, 1});
91+
} else {
92+
shapedInput = input;
93+
}
94+
auto mulParam = new io::InputNode({1, 1, outSize, inputSize}, true);
95+
mulParam->output->randomize();
96+
hiddenNodes.push_back(mulParam);
97+
computeGraph.addNode(mulParam);
98+
auto mulResult = Mul(mulParam, shapedInput);
99+
auto biasResult = Bias(mulResult);
100+
return biasResult;
101+
}
102+
103+
Node* nz::Model::ReLU(Node* input) {
104+
if (!computeGraph.inGraph(input)) {
105+
computeGraph.addNode(input);
106+
}
107+
auto* reluNode = new calc::ReLUNode(input);
108+
hiddenNodes.push_back(reluNode);
109+
computeGraph.addNode(reluNode);
110+
return reluNode;
111+
}
112+
113+
Node* nz::Model::Sigmoid(Node* input) {
114+
if (!computeGraph.inGraph(input)) {
115+
computeGraph.addNode(input);
116+
}
117+
auto* sigmoidNode = new calc::SigmoidNode(input);
118+
hiddenNodes.push_back(sigmoidNode);
119+
computeGraph.addNode(sigmoidNode);
120+
return sigmoidNode;
121+
}
122+
123+
Node* nz::Model::Tanh(Node* input) {
124+
if (!computeGraph.inGraph(input)) {
125+
computeGraph.addNode(input);
126+
}
127+
auto* tanhNode = new calc::TanhNode(input);
128+
hiddenNodes.push_back(tanhNode);
129+
computeGraph.addNode(tanhNode);
130+
return tanhNode;
131+
}
132+
133+
Node* nz::Model::LeakyReLU(Node* input, const float alpha) {
134+
if (!computeGraph.inGraph(input)) {
135+
computeGraph.addNode(input);
136+
}
137+
auto* leakyReLUNode = new calc::LeakyReLUNode(input, alpha);
138+
hiddenNodes.push_back(leakyReLUNode);
139+
computeGraph.addNode(leakyReLUNode);
140+
return leakyReLUNode;
141+
}
142+
143+
Node* nz::Model::Swish(Node* input) {
144+
if (!computeGraph.inGraph(input)) {
145+
computeGraph.addNode(input);
146+
}
147+
auto* swishNode = new calc::SwishNode(input);
148+
hiddenNodes.push_back(swishNode);
149+
computeGraph.addNode(swishNode);
150+
return swishNode;
151+
}
152+
153+
Node* nz::Model::ELU(Node* input, const float alpha) {
154+
if (!computeGraph.inGraph(input)) {
155+
computeGraph.addNode(input);
156+
}
157+
auto* eluNode = new calc::ELUNode(input, alpha);
158+
hiddenNodes.push_back(eluNode);
159+
computeGraph.addNode(eluNode);
160+
return eluNode;
161+
}
162+
163+
Node* nz::Model::HardSigmoid(Node* input, const float alpha, const float beta) {
164+
if (!computeGraph.inGraph(input)) {
165+
computeGraph.addNode(input);
166+
}
167+
auto* hardSigmoidNode = new calc::HardSigmoidNode(input, alpha, beta);
168+
hiddenNodes.push_back(hardSigmoidNode);
169+
computeGraph.addNode(hardSigmoidNode);
170+
return hardSigmoidNode;
171+
}
172+
173+
Node* nz::Model::HardSwish(Node* input, float alpha, float beta) {
174+
if (!computeGraph.inGraph(input)) {
175+
computeGraph.addNode(input);
176+
}
177+
auto* hardSwishNode = new calc::HardSwishNode(input, alpha, beta);
178+
hiddenNodes.push_back(hardSwishNode);
179+
computeGraph.addNode(hardSwishNode);
180+
return hardSwishNode;
181+
}
182+
183+
Node* nz::Model::Softmax(Node* input) {
184+
if (!computeGraph.inGraph(input)) {
185+
computeGraph.addNode(input);
186+
}
187+
auto size = input->output->shape()[1] * input->output->shape()[2] * input->output->shape()[3];
188+
auto batch = input->output->shape()[0];
189+
Node* reshapedInput;
190+
if (input->output->shape()[2] != size) {
191+
reshapedInput = Reshape(input, {batch, 1, size, 1});
192+
} else {
193+
reshapedInput = input;
194+
}
195+
auto* softmaxNode = new calc::SoftmaxNode(reshapedInput);
196+
hiddenNodes.push_back(softmaxNode);
197+
computeGraph.addNode(softmaxNode);
198+
return softmaxNode;
199+
}
200+
201+
Node* nz::Model::TargetExpand(Node* input, const Tensor::shape_type& shape) {
202+
if (!computeGraph.inGraph(input)) {
203+
computeGraph.addNode(input);
204+
}
205+
if (input->output->shape() == shape) {
206+
return input;
207+
}
208+
if (input->output->shape()[0] != 1 ||
209+
input->output->shape()[1] != shape[1] ||
210+
input->output->shape()[2] != shape[2] ||
211+
input->output->shape()[3] != shape[3]) {
212+
throw std::runtime_error("The input data cannot be expanded.");
213+
}
214+
auto* expandNode = new calc::ExpandNode(input, shape.N());
215+
hiddenNodes.push_back(expandNode);
216+
computeGraph.addNode(expandNode);
217+
return expandNode;
218+
}
219+
220+
void nz::Model::MSELoss(Node* input, Node* target) {
221+
if (!computeGraph.inGraph(input)) {
222+
computeGraph.addNode(input);
223+
}
224+
auto* expandedTarget = TargetExpand(target, input->output->shape());
225+
auto* mseNode = new loss::MeanSquaredErrorNode(input, expandedTarget);
226+
hiddenNodes.push_back(mseNode);
227+
computeGraph.addOutput(mseNode);
228+
}
229+
230+
void nz::Model::BCELoss(Node* input, Node* target) {
231+
if (!computeGraph.inGraph(input)) {
232+
computeGraph.addNode(input);
233+
}
234+
auto* expandedTarget = TargetExpand(target, input->output->shape());
235+
auto* bceNode = new loss::BinaryCrossEntropyNode(input, expandedTarget);
236+
hiddenNodes.push_back(bceNode);
237+
computeGraph.addOutput(bceNode);
238+
}
239+
240+
void nz::Model::defaultOutput(Node* input) {
241+
auto* output = new io::OutputNode(input);
242+
hiddenNodes.push_back(output);
243+
computeGraph.addOutput(output);
244+
if (!computeGraph.inGraph(input)) {
245+
computeGraph.addNode(input);
246+
}
247+
}
248+
249+
std::ostream& nz::operator<<(std::ostream& os, Model& model) {
250+
return os << model.computeGraph;
251+
}

0 commit comments

Comments
 (0)