Skip to content

Commit aa8fa62

Browse files
committed
Added simple Zig implementation
1 parent 622e2a3 commit aa8fa62

File tree

7 files changed

+501
-0
lines changed

7 files changed

+501
-0
lines changed

.github/workflows/ci.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,20 @@ jobs:
112112
- name: Tests
113113
working-directory: Java
114114
run: bats test.bats
115+
116+
zig:
117+
runs-on: ubuntu-latest
118+
steps:
119+
- uses: actions/checkout@v3
120+
- name: Setup
121+
run: |
122+
sudo npm install -g bats
123+
124+
- name: Setup Zig
125+
uses: mlugg/setup-zig@v1
126+
with:
127+
version: 0.14.0
128+
129+
- name: Tests
130+
working-directory: Zig
131+
run: bats test.bats

Zig/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
zig-out
2+
.zig-cache

Zig/build.zig

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
const std = @import("std");
2+
3+
pub fn build(b: *std.Build) void {
4+
const target = b.standardTargetOptions(.{});
5+
const optimize = b.standardOptimizeOption(.{});
6+
7+
const exe_mod = b.createModule(.{
8+
.root_source_file = b.path("src/main.zig"),
9+
.target = target,
10+
.optimize = optimize,
11+
});
12+
13+
const exe = b.addExecutable(.{
14+
.name = "NeuralNetworkInAllLangs",
15+
.root_module = exe_mod,
16+
});
17+
18+
b.installArtifact(exe);
19+
20+
const run_cmd = b.addRunArtifact(exe);
21+
22+
run_cmd.step.dependOn(b.getInstallStep());
23+
24+
if (b.args) |args| {
25+
run_cmd.addArgs(args);
26+
}
27+
28+
const run_step = b.step("run", "Run the app");
29+
run_step.dependOn(&run_cmd.step);
30+
31+
const exe_unit_tests = b.addTest(.{
32+
.root_module = exe_mod,
33+
});
34+
35+
const run_exe_unit_tests = b.addRunArtifact(exe_unit_tests);
36+
37+
const test_step = b.step("test", "Run unit tests");
38+
test_step.dependOn(&run_exe_unit_tests.step);
39+
}

Zig/build.zig.zon

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
.{
2+
// This is the default name used by packages depending on this one. For
3+
// example, when a user runs `zig fetch --save <url>`, this field is used
4+
// as the key in the `dependencies` table. Although the user can choose a
5+
// different name, most users will stick with this provided value.
6+
//
7+
// It is redundant to include "zig" in this name because it is already
8+
// within the Zig package namespace.
9+
.name = .Zig,
10+
11+
// This is a [Semantic Version](https://semver.org/).
12+
// In a future version of Zig it will be used for package deduplication.
13+
.version = "0.0.0",
14+
15+
// Together with name, this represents a globally unique package
16+
// identifier. This field is generated by the Zig toolchain when the
17+
// package is first created, and then *never changes*. This allows
18+
// unambiguous detection of one package being an updated version of
19+
// another.
20+
//
21+
// When forking a Zig project, this id should be regenerated (delete the
22+
// field and run `zig build`) if the upstream project is still maintained.
23+
// Otherwise, the fork is *hostile*, attempting to take control over the
24+
// original project's identity. Thus it is recommended to leave the comment
25+
// on the following line intact, so that it shows up in code reviews that
26+
// modify the field.
27+
.fingerprint = 0xf9835661bb6dc018, // Changing this has security and trust implications.
28+
29+
// Tracks the earliest Zig version that the package considers to be a
30+
// supported use case.
31+
.minimum_zig_version = "0.14.0",
32+
33+
// This field is optional.
34+
// Each dependency must either provide a `url` and `hash`, or a `path`.
35+
// `zig build --fetch` can be used to fetch all dependencies of a package, recursively.
36+
// Once all dependencies are fetched, `zig build` no longer requires
37+
// internet connectivity.
38+
.dependencies = .{
39+
// See `zig fetch --save <url>` for a command-line interface for adding dependencies.
40+
//.example = .{
41+
// // When updating this field to a new URL, be sure to delete the corresponding
42+
// // `hash`, otherwise you are communicating that you expect to find the old hash at
43+
// // the new URL. If the contents of a URL change this will result in a hash mismatch
44+
// // which will prevent zig from using it.
45+
// .url = "https://example.com/foo.tar.gz",
46+
//
47+
// // This is computed from the file contents of the directory of files that is
48+
// // obtained after fetching `url` and applying the inclusion rules given by
49+
// // `paths`.
50+
// //
51+
// // This field is the source of truth; packages do not come from a `url`; they
52+
// // come from a `hash`. `url` is just one of many possible mirrors for how to
53+
// // obtain a package matching this `hash`.
54+
// //
55+
// // Uses the [multihash](https://multiformats.io/multihash/) format.
56+
// .hash = "...",
57+
//
58+
// // When this is provided, the package is found in a directory relative to the
59+
// // build root. In this case the package's hash is irrelevant and therefore not
60+
// // computed. This field and `url` are mutually exclusive.
61+
// .path = "foo",
62+
//
63+
// // When this is set to `true`, a package is declared to be lazily
64+
// // fetched. This makes the dependency only get fetched if it is
65+
// // actually used.
66+
// .lazy = false,
67+
//},
68+
},
69+
70+
// Specifies the set of files and directories that are included in this package.
71+
// Only files and directories listed here are included in the `hash` that
72+
// is computed for this package. Only files listed here will remain on disk
73+
// when using the zig package manager. As a rule of thumb, one should list
74+
// files required for compilation plus any license(s).
75+
// Paths are relative to the build root. Use the empty string (`""`) to refer to
76+
// the build root itself.
77+
// A directory listed here means that all files within, recursively, are included.
78+
.paths = .{
79+
"build.zig",
80+
"build.zig.zon",
81+
"src",
82+
// For example...
83+
//"LICENSE",
84+
//"README.md",
85+
},
86+
}

