floor and onehot #310

Open · wants to merge 2 commits into base: feature
70 changes: 11 additions & 59 deletions build.zig
@@ -29,10 +29,11 @@ pub fn build(b: *std.Build) void {
const zant_mod = b.createModule(.{ .root_source_file = b.path("src/zant.zig") });
zant_mod.addOptions("build_options", build_options);

//************************************************UNIT TESTS************************************************
const codeGen_mod = b.createModule(.{ .root_source_file = b.path("src/CodeGen/codegen.zig") });
codeGen_mod.addImport("zant", zant_mod);

//************************************************UNIT TESTS************************************************

// Define unified tests for the project.
const unit_tests = b.addTest(.{
.name = "test_lib",
@@ -41,6 +42,14 @@ pub fn build(b: *std.Build) void {
.optimize = optimize,
});

// Define test options
const test_options = b.addOptions();
test_options.addOption(bool, "heavy", b.option(bool, "heavy", "Run heavy tests") orelse false);
unit_tests.root_module.addOptions("test_options", test_options);

const test_name = b.option([]const u8, "test_name", "Specify a test name to run") orelse "";
test_options.addOption([]const u8, "test_name", test_name);
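// A minimal usage sketch, assuming a `test` step wires up `unit_tests`
// elsewhere in this script; both options fall back to their defaults when
// omitted, so a plain `zig build test` is unchanged:
//
//   zig build test -Dheavy=true -Dtest_name="onehot"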

unit_tests.root_module.addImport("zant", zant_mod);
unit_tests.root_module.addImport("codegen", codeGen_mod);

@@ -192,7 +201,6 @@ pub fn build(b: *std.Build) void {
test_step_generated_lib.dependOn(&run_test_generated_lib.step);

// ************************************************ ONEOP CODEGEN ************************************************

// Setup oneOp codegen

const oneop_codegen_exe = b.addExecutable(.{
@@ -212,8 +220,7 @@ pub fn build(b: *std.Build) void {
step_test_oneOp_codegen.dependOn(&run_oneop_codegen_exe.step);

// ************************************************

//Setup test_all_oneOp
// Setup test_all_oneOp

const test_all_oneOp = b.addTest(.{
.name = "test_all_oneOp",
@@ -238,34 +245,12 @@ pub fn build(b: *std.Build) void {
const step_test_oneOp = b.step("test-codegen", "Run generated library tests");
step_test_oneOp.dependOn(&run_test_all_oneOp.step);

// ************************************************
// Benchmark

const benchmark = b.addExecutable(.{
.name = "benchmark",
.root_source_file = b.path("benchmarks/main.zig"),
.target = target,
.optimize = optimize,
});

const bench_options = b.addOptions();
bench_options.addOption(bool, "full", b.option(bool, "full", "Choose whether to run the full benchmark or not") orelse false);

benchmark.root_module.addImport("zant", zant_mod);
benchmark.root_module.addOptions("bench_options", bench_options);
benchmark.linkLibC();

const run_benchmark = b.addRunArtifact(benchmark);
const benchmark_step = b.step("benchmark", "Run benchmarks");
benchmark_step.dependOn(&run_benchmark.step);

// ************************************************ ONNX PARSER TESTS ************************************************
// Add test for generated library

const test_onnx_parser = b.addTest(.{
.name = "test_generated_lib",
.root_source_file = b.path("tests/Onnx/onnx_loader.zig"),

.target = target,
.optimize = optimize,
});
@@ -287,7 +272,6 @@ pub fn build(b: *std.Build) void {

const main_executable = b.addExecutable(.{
.name = "main_profiling_target",
.root_source_file = b.path("src/main.zig"),
.target = target,
.optimize = optimize,
});
@@ -304,36 +288,4 @@ pub fn build(b: *std.Build) void {

const build_main_step = b.step("build-main", "Build the main executable for profiling");
build_main_step.dependOn(&install_main_exe_step.step);

// ************************************************ NATIVE GUI ************************************************

{
const dvui_dep = b.dependency("dvui", .{ .target = target, .optimize = optimize, .backend = .sdl, .sdl3 = true });

const gui_exe = b.addExecutable(.{
.name = "gui",
.root_source_file = b.path("gui/sdl/sdl-standalone.zig"),
.target = target,
.optimize = optimize,
});

// Can either link the backend ourselves:
// const dvui_mod = dvui_dep.module("dvui");
// const sdl = dvui_dep.module("sdl");
// @import("dvui").linkBackend(dvui_mod, sdl);
// exe.root_module.addImport("dvui", dvui_mod);

// Or use a prelinked one:
gui_exe.root_module.addImport("dvui", dvui_dep.module("dvui_sdl"));

const compile_step = b.step("compile-gui", "Compile gui");
compile_step.dependOn(&b.addInstallArtifact(gui_exe, .{}).step);
b.getInstallStep().dependOn(compile_step);

const run_cmd = b.addRunArtifact(gui_exe);
run_cmd.step.dependOn(compile_step);

const run_step = b.step("gui", "Run gui");
run_step.dependOn(&run_cmd.step);
}
}
126 changes: 125 additions & 1 deletion src/CodeGen/math_handler.zig
@@ -92,6 +92,8 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
try write_elu(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Flatten")) {
try write_flatten(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Floor")) {
try write_floor(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gather")) {
try write_gather(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gemm")) {
@@ -111,7 +113,7 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Neg")) {
try write_neg(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "OneHot")) {
try writer.writeAll("// Handle OneHot\n");
try write_oneHot(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Pad")) {
try write_pads(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "ReduceMean")) {
@@ -366,6 +368,97 @@ inline fn write_BatchNormalization(writer: std.fs.File.Writer, node: *ReadyNode)
});
}

inline fn write_oneHot(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__OneHot.html
// INPUTS:
// - indices (heterogeneous) - T1: Tensor of indices.
// - depth (heterogeneous) - T2: Scalar tensor for depth.
// - values (heterogeneous) - T3: Tensor of shape [off_value, on_value].
// OUTPUT:
// - output (heterogeneous) - T3: Output tensor with one-hot encoding.
// ATTRIBUTES:
// - axis - INT (default is -1): Axis along which to add the one-hot dimension.

var axis: i64 = -1; // Default axis per ONNX
for (node.nodeProto.attribute) |attr| {
if (std.mem.eql(u8, attr.name, "axis")) {
if (attr.type != AttributeType.INT) return error.InvalidAxisType;
axis = attr.i;
}
}

//----create indices string
var indices_string: []u8 = undefined;
defer allocator.free(indices_string);
if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&param_lib.tensor_",
try utils.getSanitizedName(node.inputs.items[0].?.name),
")",
});
} else {
indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&tensor_",
try utils.getSanitizedName(node.inputs.items[0].?.name),
")",
});
}

//----create depth string
var depth_string: []u8 = undefined;
defer allocator.free(depth_string);
if (node.inputs.items[1].?.tag == globals.TensorTag.INITIALIZER) {
depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&param_lib.tensor_",
try utils.getSanitizedName(node.inputs.items[1].?.name),
")",
});
} else {
depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&tensor_",
try utils.getSanitizedName(node.inputs.items[1].?.name),
")",
});
}

//----create values string
var values_string: []u8 = undefined;
defer allocator.free(values_string);
if (node.inputs.items[2].?.tag == globals.TensorTag.INITIALIZER) {
values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&param_lib.tensor_",
try utils.getSanitizedName(node.inputs.items[2].?.name),
")",
});
} else {
values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&tensor_",
try utils.getSanitizedName(node.inputs.items[2].?.name),
")",
});
}

_ = try writer.print(
\\
\\
\\ tensMath.oneHot_lean(
\\ {s}, // T
\\ {s}, // indices
\\ {s}.data[0], // depth (scalar)
\\ {s}, // values
\\ {}, // axis
\\ &tensor_{s}, // output
\\ )
, .{
try utils.getTypeString(globals.tensorHashMap.getPtr(node.inputs.items[2].?.name).?.tensorProto.?.data_type), // T
indices_string, // indices
depth_string, // depth
values_string, // values
axis, // axis
try utils.getSanitizedName(node.outputs.items[0].name), // output
});
}
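
// For illustration only, with hypothetical tensor names (`indices` as a graph
// input, `depth` and `values` as initializers) and a float32 node, the print
// above would emit:
//
//     tensMath.oneHot_lean(
//         f32, // T
//         @constCast(&tensor_indices), // indices
//         @constCast(&param_lib.tensor_depth).data[0], // depth (scalar)
//         @constCast(&param_lib.tensor_values), // values
//         -1, // axis
//         &tensor_output, // output
//     )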

inline fn write_sub(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__Sub.html
// INPUTS:
@@ -2177,6 +2270,37 @@ inline fn write_transpose(writer: std.fs.File.Writer, node: *ReadyNode) !void {
});
}

inline fn write_floor(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__Floor.html
// INPUTS:
// - X (heterogeneous) - T: Input tensor
// OUTPUTS:
// - Y (heterogeneous) - T: Output tensor with floor of input elements (If x is integral, +0, -0, NaN, or infinite, x itself is returned)

// Create input tensor string
var input_tensor_string: []u8 = undefined;
defer allocator.free(input_tensor_string);

if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{
"@constCast(&param_lib.tensor_",
try utils.getSanitizedName(node.inputs.items[0].?.name),
")",
});
} else {
input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{ "&tensor_", try utils.getSanitizedName(node.inputs.items[0].?.name) });
}

_ = try writer.print(
\\
\\
\\ tensMath.floor_lean(T, {s}, &tensor_{s})
, .{
input_tensor_string,
try utils.getSanitizedName(node.outputs.items[0].name),
});
}
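
// For illustration only, with a hypothetical non-initializer input named
// `input` and an output named `output`, the print above would emit:
//
//     tensMath.floor_lean(T, &tensor_input, &tensor_output)
//
// Note that, unlike write_oneHot, the element type here is emitted as the
// literal identifier `T` rather than a resolved type string.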

inline fn write_tanh(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__Tanh.html
// INPUTS: