Skip to content

Commit 7009ec6

Browse files
committed
Use same SD config object for all SD components
1 parent 58cdaa9 commit 7009ec6

File tree

3 files changed

+9
-25
lines changed

3 files changed

+9
-25
lines changed

AIDevGallery/Samples/SharedCode/StableDiffusionCode/StableDiffusion.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ public async Task InitializeAsync(ExecutionProviderDevicePolicy? policy, string?
4747
{
4848
string tokenizerPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "Assets", config.TokenizerModelPath);
4949

50-
textProcessor = await TextProcessing.CreateAsync(tokenizerPath, config.TextEncoderModelPath, policy, device, compileOption);
50+
textProcessor = await TextProcessing.CreateAsync(config, tokenizerPath, config.TextEncoderModelPath, policy, device, compileOption);
5151
unetInferenceSession = await GetInferenceSession(config.UnetModelPath, policy, device, compileOption);
52-
vaeDecoder = await VaeDecoder.CreateAsync(config.VaeDecoderModelPath, policy, device, compileOption);
52+
vaeDecoder = await VaeDecoder.CreateAsync(config, config.VaeDecoderModelPath, policy, device, compileOption);
5353
safetyChecker = await SafetyChecker.CreateAsync(config.SafetyModelPath, policy, device, compileOption);
5454
}
5555

AIDevGallery/Samples/SharedCode/StableDiffusionCode/TextProcessing.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,6 @@ namespace AIDevGallery.Samples.SharedCode.StableDiffusionCode;
1414

1515
internal class TextProcessing : IDisposable
1616
{
17-
private readonly StableDiffusionConfig config = new()
18-
{
19-
// Number of denoising steps
20-
NumInferenceSteps = 15,
21-
22-
// Scale for classifier-free guidance
23-
GuidanceScale = 7.5
24-
};
25-
2617
private InferenceSession? tokenizerInferenceSession;
2718
private InferenceSession? encoderInferenceSession;
2819
private bool disposedValue;
@@ -32,19 +23,20 @@ private TextProcessing()
3223
}
3324

3425
public static async Task<TextProcessing> CreateAsync(
26+
StableDiffusionConfig config,
3527
string tokenizerPath,
3628
string encoderPath,
3729
ExecutionProviderDevicePolicy? policy,
3830
string? device,
3931
bool compileOption)
4032
{
4133
var instance = new TextProcessing();
42-
instance.tokenizerInferenceSession = await instance.GetInferenceSession(tokenizerPath, policy, device, compileOption);
43-
instance.encoderInferenceSession = await instance.GetInferenceSession(encoderPath, policy, device, compileOption);
34+
instance.tokenizerInferenceSession = await instance.GetInferenceSession(config, tokenizerPath, policy, device, compileOption);
35+
instance.encoderInferenceSession = await instance.GetInferenceSession(config, encoderPath, policy, device, compileOption);
4436
return instance;
4537
}
4638

47-
private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionProviderDevicePolicy? policy, string? device, bool compileOption)
39+
private Task<InferenceSession> GetInferenceSession(StableDiffusionConfig config, string modelPath, ExecutionProviderDevicePolicy? policy, string? device, bool compileOption)
4840
{
4941
return Task.Run(async () =>
5042
{

AIDevGallery/Samples/SharedCode/StableDiffusionCode/VaeDecoder.cs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,6 @@ namespace AIDevGallery.Samples.SharedCode.StableDiffusionCode;
1414

1515
internal class VaeDecoder : IDisposable
1616
{
17-
private readonly StableDiffusionConfig config = new()
18-
{
19-
// Number of denoising steps
20-
NumInferenceSteps = 15,
21-
22-
// Scale for classifier-free guidance
23-
GuidanceScale = 7.5
24-
};
25-
2617
private InferenceSession? vaeDecoderInferenceSession;
2718
private bool disposedValue;
2819

@@ -31,17 +22,18 @@ private VaeDecoder()
3122
}
3223

3324
public static async Task<VaeDecoder> CreateAsync(
25+
StableDiffusionConfig config,
3426
string modelPath,
3527
ExecutionProviderDevicePolicy? policy,
3628
string? device,
3729
bool compileOption)
3830
{
3931
var instance = new VaeDecoder();
40-
instance.vaeDecoderInferenceSession = await instance.GetInferenceSession(modelPath, policy, device, compileOption);
32+
instance.vaeDecoderInferenceSession = await instance.GetInferenceSession(config, modelPath, policy, device, compileOption);
4133
return instance;
4234
}
4335

44-
private Task<InferenceSession> GetInferenceSession(string modelPath, ExecutionProviderDevicePolicy? policy, string? device, bool compileOption)
36+
private Task<InferenceSession> GetInferenceSession(StableDiffusionConfig config, string modelPath, ExecutionProviderDevicePolicy? policy, string? device, bool compileOption)
4537
{
4638
return Task.Run(async () =>
4739
{

0 commit comments

Comments
 (0)