Skip to content

Commit d3f0931

Browse files
committed
refactor: improve code readability and maintainability in neural network implementation
1 parent b4935a1 commit d3f0931

File tree

4 files changed

+186
-201
lines changed

4 files changed

+186
-201
lines changed

C/Makefile

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1717

1818
CC = clang
19-
CC_FLAGS = -Wfatal-errors -Wall -Wextra -Wpedantic -Wconversion -Wshadow -lm
19+
CC_FLAGS = -Wfatal-errors -Wall -Wextra -Wpedantic -Wconversion -Wshadow
20+
21+
# Detect OS and add -lm on Linux
22+
UNAME_S := $(shell uname -s)
23+
ifeq ($(UNAME_S),Linux)
24+
LDFLAGS = -lm
25+
endif
2026

2127
# Final binary
2228
BIN = cmain
@@ -25,7 +31,7 @@ BIN = cmain
2531
BUILD_DIR = ./build
2632

2733
# List of all .c source files.
28-
CCS = main.c $(wildcard *.c)
34+
CCS = $(wildcard *.c)
2935

3036
# All .o files go to build dir.
3137
OBJ = $(CCS:%.c=$(BUILD_DIR)/%.o)
@@ -39,7 +45,7 @@ $(BIN) : $(BUILD_DIR)/$(BIN)
3945
# Actual target of the binary - depends on all .o files.
4046
$(BUILD_DIR)/$(BIN) : $(OBJ)
4147
mkdir -p $(@D)
42-
$(CC) $(CC_FLAGS) $^ -o $@
48+
$(CC) $(CC_FLAGS) $^ -o $@ $(LDFLAGS)
4349

4450
# Include all .d files
4551
-include $(DEP)

C/main.c

Lines changed: 70 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,110 +18,98 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1818
*/
1919

2020
#include "neural.h"
21-
#include <stdlib.h>
2221
#include <stdio.h>
22+
#include <stdlib.h>
2323

2424
uint32_t P = 2147483647;
2525
uint32_t A = 16807;
2626
uint32_t current = 1;
2727

