Skip to content

Commit aec4c2c

Browse files
floor and onehot
1 parent 727e656 commit aec4c2c

File tree

11 files changed

+688
-6
lines changed

11 files changed

+688
-6
lines changed

src/CodeGen/math_handler.zig

+125-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
6969
try write_elu(writer, node);
7070
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Flatten")) {
7171
try write_flatten(writer, node);
72+
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Floor")) {
73+
try write_floor(writer, node);
7274
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gather")) {
7375
try write_gather(writer, node);
7476
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gemm")) {
@@ -88,7 +90,7 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
8890
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Neg")) {
8991
try write_neg(writer, node);
9092
} else if (std.mem.eql(u8, node.nodeProto.op_type, "OneHot")) {
91-
try writer.writeAll("// Handle OneHot\n");
93+
try write_oneHot(writer, node);
9294
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Pad")) {
9395
try write_pads(writer, node);
9496
} else if (std.mem.eql(u8, node.nodeProto.op_type, "ReduceMean")) {
@@ -337,6 +339,97 @@ inline fn write_BatchNormalization(writer: std.fs.File.Writer, node: *ReadyNode)
337339
});
338340
}
339341

342+
inline fn write_oneHot(writer: std.fs.File.Writer, node: *ReadyNode) !void {
    // https://onnx.ai/onnx/operators/onnx__OneHot.html
    // INPUTS:
    //      - indices (heterogeneous) - T1: Tensor of indices.
    //      - depth (heterogeneous) - T2: Scalar tensor for depth.
    //      - values (heterogeneous) - T3: Tensor of shape [off_value, on_value].
    // OUTPUT:
    //      - output (heterogeneous) - T3: Output tensor with one-hot encoding.
    // ATTRIBUTES:
    //      - axis - INT (default is -1): Axis along which to add the one-hot dimension.

    // Read the optional `axis` attribute (ONNX default: -1, innermost dimension).
    var axis: i64 = -1;
    for (node.nodeProto.attribute) |attr| {
        if (std.mem.eql(u8, attr.name, "axis")) {
            if (attr.type != AttributeType.INT) return error.InvalidAxisType;
            axis = attr.i;
            break; // axis found; no other attribute is relevant for OneHot
        }
    }

    // Local helper that renders the code-gen reference for one input tensor.
    // Initializers live in `param_lib`, runtime tensors are plain locals; both
    // are wrapped in @constCast exactly as the original emitted for all three
    // inputs. Caller owns (and must free) the returned slice.
    const Ref = struct {
        fn make(input: anytype) ![]u8 {
            const prefix = if (input.tag == globals.TensorTag.INITIALIZER)
                "@constCast(&param_lib.tensor_"
            else
                "@constCast(&tensor_";
            return std.mem.concat(allocator, u8, &[_][]const u8{
                prefix,
                try utils.getSanitizedName(input.name),
                ")",
            });
        }
    };

    // Each `defer` is registered only AFTER its allocation succeeded. The
    // original declared `defer allocator.free(x)` while `x` was still
    // `undefined`, so a failing later concat freed an undefined pointer (UB).
    const indices_string = try Ref.make(node.inputs.items[0].?);
    defer allocator.free(indices_string);
    const depth_string = try Ref.make(node.inputs.items[1].?);
    defer allocator.free(depth_string);
    const values_string = try Ref.make(node.inputs.items[2].?);
    defer allocator.free(values_string);

    // Resolve the element type of `values` (T3). Fail with a descriptive
    // error instead of panicking via `.?` when the tensor or its proto is
    // missing (the `values` input is not guaranteed to be an initializer).
    const values_entry = globals.tensorHashMap.getPtr(node.inputs.items[2].?.name) orelse return error.ValuesTensorNotFound;
    const values_proto = values_entry.tensorProto orelse return error.ValuesTensorProtoMissing;

    _ = try writer.print(
        \\
        \\
        \\    tensMath.oneHot_lean(
        \\        {s}, // T
        \\        {s}, // indices
        \\        {s}.data[0], // depth (scalare)
        \\        {s}, // values
        \\        {}, // axis
        \\        &tensor_{s}, // output
        \\    )
    , .{
        try utils.getTypeString(values_proto.data_type), // T
        indices_string, // indices
        depth_string, // depth
        values_string, // values
        axis, // axis
        try utils.getSanitizedName(node.outputs.items[0].name), // output
    });
}
432+
340433
inline fn write_sub(writer: std.fs.File.Writer, node: *ReadyNode) !void {
341434
// https://onnx.ai/onnx/operators/onnx__Sub.html
342435
// INPUTS:
@@ -2136,6 +2229,37 @@ inline fn write_transpose(writer: std.fs.File.Writer, node: *ReadyNode) !void {
21362229
});
21372230
}
21382231

