Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 51f0d76

Browse files
committed
Split processes to accommodate Inpainting
1 parent 6bceed8 commit 51f0d76

File tree

12 files changed

+376
-72
lines changed

12 files changed

+376
-72
lines changed

OnnxStack.StableDiffusion/Common/ISchedulerService.cs renamed to OnnxStack.StableDiffusion/Common/IDiffuserService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
namespace OnnxStack.StableDiffusion.Common
88
{
9-
public interface ISchedulerService
9+
public interface IDiffuserService
1010
{
1111

1212
/// <summary>
13-
/// Runs the specified Scheduler with the prompt inputs provided.
13+
/// Runs the specified Diffuser with the prompt inputs provided.
1414
/// </summary>
1515
/// <param name="prompt">The prompt.</param>
1616
/// <param name="options">The options.</param>

OnnxStack.StableDiffusion/Config/PromptOptions.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using OnnxStack.StableDiffusion.Enums;
22
using OnnxStack.StableDiffusion.Models;
33
using System.ComponentModel.DataAnnotations;
4-
using System.Text.Json.Serialization;
54

65
namespace OnnxStack.StableDiffusion.Config
76
{
87
public class PromptOptions
98
{
9+
public ProcessType ProcessType { get; set; }
10+
1011
[Required]
1112
[StringLength(512, MinimumLength = 4)]
1213
public string Prompt { get; set; }
@@ -22,4 +23,11 @@ public class PromptOptions
2223
public bool HasInputImage => InputImage?.HasImage ?? false;
2324
public bool HasInputImageMask => InputImageMask?.HasImage ?? false;
2425
}
26+
27+
public enum ProcessType
28+
{
29+
TextToImage = 0,
30+
ImageToImage = 1,
31+
ImageInpaint = 2
32+
}
2533
}

OnnxStack.StableDiffusion/Services/SchedulerService.cs renamed to OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 38 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
11
using Microsoft.ML.OnnxRuntime;
22
using Microsoft.ML.OnnxRuntime.Tensors;
3-
using OnnxStack.Core;
43
using OnnxStack.Core.Config;
54
using OnnxStack.Core.Services;
65
using OnnxStack.StableDiffusion.Common;
76
using OnnxStack.StableDiffusion.Config;
87
using OnnxStack.StableDiffusion.Enums;
98
using OnnxStack.StableDiffusion.Helpers;
109
using OnnxStack.StableDiffusion.Schedulers;
11-
using SixLabors.ImageSharp;
1210
using System;
1311
using System.Collections.Generic;
1412
using System.Linq;
1513
using System.Threading;
1614
using System.Threading.Tasks;
1715

18-
19-
namespace OnnxStack.StableDiffusion.Services
16+
namespace OnnxStack.StableDiffusion.Diffusers
2017
{
21-
public sealed class SchedulerService : ISchedulerService
18+
public abstract class DiffuserBase : IDiffuser
2219
{
23-
private readonly IPromptService _promptService;
24-
private readonly OnnxStackConfig _configuration;
25-
private readonly IOnnxModelService _onnxModelService;
20+
protected readonly IPromptService _promptService;
21+
protected readonly OnnxStackConfig _configuration;
22+
protected readonly IOnnxModelService _onnxModelService;
2623

2724
/// <summary>
28-
/// Initializes a new instance of the <see cref="SchedulerService"/> class.
25+
/// Initializes a new instance of the <see cref="DiffuserBase"/> class.
2926
/// </summary>
3027
/// <param name="configuration">The configuration.</param>
3128
/// <param name="onnxModelService">The onnx model service.</param>
32-
public SchedulerService(IOnnxModelService onnxModelService, IPromptService promptService)
29+
public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptService)
3330
{
3431
_promptService = promptService;
3532
_onnxModelService = onnxModelService;
@@ -38,12 +35,34 @@ public SchedulerService(IOnnxModelService onnxModelService, IPromptService promp
3835

3936

4037
/// <summary>
41-
/// Runs the Stable Diffusion inference.
38+
/// Gets the timesteps.
39+
/// </summary>
40+
/// <param name="prompt">The prompt.</param>
41+
/// <param name="options">The options.</param>
42+
/// <param name="scheduler">The scheduler.</param>
43+
/// <returns></returns>
44+
protected abstract IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler);
45+
46+
/// <summary>
47+
/// Prepares the latents.
48+
/// </summary>
49+
/// <param name="prompt">The prompt.</param>
50+
/// <param name="options">The options.</param>
51+
/// <param name="scheduler">The scheduler.</param>
52+
/// <param name="timesteps">The timesteps.</param>
53+
/// <returns></returns>
54+
protected abstract DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps);
55+
56+
57+
/// <summary>
58+
/// Rund the stable diffusion loop
4259
/// </summary>
43-
/// <param name="promptOptions">The options.</param>
44-
/// <param name="schedulerOptions">The scheduler configuration.</param>
60+
/// <param name="promptOptions">The prompt options.</param>
61+
/// <param name="schedulerOptions">The scheduler options.</param>
62+
/// <param name="progress">The progress.</param>
63+
/// <param name="cancellationToken">The cancellation token.</param>
4564
/// <returns></returns>
46-
public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
65+
public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
4766
{
4867
// Create random seed if none was set
4968
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -103,53 +122,13 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
103122
}
104123
}
105124

106-
private IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
107-
{
108-
if (!prompt.HasInputImage)
109-
return scheduler.Timesteps;
110-
111-
// Image2Image we narrow step the range by the Strength
112-
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
113-
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
114-
return scheduler.Timesteps.Skip(start).ToList();
115-
}
116-
117-
/// <summary>
118-
/// Prepares the latents for inference.
119-
/// </summary>
120-
/// <param name="prompt">The prompt.</param>
121-
/// <param name="options">The options.</param>
122-
/// <param name="scheduler">The scheduler.</param>
123-
/// <returns></returns>
124-
private DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
125-
{
126-
// If we dont have an initial image create random sample
127-
if (!prompt.HasInputImage)
128-
return scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
129-
130-
// Image input, decode, add noise, return as latent 0
131-
var imageTensor = prompt.InputImage.ToDenseTensor(options.Width, options.Height);
132-
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeEncoder);
133-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
134-
using (var inferResult = _onnxModelService.RunInference(OnnxModelType.VaeEncoder, inputParameters))
135-
{
136-
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
137-
var noisySample = sample
138-
.AddTensors(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
139-
.MultipleTensorByFloat(_configuration.ScaleFactor);
140-
var noise = scheduler.CreateRandomSample(sample.Dimensions);
141-
return scheduler.AddNoise(noisySample, noise, timesteps);
142-
}
143-
}
144-
145-
146125
/// <summary>
147126
/// Decodes the latents.
148127
/// </summary>
149128
/// <param name="options">The options.</param>
150129
/// <param name="latents">The latents.</param>
151130
/// <returns></returns>
152-
private async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, DenseTensor<float> latents)
131+
protected async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, DenseTensor<float> latents)
153132
{
154133
// Scale and decode the image latents with vae.
155134
// latents = 1 / 0.18215 * latents
@@ -181,7 +160,7 @@ private async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, D
181160
/// <returns>
182161
/// <c>true</c> if the specified result image is safe; otherwise, <c>false</c>.
183162
/// </returns>
184-
private async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float> resultImage)
163+
protected async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float> resultImage)
185164
{
186165
//clip input
187166
var inputTensor = ClipImageFeatureExtractor(options, resultImage);
@@ -207,7 +186,7 @@ private async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float
207186
/// </summary>
208187
/// <param name="imageTensor">The image tensor.</param>
209188
/// <returns></returns>
210-
private static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions options, DenseTensor<float> imageTensor)
189+
protected static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions options, DenseTensor<float> imageTensor)
211190
{
212191
//convert tensor result to image
213192
using (var image = imageTensor.ToImage())
@@ -243,7 +222,7 @@ private static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions opt
243222
/// <param name="options">The options.</param>
244223
/// <param name="schedulerConfig">The scheduler configuration.</param>
245224
/// <returns></returns>
246-
private static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options)
225+
protected static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options)
247226
{
248227
return prompt.SchedulerType switch
249228
{
@@ -259,7 +238,7 @@ private static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions op
259238
/// </summary>
260239
/// <param name="parameters">The parameters.</param>
261240
/// <returns></returns>
262-
private static IReadOnlyCollection<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
241+
protected static IReadOnlyCollection<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
263242
{
264243
return parameters.ToList().AsReadOnly();
265244
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Config;
3+
using System;
4+
using System.Threading;
5+
using System.Threading.Tasks;
6+
7+
namespace OnnxStack.StableDiffusion.Diffusers
8+
{
9+
public interface IDiffuser
10+
{
11+
Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default);
12+
}
13+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Config;
4+
using OnnxStack.Core.Services;
5+
using OnnxStack.StableDiffusion.Common;
6+
using OnnxStack.StableDiffusion.Config;
7+
using OnnxStack.StableDiffusion.Diffusers;
8+
using OnnxStack.StableDiffusion.Helpers;
9+
using SixLabors.ImageSharp;
10+
using System;
11+
using System.Collections.Generic;
12+
using System.Linq;
13+
14+
15+
namespace OnnxStack.StableDiffusion.Services
16+
{
17+
public sealed class ImageDiffuser : DiffuserBase
18+
{
19+
/// <summary>
20+
/// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
21+
/// </summary>
22+
/// <param name="configuration">The configuration.</param>
23+
/// <param name="onnxModelService">The onnx model service.</param>
24+
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
25+
:base(onnxModelService, promptService)
26+
{
27+
}
28+
29+
30+
/// <summary>
31+
/// Gets the timesteps.
32+
/// </summary>
33+
/// <param name="prompt">The prompt.</param>
34+
/// <param name="options">The options.</param>
35+
/// <param name="scheduler">The scheduler.</param>
36+
/// <returns></returns>
37+
protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
38+
{
39+
// Image2Image we narrow step the range by the Strength
40+
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
41+
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
42+
return scheduler.Timesteps.Skip(start).ToList();
43+
}
44+
45+
46+
/// <summary>
47+
/// Prepares the latents for inference.
48+
/// </summary>
49+
/// <param name="prompt">The prompt.</param>
50+
/// <param name="options">The options.</param>
51+
/// <param name="scheduler">The scheduler.</param>
52+
/// <returns></returns>
53+
protected override DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
54+
{
55+
// Image input, decode, add noise, return as latent 0
56+
var imageTensor = prompt.InputImage.ToDenseTensor(options.Width, options.Height);
57+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeEncoder);
58+
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
59+
using (var inferResult = _onnxModelService.RunInference(OnnxModelType.VaeEncoder, inputParameters))
60+
{
61+
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
62+
var noisySample = sample
63+
.AddTensors(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
64+
.MultipleTensorByFloat(_configuration.ScaleFactor);
65+
var noise = scheduler.CreateRandomSample(sample.Dimensions);
66+
return scheduler.AddNoise(noisySample, noise, timesteps);
67+
}
68+
}
69+
70+
}
71+
}

0 commit comments

Comments
 (0)