@@ -69,6 +69,8 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
try write_elu(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Flatten")) {
try write_flatten(writer, node);
+ } else if (std.mem.eql(u8, node.nodeProto.op_type, "Floor")) {
+ try write_floor(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gather")) {
try write_gather(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Gemm")) {
@@ -88,7 +90,7 @@ pub fn write_math_op(writer: std.fs.File.Writer, node: *ReadyNode) !void {
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Neg")) {
try write_neg(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "OneHot")) {
- try writer.writeAll("// Handle OneHot\n");
+ try write_oneHot(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "Pad")) {
try write_pads(writer, node);
} else if (std.mem.eql(u8, node.nodeProto.op_type, "ReduceMean")) {
@@ -337,6 +339,97 @@ inline fn write_BatchNormalization(writer: std.fs.File.Writer, node: *ReadyNode)
});
}
+ inline fn write_oneHot(writer: std.fs.File.Writer, node: *ReadyNode) !void {
+ // https://onnx.ai/onnx/operators/onnx__OneHot.html
+ // INPUTS:
+ // - indices (heterogeneous) - T1: Tensor of indices.
+ // - depth (heterogeneous) - T2: Scalar tensor for depth.
+ // - values (heterogeneous) - T3: Tensor of shape [off_value, on_value].
+ // OUTPUT:
+ // - output (heterogeneous) - T3: Output tensor with one-hot encoding.
+ // ATTRIBUTES:
+ // - axis - INT (default is -1): Axis along which to add the one-hot dimension.
+
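+ // Illustrative example of the operator's semantics (comment only, values assumed):
+ // with indices = [0, 2], depth = 3, values = [off, on] and the default axis = -1,
+ // the output is [[on, off, off], [off, off, on]].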
+ var axis: i64 = -1; // Default axis per ONNX
+ for (node.nodeProto.attribute) |attr| {
+ if (std.mem.eql(u8, attr.name, "axis")) {
+ if (attr.type != AttributeType.INT) return error.InvalidAxisType;
+ axis = attr.i;
+ }
+ }
+
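+ // Initializer tensors live in param_lib and are referenced as @constCast(&param_lib.tensor_<name>);
+ // runtime tensors are referenced as @constCast(&tensor_<name>). The same pattern is applied
+ // to the indices, depth and values inputs below.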
+ //----create indices string
+ var indices_string: []u8 = undefined;
+ defer allocator.free(indices_string);
+ if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
+ indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&param_lib.tensor_",
+ try utils.getSanitizedName(node.inputs.items[0].?.name),
+ ")",
+ });
+ } else {
+ indices_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&tensor_",
+ try utils.getSanitizedName(node.inputs.items[0].?.name),
+ ")",
+ });
+ }
+
+ //----create depth string
+ var depth_string: []u8 = undefined;
+ defer allocator.free(depth_string);
+ if (node.inputs.items[1].?.tag == globals.TensorTag.INITIALIZER) {
+ depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&param_lib.tensor_",
+ try utils.getSanitizedName(node.inputs.items[1].?.name),
+ ")",
+ });
+ } else {
+ depth_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&tensor_",
+ try utils.getSanitizedName(node.inputs.items[1].?.name),
+ ")",
+ });
+ }
+
+ //----create values string
+ var values_string: []u8 = undefined;
+ defer allocator.free(values_string);
+ if (node.inputs.items[2].?.tag == globals.TensorTag.INITIALIZER) {
+ values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&param_lib.tensor_",
+ try utils.getSanitizedName(node.inputs.items[2].?.name),
+ ")",
+ });
+ } else {
+ values_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&tensor_",
+ try utils.getSanitizedName(node.inputs.items[2].?.name),
+ ")",
+ });
+ }
+
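+ // Sketch of the generated call for a node with runtime indices "X", initializer depth
+ // "depth", initializer values "values", output "Y" and f32 data (hypothetical names,
+ // assuming getTypeString returns "f32"):
+ //     tensMath.oneHot_lean(f32, @constCast(&tensor_X), @constCast(&param_lib.tensor_depth).data[0],
+ //         @constCast(&param_lib.tensor_values), -1, &tensor_Y)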
+ _ = try writer.print(
+ \\
+ \\
+ \\ tensMath.oneHot_lean(
+ \\ {s}, // T
+ \\ {s}, // indices
+ \\ {s}.data[0], // depth (scalar)
+ \\ {s}, // values
+ \\ {}, // axis
+ \\ &tensor_{s}, // output
+ \\ )
+ , .{
+ try utils.getTypeString(globals.tensorHashMap.getPtr(node.inputs.items[2].?.name).?.tensorProto.?.data_type), // T
+ indices_string, // indices
+ depth_string, // depth
+ values_string, // values
+ axis, // axis
+ try utils.getSanitizedName(node.outputs.items[0].name), // output
+ });
+ }
+
inline fn write_sub(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__Sub.html
// INPUTS:
@@ -2136,6 +2229,37 @@ inline fn write_transpose(writer: std.fs.File.Writer, node: *ReadyNode) !void {
});
}
+ inline fn write_floor(writer: std.fs.File.Writer, node: *ReadyNode) !void {
+ // https://onnx.ai/onnx/operators/onnx__Floor.html
+ // INPUTS:
+ // - X (heterogeneous) - T: Input tensor
+ // OUTPUTS:
+ // - Y (heterogeneous) - T: Output tensor with floor of input elements (if x is integral, +0, -0, NaN, or infinite, x itself is returned)
+
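+ // Illustrative values (comment only): floor(1.7) = 1.0, floor(-1.2) = -2.0,
+ // floor(-0.0) = -0.0, floor(NaN) = NaN.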
+ // Create input tensor string
+ var input_tensor_string: []u8 = undefined;
+ defer allocator.free(input_tensor_string);
+
+ if (node.inputs.items[0].?.tag == globals.TensorTag.INITIALIZER) {
+ input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{
+ "@constCast(&param_lib.tensor_",
+ try utils.getSanitizedName(node.inputs.items[0].?.name),
+ ")",
+ });
+ } else {
+ input_tensor_string = try std.mem.concat(allocator, u8, &[_][]const u8{ "&tensor_", try utils.getSanitizedName(node.inputs.items[0].?.name) });
+ }
+
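+ // Sketch of the generated call for a runtime input "X" and output "Y" (hypothetical names):
+ //     tensMath.floor_lean(T, &tensor_X, &tensor_Y)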
+ _ = try writer.print(
+ \\
+ \\
+ \\ tensMath.floor_lean(T, {s}, &tensor_{s})
+ , .{
+ input_tensor_string,
+ try utils.getSanitizedName(node.outputs.items[0].name),
+ });
+ }
+
inline fn write_tanh(writer: std.fs.File.Writer, node: *ReadyNode) !void {
// https://onnx.ai/onnx/operators/onnx__Tanh.html
// INPUTS: