Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 0504e02

Browse files
committed
Add asynchronous batching
1 parent a8bdc65 commit 0504e02

File tree

9 files changed

+361
-114
lines changed

9 files changed

+361
-114
lines changed
Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
using OnnxStack.Core;
2-
using OnnxStack.StableDiffusion.Common;
1+
using OnnxStack.StableDiffusion.Common;
32
using OnnxStack.StableDiffusion.Config;
4-
using OnnxStack.StableDiffusion.Enums;
53
using OnnxStack.StableDiffusion.Helpers;
64
using SixLabors.ImageSharp;
75

@@ -31,68 +29,53 @@ public async Task RunAsync()
3129

3230
while (true)
3331
{
34-
OutputHelpers.WriteConsole("Please type a prompt and press ENTER", ConsoleColor.Yellow);
35-
var prompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
36-
37-
OutputHelpers.WriteConsole("Please type a negative prompt and press ENTER (optional)", ConsoleColor.Yellow);
38-
var negativePrompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
39-
40-
OutputHelpers.WriteConsole("Please enter a batch count and press ENTER", ConsoleColor.Yellow);
41-
var batch = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
42-
int.TryParse(batch, out var batchCount);
43-
batchCount = Math.Max(1, batchCount);
4432

4533
var promptOptions = new PromptOptions
4634
{
47-
Prompt = prompt,
48-
NegativePrompt = negativePrompt,
49-
BatchCount = batchCount
35+
Prompt = "Photo of a cat"
5036
};
5137

5238
var schedulerOptions = new SchedulerOptions
5339
{
5440
Seed = Random.Shared.Next(),
5541

5642
GuidanceScale = 8,
57-
InferenceSteps = 22,
43+
InferenceSteps = 8,
5844
Strength = 0.6f
5945
};
6046

47+
var batchOptions = new BatchOptions
48+
{
49+
BatchType = BatchOptionType.Seed,
50+
Count = 5
51+
};
52+
6153
foreach (var model in _stableDiffusionService.Models)
6254
{
6355
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
6456
await _stableDiffusionService.LoadModel(model);
6557

66-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
58+
var callback = (int batch, int batchCount, int step, int steps) =>
59+
{
60+
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
61+
};
62+
63+
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType).Take(1))
6764
{
6865
promptOptions.SchedulerType = schedulerType;
69-
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
70-
await GenerateImage(model, promptOptions, schedulerOptions);
66+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
67+
{
68+
var outputFilename = Path.Combine(_outputDirectory, $"{schedulerOptions.Seed}.png");
69+
var image = result.ToImage();
70+
await image.SaveAsPngAsync(outputFilename);
71+
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
72+
}
7173
}
7274

