Skip to content

Commit ecc2b37

Browse files
author
Ryan Lai
authored
Added check to handle free dimensions when creating tensor with arbitrary data (#184)
1 parent bdd22fe commit ecc2b37

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

Tools/WinMLRunner/src/BindingUtilities.h

+23-3
Original file line numberDiff line numberDiff line change
@@ -222,15 +222,35 @@ namespace BindingUtilities
222222
}
223223
else if(args.IsGarbageInput())
224224
{
225-
auto tensorValue = TensorValue::Create(tensorDescriptor.Shape());
225+
std::vector<int64_t> vecShape = {};
226+
auto tensorDescriptorShape = tensorDescriptor.Shape();
227+
for (UINT dim = 0; dim < tensorDescriptorShape.Size(); dim++)
228+
{
229+
INT64 dimSize = tensorDescriptorShape.GetAt(dim);
230+
if (dimSize > 0) //If the dimension is greater than 0, then it is known.
231+
{
232+
vecShape.push_back(dimSize);
233+
}
234+
else //otherwise, make sure that the dimension is -1, representing free dimension. If not, then it's an invalid model.
235+
{
236+
if (dimSize == -1)
237+
{
238+
vecShape.push_back(1);
239+
}
240+
else
241+
{
242+
throw hresult_invalid_argument(L"Failed to create a tensor with an unknown dimension of: " + dimSize);
243+
}
244+
}
245+
}
246+
auto tensorValue = TensorValue::Create(vecShape);
226247

227248
com_ptr<ITensorNative> spTensorValueNative;
228249
tensorValue.as(spTensorValueNative);
229250

230251
BYTE* actualData;
231252
uint32_t actualSizeInBytes;
232-
spTensorValueNative->GetBuffer(&actualData, &actualSizeInBytes);
233-
253+
spTensorValueNative->GetBuffer(&actualData, &actualSizeInBytes); //Need to GetBuffer to have CPU memory backing tensorValue
234254
return tensorValue;
235255
}
236256
else

0 commit comments

Comments
 (0)