Skip to content

Commit 6456e79

Browse files
committed
Fix channel dimension in textprocessing and make dimension overrides EP-independent
1 parent 18b6d3e commit 6456e79

File tree

4 files changed

+19
-31
lines changed

4 files changed

+19
-31
lines changed

AIDevGallery/Samples/SharedCode/StableDiffusionCode/SafetyChecker.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,10 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5858
SessionOptions sessionOptions = new();
5959
sessionOptions.RegisterOrtExtensions();
6060

61-
if(device == "NvTensorRTRTXExecutionProvider")
62-
{
63-
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
64-
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
65-
sessionOptions.AddFreeDimensionOverrideByName("height", 224);
66-
sessionOptions.AddFreeDimensionOverrideByName("width", 224);
67-
}
61+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
62+
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
63+
sessionOptions.AddFreeDimensionOverrideByName("height", 224);
64+
sessionOptions.AddFreeDimensionOverrideByName("width", 224);
6865

6966
if (policy != null)
7067
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/StableDiffusion.cs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,12 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
8383
SessionOptions sessionOptions = new();
8484
sessionOptions.RegisterOrtExtensions();
8585

86-
if(device == "NvTensorRTRTXExecutionProvider")
87-
{
88-
sessionOptions.AddFreeDimensionOverrideByName("batch", 2);
89-
sessionOptions.AddFreeDimensionOverrideByName("time_batch", 1);
90-
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
91-
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
92-
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
93-
sessionOptions.AddFreeDimensionOverrideByName("sequence", 77);
94-
}
86+
sessionOptions.AddFreeDimensionOverrideByName("batch", 2);
87+
sessionOptions.AddFreeDimensionOverrideByName("time_batch", 1);
88+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
89+
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
90+
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
91+
sessionOptions.AddFreeDimensionOverrideByName("sequence", 77);
9592

9693

9794
if (policy != null)

AIDevGallery/Samples/SharedCode/StableDiffusionCode/TextProcessing.cs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,11 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5959

6060
SessionOptions sessionOptions = new();
6161
sessionOptions.RegisterOrtExtensions();
62-
63-
if(device == "NvTensorRTRTXExecutionProvider")
64-
{
65-
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
66-
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
67-
sessionOptions.AddFreeDimensionOverrideByName("height", 512);
68-
sessionOptions.AddFreeDimensionOverrideByName("width", 512);
69-
}
62+
63+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
64+
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
65+
sessionOptions.AddFreeDimensionOverrideByName("height", 512);
66+
sessionOptions.AddFreeDimensionOverrideByName("width", 512);
7067

7168
if (policy != null)
7269
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/VaeDecoder.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,10 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5757
SessionOptions sessionOptions = new();
5858
sessionOptions.RegisterOrtExtensions();
5959

60-
if(device == "NvTensorRTRTXExecutionProvider")
61-
{
62-
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
63-
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
64-
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
65-
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
66-
}
60+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
61+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
62+
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
63+
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
6764

6865
if (policy != null)
6966
{

0 commit comments

Comments
 (0)