Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit f10c73e

Browse files
committed
Add current Latent/Image to progress callback
1 parent 8c00ba9 commit f10c73e

20 files changed

+143
-69
lines changed

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.StableDiffusion.Config;
33
using OnnxStack.StableDiffusion.Enums;
44
using OnnxStack.StableDiffusion.Helpers;
5+
using OnnxStack.StableDiffusion.Models;
56
using SixLabors.ImageSharp;
67

78
namespace OnnxStack.Console.Runner
@@ -58,10 +59,10 @@ public async Task RunAsync()
5859
await _stableDiffusionService.LoadModelAsync(model);
5960

6061
var batchIndex = 0;
61-
var callback = (int batch, int batchCount, int step, int steps) =>
62+
var callback = (DiffusionProgress progress) =>
6263
{
63-
batchIndex = batch;
64-
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
64+
batchIndex = progress.ProgressValue;
65+
OutputHelpers.WriteConsole($"Image: {progress.ProgressValue}/{progress.ProgressMax} - Step: {progress.SubProgressValue}/{progress.SubProgressMax}", ConsoleColor.Cyan);
6566
};
6667

6768
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public interface IStableDiffusionService
4545
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
4646
/// <param name="cancellationToken">The cancellation token.</param>
4747
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
48-
Task<DenseTensor<float>> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
48+
Task<DenseTensor<float>> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
4949

5050
/// <summary>
5151
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -55,7 +55,7 @@ public interface IStableDiffusionService
5555
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
5656
/// <param name="cancellationToken">The cancellation token.</param>
5757
/// <returns>The diffusion result as <see cref="SixLabors.ImageSharp.Image<Rgba32>"/></returns>
58-
Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
58+
Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
5959

6060
/// <summary>
6161
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -65,7 +65,7 @@ public interface IStableDiffusionService
6565
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
6666
/// <param name="cancellationToken">The cancellation token.</param>
6767
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
68-
Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
68+
Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
6969

7070
/// <summary>
7171
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -75,7 +75,7 @@ public interface IStableDiffusionService
7575
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
7676
/// <param name="cancellationToken">The cancellation token.</param>
7777
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
78-
Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
78+
Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
7979

8080
/// <summary>
8181
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -87,7 +87,7 @@ public interface IStableDiffusionService
8787
/// <param name="progressCallback">The progress callback.</param>
8888
/// <param name="cancellationToken">The cancellation token.</param>
8989
/// <returns></returns>
90-
IAsyncEnumerable<BatchResult> GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
90+
IAsyncEnumerable<BatchResult> GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
9191

9292
/// <summary>
9393
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -99,7 +99,7 @@ public interface IStableDiffusionService
9999
/// <param name="progressCallback">The progress callback.</param>
100100
/// <param name="cancellationToken">The cancellation token.</param>
101101
/// <returns></returns>
102-
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
102+
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
103103

104104
/// <summary>
105105
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -111,7 +111,7 @@ public interface IStableDiffusionService
111111
/// <param name="progressCallback">The progress callback.</param>
112112
/// <param name="cancellationToken">The cancellation token.</param>
113113
/// <returns></returns>
114-
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
114+
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
115115

116116
/// <summary>
117117
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -123,6 +123,6 @@ public interface IStableDiffusionService
123123
/// <param name="progressCallback">The progress callback.</param>
124124
/// <param name="cancellationToken">The cancellation token.</param>
125125
/// <returns></returns>
126-
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
126+
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
127127
}
128128
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
8888
/// <param name="progressCallback">The progress callback.</param>
8989
/// <param name="cancellationToken">The cancellation token.</param>
9090
/// <returns></returns>
91-
protected abstract Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
91+
protected abstract Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
9292

9393

