Skip to content

Commit 58cdaa9

Browse files
authored
Merge pull request microsoft#393 from keshavv27/rel/v0.4.2
[StableDiffusion Sample][NvTensorRTRTX EP] Add free dimension overrides for SD1.4 components
2 parents 74ec8c0 + befe00c commit 58cdaa9

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

AIDevGallery/Samples/SharedCode/StableDiffusionCode/SafetyChecker.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5858
SessionOptions sessionOptions = new();
5959
sessionOptions.RegisterOrtExtensions();
6060

61+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
62+
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
63+
sessionOptions.AddFreeDimensionOverrideByName("height", 224);
64+
sessionOptions.AddFreeDimensionOverrideByName("width", 224);
65+
6166
if (policy != null)
6267
{
6368
sessionOptions.SetEpSelectionPolicy(policy.Value);
@@ -88,7 +93,7 @@ public bool IsNotSafe(Tensor<float> resultImage, StableDiffusionConfig config)
8893
var inputTensor = ClipImageFeatureExtractor(resultImage, config);
8994

9095
// images input
91-
var inputImagesTensor = ReorderTensor(inputTensor);
96+
var inputImagesTensor = inputTensor;
9297

9398
var input = new List<NamedOnnxValue>
9499
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/StableDiffusion.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
8383
SessionOptions sessionOptions = new();
8484
sessionOptions.RegisterOrtExtensions();
8585

86+
sessionOptions.AddFreeDimensionOverrideByName("batch", 2);
87+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
88+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height / 8);
89+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width / 8);
90+
sessionOptions.AddFreeDimensionOverrideByName("sequence", 77);
91+
8692
if (policy != null)
8793
{
8894
sessionOptions.SetEpSelectionPolicy(policy.Value);
@@ -108,7 +114,7 @@ public static List<NamedOnnxValue> CreateUnetModelInput(Tensor<float> encoderHid
108114
{
109115
NamedOnnxValue.CreateFromTensor("encoder_hidden_states", encoderHiddenStates),
110116
NamedOnnxValue.CreateFromTensor("sample", sample),
111-
NamedOnnxValue.CreateFromTensor("timestep", new DenseTensor<long>(new long[] { timeStep }, [ 1 ]))
117+
NamedOnnxValue.CreateFromTensor("timestep", new DenseTensor<long>(new long[] { timeStep, timeStep }, [ 2 ]))
112118
};
113119

114120
return input;

AIDevGallery/Samples/SharedCode/StableDiffusionCode/TextProcessing.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ namespace AIDevGallery.Samples.SharedCode.StableDiffusionCode;
1414

1515
internal class TextProcessing : IDisposable
1616
{
17+
private readonly StableDiffusionConfig config = new()
18+
{
19+
// Number of denoising steps
20+
NumInferenceSteps = 15,
21+
22+
// Scale for classifier-free guidance
23+
GuidanceScale = 7.5
24+
};
25+
1726
private InferenceSession? tokenizerInferenceSession;
1827
private InferenceSession? encoderInferenceSession;
1928
private bool disposedValue;
@@ -60,6 +69,11 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
6069
SessionOptions sessionOptions = new();
6170
sessionOptions.RegisterOrtExtensions();
6271

72+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
73+
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
74+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height);
75+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width);
76+
6377
if (policy != null)
6478
{
6579
sessionOptions.SetEpSelectionPolicy(policy.Value);

AIDevGallery/Samples/SharedCode/StableDiffusionCode/VaeDecoder.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ namespace AIDevGallery.Samples.SharedCode.StableDiffusionCode;
1414

1515
internal class VaeDecoder : IDisposable
1616
{
17+
private readonly StableDiffusionConfig config = new()
18+
{
19+
// Number of denoising steps
20+
NumInferenceSteps = 15,
21+
22+
// Scale for classifier-free guidance
23+
GuidanceScale = 7.5
24+
};
25+
1726
private InferenceSession? vaeDecoderInferenceSession;
1827
private bool disposedValue;
1928

@@ -57,6 +66,11 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5766
SessionOptions sessionOptions = new();
5867
sessionOptions.RegisterOrtExtensions();
5968

69+
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
70+
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
71+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height / 8);
72+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width / 8);
73+
6074
if (policy != null)
6175
{
6276
sessionOptions.SetEpSelectionPolicy(policy.Value);

0 commit comments

Comments
 (0)