Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 16e5d29

Browse files
committed
Return the SchedulerOptions with each batch result, as it may have been mutated
1 parent 0504e02 commit 16e5d29

File tree

10 files changed

+103
-40
lines changed

10 files changed

+103
-40
lines changed

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,27 @@ public async Task RunAsync()
4040
Seed = Random.Shared.Next(),
4141

4242
GuidanceScale = 8,
43-
InferenceSteps = 8,
43+
InferenceSteps = 20,
4444
Strength = 0.6f
4545
};
4646

4747
var batchOptions = new BatchOptions
4848
{
49-
BatchType = BatchOptionType.Seed,
50-
Count = 5
49+
BatchType = BatchOptionType.Guidance,
50+
ValueFrom = 4,
51+
ValueTo = 20,
52+
Increment = 0.5f
5153
};
5254

5355
foreach (var model in _stableDiffusionService.Models)
5456
{
5557
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5658
await _stableDiffusionService.LoadModel(model);
5759

60+
var batchIndex = 0;
5861
var callback = (int batch, int batchCount, int step, int steps) =>
5962
{
63+
batchIndex = batch;
6064
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
6165
};
6266

@@ -65,8 +69,8 @@ public async Task RunAsync()
6569
promptOptions.SchedulerType = schedulerType;
6670
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
6771
{
68-
var outputFilename = Path.Combine(_outputDirectory, $"{schedulerOptions.Seed}.png");
69-
var image = result.ToImage();
72+
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
73+
var image = result.ImageResult.ToImage();
7074
await image.SaveAsPngAsync(outputFilename);
7175
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
7276
}

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.Core.Config;
33
using OnnxStack.Core.Model;
44
using OnnxStack.StableDiffusion.Config;
5+
using OnnxStack.StableDiffusion.Models;
56
using SixLabors.ImageSharp;
67
using SixLabors.ImageSharp.PixelFormats;
78
using System;
@@ -94,7 +95,7 @@ public interface IStableDiffusionService
9495
/// <param name="progressCallback">The progress callback.</param>
9596
/// <param name="cancellationToken">The cancellation token.</param>
9697
/// <returns></returns>
97-
IAsyncEnumerable<DenseTensor<float>> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
98+
IAsyncEnumerable<BatchResult> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
9899

99100
/// <summary>
100101
/// Generates a batch of StableDiffusion image using the prompt and options provided.

OnnxStack.StableDiffusion/Config/BatchOptions.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@ public class BatchOptions
1010
{
1111
public BatchOptionType BatchType { get; set; }
1212
public int Count { get; set; }
13+
14+
public float ValueTo { get; set; }
15+
public float ValueFrom { get; set; }
16+
public float Increment { get; set; } = 1f;
1317
}
1418

1519
public enum BatchOptionType
1620
{
17-
Seed = 0
21+
Seed = 0,
22+
Step = 1,
23+
Guidance = 2
1824
}
1925
}

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace OnnxStack.StableDiffusion.Config
66
{
7-
public class SchedulerOptions
7+
public record SchedulerOptions
88
{
99
/// <summary>
1010
/// Gets or sets the height.

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.StableDiffusion.Enums;
5+
using OnnxStack.StableDiffusion.Models;
56
using System;
67
using System.Collections.Generic;
78
using System.Threading;
@@ -46,6 +47,6 @@ public interface IDiffuser
4647
/// <param name="progressCallback">The progress callback.</param>
4748
/// <param name="cancellationToken">The cancellation token.</param>
4849
/// <returns></returns>
49-
IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
50+
IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
5051
}
5152
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using OnnxStack.StableDiffusion.Config;
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
1112
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
1213
using System;
1314
using System.Collections.Generic;
@@ -113,7 +114,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
113114
/// <param name="cancellationToken">The cancellation token.</param>
114115
/// <returns></returns>
115116
/// <exception cref="System.NotImplementedException"></exception>
116-
public async IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation]CancellationToken cancellationToken = default)
117+
public async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
117118
{
118119
var diffuseBatchTime = _logger?.LogBegin("Begin...");
119120
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
@@ -122,23 +123,23 @@ public async IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOption
122123
var performGuidance = false;
123124
promptOptions.NegativePrompt = string.Empty;
124125

125-
var batchIndex = 1;
126-
var batchCount = batchOptions.Count;
127-
var schedulerCallback = (int p, int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t);
128-
129126
// Process prompts
130127
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
131128

132-
if (batchOptions.BatchType == BatchOptionType.Seed)
129+
// Generate batch options
130+
var batchSchedulerOptions = BatchGenerator.GenerateBatch(batchOptions, schedulerOptions);
131+
132+
var batchIndex = 1;
133+
var batchCount = batchSchedulerOptions.Count;
134+
var schedulerCallback = (int p, int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t);
135+
136+
foreach (var batchSchedulerOption in batchSchedulerOptions)
133137
{
134-
var randomSeeds = Enumerable.Range(0, Math.Max(1, batchOptions.Count)).Select(x => Random.Shared.Next());
135-
foreach (var randomSeed in randomSeeds)
136-
{
137-
schedulerOptions.Seed = randomSeed;
138-
yield return await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken);
139-
batchIndex++;
140-
}
138+
yield return new BatchResult(batchSchedulerOption, await RunSchedulerSteps(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken));
139+
batchIndex++;
141140
}
141+
142+
_logger?.LogEnd($"End", diffuseBatchTime);
142143
}
143144

144145

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using OnnxStack.StableDiffusion.Config;
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
1112
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
1213
using System;
1314
using System.Collections.Generic;
@@ -111,31 +112,31 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
111112
/// <param name="cancellationToken">The cancellation token.</param>
112113
/// <returns></returns>
113114
/// <exception cref="System.NotImplementedException"></exception>
114-
public async IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation]CancellationToken cancellationToken = default)
115+
public async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation]CancellationToken cancellationToken = default)
115116
{
116117
var diffuseBatchTime = _logger?.LogBegin("Begin...");
117118
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
118119

119120
// Should we perform classifier free guidance
120121
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
121122

122-
var batchIndex = 1;
123-
var batchCount = batchOptions.Count;
124-
var schedulerCallback = (int p,int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t);
125-
126123
// Process prompts
127124
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
128125

129-
if (batchOptions.BatchType == BatchOptionType.Seed)
126+
// Generate batch options
127+
var batchSchedulerOptions = BatchGenerator.GenerateBatch(batchOptions, schedulerOptions);
128+
129+
var batchIndex = 1;
130+
var batchCount = batchSchedulerOptions.Count;
131+
var schedulerCallback = (int p, int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t);
132+
133+
foreach (var batchSchedulerOption in batchSchedulerOptions)
130134
{
131-
var randomSeeds = Enumerable.Range(0, Math.Max(1, batchOptions.Count)).Select(x => Random.Shared.Next());
132-
foreach (var randomSeed in randomSeeds)
133-
{
134-
schedulerOptions.Seed = randomSeed;
135-
yield return await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken);
136-
batchIndex++;
137-
}
135+
yield return new BatchResult(batchSchedulerOption, await RunSchedulerSteps(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken));
136+
batchIndex++;
138137
}
138+
139+
_logger?.LogEnd($"End", diffuseBatchTime);
139140
}
140141

141142

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using OnnxStack.StableDiffusion.Config;
using System;
using System.Collections.Generic;
using System.Linq;

namespace OnnxStack.StableDiffusion.Helpers
{
    public static class BatchGenerator
    {
        /// <summary>
        /// Generates the batch of SchedulerOptions for batch processing.
        /// </summary>
        /// <param name="batchOptions">The batch options (batch type, count, value range and increment).</param>
        /// <param name="schedulerOptions">The base scheduler options each batch item is cloned from.</param>
        /// <returns>One SchedulerOptions clone per batch item; empty list for an unknown batch type.</returns>
        public static List<SchedulerOptions> GenerateBatch(BatchOptions batchOptions, SchedulerOptions schedulerOptions)
        {
            if (batchOptions.BatchType == BatchOptionType.Seed)
            {
                // One clone per random seed, Count times (minimum of one).
                return Enumerable.Range(0, Math.Max(1, batchOptions.Count))
                    .Select(_ => Random.Shared.Next())
                    .Select(seed => schedulerOptions with { Seed = seed })
                    .ToList();
            }

            if (batchOptions.BatchType == BatchOptionType.Step)
            {
                // BUG FIX: Enumerable.Range takes (start, count), not (start, end).
                // The original passed ValueTo as the count, producing steps
                // ValueFrom..(ValueFrom + ValueTo - 1) instead of ValueFrom..ValueTo.
                var start = Math.Max(0, (int)batchOptions.ValueFrom);
                var count = Math.Max(1, (int)batchOptions.ValueTo - start + 1);
                return Enumerable.Range(start, count)
                    .Select(steps => schedulerOptions with { InferenceSteps = steps })
                    .ToList();
            }

            if (batchOptions.BatchType == BatchOptionType.Guidance)
            {
                // Guard: a non-positive increment would make the division below yield
                // infinity/negative values; degrade to a single item at ValueFrom,
                // matching what Math.Max(1, ...) produced for the negative case.
                if (batchOptions.Increment <= 0f)
                    return new List<SchedulerOptions> { schedulerOptions with { GuidanceScale = batchOptions.ValueFrom } };

                // Step GuidanceScale from ValueFrom towards ValueTo in Increment-sized steps
                // (ValueTo itself is excluded unless it falls exactly on an increment).
                var totalIncrements = (batchOptions.ValueTo - batchOptions.ValueFrom) / batchOptions.Increment;
                return Enumerable.Range(0, Math.Max(1, (int)totalIncrements))
                    .Select(i => schedulerOptions with { GuidanceScale = batchOptions.ValueFrom + (batchOptions.Increment * i) })
                    .ToList();
            }

            return new List<SchedulerOptions>();
        }
    }
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.StableDiffusion.Config;

namespace OnnxStack.StableDiffusion.Models
{
    /// <summary>
    /// Result of a single item in a batch generation run: the SchedulerOptions the item
    /// was generated with (returned because batching may mutate them per item, e.g. Seed,
    /// InferenceSteps or GuidanceScale), paired with the generated image tensor.
    /// </summary>
    /// <param name="SchedulerOptions">The scheduler options used for this batch item.</param>
    /// <param name="ImageResult">The raw image tensor produced by the diffusion run.</param>
    public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor<float> ImageResult);
}

OnnxStack.StableDiffusion/Services/StableDiffusionService.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using OnnxStack.StableDiffusion.Config;
66
using OnnxStack.StableDiffusion.Enums;
77
using OnnxStack.StableDiffusion.Helpers;
8+
using OnnxStack.StableDiffusion.Models;
89
using SixLabors.ImageSharp;
910
using SixLabors.ImageSharp.PixelFormats;
1011
using System;
@@ -150,7 +151,7 @@ public async Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptio
150151
/// <param name="progressCallback">The progress callback.</param>
151152
/// <param name="cancellationToken">The cancellation token.</param>
152153
/// <returns></returns>
153-
public IAsyncEnumerable<DenseTensor<float>> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
154+
public IAsyncEnumerable<BatchResult> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
154155
{
155156
return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken);
156157
}
@@ -169,7 +170,7 @@ public IAsyncEnumerable<DenseTensor<float>> GenerateBatchAsync(IModelOptions mod
169170
public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
170171
{
171172
await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
172-
yield return result.ToImage();
173+
yield return result.ImageResult.ToImage();
173174
}
174175

175176

@@ -186,7 +187,7 @@ public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOpt
186187
public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
187188
{
188189
await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
189-
yield return result.ToImageBytes();
190+
yield return result.ImageResult.ToImageBytes();
190191
}
191192

192193

@@ -203,7 +204,7 @@ public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions mo
203204
public async IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
204205
{
205206
await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
206-
yield return result.ToImageStream();
207+
yield return result.ImageResult.ToImageStream();
207208
}
208209

209210

@@ -220,7 +221,7 @@ private async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions,
220221
}
221222

222223

223-
private IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progress = null, CancellationToken cancellationToken = default)
224+
private IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progress = null, CancellationToken cancellationToken = default)
224225
{
225226
if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline))
226227
throw new Exception("Pipeline not found or is unsupported");

0 commit comments

Comments
 (0)