28-
double Rand() {
29-
current = current * A % P;
30-
double result = (double)current / P;
31-
return result;
28+
double Rand(void) {
29+
current = current * A % P;
30+
double result = (double)current / P;
31+
return result;
3232
}
3333

34-
void print_network(const Network* network);
34+
void print_network(const Network *network);
3535

36-
static uint32_t xor(uint32_t i, uint32_t j) { return i ^ j; }
37-
static uint32_t xnor(uint32_t i, uint32_t j) { return 1 - xor(i, j); }
38-
static uint32_t or(uint32_t i, uint32_t j) { return i | j; }
39-
static uint32_t and(uint32_t i, uint32_t j) { return i & j; }
40-
static uint32_t nor(uint32_t i, uint32_t j) { return 1 - or(i, j); }
36+
static uint32_t xor (uint32_t i, uint32_t j) { return i ^ j; } static uint32_t
37+
xnor(uint32_t i, uint32_t j) {
38+
return 1 - xor(i, j);
39+
}
40+
static uint32_t or (uint32_t i, uint32_t j) { return i | j; }
41+
static uint32_t and (uint32_t i, uint32_t j) { return i & j; }
42+
static uint32_t nor(uint32_t i, uint32_t j) { return 1 - or (i, j); }
4143
static uint32_t nand(uint32_t i, uint32_t j) { return 1 - and(i, j); }
4244

4345
const int ITERS = 4000;
4446

45-
int main() {
46-
Network network = {0};
47-
network_init(&network, 2, 2, 6, Rand);
48-
Trainer trainer = {0};
49-
trainer_init(&trainer, &network);
50-
double inputs[4][2] = {
51-
{0, 0},
52-
{0, 1},
53-
{1, 0},
54-
{1, 1}
55-
};
56-
double outputs[4][6] = {
57-
{ xor(0, 0), xnor(0, 0), or(0, 0), and(0, 0), nor(0, 0), nand(0, 0) },
58-
{ xor(0, 1), xnor(0, 1), or(0, 1), and(0, 1), nor(0, 1), nand(0, 1) },
59-
{ xor(1, 0), xnor(1, 0), or(1, 0), and(1, 0), nor(1, 0), nand(1, 0) },
60-
{ xor(1, 1), xnor(1, 1), or(1, 1), and(1, 1), nor(1, 1), nand(1, 1) }
61-
};
62-
63-
for (size_t i = 0; i < ITERS; i++) {
64-
double* input = inputs[i % 4];
65-
double* output = outputs[i % 4];
66-
67-
trainer_train(&trainer, &network, input, output, 1.0);
68-
}
69-
70-
printf(
71-
"Result after %d iterations\n XOR XNOR OR AND NOR NAND\n",
72-
ITERS);
73-
for (size_t i = 0; i < 4; i++)
74-
{
75-
double* input = inputs[i % 4];
76-
network_predict(&network, input);
77-
printf(
78-
"%.0f,%.0f = %.3f %.3f %.3f %.3f %.3f %.3f\n",
79-
input[0],
80-
input[1],
81-
network.output[0],
82-
network.output[1],
83-
network.output[2],
84-
network.output[3],
85-
network.output[4],
86-
network.output[5]);
87-
}
88-
89-
print_network(&network);
90-
trainer_free(&trainer);
91-
network_free(&network);
92-
return 0;
47+
int main(void) {
48+
Network network = {0};
49+
network_init(&network, 2, 2, 6, Rand);
50+
Trainer trainer = {0};
51+
trainer_init(&trainer, &network);
52+
double inputs[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
53+
double outputs[4][6] = {
54+
{xor(0, 0), xnor(0, 0), or (0, 0), and(0, 0), nor(0, 0), nand(0, 0)},
55+
{xor(0, 1), xnor(0, 1), or (0, 1), and(0, 1), nor(0, 1), nand(0, 1)},
56+
{xor(1, 0), xnor(1, 0), or (1, 0), and(1, 0), nor(1, 0), nand(1, 0)},
57+
{xor(1, 1), xnor(1, 1), or (1, 1), and(1, 1), nor(1, 1), nand(1, 1)}};
58+
59+
for (size_t i = 0; i < ITERS; i++) {
60+
double *input = inputs[i % 4];
61+
double *output = outputs[i % 4];
62+
63+
trainer_train(&trainer, &network, input, output, 1.0);
64+
}
65+
66+
printf(
67+
"Result after %d iterations\n XOR XNOR OR AND NOR NAND\n",
68+
ITERS);
69+
for (size_t i = 0; i < 4; i++) {
70+
double *input = inputs[i % 4];
71+
network_predict(&network, input);
72+
printf("%.0f,%.0f = %.3f %.3f %.3f %.3f %.3f %.3f\n", input[0], input[1],
73+
network.output[0], network.output[1], network.output[2],
74+
network.output[3], network.output[4], network.output[5]);
75+
}
76+
77+
print_network(&network);
78+
trainer_free(&trainer);
79+
network_free(&network);
80+
return 0;
9381
}
9482

95-
void print_network(const Network* network) {
96-
printf("weights hidden:\n");
97-
for (size_t i = 0; i < network->n_inputs; i++) {
98-
for (size_t j = 0; j < network->n_hidden; j++) {
99-
printf(" %9.6f", network->weights_hidden[network->n_inputs * i + j]);
100-
}
101-
102-
printf("\n");
103-
}
104-
105-
printf("biases hidden:\n");
106-
for (size_t i = 0; i < network->n_hidden; i++) {
107-
printf(" %9.6f", network->biases_hidden[i]);
83+
void print_network(const Network *network) {
84+
printf("weights hidden:\n");
85+
for (size_t i = 0; i < network->n_inputs; i++) {
86+
for (size_t j = 0; j < network->n_hidden; j++) {
87+
printf(" %9.6f", network->weights_hidden[network->n_inputs * i + j]);
10888
}
10989

11090
printf("\n");
91+
}
11192

112-
printf("weights output:\n");
113-
for (size_t i = 0; i < network->n_hidden; i++) {
114-
for (size_t j = 0; j < network->n_outputs; j++) {
115-
printf(" %9.6f", network->weights_output[i * network->n_outputs + j]);
116-
}
93+
printf("biases hidden:\n");
94+
for (size_t i = 0; i < network->n_hidden; i++) {
95+
printf(" %9.6f", network->biases_hidden[i]);
96+
}
11797

118-
printf("\n");
119-
}
98+
printf("\n");
12099

121-
printf("biases output:\n");
122-
for (size_t i = 0; i < network->n_outputs; i++) {
123-
printf(" %9.6f", network->biases_output[i]);
100+
printf("weights output:\n");
101+
for (size_t i = 0; i < network->n_hidden; i++) {
102+
for (size_t j = 0; j < network->n_outputs; j++) {
103+
printf(" %9.6f", network->weights_output[i * network->n_outputs + j]);
124104
}
125105

126106
printf("\n");
107+
}
108+
109+
printf("biases output:\n");
110+
for (size_t i = 0; i < network->n_outputs; i++) {
111+
printf(" %9.6f", network->biases_output[i]);
112+
}
113+
114+
printf("\n");
127115
}

0 commit comments

Comments
 (0)