@@ -92,6 +92,8 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
9292 try write_elu (writer , node );
9393 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Flatten" )) {
9494 try write_flatten (writer , node );
95+ } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Floor" )) {
96+ try write_floor (writer , node );
9597 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Gather" )) {
9698 try write_gather (writer , node );
9799 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Gemm" )) {
@@ -111,7 +113,7 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
111113 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Neg" )) {
112114 try write_neg (writer , node );
113115 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "OneHot" )) {
114- try writer . writeAll ( "// Handle OneHot \n " );
116+ try write_oneHot ( writer , node );
115117 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "Pad" )) {
116118 try write_pads (writer , node );
117119 } else if (std .mem .eql (u8 , node .nodeProto .op_type , "ReduceMean" )) {
@@ -366,6 +368,97 @@ inline fn write_BatchNormalization(writer: std.fs.File.Writer, node: *ReadyNode)
366368 });
367369}
368370
inline fn write_oneHot(writer: std.fs.File.Writer, node: *ReadyNode) !void {
    // https://onnx.ai/onnx/operators/onnx__OneHot.html
    // INPUTS:
    //      - indices (heterogeneous) - T1: Tensor of indices.
    //      - depth (heterogeneous) - T2: Scalar tensor for depth.
    //      - values (heterogeneous) - T3: Tensor of shape [off_value, on_value].
    // OUTPUT:
    //      - output (heterogeneous) - T3: Output tensor with one-hot encoding.
    // ATTRIBUTES:
    //      - axis - INT (default is -1): Axis along which to add the one-hot dimension.

    // Resolve the optional `axis` attribute; ONNX defaults to -1 (innermost axis).
    var axis: i64 = -1;
    for (node.nodeProto.attribute) |attr| {
        if (std.mem.eql(u8, attr.name, "axis")) {
            if (attr.type != AttributeType.INT) return error.InvalidAxisType;
            axis = attr.i;
        }
    }

    // Build the reference expression for each of the three inputs
    // (0 = indices, 1 = depth, 2 = values). Initializer tensors live in
    // param_lib; anything else is a locally generated `tensor_<name>`.
    //
    // `built` tracks how many strings were successfully allocated so the
    // deferred cleanup never frees an undefined slice if an allocation
    // (or getSanitizedName) fails partway through.
    var refs: [3][]u8 = undefined;
    var built: usize = 0;
    defer {
        for (refs[0..built]) |s| allocator.free(s);
    }
    for (0..3) |i| {
        const input = node.inputs.items[i].?;
        const prefix: []const u8 = if (input.tag == globals.TensorTag.INITIALIZER)
            "@constCast(&param_lib.tensor_"
        else
            "@constCast(&tensor_";
        refs[i] = try std.mem.concat(allocator, u8, &[_][]const u8{
            prefix,
            try utils.getSanitizedName(input.name),
            ")",
        });
        built += 1;
    }

    _ = try writer.print(
        \\
        \\
        \\    tensMath.oneHot_lean(
        \\        {s}, // T
        \\        {s}, // indices
        \\        {s}.data[0], // depth (scalare)
        \\        {s}, // values
        \\        {}, // axis
        \\        &tensor_{s}, // output
        \\    )
    , .{
        // The element type T is taken from the `values` input, which per the
        // ONNX spec shares its type (T3) with the output tensor.
        try utils.getTypeString(globals.tensorHashMap.getPtr(node.inputs.items[2].?.name).?.tensorProto.?.data_type), // T
        refs[0], // indices
        refs[1], // depth
        refs[2], // values
        axis, // axis
        try utils.getSanitizedName(node.outputs.items[0].name), // output
    });
}
461+
369462inline fn write_sub (writer : std.fs.File.Writer , node : * ReadyNode ) ! void {
370463 // https://onnx.ai/onnx/operators/onnx__Sub.html
371464 // INPUTS:
@@ -2177,6 +2270,37 @@ inline fn write_transpose(writer: std.fs.File.Writer, node: *ReadyNode) !void {
21772270 });
21782271}
21792272
inline fn write_floor(writer: std.fs.File.Writer, node: *ReadyNode) !void {
    // https://onnx.ai/onnx/operators/onnx__Floor.html
    // INPUTS:
    //      - X (heterogeneous) - T: Input tensor
    // OUTPUTS:
    //      - Y (heterogeneous) - T: Output tensor with floor of input elements (If x is integral, +0, -0, NaN, or infinite, x itself is returned)

    const input_name = try utils.getSanitizedName(node.inputs.items[0].?.name);

    // Build the expression that references the input tensor. Initializer
    // tensors live in param_lib; anything else is a locally generated
    // `tensor_<name>` variable. The slice is bound with `try` BEFORE the
    // defer is registered, so an allocation failure can never lead to
    // freeing an undefined slice on the error path.
    const input_tensor_string: []u8 = if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER)
        try std.mem.concat(allocator, u8, &[_][]const u8{
            "@constCast(&param_lib.tensor_",
            input_name,
            ")",
        })
    else
        try std.mem.concat(allocator, u8, &[_][]const u8{ "&tensor_", input_name });
    defer allocator.free(input_tensor_string);

    _ = try writer.print(
        \\
        \\
        \\    tensMath.floor_lean(T, {s}, &tensor_{s})
    , .{
        input_tensor_string,
        try utils.getSanitizedName(node.outputs.items[0].name),
    });
}
2303+
21802304inline fn write_tanh (writer : std.fs.File.Writer , node : * ReadyNode ) ! void {
21812305 // https://onnx.ai/onnx/operators/onnx__Tanh.html
21822306 // INPUTS:
0 commit comments