Zig/src/main.zig

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
const std = @import("std");
2+
const Trainer = @import("neural.zig").Trainer;
3+
4+
fn rand() f64 {
5+
const P: u32 = 2147483647;
6+
const A: u32 = 16807;
7+
const S = struct {
8+
var current: u32 = 1;
9+
};
10+
11+
S.current = S.current *% A % P;
12+
13+
return @as(f64, @floatFromInt(S.current)) / @as(f64, @floatFromInt(P));
14+
}
15+
16+
fn _xor(i: u32, j: u32) u32 {
17+
return i ^ j;
18+
}
19+
20+
fn _xnor(i: u32, j: u32) u32 {
21+
return 1 - _xor(i, j);
22+
}
23+
24+
fn _or(i: u32, j: u32) u32 {
25+
return i | j;
26+
}
27+
28+
fn _and(i: u32, j: u32) u32 {
29+
return i & j;
30+
}
31+
32+
fn _nor(i: u32, j: u32) u32 {
33+
return 1 - _or(i, j);
34+
}
35+
36+
fn _nand(i: u32, j: u32) u32 {
37+
return 1 - _and(i, j);
38+
}
39+
40+
fn DataItem(comptime I: usize, comptime O: usize) type {
41+
return struct {
42+
input: [I]f64,
43+
output: [O]f64,
44+
};
45+
}
46+
47+
pub fn main() !void {
48+
var dba = std.heap.DebugAllocator(.{}){};
49+
defer std.debug.assert(dba.deinit() == .ok);
50+
const allocator = dba.allocator();
51+
52+
var all_data = try allocator.alloc(DataItem(2, 6), 4);
53+
defer allocator.free(all_data);
54+
55+
for ([_]u32{ 0, 1 }) |i| {
56+
for ([_]u32{ 0, 1 }) |j| {
57+
all_data[i * 2 + j] = .{
58+
.input = .{ @floatFromInt(i), @floatFromInt(j) },
59+
.output = .{
60+
@floatFromInt(_xor(i, j)),
61+
@floatFromInt(_xnor(i, j)),
62+
@floatFromInt(_or(i, j)),
63+
@floatFromInt(_and(i, j)),
64+
@floatFromInt(_nor(i, j)),
65+
@floatFromInt(_nand(i, j)),
66+
},
67+
};
68+
}
69+
}
70+
71+
var trainer = try Trainer.init(allocator, 2, 2, 6, rand);
72+
defer trainer.deinit(allocator);
73+
74+
const steps = 4000;
75+
const lr: f64 = 1.0;
76+
77+
for (0..steps) |i| {
78+
var example = all_data[i % 4];
79+
try trainer.train(&example.input, &example.output, lr);
80+
}
81+
82+
std.debug.print("Result after {d} iterations\n", .{steps});
83+
std.debug.print(" XOR XNOR OR AND NOR NAND\n", .{});
84+
for (0..all_data.len) |i| {
85+
var example = all_data[i];
86+
const pred = try trainer.network.predict(allocator, &example.input);
87+
defer allocator.free(pred);
88+
std.debug.print(
89+
"{d:.0},{d:.0} = {d:.3} {d:.3} {d:.3} {d:.3} {d:.3} {d:.3}\n",
90+
.{
91+
example.input[0],
92+
example.input[1],
93+
pred[0],
94+
pred[1],
95+
pred[2],
96+
pred[3],
97+
pred[4],
98+
pred[5],
99+
},
100+
);
101+
}
102+
103+
trainer.network.print();
104+
}

0 commit comments

Comments
 (0)