7375
OutputHelpers.WriteConsole($"Unloading Model `{model.Name}`...", ConsoleColor.Green);
7476
await _stableDiffusionService.UnloadModel(model);
7577
}
7678
}
7779
}
78-
79-
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
80-
{
81-
82-
var result = await _stableDiffusionService.GenerateAsync(model, prompt, options);
83-
if (result == null)
84-
return false;
85-
86-
var imageTensors = result.Split(prompt.BatchCount);
87-
for (int i = 0; i < imageTensors.Length; i++)
88-
{
89-
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{i}.png");
90-
var image = imageTensors[i].ToImage();
91-
await image.SaveAsPngAsync(outputFilename);
92-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
93-
}
94-
95-
return true;
96-
}
9780
}
9881
}

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,53 @@ public interface IStableDiffusionService
8383
/// <param name="cancellationToken">The cancellation token.</param>
8484
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
8585
Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
86+
87+
/// <summary>
88+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
89+
/// </summary>
90+
/// <param name="modelOptions">The model options.</param>
91+
/// <param name="promptOptions">The prompt options.</param>
92+
/// <param name="schedulerOptions">The scheduler options.</param>
93+
/// <param name="batchOptions">The batch options.</param>
94+
/// <param name="progressCallback">The progress callback.</param>
95+
/// <param name="cancellationToken">The cancellation token.</param>
96+
/// <returns></returns>
97+
IAsyncEnumerable<DenseTensor<float>> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
98+
99+
/// <summary>
100+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
101+
/// </summary>
102+
/// <param name="modelOptions">The model options.</param>
103+
/// <param name="promptOptions">The prompt options.</param>
104+
/// <param name="schedulerOptions">The scheduler options.</param>
105+
/// <param name="batchOptions">The batch options.</param>
106+
/// <param name="progressCallback">The progress callback.</param>
107+
/// <param name="cancellationToken">The cancellation token.</param>
108+
/// <returns></returns>
109+
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
110+
111+
/// <summary>
112+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
113+
/// </summary>
114+
/// <param name="modelOptions">The model options.</param>
115+
/// <param name="promptOptions">The prompt options.</param>
116+
/// <param name="schedulerOptions">The scheduler options.</param>
117+
/// <param name="batchOptions">The batch options.</param>
118+
/// <param name="progressCallback">The progress callback.</param>
119+
/// <param name="cancellationToken">The cancellation token.</param>
120+
/// <returns></returns>
121+
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
122+
123+
/// <summary>
124+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
125+
/// </summary>
126+
/// <param name="modelOptions">The model options.</param>
127+
/// <param name="promptOptions">The prompt options.</param>
128+
/// <param name="schedulerOptions">The scheduler options.</param>
129+
/// <param name="batchOptions">The batch options.</param>
130+
/// <param name="progressCallback">The progress callback.</param>
131+
/// <param name="cancellationToken">The cancellation token.</param>
132+
/// <returns></returns>
133+
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
86134
}
87135
}

OnnxStack.StableDiffusion/Config/BatchOptions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ namespace OnnxStack.StableDiffusion.Config
88
{
99
public class BatchOptions
1010
{
11+
public BatchOptionType BatchType { get; set; }
12+
public int Count { get; set; }
13+
}
1114

15+
public enum BatchOptionType
16+
{
17+
Seed = 0
1218
}
1319
}

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ public interface IDiffuser
4646
/// <param name="progressCallback">The progress callback.</param>
4747
/// <param name="cancellationToken">The cancellation token.</param>
4848
/// <returns></returns>
49-
IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
49+
IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
5050
}
5151
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using System.Collections.Generic;
1414
using System.Diagnostics;
1515
using System.Linq;
16+
using System.Runtime.CompilerServices;
1617
using System.Threading;
1718
using System.Threading.Tasks;
1819

@@ -86,14 +87,77 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
8687
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
8788

8889
// LCM does not support negative prompting
90+
var performGuidance = false;
8991
promptOptions.NegativePrompt = string.Empty;
9092

93+
// Process prompts
94+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
95+
96+
// Run Scheduler steps
97+
var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
98+
99+
_logger?.LogEnd($"End", diffuseTime);
100+
101+
return schedulerResult;
102+
}
103+
104+
105+
/// <summary>
106+
/// Runs the stable diffusion batch loop
107+
/// </summary>
108+
/// <param name="modelOptions">The model options.</param>
109+
/// <param name="promptOptions">The prompt options.</param>
110+
/// <param name="schedulerOptions">The scheduler options.</param>
111+
/// <param name="batchOptions">The batch options.</param>
112+
/// <param name="progressCallback">The progress callback.</param>
113+
/// <param name="cancellationToken">The cancellation token.</param>
114+
/// <returns></returns>
115+
/// <exception cref="System.NotImplementedException"></exception>
116+
public async IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation]CancellationToken cancellationToken = default)
117+
{
118+
var diffuseBatchTime = _logger?.LogBegin("Begin...");
119+
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
120+
121+
// LCM does not support negative prompting
122+
var performGuidance = false;
123+
promptOptions.NegativePrompt = string.Empty;
124+
125+
var batchIndex = 1;
126+
var batchCount = batchOptions.Count;
127+
var schedulerCallback = (int p, int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t);
128+
129+
// Process prompts
130+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
131+
132+
if (batchOptions.BatchType == BatchOptionType.Seed)
133+
{
134+
var randomSeeds = Enumerable.Range(0, Math.Max(1, batchOptions.Count)).Select(x => Random.Shared.Next());
135+
foreach (var randomSeed in randomSeeds)
136+
{
137+
schedulerOptions.Seed = randomSeed;
138+
yield return await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken);
139+
batchIndex++;
140+
}
141+
}
142+
}
143+
144+
145+
/// <summary>
146+
/// Runs the scheduler steps.
147+
/// </summary>
148+
/// <param name="modelOptions">The model options.</param>
149+
/// <param name="promptOptions">The prompt options.</param>
150+
/// <param name="schedulerOptions">The scheduler options.</param>
151+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
152+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
153+
/// <param name="progressCallback">The progress callback.</param>
154+
/// <param name="cancellationToken">The cancellation token.</param>
155+
/// <returns></returns>
156+
protected virtual async Task<DenseTensor<float>> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
157+
{
91158
// Get Scheduler
92159
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
93160
{
94-
// Process prompts
95-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, false);
96-
97161
// Get timesteps
98162
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
99163

@@ -137,30 +201,11 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
137201
}
138202

