Skip to content

Commit e1d723a

Browse files
committed
Add free dimension overrides for SD 1.4 components
1 parent 74ec8c0 commit e1d723a

File tree

4 files changed

+43
-1
lines changed

4 files changed

+43
-1
lines changed

AIDevGallery/Samples/SharedCode/StableDiffusionCode/SafetyChecker.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5858
SessionOptions sessionOptions = new();
5959
sessionOptions.RegisterOrtExtensions();
6060

61+
if(device == "NvTensorRTRTXExecutionProvider")
62+
{
63+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
64+
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
65+
sessionOptions.AddFreeDimensionOverrideByName("height", 224);
66+
sessionOptions.AddFreeDimensionOverrideByName("width", 224);
67+
}
68+
6169
if (policy != null)
6270
{
6371
sessionOptions.SetEpSelectionPolicy(policy.Value);
@@ -88,7 +96,14 @@ public bool IsNotSafe(Tensor<float> resultImage, StableDiffusionConfig config)
8896
var inputTensor = ClipImageFeatureExtractor(resultImage, config);
8997

9098
// images input
91-
var inputImagesTensor = ReorderTensor(inputTensor);
99+
if(device == "NvTensorRTRTXExecutionProvider")
100+
{
101+
var inputImagesTensor = inputTensor;
102+
}
103+
else
104+
{
105+
var inputImagesTensor = ReorderTensor(inputTensor);
106+
}
92107

93108
var input = new List<NamedOnnxValue>
94109
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/StableDiffusion.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
8383
SessionOptions sessionOptions = new();
8484
sessionOptions.RegisterOrtExtensions();
8585

86+
if(device == "NvTensorRTRTXExecutionProvider")
87+
{
88+
sessionOptions.AddFreeDimensionOverrideByName("batch", 2);
89+
sessionOptions.AddFreeDimensionOverrideByName("time_batch", 1);
90+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
91+
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
92+
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
93+
sessionOptions.AddFreeDimensionOverrideByName("sequence", 77);
94+
}
95+
96+
8697
if (policy != null)
8798
{
8899
sessionOptions.SetEpSelectionPolicy(policy.Value);

AIDevGallery/Samples/SharedCode/StableDiffusionCode/TextProcessing.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
6060
SessionOptions sessionOptions = new();
6161
sessionOptions.RegisterOrtExtensions();
6262

63+
if(device == "NvTensorRTRTXExecutionProvider")
64+
{
65+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
66+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
67+
sessionOptions.AddFreeDimensionOverrideByName("height", 512);
68+
sessionOptions.AddFreeDimensionOverrideByName("width", 512);
69+
}
70+
6371
if (policy != null)
6472
{
6573
sessionOptions.SetEpSelectionPolicy(policy.Value);

AIDevGallery/Samples/SharedCode/StableDiffusionCode/VaeDecoder.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5757
SessionOptions sessionOptions = new();
5858
sessionOptions.RegisterOrtExtensions();
5959

60+
if(device == "NvTensorRTRTXExecutionProvider")
61+
{
62+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
63+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
64+
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
65+
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
66+
}
67+
6068
if (policy != null)
6169
{
6270
sessionOptions.SetEpSelectionPolicy(policy.Value);

0 commit comments

Comments
 (0)