-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathCustomVisionONNXModel.cs
More file actions
91 lines (81 loc) · 3.19 KB
/
CustomVisionONNXModel.cs
File metadata and controls
91 lines (81 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Windows.AI.MachineLearning;
using Windows.Media;
using Windows.Storage;
/// <summary>
/// See Custom Vision ONNX UWP sample https://github.com/Azure-Samples/Custom-Vision-ONNX-UWP/blob/master/VisionApp/ONNXModel.cs
/// </summary>
namespace IoTVisualAlerts.CustomVision
{
/// <summary>
/// Input container passed to <c>CustomVisionONNXModel.EvaluateAsync</c>.
/// </summary>
public sealed class CustomVisionModelInput
{
/// <summary>
/// The video frame to evaluate; it is bound to the model input named "data".
/// </summary>
public VideoFrame data { get; set; }
}
/// <summary>
/// Output container for a model evaluation. Its two public fields are the objects
/// bound to the model outputs named "classLabel" and "loss".
/// </summary>
public sealed class CustomVisionModelOutput
{
    /// <summary>
    /// The label returned by the model, held in a string tensor created with shape { 1, 1 }.
    /// </summary>
    public TensorString classLabel = TensorString.Create(new long[] { 1, 1 });

    /// <summary>
    /// The loss returned by the model: a list of dictionaries from string keys to float values
    /// (presumably label -> score; confirm against the exported model).
    /// </summary>
    public IList<IDictionary<string, float>> loss = new List<IDictionary<string, float>>();

    /// <summary>
    /// Flattens every dictionary in <see cref="loss"/> into one list of (key, value) tuples,
    /// in enumeration order.
    /// </summary>
    /// <returns>A new list with one tuple per dictionary entry; empty when <see cref="loss"/> has no entries.</returns>
    public List<Tuple<string, float>> GetPredictionResult()
    {
        var predictions = new List<Tuple<string, float>>();

        foreach (var scores in loss)
        {
            foreach (KeyValuePair<string, float> entry in scores)
            {
                predictions.Add(Tuple.Create(entry.Key, entry.Value));
            }
        }

        return predictions;
    }
}
/// <summary>
/// Wraps a Windows ML <see cref="LearningModel"/> together with a <see cref="LearningModelSession"/>
/// created from it, and evaluates <see cref="CustomVisionModelInput"/> frames against that session.
/// Instances are normally obtained through <see cref="CreateONNXModel"/>.
/// </summary>
public sealed class CustomVisionONNXModel
{
    // Both fields are assigned by CreateONNXModel; the session is what EvaluateAsync uses.
    private LearningModel _learningModel;
    private LearningModelSession _session;

    /// <summary>
    /// Width reported by the model's first image input feature, or 0 if the model declares no image input.
    /// </summary>
    public int InputImageWidth { get; private set; }

    /// <summary>
    /// Height reported by the model's first image input feature, or 0 if the model declares no image input.
    /// </summary>
    public int InputImageHeight { get; private set; }

    /// <summary>
    /// Create a model from an ONNX 1.2 file.
    /// </summary>
    /// <param name="file">The storage file containing the ONNX model.</param>
    /// <returns>A model wrapper with its session created and input image dimensions populated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="file"/> is null.</exception>
    public static async Task<CustomVisionONNXModel> CreateONNXModel(StorageFile file)
    {
        if (file == null)
        {
            throw new ArgumentNullException(nameof(file));
        }

        // No try/catch here: load failures propagate to the caller with their original stack trace.
        LearningModel learningModel = await LearningModel.LoadFromStorageFileAsync(file);

        // Pick the first input feature that is an image; null when there is none.
        var imageDescriptor = learningModel.InputFeatures?
            .FirstOrDefault(feature => feature.Kind == LearningModelFeatureKind.Image) as ImageFeatureDescriptor;

        return new CustomVisionONNXModel
        {
            _learningModel = learningModel,
            _session = new LearningModelSession(learningModel),
            InputImageWidth = (int)(imageDescriptor?.Width ?? 0u),
            InputImageHeight = (int)(imageDescriptor?.Height ?? 0u)
        };
    }

    /// <summary>
    /// Evaluates the model on the given input.
    /// </summary>
    /// <param name="input">The input whose <c>data</c> frame is bound to the model input named "data".</param>
    /// <returns>
    /// A new <see cref="CustomVisionModelOutput"/> whose <c>classLabel</c> and <c>loss</c> members
    /// were bound to the model outputs of the same names for this evaluation.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    public async Task<CustomVisionModelOutput> EvaluateAsync(CustomVisionModelInput input)
    {
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input));
        }

        var output = new CustomVisionModelOutput();
        var binding = new LearningModelBinding(_session);
        binding.Bind("data", input.data);
        binding.Bind("classLabel", output.classLabel);
        binding.Bind("loss", output.loss);

        // The returned evaluation result is not used; the caller receives the
        // output object whose members were bound above. "0" is the correlation id.
        await _session.EvaluateAsync(binding, "0");
        return output;
    }
}
}