9494
/// <summary>
@@ -99,7 +99,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
9999
/// <param name="progress">The progress.</param>
100100
/// <param name="cancellationToken">The cancellation token.</param>
101101
/// <returns></returns>
102-
public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
102+
public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
103103
{
104104
// Create random seed if none was set
105105
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -133,7 +133,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
133133
/// <param name="cancellationToken">The cancellation token.</param>
134134
/// <returns></returns>
135135
/// <exception cref="System.NotImplementedException"></exception>
136-
public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
136+
public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
137137
{
138138
// Create random seed if none was set
139139
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -152,7 +152,11 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffu
152152
var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);
153153

154154
var batchIndex = 1;
155-
var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps);
155+
var schedulerCallback = (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress(batchIndex, batchSchedulerOptions.Count, progress.ProgressTensor)
156+
{
157+
SubProgressMax = progress.ProgressMax,
158+
SubProgressValue = progress.ProgressValue,
159+
});
156160
foreach (var batchSchedulerOption in batchSchedulerOptions)
157161
{
158162
var diffuseTime = _logger?.LogBegin("Diffuse starting...");
@@ -251,5 +255,37 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
251255
{
252256
return parameters.ToList();
253257
}
258+
259+
260+
/// <summary>
261+
/// Reports the progress.
262+
/// </summary>
263+
/// <param name="progressCallback">The progress callback.</param>
264+
/// <param name="progress">The progress.</param>
265+
/// <param name="progressMax">The progress maximum.</param>
266+
/// <param name="output">The output.</param>
267+
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> output)
268+
{
269+
progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output));
270+
}
271+
272+
273+
/// <summary>
274+
/// Reports the progress.
275+
/// </summary>
276+
/// <param name="progressCallback">The progress callback.</param>
277+
/// <param name="progress">The progress.</param>
278+
/// <param name="progressMax">The progress maximum.</param>
279+
/// <param name="subProgress">The sub progress.</param>
280+
/// <param name="subProgressMax">The sub progress maximum.</param>
281+
/// <param name="output">The output.</param>
282+
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, int subProgress, int subProgressMax, DenseTensor<float> output)
283+
{
284+
progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output)
285+
{
286+
SubProgressMax = subProgressMax,
287+
SubProgressValue = subProgress,
288+
});
289+
}
254290
}
255291
}

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public interface IDiffuser
3333
/// <param name="progressCallback">The progress callback.</param>
3434
/// <param name="cancellationToken">The cancellation token.</param>
3535
/// <returns></returns>
36-
Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
36+
Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
3737

3838

3939
/// <summary>
@@ -46,6 +46,6 @@ public interface IDiffuser
4646
/// <param name="progressCallback">The progress callback.</param>
4747
/// <param name="cancellationToken">The cancellation token.</param>
4848
/// <returns></returns>
49-
IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
49+
IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
5050
}
5151
}

OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using OnnxStack.StableDiffusion.Config;
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
1112
using OnnxStack.StableDiffusion.Schedulers.InstaFlow;
1213
using System;
1314
using System.Diagnostics;
@@ -45,7 +46,7 @@ public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService prom
4546
/// <param name="progressCallback">The progress callback.</param>
4647
/// <param name="cancellationToken">The cancellation token.</param>
4748
/// <returns></returns>
48-
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
49+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
4950
{
5051
// Get Scheduler
5152
using (var scheduler = GetScheduler(schedulerOptions))
@@ -102,7 +103,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
102103
}
103104
}
104105

105-
progressCallback?.Invoke(step, timesteps.Count);
106+
ReportProgress(progressCallback, step, timesteps.Count, latents);
106107
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
107108
}
108109

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using OnnxStack.StableDiffusion.Config;
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
1112
using SixLabors.ImageSharp;
1213
using SixLabors.ImageSharp.Processing;
1314
using System;
@@ -65,7 +66,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
6566
/// <param name="progressCallback">The progress callback.</param>
6667
/// <param name="cancellationToken">The cancellation token.</param>
6768
/// <returns></returns>
68-
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
69+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
6970
{
7071
using (var scheduler = GetScheduler(schedulerOptions))
7172
{
@@ -138,7 +139,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
138139
}
139140
}
140141

141-
progressCallback?.Invoke(step, timesteps.Count);
142+
ReportProgress(progressCallback, step, timesteps.Count, latents);
142143
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
143144
}
144145

0 commit comments

Comments
 (0)