1
1
using Microsoft . ML . OnnxRuntime ;
2
2
using Microsoft . ML . OnnxRuntime . Tensors ;
3
- using OnnxStack . Core ;
4
3
using OnnxStack . Core . Config ;
5
4
using OnnxStack . Core . Services ;
6
5
using OnnxStack . StableDiffusion . Common ;
7
6
using OnnxStack . StableDiffusion . Config ;
8
7
using OnnxStack . StableDiffusion . Enums ;
9
8
using OnnxStack . StableDiffusion . Helpers ;
10
9
using OnnxStack . StableDiffusion . Schedulers ;
11
- using SixLabors . ImageSharp ;
12
10
using System ;
13
11
using System . Collections . Generic ;
14
12
using System . Linq ;
15
13
using System . Threading ;
16
14
using System . Threading . Tasks ;
17
15
18
-
19
- namespace OnnxStack . StableDiffusion . Services
16
+ namespace OnnxStack . StableDiffusion . Diffusers
20
17
{
21
- public sealed class SchedulerService : ISchedulerService
18
+ public abstract class DiffuserBase : IDiffuser
22
19
{
23
- private readonly IPromptService _promptService ;
24
- private readonly OnnxStackConfig _configuration ;
25
- private readonly IOnnxModelService _onnxModelService ;
20
+ protected readonly IPromptService _promptService ;
21
+ protected readonly OnnxStackConfig _configuration ;
22
+ protected readonly IOnnxModelService _onnxModelService ;
26
23
27
24
/// <summary>
28
- /// Initializes a new instance of the <see cref="SchedulerService "/> class.
25
+ /// Initializes a new instance of the <see cref="DiffuserBase "/> class.
29
26
/// </summary>
30
27
/// <param name="configuration">The configuration.</param>
31
28
/// <param name="onnxModelService">The onnx model service.</param>
32
- public SchedulerService ( IOnnxModelService onnxModelService , IPromptService promptService )
29
+ public DiffuserBase ( IOnnxModelService onnxModelService , IPromptService promptService )
33
30
{
34
31
_promptService = promptService ;
35
32
_onnxModelService = onnxModelService ;
@@ -38,12 +35,34 @@ public SchedulerService(IOnnxModelService onnxModelService, IPromptService promp
38
35
39
36
40
37
/// <summary>
41
- /// Runs the Stable Diffusion inference.
38
+ /// Gets the timesteps.
39
+ /// </summary>
40
+ /// <param name="prompt">The prompt.</param>
41
+ /// <param name="options">The options.</param>
42
+ /// <param name="scheduler">The scheduler.</param>
43
+ /// <returns></returns>
44
+ protected abstract IReadOnlyList < int > GetTimesteps ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler ) ;
45
+
46
+ /// <summary>
47
+ /// Prepares the latents.
48
+ /// </summary>
49
+ /// <param name="prompt">The prompt.</param>
50
+ /// <param name="options">The options.</param>
51
+ /// <param name="scheduler">The scheduler.</param>
52
+ /// <param name="timesteps">The timesteps.</param>
53
+ /// <returns></returns>
54
+ protected abstract DenseTensor < float > PrepareLatents ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps ) ;
55
+
56
+
57
+ /// <summary>
58
+ /// Rund the stable diffusion loop
42
59
/// </summary>
43
- /// <param name="promptOptions">The options.</param>
44
- /// <param name="schedulerOptions">The scheduler configuration.</param>
60
+ /// <param name="promptOptions">The prompt options.</param>
61
+ /// <param name="schedulerOptions">The scheduler options.</param>
62
+ /// <param name="progress">The progress.</param>
63
+ /// <param name="cancellationToken">The cancellation token.</param>
45
64
/// <returns></returns>
46
- public async Task < DenseTensor < float > > RunAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < int , int > progress = null , CancellationToken cancellationToken = default )
65
+ public virtual async Task < DenseTensor < float > > DiffuseAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < int , int > progress = null , CancellationToken cancellationToken = default )
47
66
{
48
67
// Create random seed if none was set
49
68
schedulerOptions . Seed = schedulerOptions . Seed > 0 ? schedulerOptions . Seed : Random . Shared . Next ( ) ;
@@ -103,53 +122,13 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
103
122
}
104
123
}
105
124
106
- private IReadOnlyList < int > GetTimesteps ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler )
107
- {
108
- if ( ! prompt . HasInputImage )
109
- return scheduler . Timesteps ;
110
-
111
- // Image2Image we narrow step the range by the Strength
112
- var inittimestep = Math . Min ( ( int ) ( options . InferenceSteps * options . Strength ) , options . InferenceSteps ) ;
113
- var start = Math . Max ( options . InferenceSteps - inittimestep , 0 ) ;
114
- return scheduler . Timesteps . Skip ( start ) . ToList ( ) ;
115
- }
116
-
117
- /// <summary>
118
- /// Prepares the latents for inference.
119
- /// </summary>
120
- /// <param name="prompt">The prompt.</param>
121
- /// <param name="options">The options.</param>
122
- /// <param name="scheduler">The scheduler.</param>
123
- /// <returns></returns>
124
- private DenseTensor < float > PrepareLatents ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps )
125
- {
126
- // If we dont have an initial image create random sample
127
- if ( ! prompt . HasInputImage )
128
- return scheduler . CreateRandomSample ( options . GetScaledDimension ( ) , scheduler . InitNoiseSigma ) ;
129
-
130
- // Image input, decode, add noise, return as latent 0
131
- var imageTensor = prompt . InputImage . ToDenseTensor ( options . Width , options . Height ) ;
132
- var inputNames = _onnxModelService . GetInputNames ( OnnxModelType . VaeEncoder ) ;
133
- var inputParameters = CreateInputParameters ( NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , imageTensor ) ) ;
134
- using ( var inferResult = _onnxModelService . RunInference ( OnnxModelType . VaeEncoder , inputParameters ) )
135
- {
136
- var sample = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
137
- var noisySample = sample
138
- . AddTensors ( scheduler . CreateRandomSample ( sample . Dimensions , options . InitialNoiseLevel ) )
139
- . MultipleTensorByFloat ( _configuration . ScaleFactor ) ;
140
- var noise = scheduler . CreateRandomSample ( sample . Dimensions ) ;
141
- return scheduler . AddNoise ( noisySample , noise , timesteps ) ;
142
- }
143
- }
144
-
145
-
146
125
/// <summary>
147
126
/// Decodes the latents.
148
127
/// </summary>
149
128
/// <param name="options">The options.</param>
150
129
/// <param name="latents">The latents.</param>
151
130
/// <returns></returns>
152
- private async Task < DenseTensor < float > > DecodeLatents ( SchedulerOptions options , DenseTensor < float > latents )
131
+ protected async Task < DenseTensor < float > > DecodeLatents ( SchedulerOptions options , DenseTensor < float > latents )
153
132
{
154
133
// Scale and decode the image latents with vae.
155
134
// latents = 1 / 0.18215 * latents
@@ -181,7 +160,7 @@ private async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, D
181
160
/// <returns>
182
161
/// <c>true</c> if the specified result image is safe; otherwise, <c>false</c>.
183
162
/// </returns>
184
- private async Task < bool > IsImageSafe ( SchedulerOptions options , DenseTensor < float > resultImage )
163
+ protected async Task < bool > IsImageSafe ( SchedulerOptions options , DenseTensor < float > resultImage )
185
164
{
186
165
//clip input
187
166
var inputTensor = ClipImageFeatureExtractor ( options , resultImage ) ;
@@ -207,7 +186,7 @@ private async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float
207
186
/// </summary>
208
187
/// <param name="imageTensor">The image tensor.</param>
209
188
/// <returns></returns>
210
- private static DenseTensor < float > ClipImageFeatureExtractor ( SchedulerOptions options , DenseTensor < float > imageTensor )
189
+ protected static DenseTensor < float > ClipImageFeatureExtractor ( SchedulerOptions options , DenseTensor < float > imageTensor )
211
190
{
212
191
//convert tensor result to image
213
192
using ( var image = imageTensor . ToImage ( ) )
@@ -243,7 +222,7 @@ private static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions opt
243
222
/// <param name="options">The options.</param>
244
223
/// <param name="schedulerConfig">The scheduler configuration.</param>
245
224
/// <returns></returns>
246
- private static IScheduler GetScheduler ( PromptOptions prompt , SchedulerOptions options )
225
+ protected static IScheduler GetScheduler ( PromptOptions prompt , SchedulerOptions options )
247
226
{
248
227
return prompt . SchedulerType switch
249
228
{
@@ -259,7 +238,7 @@ private static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions op
259
238
/// </summary>
260
239
/// <param name="parameters">The parameters.</param>
261
240
/// <returns></returns>
262
- private static IReadOnlyCollection < NamedOnnxValue > CreateInputParameters ( params NamedOnnxValue [ ] parameters )
241
+ protected static IReadOnlyCollection < NamedOnnxValue > CreateInputParameters ( params NamedOnnxValue [ ] parameters )
263
242
{
264
243
return parameters . ToList ( ) . AsReadOnly ( ) ;
265
244
}
0 commit comments