@@ -13,6 +13,36 @@ using namespace winrt::Windows::Graphics::DirectX;
13
13
using namespace winrt ::Windows::Graphics::Imaging;
14
14
using namespace winrt ::Windows::Graphics::DirectX::Direct3D11;
15
15
16
+ template <TensorKind T> struct TensorKindToType { static_assert(true , " No TensorKind mapped for given type!" ); };
17
+ template <> struct TensorKindToType <TensorKind::UInt8 > { typedef uint8_t Type; };
18
+ template <> struct TensorKindToType <TensorKind::Int8> { typedef uint8_t Type; };
19
+ template <> struct TensorKindToType <TensorKind::UInt16 > { typedef uint16_t Type; };
20
+ template <> struct TensorKindToType <TensorKind::Int16> { typedef int16_t Type; };
21
+ template <> struct TensorKindToType <TensorKind::UInt32 > { typedef uint32_t Type; };
22
+ template <> struct TensorKindToType <TensorKind::Int32> { typedef int32_t Type; };
23
+ template <> struct TensorKindToType <TensorKind::UInt64 > { typedef uint64_t Type; };
24
+ template <> struct TensorKindToType <TensorKind::Int64> { typedef int64_t Type; };
25
+ template <> struct TensorKindToType <TensorKind::Boolean > { typedef boolean Type; };
26
+ template <> struct TensorKindToType <TensorKind::Double> { typedef double Type; };
27
+ template <> struct TensorKindToType <TensorKind::Float> { typedef float Type; };
28
+ template <> struct TensorKindToType <TensorKind::Float16> { typedef float Type; };
29
+ template <> struct TensorKindToType <TensorKind::String> { typedef winrt::hstring Type; };
30
+
31
+ template <TensorKind T> struct TensorKindToValue { static_assert(true , " No TensorKind mapped for given type!" ); };
32
+ template <> struct TensorKindToValue <TensorKind::UInt8 > { typedef TensorUInt8Bit Type; };
33
+ template <> struct TensorKindToValue <TensorKind::Int8> { typedef TensorInt8Bit Type; };
34
+ template <> struct TensorKindToValue <TensorKind::UInt16 > { typedef TensorUInt16Bit Type; };
35
+ template <> struct TensorKindToValue <TensorKind::Int16> { typedef TensorInt16Bit Type; };
36
+ template <> struct TensorKindToValue <TensorKind::UInt32 > { typedef TensorUInt32Bit Type; };
37
+ template <> struct TensorKindToValue <TensorKind::Int32> { typedef TensorInt32Bit Type; };
38
+ template <> struct TensorKindToValue <TensorKind::UInt64 > { typedef TensorUInt64Bit Type; };
39
+ template <> struct TensorKindToValue <TensorKind::Int64> { typedef TensorInt64Bit Type; };
40
+ template <> struct TensorKindToValue <TensorKind::Boolean > { typedef TensorBoolean Type; };
41
+ template <> struct TensorKindToValue <TensorKind::Double> { typedef TensorDouble Type; };
42
+ template <> struct TensorKindToValue <TensorKind::Float> { typedef TensorFloat Type; };
43
+ template <> struct TensorKindToValue <TensorKind::Float16> { typedef TensorFloat16Bit Type; };
44
+ template <> struct TensorKindToValue <TensorKind::String> { typedef TensorString Type; };
45
+
16
46
namespace BindingUtilities
17
47
{
18
48
static unsigned int seed = 0 ;
@@ -175,6 +205,41 @@ namespace BindingUtilities
175
205
return elementStrings;
176
206
}
177
207
208
+ template <TensorKind T>
209
+ static ITensor CreateTensor (
210
+ const CommandLineArgs& args,
211
+ std::vector<std::string>& tensorStringInput,
212
+ TensorFeatureDescriptor& tensorDescriptor)
213
+ {
214
+ using TensorValue = typename TensorKindToValue<T>::Type;
215
+ using DataType = typename TensorKindToType<T>::Type;
216
+
217
+ if (!args.CsvPath ().empty ())
218
+ {
219
+ ModelBinding<DataType> binding (tensorDescriptor);
220
+ WriteDataToBinding<DataType>(tensorStringInput, binding);
221
+ return TensorValue::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
222
+ }
223
+ else if (args.IsGarbageInput ())
224
+ {
225
+ auto tensorValue = TensorValue::Create (tensorDescriptor.Shape ());
226
+
227
+ com_ptr<ITensorNative> spTensorValueNative;
228
+ tensorValue.as (spTensorValueNative);
229
+
230
+ BYTE* actualData;
231
+ uint32_t actualSizeInBytes;
232
+ spTensorValueNative->GetBuffer (&actualData, &actualSizeInBytes);
233
+
234
+ return tensorValue;
235
+ }
236
+ else
237
+ {
238
+ // Creating Tensors for Input Images haven't been added yet.
239
+ throw hresult_not_implemented (L" Creating Tensors for Input Images haven't been implemented yet!" );
240
+ }
241
+ }
242
+
178
243
// Binds tensor floats, ints, doubles from CSV data.
179
244
ITensor CreateBindableTensor (const ILearningModelFeatureDescriptor& description, const CommandLineArgs& args)
180
245
{
@@ -188,6 +253,10 @@ namespace BindingUtilities
188
253
}
189
254
190
255
std::vector<std::string> elementStrings;
256
+ if (!args.CsvPath ().empty ())
257
+ {
258
+ elementStrings = ParseCSVElementStrings (args.CsvPath ());
259
+ }
191
260
switch (tensorDescriptor.TensorKind ())
192
261
{
193
262
case TensorKind::Undefined:
@@ -197,167 +266,57 @@ namespace BindingUtilities
197
266
}
198
267
case TensorKind::Float:
199
268
{
200
- ModelBinding<float > binding (description);
201
- if (args.IsGarbageInput ())
202
- {
203
- memset (binding.GetData (), 0 , sizeof (float ) * binding.GetDataBufferSize ());
204
- }
205
- else
206
- {
207
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
208
- WriteDataToBinding<float >(elementStrings, binding);
209
- }
210
- return TensorFloat::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
269
+ return CreateTensor<TensorKind::Float>(args, elementStrings, tensorDescriptor);
211
270
}
212
271
break ;
213
272
case TensorKind::Float16:
214
273
{
215
- ModelBinding<float > binding (description);
216
- if (args.IsGarbageInput ())
217
- {
218
- memset (binding.GetData (), 0 , sizeof (float ) * binding.GetDataBufferSize ());
219
- }
220
- else
221
- {
222
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
223
- WriteDataToBinding<float >(elementStrings, binding);
224
- }
225
- return TensorFloat16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
274
+ return CreateTensor<TensorKind::Float16>(args, elementStrings, tensorDescriptor);
226
275
}
227
276
break ;
228
277
case TensorKind::Double:
229
278
{
230
- ModelBinding<double > binding (description);
231
- if (args.IsGarbageInput ())
232
- {
233
- memset (binding.GetData (), 0 , sizeof (double ) * binding.GetDataBufferSize ());
234
- }
235
- else
236
- {
237
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
238
- WriteDataToBinding<double >(elementStrings, binding);
239
- }
240
- return TensorDouble::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
279
+ return CreateTensor<TensorKind::Double>(args, elementStrings, tensorDescriptor);
241
280
}
242
281
break ;
243
282
case TensorKind::Int8:
244
283
{
245
- ModelBinding<uint8_t > binding (description);
246
- if (args.IsGarbageInput ())
247
- {
248
- memset (binding.GetData (), 0 , sizeof (uint8_t ) * binding.GetDataBufferSize ());
249
- }
250
- else
251
- {
252
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
253
- WriteDataToBinding<uint8_t >(elementStrings, binding);
254
- }
255
- return TensorInt8Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
284
+ return CreateTensor<TensorKind::Int8>(args, elementStrings, tensorDescriptor);
256
285
}
257
286
break ;
258
287
case TensorKind::UInt8 :
259
288
{
260
- ModelBinding<uint8_t > binding (description);
261
- if (args.IsGarbageInput ())
262
- {
263
- memset (binding.GetData (), 0 , sizeof (uint8_t ) * binding.GetDataBufferSize ());
264
- }
265
- else
266
- {
267
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
268
- WriteDataToBinding<uint8_t >(elementStrings, binding);
269
- }
270
- return TensorUInt8Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
289
+ return CreateTensor<TensorKind::UInt8 >(args, elementStrings, tensorDescriptor);
271
290
}
272
291
break ;
273
292
case TensorKind::Int16:
274
293
{
275
- ModelBinding<int16_t > binding (description);
276
- if (args.IsGarbageInput ())
277
- {
278
- memset (binding.GetData (), 0 , sizeof (int16_t ) * binding.GetDataBufferSize ());
279
- }
280
- else
281
- {
282
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
283
- WriteDataToBinding<int16_t >(elementStrings, binding);
284
- }
285
- return TensorInt16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
294
+ return CreateTensor<TensorKind::Int16>(args, elementStrings, tensorDescriptor);
286
295
}
287
296
break ;
288
297
case TensorKind::UInt16 :
289
298
{
290
- ModelBinding<uint16_t > binding (description);
291
- if (args.IsGarbageInput ())
292
- {
293
- memset (binding.GetData (), 0 , sizeof (uint16_t ) * binding.GetDataBufferSize ());
294
- }
295
- else
296
- {
297
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
298
- WriteDataToBinding<uint16_t >(elementStrings, binding);
299
- }
300
- return TensorUInt16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
299
+ return CreateTensor<TensorKind::UInt16 >(args, elementStrings, tensorDescriptor);
301
300
}
302
301
break ;
303
302
case TensorKind::Int32:
304
303
{
305
- ModelBinding<int32_t > binding (description);
306
- if (args.IsGarbageInput ())
307
- {
308
- memset (binding.GetData (), 0 , sizeof (int32_t ) * binding.GetDataBufferSize ());
309
- }
310
- else
311
- {
312
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
313
- WriteDataToBinding<int32_t >(elementStrings, binding);
314
- }
315
- return TensorInt32Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
304
+ return CreateTensor<TensorKind::Int32>(args, elementStrings, tensorDescriptor);
316
305
}
317
306
break ;
318
307
case TensorKind::UInt32 :
319
308
{
320
- ModelBinding<uint32_t > binding (description);
321
- if (args.IsGarbageInput ())
322
- {
323
- memset (binding.GetData (), 0 , sizeof (uint32_t ) * binding.GetDataBufferSize ());
324
- }
325
- else
326
- {
327
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
328
- WriteDataToBinding<uint32_t >(elementStrings, binding);
329
- }
330
- return TensorUInt32Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
309
+ return CreateTensor<TensorKind::UInt32 >(args, elementStrings, tensorDescriptor);
331
310
}
332
311
break ;
333
312
case TensorKind::Int64:
334
313
{
335
- ModelBinding<int64_t > binding (description);
336
- if (args.IsGarbageInput ())
337
- {
338
- memset (binding.GetData (), 0 , sizeof (int64_t ) * binding.GetDataBufferSize ());
339
- }
340
- else
341
- {
342
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
343
- WriteDataToBinding<int64_t >(elementStrings, binding);
344
- }
345
- return TensorInt64Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
314
+ return CreateTensor<TensorKind::Int64>(args, elementStrings, tensorDescriptor);
346
315
}
347
316
break ;
348
317
case TensorKind::UInt64 :
349
318
{
350
- ModelBinding<uint64_t > binding (description);
351
- if (args.IsGarbageInput ())
352
- {
353
- memset (binding.GetData (), 0 , sizeof (uint64_t ) * binding.GetDataBufferSize ());
354
- }
355
- else
356
- {
357
- elementStrings = ParseCSVElementStrings (args.CsvPath ());
358
- WriteDataToBinding<uint64_t >(elementStrings, binding);
359
- }
360
- return TensorUInt64Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
319
+ return CreateTensor<TensorKind::UInt64 >(args, elementStrings, tensorDescriptor);
361
320
}
362
321
break ;
363
322
}
0 commit comments