
Commit 144ba08

floor and onehot
1 parent 44a914d commit 144ba08

File tree

11 files changed: +688 -6 lines changed

src/CodeGen/math_handler.zig

Lines changed: 125 additions & 1 deletion
@@ -92,6 +92,8 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
         try write_elu(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "Flatten")) {
         try write_flatten(writer, node);
+    } else if (std.mem.eql(u8, node.nodeProto.op_type, "Floor")) {
+        try write_floor(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "Gather")) {
         try write_gather(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "Gemm")) {
@@ -111,7 +113,7 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "Neg")) {
         try write_neg(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "OneHot")) {
-        try writer.writeAll("// Handle OneHot\n");
+        try write_oneHot(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "Pad")) {
         try write_pads(writer, node);
     } else if (std.mem.eql(u8, node.nodeProto.op_type, "ReduceMean")) {
@@ -366,6 +368,97 @@ inline fn write_BatchNormalization(writer: std.fs.File.Writer, node: *ReadyNode)
     });
 }
 
+inline fn write_oneHot(writer: std.fs.File.Writer, node: *ReadyNode) !void {
+    // https://onnx.ai/onnx/operators/onnx__OneHot.html
+    // INPUTS:
+    //      - indices (heterogeneous) - T1: Tensor of indices.
+    //      - depth (heterogeneous) - T2: Scalar tensor for depth.
+    //      - values (heterogeneous) - T3: Rank-1 tensor of two elements: [off_value, on_value].
+    // OUTPUT:
+    //      - output (heterogeneous) - T3: Output tensor with one-hot encoding.
+    // ATTRIBUTES:
+    //      - axis - INT (default is -1): Axis along which to add the one-hot dimension.
+
+    var axis: i64 = -1; // Default axis per ONNX
+    for (node.nodeProto.attribute) |attr| {
+        if (std.mem.eql(u8, attr.name, "axis")) {
+            if (attr.type != AttributeType.INT) return error.InvalidAxisType;
+            axis = attr.i;
+        }
+    }
+
+    //----create indices string
+    var indices_string: []u8 = undefined;
+    defer allocator.free(indices_string);
+    if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
+        indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&param_lib.tensor_",
+            try utils.getSanitizedName(node.inputs.items[0].?.name),
+            ")",
+        });
+    } else {
+        indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&tensor_",
+            try utils.getSanitizedName(node.inputs.items[0].?.name),
+            ")",
+        });
+    }
+
+    //----create depth string
+    var depth_string: []u8 = undefined;
+    defer allocator.free(depth_string);
+    if (node.inputs.items[1].?.tag == globals.TensorTag.INITIALIZER) {
+        depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&param_lib.tensor_",
+            try utils.getSanitizedName(node.inputs.items[1].?.name),
+            ")",
+        });
+    } else {
+        depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&tensor_",
+            try utils.getSanitizedName(node.inputs.items[1].?.name),
+            ")",
+        });
+    }
+
+    //----create values string
+    var values_string: []u8 = undefined;
+    defer allocator.free(values_string);
+    if (node.inputs.items[2].?.tag == globals.TensorTag.INITIALIZER) {
+        values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&param_lib.tensor_",
+            try utils.getSanitizedName(node.inputs.items[2].?.name),
+            ")",
+        });
+    } else {
+        values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&tensor_",
+            try utils.getSanitizedName(node.inputs.items[2].?.name),
+            ")",
+        });
+    }
+
+    _ = try writer.print(
+        \\
+        \\
+        \\    tensMath.oneHot_lean(
+        \\        {s}, // T
+        \\        {s}, // indices
+        \\        {s}.data[0], // depth (scalar)
+        \\        {s}, // values
+        \\        {}, // axis
+        \\        &tensor_{s}, // output
+        \\    )
+    , .{
+        try utils.getTypeString(globals.tensorHashMap.getPtr(node.inputs.items[2].?.name).?.tensorProto.?.data_type), // T
+        indices_string, // indices
+        depth_string, // depth
+        values_string, // values
+        axis, // axis
+        try utils.getSanitizedName(node.outputs.items[0].name), // output
+    });
+}
+
 inline fn write_sub(writer: std.fs.File.Writer, node: *ReadyNode) !void {
     // https://onnx.ai/onnx/operators/onnx__Sub.html
     // INPUTS:
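
For reference, this is roughly the fragment write_oneHot emits into the generated model source for a hypothetical node named onehot0 whose three inputs are all initializers (the tensor names here are illustrative, and the surrounding generated function is assumed to wrap the call):

    tensMath.oneHot_lean(
        f32, // T
        @constCast(&param_lib.tensor_indices0), // indices
        @constCast(&param_lib.tensor_depth0).data[0], // depth (scalar)
        @constCast(&param_lib.tensor_values0), // values
        -1, // axis
        &tensor_onehot0, // output
    )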
@@ -2177,6 +2270,37 @@ inline fn write_transpose(writer: std.fs.File.Writer, node: *ReadyNode) !void {
     });
 }
 
+inline fn write_floor(writer: std.fs.File.Writer, node: *ReadyNode) !void {
+    // https://onnx.ai/onnx/operators/onnx__Floor.html
+    // INPUTS:
+    //      - X (heterogeneous) - T: Input tensor
+    // OUTPUTS:
+    //      - Y (heterogeneous) - T: Output tensor with the floor of the input elements
+    //        (if x is integral, +0, -0, NaN, or infinite, x itself is returned)
+
+    // Create input tensor string
+    var input_tensor_string: []u8 = undefined;
+    defer allocator.free(input_tensor_string);
+
+    if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
+        input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+            "@constCast(&param_lib.tensor_",
+            try utils.getSanitizedName(node.inputs.items[0].?.name),
+            ")",
+        });
+    } else {
+        input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{ "&tensor_", try utils.getSanitizedName(node.inputs.items[0].?.name) });
+    }
+
+    _ = try writer.print(
+        \\
+        \\
+        \\    tensMath.floor_lean(T, {s}, &tensor_{s})
+    , .{
+        input_tensor_string,
+        try utils.getSanitizedName(node.outputs.items[0].name),
+    });
+}
+
 inline fn write_tanh(writer: std.fs.File.Writer, node: *ReadyNode) !void {
     // https://onnx.ai/onnx/operators/onnx__Tanh.html
    // INPUTS:
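
Similarly, for a hypothetical non-initializer input input0 and output floor0, write_floor emits a fragment along these lines (names illustrative; the type parameter T is assumed to be an alias defined by the surrounding generated code):

    tensMath.floor_lean(T, &tensor_input0, &tensor_floor0)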

src/CodeGen/shape_handler.zig

Lines changed: 110 additions & 2 deletions
@@ -80,6 +80,9 @@ pub fn compute_output_shape(readyNode: *ReadyNode) !void {
     } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Flatten")) {
         //https://onnx.ai/onnx/operators/onnx__Flatten.html
         try compute_flatten_output_shape(readyNode);
+    } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Floor")) {
+        //https://onnx.ai/onnx/operators/onnx__Floor.html
+        try compute_floor_output_shape(readyNode);
     } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Gather")) {
         try compute_gather_output_shape(readyNode);
     } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Gemm")) {
@@ -100,8 +103,8 @@ pub fn compute_output_shape(readyNode: *ReadyNode) !void {
         //https://onnx.ai/onnx/operators/onnx__Neg.html
         try compute_neg_output_shape(readyNode);
     } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "OneHot")) {
-        // TODO
-        return error.OperationWIP;
+        //https://onnx.ai/onnx/operators/onnx__OneHot.html
+        try compute_oneHot_output_shape(readyNode);
     } else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Pad")) {
         //https://onnx.ai/onnx/operators/onnx__Pad.html
         try compute_pads_output_shape(readyNode);
@@ -438,6 +441,93 @@ inline fn compute_gemm_output_shape(readyNode: *ReadyNode) !void {
     readyNode.outputs.items[0].shape = shape;
 }
 
+inline fn compute_oneHot_output_shape(readyNode: *ReadyNode) !void {
+    std.debug.print("\n====== compute_oneHot_output_shape node: {s} ======\n", .{readyNode.nodeProto.name orelse "(unnamed)"});
+
+    var shape: []const i64 = undefined;
+
+    if (utils.getTensorShape(readyNode.outputs.items[0].name)) |tensorShape| {
+        shape = tensorShape;
+    } else {
+        // Verify there are exactly 3 inputs: indices, depth, values
+        if (readyNode.inputs.items.len != 3) {
+            std.debug.print("\n ERROR: OneHot expects exactly 3 inputs, got {d}\n", .{readyNode.inputs.items.len});
+            return error.InvalidNumberOfInputs;
+        }
+
+        const indices = readyNode.inputs.items[0].?;
+        const depth_tensor = readyNode.inputs.items[1].?;
+        const values = readyNode.inputs.items[2].?;
+
+        std.debug.print("\n indices_shape: []i64 = {any}", .{indices.shape});
+        std.debug.print("\n depth_shape: []i64 = {any}", .{depth_tensor.shape});
+        std.debug.print("\n values_shape: []i64 = {any}", .{values.shape});
+
+        // Verify that depth is a scalar (shape [] or [1])
+        const depth_shape_i64 = depth_tensor.shape;
+        const effective_depth_shape_i64 = if (depth_shape_i64.len == 0) &[_]i64{1} else depth_shape_i64;
+        if (effective_depth_shape_i64.len > 1 or effective_depth_shape_i64[0] != 1) {
+            std.debug.print("\n ERROR: depth must be a scalar, got shape {any}\n", .{effective_depth_shape_i64});
+            return error.InvalidDepthShape;
+        }
+
+        // Verify that values has shape [2]
+        const values_shape_i64 = values.shape;
+        const effective_values_shape_i64 = if (values_shape_i64.len == 0) &[_]i64{1} else values_shape_i64;
+        if (effective_values_shape_i64.len != 1 or effective_values_shape_i64[0] != 2) {
+            std.debug.print("\n ERROR: values must have shape [2], got shape {any}\n", .{effective_values_shape_i64});
+            return error.InvalidValuesShape;
+        }
+
+        // Extract the depth value
+        var depth: i64 = undefined;
+        if (depth_tensor.tensorProto != null and depth_tensor.tensorProto.?.int64_data != null) {
+            depth = depth_tensor.tensorProto.?.int64_data.?[0];
+        } else if (depth_tensor.tensorProto != null and depth_tensor.tensorProto.?.raw_data != null) {
+            const raw = depth_tensor.tensorProto.?.raw_data.?;
+            if (raw.len < @sizeOf(i64)) {
+                std.debug.print("\n ERROR: depth raw_data is too small to contain an i64\n", .{});
+                return error.InvalidDepthData;
+            }
+            depth = std.mem.readInt(i64, raw[0..@sizeOf(i64)], .little);
+        } else {
+            std.debug.print("\n ERROR: depth tensorProto is missing valid data\n", .{});
+            return error.DepthDataMissing;
+        }
+
+        // Verify that depth is positive
+        if (depth <= 0) {
+            std.debug.print("\n ERROR: depth must be positive, got {d}\n", .{depth});
+            return error.InvalidDepthValue;
+        }
+
+        // Extract the axis attribute (default: -1)
+        var axis: i64 = -1;
+        for (readyNode.nodeProto.attribute) |attr| {
+            if (std.mem.eql(u8, attr.name, "axis")) {
+                if (attr.type != AttributeType.INT) {
+                    std.debug.print("\n ERROR: axis attribute must be INT, got type {any}\n", .{attr.type});
+                    return error.InvalidAttributeType;
+                }
+                axis = attr.i;
+                break;
+            }
+        }
+
+        const indices_shape_i64 = indices.shape;
+        const indices_shape_usize = try utils.i64SliceToUsizeSlice(indices_shape_i64);
+        defer allocator.free(indices_shape_usize);
+
+        const output_shape_usize = try tensorMath.get_oneHot_output_shape(indices_shape_usize, depth, axis);
+        defer allocator.free(output_shape_usize);
+
+        shape = try utils.usizeSliceToI64Slice(output_shape_usize);
+    }
+
+    readyNode.outputs.items[0].shape = shape;
+    std.debug.print("\n output_shape: []i64 = {any}", .{readyNode.outputs.items[0].shape});
+}
+
 inline fn compute_mul_output_shape(readyNode: *ReadyNode) !void {
     Codegen_log.info("\n====== compute_mul_output_shape node: {s} ======\n", .{readyNode.nodeProto.name.?});
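
The shape rule itself is delegated to tensorMath.get_oneHot_output_shape, which is not part of this diff. A minimal sketch of the rule it is assumed to implement, per the ONNX spec (function name and error value here are illustrative): insert depth into the indices shape at the normalized axis.

    // Example: indices [2, 3], depth 10 -> axis -1 gives [2, 3, 10]; axis 0 gives [10, 2, 3].
    fn oneHotShapeSketch(alloc: std.mem.Allocator, indices_shape: []const usize, depth: i64, axis: i64) ![]usize {
        const rank: i64 = @intCast(indices_shape.len);
        // ONNX allows axis in [-rank-1, rank]; negative values wrap around.
        var norm_axis = axis;
        if (norm_axis < 0) norm_axis += rank + 1;
        if (norm_axis < 0 or norm_axis > rank) return error.InvalidAxis;
        const ax: usize = @intCast(norm_axis);

        // Output rank is rank+1: copy the indices shape and splice depth in at `ax`.
        const out = try alloc.alloc(usize, indices_shape.len + 1);
        @memcpy(out[0..ax], indices_shape[0..ax]);
        out[ax] = @intCast(depth);
        @memcpy(out[ax + 1 ..], indices_shape[ax..]);
        return out;
    }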

@@ -1028,6 +1118,24 @@ inline fn compute_tanh_output_shape(readyNode: *ReadyNode) !void {
     readyNode.outputs.items[0].shape = shape;
 }
 
+inline fn compute_floor_output_shape(readyNode: *ReadyNode) !void {
+    const input = readyNode.inputs.items[0] orelse {
+        return error.InputTensorIsNull;
+    };
+
+    var shape: []const i64 = undefined;
+
+    if (utils.getTensorShape(readyNode.outputs.items[0].name)) |tensorShape| {
+        shape = tensorShape;
+    } else {
+        const input_shape = input.shape;
+        std.debug.print("\n input_shape: []i64 = {any}", .{input_shape});
+
+        shape = try utils.usizeSliceToI64Slice(try tensorMath.get_floor_output_shape(try utils.i64SliceToUsizeSlice(input_shape)));
+    }
+    readyNode.outputs.items[0].shape = shape;
+}
+
 inline fn compute_elu_output_shape(readyNode: *ReadyNode) !void {
     const input = readyNode.inputs.items[0] orelse {
         return error.InputTensorIsNull;
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+const std = @import("std");
+const zant = @import("../../../../zant.zig");
+const pkg_allocator = zant.utils.allocator.allocator;
+
+const Tensor = zant.core.tensor.Tensor; // Import Tensor type
+
+pub fn get_floor_output_shape(input_shape: []const usize) ![]usize {
+    // Allocate and copy the input shape; Floor is element-wise, so the shape is unchanged
+    const output_shape = try pkg_allocator.alloc(usize, input_shape.len);
+    errdefer pkg_allocator.free(output_shape);
+
+    std.mem.copyForwards(usize, output_shape, input_shape);
+
+    return output_shape;
+}
+
+pub fn floor(comptime T: anytype, input: *Tensor(T)) !Tensor(T) {
+    comptime if (!(std.meta.eql(T, f64) or std.meta.eql(T, f32) or std.meta.eql(T, f16))) {
+        @compileError("Unsupported type in floor");
+    };
+
+    const output_shape = try get_floor_output_shape(input.shape);
+    var output = try Tensor(T).fromShape(&pkg_allocator, output_shape);
+    defer pkg_allocator.free(output_shape);
+    errdefer output.deinit();
+
+    try floor_lean(T, input, &output);
+    return output;
+}
+
+pub fn floor_lean(comptime T: anytype, input: *Tensor(T), output: *Tensor(T)) !void {
+    // Compute floor(x) for each element of the tensor; per the ONNX spec,
+    // +0, -0, NaN, infinities, and already-integral values pass through unchanged
+    for (input.data, output.data) |in_val, *out_val| {
+        if (std.math.isNan(in_val) or std.math.isInf(in_val) or in_val == 0 or in_val == @trunc(in_val)) {
+            out_val.* = in_val;
+        } else {
+            out_val.* = std.math.floor(in_val);
+        }
+    }
+}
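
A minimal usage sketch for the new op. The module alias floor_ops stands in for wherever this file is exposed (its path is not shown in this view), and the values are illustrative; Tensor(T).fromShape, .data, and .deinit are used exactly as in the diff above:

    const std = @import("std");
    const zant = @import("zant"); // assumed: module name for the library root
    const Tensor = zant.core.tensor.Tensor;
    const pkg_allocator = zant.utils.allocator.allocator;
    const floor_ops = zant.core.tensor.math_standard; // assumed namespace for floor()

    test "floor rounds each element down" {
        var shape = [_]usize{4};
        var input = try Tensor(f32).fromShape(&pkg_allocator, &shape);
        defer input.deinit();

        // Values that exercise both the plain and the pass-through cases.
        @memcpy(input.data, &[_]f32{ 1.7, -1.2, 3.0, -0.0 });

        var output = try floor_ops.floor(f32, &input);
        defer output.deinit();

        try std.testing.expectEqualSlices(f32, &[_]f32{ 1.0, -2.0, 3.0, -0.0 }, output.data);
    }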
