@@ -11,25 +11,25 @@ extern fn predict(
1111 result : * [* ]f32 , // Pointer to receive the output slice pointer
1212) i32 ;
1313
14- fn prepareInputData (allocator : std.mem.Allocator ) ! []model_opts.data_type {
14+ fn prepareInputData (allocator : std.mem.Allocator ) ! []model_opts.input_data_type {
1515 const shape = model_opts .input_shape ;
1616 var total_size : usize = 1 ;
1717 for (shape ) | dim | {
1818 total_size *= dim ;
1919 }
2020
21- const data = try allocator .alloc (model_opts .data_type , total_size );
21+ const data = try allocator .alloc (model_opts .input_data_type , total_size );
2222 errdefer allocator .free (data );
2323
2424 for (data , 0.. ) | * val , i | {
25- val .* = @as (model_opts .data_type , @floatFromInt (i ));
25+ val .* = @as (model_opts .input_data_type , @floatFromInt (i ));
2626 }
2727
2828 return data ;
2929}
3030
3131fn getPredictOutputSize () usize {
32- return 1 * 84 * 1344 ;
32+ return 1 * 4 ;
3333}
3434
3535pub fn main () ! void {
@@ -41,17 +41,20 @@ pub fn main() !void {
4141 const input_data = try prepareInputData (allocator );
4242 const input_shape = model_opts .input_shape ;
4343
44- var output_ptr : [* ]model_opts.data_type = undefined ;
44+ var output_ptr : [* ]model_opts.output_data_type = undefined ;
4545
4646 main_log .info ("Calling predict (via model_opts.lib)...\\ n" , .{});
4747
48- model_opts .lib .predict (
48+ const res = model_opts .lib .predict (
4949 input_data .ptr ,
5050 @constCast (@ptrCast (& input_shape )),
5151 @intCast (input_shape .len ),
5252 & output_ptr ,
5353 );
5454
55+ if (res == 0 ) {
56+ main_log .info ("\n !!!! ERRORR!!! \n\n something went wrong" , .{});
57+ }
5558 main_log .info ("Predict call finished.\n " , .{});
5659
5760 const output_size = getPredictOutputSize ();
@@ -64,7 +67,7 @@ pub fn main() !void {
6467 return ;
6568 }
6669
67- const output_slice = @as ([* ]model_opts .data_type , @ptrCast (output_ptr ))[0.. output_size ];
70+ const output_slice = @as ([* ]model_opts .output_data_type , @ptrCast (output_ptr ))[0.. output_size ];
6871
6972 //print the output
7073 main_log .info ("Output (first 10 elements):\n " , .{});
0 commit comments