2232+
inline fn write_floor(writer: std.fs.File.Writer, node: *ReadyNode) !void {
    // https://onnx.ai/onnx/operators/onnx__Floor.html
    // INPUTS:
    //      - X (heterogeneous) - T: Input tensor
    // OUTPUTS:
    //      - Y (heterogeneous) - T: Output tensor with floor of input elements
    //        (if x is integral, +0, -0, NaN, or infinite, x itself is returned)

    const sanitized_input_name = try utils.getSanitizedName(node.inputs.items[0].?.name);

    // Build the code-gen reference for the input tensor: initializers live in
    // `param_lib` and need @constCast; runtime tensors are referenced as plain
    // locals (same two spellings the original emitted).
    // The `defer` is registered only AFTER the concat succeeded; the original
    // declared it while the slice was still `undefined`, so a failing concat
    // freed an undefined pointer (UB on the error path).
    const input_tensor_string: []u8 = if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER)
        try std.mem.concat(allocator, u8, &[_][]const u8{
            "@constCast(&param_lib.tensor_",
            sanitized_input_name,
            ")",
        })
    else
        try std.mem.concat(allocator, u8, &[_][]const u8{ "&tensor_", sanitized_input_name });
    defer allocator.free(input_tensor_string);

    _ = try writer.print(
        \\
        \\
        \\    tensMath.floor_lean(T, {s}, &tensor_{s})
    , .{
        input_tensor_string,
        try utils.getSanitizedName(node.outputs.items[0].name),
    });
}
2262+
21392263
inline fn write_tanh(writer: std.fs.File.Writer, node: *ReadyNode) !void {
21402264
// https://onnx.ai/onnx/operators/onnx__Tanh.html
21412265
// INPUTS:

src/CodeGen/shape_handler.zig

+110-2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ pub fn compute_output_shape(readyNode: *ReadyNode) !void {
7878
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Flatten")) {
7979
//https://onnx.ai/onnx/operators/onnx__Flatten.html
8080
try compute_flatten_output_shape(readyNode);
81+
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Floor")) {
82+
//https://onnx.ai/onnx/operators/onnx__Floor.html
83+
try compute_floor_output_shape(readyNode);
8184
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Gather")) {
8285
try compute_gather_output_shape(readyNode);
8386
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Gemm")) {
@@ -98,8 +101,8 @@ pub fn compute_output_shape(readyNode: *ReadyNode) !void {
98101
//https://onnx.ai/onnx/operators/onnx__Neg.html
99102
try compute_neg_output_shape(readyNode);
100103
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "OneHot")) {
101-
// TODO
102-
return error.OperationWIP;
104+
//https://onnx.ai/onnx/operators/onnx__OneHot.html
105+
try compute_oneHot_output_shape(readyNode);
103106
} else if (std.mem.eql(u8, readyNode.nodeProto.op_type, "Pad")) {
104107
//https://onnx.ai/onnx/operators/onnx__Pad.html
105108
try compute_pads_output_shape(readyNode);
@@ -436,6 +439,93 @@ inline fn compute_gemm_output_shape(readyNode: *ReadyNode) !void {
436439
readyNode.outputs.items[0].shape = shape;
437440
}
438441

442+
/// Computes the output shape of an ONNX OneHot node and stores it on the
/// node's first output. Per the ONNX spec the output rank is indices rank + 1,
/// with a new axis of size `depth` inserted at position `axis`; the actual
/// shape math is delegated to `tensorMath.get_oneHot_output_shape`.
/// https://onnx.ai/onnx/operators/onnx__OneHot.html
inline fn compute_oneHot_output_shape(readyNode: *ReadyNode) !void {
    std.debug.print("\n====== compute_oneHot_output_shape node: {s} ======\n", .{readyNode.nodeProto.name orelse "(unnamed)"});

    var shape: []const i64 = undefined;

    // If the model already declares the output shape, trust it as-is.
    if (utils.getTensorShape(readyNode.outputs.items[0].name)) |tensorShape| {
        shape = tensorShape;
    } else {
        // Verify there are exactly 3 inputs: indices, depth, values.
        if (readyNode.inputs.items.len != 3) {
            std.debug.print("\n ERROR: OneHot expects exactly 3 inputs, got {d}\n", .{readyNode.inputs.items.len});
            return error.InvalidNumberOfInputs;
        }

        const indices = readyNode.inputs.items[0].?;
        const depth_tensor = readyNode.inputs.items[1].?;
        const values = readyNode.inputs.items[2].?;

        std.debug.print("\n indices_shape: []i64 = {any}", .{indices.shape});
        std.debug.print("\n depth_shape: []i64 = {any}", .{depth_tensor.shape});
        std.debug.print("\n values_shape: []i64 = {any}", .{values.shape});

        // Verify that depth is a scalar (shape [] or [1]); a rank-0 shape is
        // normalized to [1] so both spellings are accepted.
        const depth_shape_i64 = depth_tensor.shape;
        const effective_depth_shape_i64 = if (depth_shape_i64.len == 0) &[_]i64{1} else depth_shape_i64;
        if (effective_depth_shape_i64.len > 1 or effective_depth_shape_i64[0] != 1) {
            std.debug.print("\n ERROR: depth must be a scalar, got shape {any}\n", .{effective_depth_shape_i64});
            return error.InvalidDepthShape;
        }

        // Verify that values has shape [2] ([off_value, on_value] per spec).
        const values_shape_i64 = values.shape;
        const effective_values_shape_i64 = if (values_shape_i64.len == 0) &[_]i64{1} else values_shape_i64;
        if (effective_values_shape_i64.len != 1 or effective_values_shape_i64[0] != 2) {
            std.debug.print("\n ERROR: values must have shape [2], got shape {any}\n", .{effective_values_shape_i64});
            return error.InvalidValuesShape;
        }

        // Extract the depth value from the initializer's proto data.
        // NOTE(review): only int64_data or an 8-byte little-endian raw_data
        // payload is accepted; a depth materialized as int32 raw_data would be
        // rejected here — confirm upstream always stores depth as i64.
        var depth: i64 = undefined;
        if (depth_tensor.tensorProto != null and depth_tensor.tensorProto.?.int64_data != null) {
            depth = depth_tensor.tensorProto.?.int64_data.?[0];
        } else if (depth_tensor.tensorProto != null and depth_tensor.tensorProto.?.raw_data != null) {
            const raw = depth_tensor.tensorProto.?.raw_data.?;
            if (raw.len < @sizeOf(i64)) {
                std.debug.print("\n ERROR: depth raw_data is too small to contain an i64\n", .{});
                return error.InvalidDepthData;
            }
            depth = std.mem.readInt(i64, raw[0..@sizeOf(i64)], .little);
        } else {
            std.debug.print("\n ERROR: depth tensorProto is missing valid data\n", .{});
            return error.DepthDataMissing;
        }

        // Verify that depth is positive (the new one-hot axis must be non-empty).
        if (depth <= 0) {
            std.debug.print("\n ERROR: depth must be positive, got {d}\n", .{depth});
            return error.InvalidDepthValue;
        }

        // Extract the axis attribute (default: -1).
        var axis: i64 = -1;
        for (readyNode.nodeProto.attribute) |attr| {
            if (std.mem.eql(u8, attr.name, "axis")) {
                if (attr.type != AttributeType.INT) {
                    std.debug.print("\n ERROR: axis attribute must be INT, got type {any}\n", .{attr.type});
                    return error.InvalidAttributeType;
                }
                axis = attr.i;
                break;
            }
        }

        // Delegate the shape computation, converting i64 <-> usize; the two
        // intermediate slices are freed here, ownership of the final i64
        // slice passes to the output tensor.
        const indices_shape_i64 = indices.shape;
        const indices_shape_usize = try utils.i64SliceToUsizeSlice(indices_shape_i64);
        defer allocator.free(indices_shape_usize);

        const output_shape_usize = try tensorMath.get_oneHot_output_shape(indices_shape_usize, depth, axis);
        defer allocator.free(output_shape_usize);

        shape = try utils.usizeSliceToI64Slice(output_shape_usize);
    }

    readyNode.outputs.items[0].shape = shape;
    std.debug.print("\n output_shape: []i64 = {any}", .{readyNode.outputs.items[0].shape});
}
528+
439529
inline fn compute_mul_output_shape(readyNode: *ReadyNode) !void {
440530
std.debug.print("\n====== compute_mul_output_shape node: {s} ======\n", .{readyNode.nodeProto.name.?});
441531

@@ -1026,6 +1116,24 @@ inline fn compute_tanh_output_shape(readyNode: *ReadyNode) !void {
10261116
readyNode.outputs.items[0].shape = shape;
10271117
}
10281118

1119+
/// Computes the output shape of an ONNX Floor node. Floor is elementwise,
/// so the output shape is identical to the input shape.
/// https://onnx.ai/onnx/operators/onnx__Floor.html
inline fn compute_floor_output_shape(readyNode: *ReadyNode) !void {
    const input = readyNode.inputs.items[0] orelse {
        return error.InputTensorIsNull;
    };

    var shape: []const i64 = undefined;

    // If the model already declares the output shape, trust it as-is.
    if (utils.getTensorShape(readyNode.outputs.items[0].name)) |tensorShape| {
        shape = tensorShape;
    } else {
        const input_shape = input.shape;
        std.debug.print("\n input_shape: []i64 = {any}", .{input_shape});

        // Free both intermediate slices: the original leaked the usize copy
        // of the input shape AND the result of get_floor_output_shape (only
        // the final i64 slice must survive; compare compute_oneHot_output_shape).
        const input_shape_usize = try utils.i64SliceToUsizeSlice(input_shape);
        defer allocator.free(input_shape_usize);

        const output_shape_usize = try tensorMath.get_floor_output_shape(input_shape_usize);
        defer allocator.free(output_shape_usize);

        shape = try utils.usizeSliceToI64Slice(output_shape_usize);
    }
    readyNode.outputs.items[0].shape = shape;
}
1136+
10291137
inline fn compute_elu_output_shape(readyNode: *ReadyNode) !void {
10301138
const input = readyNode.inputs.items[0] orelse {
10311139
return error.InputTensorIsNull;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
const std = @import("std");
2+
const zant = @import("../../../../zant.zig");
3+
const pkg_allocator = zant.utils.allocator.allocator;
4+
5+
const Tensor = zant.core.tensor.Tensor; // Import Tensor type
6+
7+
/// Returns the output shape for an elementwise Floor: a copy of the input
/// shape. Caller owns the returned slice and must free it with pkg_allocator.
pub fn get_floor_output_shape(input_shape: []const usize) ![]usize {
    // Allocator.dupe is the idiomatic replacement for the manual
    // alloc + copyForwards pair (same allocation, same contents).
    return try pkg_allocator.dupe(usize, input_shape);
}
16+
17+
/// Allocating Floor: computes floor(x) elementwise into a freshly allocated
/// tensor with the same shape as `input`. Caller owns the returned tensor
/// and must deinit it. Only f16/f32/f64 are supported (compile error otherwise).
pub fn floor(comptime T: type, input: *Tensor(T)) !Tensor(T) {
    // `comptime T: type` + direct type comparison is the idiomatic form of
    // the original `comptime T: anytype` + std.meta.eql; the error message
    // now names `floor` (the original said "floor_lean" here).
    comptime if (T != f16 and T != f32 and T != f64) {
        @compileError("Unsupported type in floor");
    };

    const output_shape = try get_floor_output_shape(input.shape);
    // Register the free immediately after allocation: the original deferred
    // it only after fromShape, leaking output_shape when fromShape failed.
    // fromShape copies the shape (the original also freed it before returning).
    defer pkg_allocator.free(output_shape);

    var output = try Tensor(T).fromShape(&pkg_allocator, output_shape);
    errdefer output.deinit();

    try floor_lean(T, input, &output);
    return output;
}
30+
31+
/// In-place Floor: writes floor(input[i]) into output[i] for every element.
/// `output` must already be allocated with the same number of elements as
/// `input` (the multi-object `for` panics on a length mismatch in safe modes).
pub fn floor_lean(comptime T: type, input: *Tensor(T), output: *Tensor(T)) !void {
    // Guard here too, so a bad instantiation reports floor_lean (the original
    // only guarded in `floor`, yet its message blamed floor_lean).
    comptime if (T != f16 and T != f32 and T != f64) {
        @compileError("Unsupported type in floor_lean");
    };

    // @floor already implements IEEE-754 floor semantics: +-0, NaN, +-inf
    // and already-integral values are returned unchanged, so the original's
    // explicit special-casing of those inputs was redundant.
    for (input.data, output.data) |in_val, *out_val| {
        out_val.* = @floor(in_val);
    }
}

0 commit comments

Comments
 (0)