@@ -88,7 +88,8 @@ Node* nz::Model::Linear(Node* input, size_t outSize) {
8888 Node* shapedInput;
8989 if (input->output ->shape ()[2 ] != inputSize) {
9090 shapedInput = Reshape (input, {input->output ->shape ()[0 ], 1 , inputSize, 1 });
91- } else {
91+ }
92+ else {
9293 shapedInput = input;
9394 }
9495 auto mulParam = new io::InputNode ({1 , 1 , outSize, inputSize}, true );
@@ -189,7 +190,8 @@ Node* nz::Model::Softmax(Node* input) {
189190 Node* reshapedInput;
190191 if (input->output ->shape ()[2 ] != size) {
191192 reshapedInput = Reshape (input, {batch, 1 , size, 1 });
192- } else {
193+ }
194+ else {
193195 reshapedInput = input;
194196 }
195197 auto * softmaxNode = new calc::SoftmaxNode (reshapedInput);
@@ -217,6 +219,48 @@ Node* nz::Model::TargetExpand(Node* input, const Tensor::shape_type& shape) {
217219 return expandNode;
218220}
219221
222+ Node* nz::Model::Img2Col (Node* input, const Tensor::size_type kernelHeight, const Tensor::size_type kernelWidth,
223+ const Tensor::size_type stride, const Tensor::size_type padding) {
224+ if (!computeGraph.inGraph (input)) {
225+ computeGraph.addNode (input);
226+ }
227+ auto * img2ColNode = new calc::Img2ColNode (input, kernelHeight, kernelWidth, stride, padding);
228+ hiddenNodes.push_back (img2ColNode);
229+ computeGraph.addNode (img2ColNode);
230+ return img2ColNode;
231+ }
232+
233+ Node* nz::Model::Col2Img (Node* input, Tensor::size_type outputHeight, Tensor::size_type outputWidth) {
234+ if (!computeGraph.inGraph (input)) {
235+ computeGraph.addNode (input);
236+ }
237+ auto * col2ImgNode = new calc::Col2ImgNode (input, outputHeight, outputWidth);
238+ hiddenNodes.push_back (col2ImgNode);
239+ computeGraph.addNode (col2ImgNode);
240+ return col2ImgNode;
241+ }
242+
243+ Node* nz::Model::Conv2d (Node* input, Tensor::size_type outChannels, Tensor::size_type kernelHeight,
244+ Tensor::size_type kernelWidth, Tensor::size_type stride, Tensor::size_type padding, bool bias) {
245+ if (!computeGraph.inGraph (input)) {
246+ computeGraph.addNode (input);
247+ }
248+ auto * convKernel = new io::InputNode ({
249+ input->output ->shape ().N (), 1 ,
250+ input->output ->shape ().C () * kernelHeight * kernelWidth, outChannels
251+ }, true );
252+ convKernel->output ->randomize ();
253+ hiddenNodes.push_back (convKernel);
254+ computeGraph.addNode (convKernel);
255+ auto inputCol = Img2Col (input, kernelHeight, kernelWidth, stride, padding);
256+ auto resultCol = Mul (inputCol, convKernel);
257+ if (bias) {
258+ resultCol = Bias (resultCol);
259+ }
260+ return Col2Img (resultCol, (input->output ->shape ().H () + 2 * padding - kernelHeight) / stride + 1 ,
261+ (input->output ->shape ().W () + 2 * padding - kernelWidth) / stride + 1 );
262+ }
263+
220264void nz::Model::MSELoss (Node* input, Node* target) {
221265 if (!computeGraph.inGraph (input)) {
222266 computeGraph.addNode (input);
0 commit comments