Skip to content

Commit f106dd7

Browse files
committed
feat(model): add Conv2d module to the Model class
- Added the Conv2d module to the Model class, enabling 2D convolution operations within the model. - Configured the necessary parameters and initialization methods for the Conv2d module. - This addition enhances the model's ability to handle image-related tasks by incorporating convolutional layers.
1 parent 54ef859 commit f106dd7

2 files changed

Lines changed: 57 additions & 2 deletions

File tree

include/NeuZephyr/Model.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ namespace nz {
2020
void update(opt::Optimizer* optimizer) const;
2121

2222
Tensor::value_type getLoss() const;
23+
2324
private:
2425
std::vector<Node*> hiddenNodes;
2526

2627
graph::ComputeGraph computeGraph;
28+
2729
protected:
2830
Node* Add(Node* lhs, Node* rhs);
2931

@@ -57,6 +59,15 @@ namespace nz {
5759

5860
Node* TargetExpand(Node* input, const Tensor::shape_type& shape);
5961

62+
/**
 * @brief Adds a calc::Img2ColNode fed by @p input to the model's compute graph.
 *
 * Per the definition in src/Model.cu, @p input is registered with the compute
 * graph first if it is not already part of it, and the new node is recorded in
 * hiddenNodes.
 *
 * @param input        Node used as the input of the new node.
 * @param kernelHeight Passed through to the calc::Img2ColNode constructor.
 * @param kernelWidth  Passed through to the calc::Img2ColNode constructor.
 * @param stride       Passed through to the calc::Img2ColNode constructor.
 * @param padding      Passed through to the calc::Img2ColNode constructor.
 * @return Pointer to the newly created node.
 */
Node* Img2Col(Node* input, Tensor::size_type kernelHeight, Tensor::size_type kernelWidth,
63+
Tensor::size_type stride, Tensor::size_type padding);
64+
65+
/**
 * @brief Adds a calc::Col2ImgNode fed by @p input to the model's compute graph.
 *
 * Per the definition in src/Model.cu, @p input is registered with the compute
 * graph first if it is not already part of it, and the new node is recorded in
 * hiddenNodes.
 *
 * @param input        Node used as the input of the new node.
 * @param outputHeight Passed through to the calc::Col2ImgNode constructor.
 * @param outputWidth  Passed through to the calc::Col2ImgNode constructor.
 * @return Pointer to the newly created node.
 */
Node* Col2Img(Node* input, Tensor::size_type outputHeight, Tensor::size_type outputWidth);
66+
67+
/**
 * @brief Adds a 2D convolution over @p input to the model and returns its output node.
 *
 * Per the definition in src/Model.cu, the layer is assembled as
 * Img2Col -> Mul with a randomized trainable kernel node -> optional Bias -> Col2Img.
 *
 * @param input        Node whose output is convolved.
 * @param outChannels  Size of the last dimension of the kernel node's shape.
 * @param kernelHeight Kernel height; forwarded to Img2Col and used for the output height.
 * @param kernelWidth  Kernel width; forwarded to Img2Col and used for the output width.
 * @param stride       Forwarded to Img2Col and used as the divisor in the output-size formula.
 * @param padding      Forwarded to Img2Col and counted twice in the output-size formula.
 * @param bias         When true (the default), a Bias node is applied to the product.
 * @return The node returned by the final Col2Img call.
 */
Node* Conv2d(Node* input, Tensor::size_type outChannels, Tensor::size_type kernelHeight,
68+
Tensor::size_type kernelWidth,
69+
Tensor::size_type stride, Tensor::size_type padding, bool bias = true);
70+
6071
void MSELoss(Node* input, Node* target);
6172

6273
void BCELoss(Node* input, Node* target);

src/Model.cu

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ Node* nz::Model::Linear(Node* input, size_t outSize) {
8888
Node* shapedInput;
8989
if (input->output->shape()[2] != inputSize) {
9090
shapedInput = Reshape(input, {input->output->shape()[0], 1, inputSize, 1});
91-
} else {
91+
}
92+
else {
9293
shapedInput = input;
9394
}
9495
auto mulParam = new io::InputNode({1, 1, outSize, inputSize}, true);
@@ -189,7 +190,8 @@ Node* nz::Model::Softmax(Node* input) {
189190
Node* reshapedInput;
190191
if (input->output->shape()[2] != size) {
191192
reshapedInput = Reshape(input, {batch, 1, size, 1});
192-
} else {
193+
}
194+
else {
193195
reshapedInput = input;
194196
}
195197
auto* softmaxNode = new calc::SoftmaxNode(reshapedInput);
@@ -217,6 +219,48 @@ Node* nz::Model::TargetExpand(Node* input, const Tensor::shape_type& shape) {
217219
return expandNode;
218220
}
219221

222+
/**
 * @brief Appends an image-to-column node for @p input to the model's graph.
 *
 * Makes sure @p input is registered with the compute graph, then creates a
 * calc::Img2ColNode from the given arguments and records it both in
 * hiddenNodes and in the compute graph.
 *
 * @param input        Node used as the input of the new node.
 * @param kernelHeight Forwarded unchanged to calc::Img2ColNode.
 * @param kernelWidth  Forwarded unchanged to calc::Img2ColNode.
 * @param stride       Forwarded unchanged to calc::Img2ColNode.
 * @param padding      Forwarded unchanged to calc::Img2ColNode.
 * @return The newly created node; the model keeps the same raw pointer in hiddenNodes.
 */
Node* nz::Model::Img2Col(Node* input, const Tensor::size_type kernelHeight, const Tensor::size_type kernelWidth,
                         const Tensor::size_type stride, const Tensor::size_type padding) {
    // The source node must be part of the graph before anything is attached to it.
    if (!computeGraph.inGraph(input)) {
        computeGraph.addNode(input);
    }

    auto* const unfoldNode = new calc::Img2ColNode(input, kernelHeight, kernelWidth, stride, padding);

    // Track the node as a hidden (model-created) node, then register it in the graph.
    hiddenNodes.push_back(unfoldNode);
    computeGraph.addNode(unfoldNode);

    return unfoldNode;
}
232+
233+
/**
 * @brief Appends a column-to-image node for @p input to the model's graph.
 *
 * Makes sure @p input is registered with the compute graph, then creates a
 * calc::Col2ImgNode from the given arguments and records it both in
 * hiddenNodes and in the compute graph.
 *
 * @param input        Node used as the input of the new node.
 * @param outputHeight Forwarded unchanged to calc::Col2ImgNode.
 * @param outputWidth  Forwarded unchanged to calc::Col2ImgNode.
 * @return The newly created node; the model keeps the same raw pointer in hiddenNodes.
 */
Node* nz::Model::Col2Img(Node* input, Tensor::size_type outputHeight, Tensor::size_type outputWidth) {
    // The source node must be part of the graph before anything is attached to it.
    if (!computeGraph.inGraph(input)) {
        computeGraph.addNode(input);
    }

    auto* const foldNode = new calc::Col2ImgNode(input, outputHeight, outputWidth);

    // Track the node as a hidden (model-created) node, then register it in the graph.
    hiddenNodes.push_back(foldNode);
    computeGraph.addNode(foldNode);

    return foldNode;
}
242+
243+
/**
 * @brief Builds a 2D convolution layer on top of @p input.
 *
 * The layer is assembled from existing building blocks:
 * Img2Col(input) -> Mul(columns, kernel) -> optional Bias -> Col2Img.
 * The kernel is a new io::InputNode created with its second constructor
 * argument set to true, randomized, and recorded in hiddenNodes and the
 * compute graph.
 *
 * @param input        Node whose output is convolved.
 * @param outChannels  Size of the last dimension of the kernel node's shape.
 * @param kernelHeight Kernel height.
 * @param kernelWidth  Kernel width.
 * @param stride       Stride forwarded to Img2Col; also the divisor of the output-size formula.
 * @param padding      Padding forwarded to Img2Col; counted twice in the output-size formula.
 * @param bias         When true, the product is passed through Bias() before Col2Img.
 * @return The node returned by Col2Img, with spatial size
 *         (H + 2 * padding - kernelHeight) / stride + 1 by
 *         (W + 2 * padding - kernelWidth) / stride + 1.
 */
Node* nz::Model::Conv2d(Node* input, Tensor::size_type outChannels, Tensor::size_type kernelHeight,
                        Tensor::size_type kernelWidth, Tensor::size_type stride, Tensor::size_type padding, bool bias) {
    // Register the input before the kernel so the graph insertion order stays input -> kernel.
    if (!computeGraph.inGraph(input)) {
        computeGraph.addNode(input);
    }

    // Kernel shape: {N, 1, C * kernelHeight * kernelWidth, outChannels}.
    // NOTE(review): the leading dimension is the input's batch size, whereas Linear() creates
    // its weight with a leading 1 -- confirm that per-sample kernels are intended here.
    const auto batch = input->output->shape().N();
    const auto patchSize = input->output->shape().C() * kernelHeight * kernelWidth;
    auto* const kernel = new io::InputNode({batch, 1, patchSize, outChannels}, true);
    kernel->output->randomize();
    hiddenNodes.push_back(kernel);
    computeGraph.addNode(kernel);

    // Unfold the input into columns and multiply by the kernel matrix.
    Node* columns = Img2Col(input, kernelHeight, kernelWidth, stride, padding);
    auto* product = Mul(columns, kernel);
    if (bias) {
        product = Bias(product);
    }

    // NOTE(review): no guard for stride == 0 or a kernel larger than the padded input;
    // presumably callers guarantee both -- TODO confirm.
    const auto outputHeight = (input->output->shape().H() + 2 * padding - kernelHeight) / stride + 1;
    const auto outputWidth = (input->output->shape().W() + 2 * padding - kernelWidth) / stride + 1;
    return Col2Img(product, outputHeight, outputWidth);
}
263+
220264
void nz::Model::MSELoss(Node* input, Node* target) {
221265
if (!computeGraph.inGraph(input)) {
222266
computeGraph.addNode(input);

0 commit comments

Comments
 (0)