139203
// Decode Latents
140-
var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
141-
_logger?.LogEnd($"End", diffuseTime);
142-
return result;
204+
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
143205
}
144206
}
145207

146208

147-
/// <summary>
148-
/// Runs the stable diffusion batch loop
149-
/// </summary>
150-
/// <param name="modelOptions">The model options.</param>
151-
/// <param name="promptOptions">The prompt options.</param>
152-
/// <param name="schedulerOptions">The scheduler options.</param>
153-
/// <param name="batchOptions">The batch options.</param>
154-
/// <param name="progressCallback">The progress callback.</param>
155-
/// <param name="cancellationToken">The cancellation token.</param>
156-
/// <returns></returns>
157-
/// <exception cref="System.NotImplementedException"></exception>
158-
public IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
159-
{
160-
throw new NotImplementedException();
161-
}
162-
163-
164209
/// <summary>
165210
/// Decodes the latents.
166211
/// </summary>
@@ -279,6 +324,6 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
279324
return parameters.ToList();
280325
}
281326

282-
327+
283328
}
284329
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ public InpaintDiffuser(IOnnxModelService onnxModelService, IPromptService prompt
3939

4040

4141
/// <summary>
42-
/// Runs the Stable Diffusion inference.
42+
/// Runs the stable diffusion loop
4343
/// </summary>
44-
/// <param name="promptOptions">The options.</param>
45-
/// <param name="schedulerOptions">The scheduler configuration.</param>
44+
/// <param name="modelOptions"></param>
45+
/// <param name="promptOptions">The prompt options.</param>
46+
/// <param name="schedulerOptions">The scheduler options.</param>
47+
/// <param name="progressCallback"></param>
48+
/// <param name="cancellationToken">The cancellation token.</param>
4649
/// <returns></returns>
4750
public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
4851
{
@@ -52,16 +55,37 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
5255
var diffuseTime = _logger?.LogBegin("Begin...");
5356
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
5457

58+
// Should we perform classifier free guidance
59+
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
60+
61+
// Process prompts
62+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
63+
64+
// Run Scheduler steps
65+
var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
66+
67+
_logger?.LogEnd($"End", diffuseTime);
68+
69+
return schedulerResult;
70+
}
71+
5572

73+
/// <summary>
74+
/// Runs the scheduler steps.
75+
/// </summary>
76+
/// <param name="modelOptions">The model options.</param>
77+
/// <param name="promptOptions">The prompt options.</param>
78+
/// <param name="schedulerOptions">The scheduler options.</param>
79+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
80+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
81+
/// <param name="progressCallback">The progress callback.</param>
82+
/// <param name="cancellationToken">The cancellation token.</param>
83+
/// <returns></returns>
84+
protected override async Task<DenseTensor<float>> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
85+
{
5686
// Get Scheduler
5787
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
5888
{
59-
// Should we perform classifier free guidance
60-
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
61-
62-
// Process prompts
63-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
64-
6589
// Get timesteps
6690
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
6791

@@ -110,9 +134,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
110134
}
111135

112136
// Decode Latents
113-
var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents);
114-
_logger?.LogEnd($"End", diffuseTime);
115-
return result;
137+
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents);
116138
}
117139
}
118140

0 commit comments

Comments
 (0)