Skip to content

Commit b6de3ca

Browse files
committed
Modify model inputs + use height and width from config
1 parent 6456e79 commit b6de3ca

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

AIDevGallery/Samples/SharedCode/StableDiffusionCode/SafetyChecker.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public bool IsNotSafe(Tensor<float> resultImage, StableDiffusionConfig config)
9393
var inputTensor = ClipImageFeatureExtractor(resultImage, config);
9494

9595
// images input
96-
var inputImagesTensor = ReorderTensor(inputTensor);
96+
var inputImagesTensor = inputTensor;
9797

9898
var input = new List<NamedOnnxValue>
9999
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/StableDiffusion.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,9 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
8484
sessionOptions.RegisterOrtExtensions();
8585

8686
sessionOptions.AddFreeDimensionOverrideByName("batch", 2);
87-
sessionOptions.AddFreeDimensionOverrideByName("time_batch", 1);
8887
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
89-
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
90-
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
88+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height / 8);
89+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width / 8);
9190
sessionOptions.AddFreeDimensionOverrideByName("sequence", 77);
9291

9392

@@ -116,7 +115,7 @@ public static List<NamedOnnxValue> CreateUnetModelInput(Tensor<float> encoderHid
116115
{
117116
NamedOnnxValue.CreateFromTensor("encoder_hidden_states", encoderHiddenStates),
118117
NamedOnnxValue.CreateFromTensor("sample", sample),
119-
NamedOnnxValue.CreateFromTensor("timestep", new DenseTensor<long>(new long[] { timeStep }, [ 1 ]))
118+
NamedOnnxValue.CreateFromTensor("timestep", new DenseTensor<long>(new long[] { timeStep, timeStep }, [ 2 ]))
120119
};
121120

122121
return input;

AIDevGallery/Samples/SharedCode/StableDiffusionCode/TextProcessing.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ internal class TextProcessing : IDisposable
1818
private InferenceSession? encoderInferenceSession;
1919
private bool disposedValue;
2020

21+
private readonly StableDiffusionConfig config = new()
22+
{
23+
// Number of denoising steps
24+
NumInferenceSteps = 15,
25+
26+
// Scale for classifier-free guidance
27+
GuidanceScale = 7.5
28+
};
29+
2130
private TextProcessing()
2231
{
2332
}
@@ -62,8 +71,8 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
6271

6372
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
6473
sessionOptions.AddFreeDimensionOverrideByName("channels", 3);
65-
sessionOptions.AddFreeDimensionOverrideByName("height", 512);
66-
sessionOptions.AddFreeDimensionOverrideByName("width", 512);
74+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height);
75+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width);
6776

6877
if (policy != null)
6978
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/VaeDecoder.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ internal class VaeDecoder : IDisposable
1717
private InferenceSession? vaeDecoderInferenceSession;
1818
private bool disposedValue;
1919

20+
private readonly StableDiffusionConfig config = new()
21+
{
22+
// Number of denoising steps
23+
NumInferenceSteps = 15,
24+
25+
// Scale for classifier-free guidance
26+
GuidanceScale = 7.5
27+
};
28+
2029
private VaeDecoder()
2130
{
2231
}
@@ -59,8 +68,8 @@ private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionPr
5968

6069
sessionOptions.AddFreeDimensionOverrideByName("batch", 1);
6170
sessionOptions.AddFreeDimensionOverrideByName("channels", 4);
62-
sessionOptions.AddFreeDimensionOverrideByName("height", 64);
63-
sessionOptions.AddFreeDimensionOverrideByName("width", 64);
71+
sessionOptions.AddFreeDimensionOverrideByName("height", config.Height / 8);
72+
sessionOptions.AddFreeDimensionOverrideByName("width", config.Width / 8);
6473

6574
if (policy != null)
6675
{

0 commit comments

Comments
 (0)