Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit cdc80c2

Browse files
committed
Console ProgressCallback, Fix batch scheduler issue
1 parent 0d68062 commit cdc80c2

File tree

9 files changed

+16
-54
lines changed

9 files changed

+16
-54
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public async Task RunAsync()
6666
OutputHelpers.WriteConsole($"Generating {scheduler} Image...", ConsoleColor.Green);
6767

6868
// Run pipeline
69-
var result = await pipeline.RunAsync(promptOptions, schedulerOptions);
69+
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);
7070

7171
// Create Image from Tensor result
7272
var image = result.ToImage();

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,9 @@ public async Task RunAsync()
5656
// Preload Models (optional)
5757
await pipeline.LoadAsync();
5858

59-
// Progress Callback (optional)
60-
var progressCallback = (DiffusionProgress progress) => OutputHelpers.WriteConsole($"Image: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan);
61-
6259
// Run Batch
6360
var timestamp = Stopwatch.GetTimestamp();
64-
await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions, progressCallback: progressCallback))
61+
await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions, progressCallback: OutputHelpers.BatchProgressCallback))
6562
{
6663
// Create Image from Tensor result
6764
var image = result.ImageResult.ToImage();

OnnxStack.Console/Examples/StableDiffusionExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public async Task RunAsync()
6767
OutputHelpers.WriteConsole($"Generating '{schedulerType}' Image...", ConsoleColor.Green);
6868

6969
// Run pipeline
70-
var result = await pipeline.RunAsync(promptOptions, schedulerOptions);
70+
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);
7171

7272
// Create Image from Tensor result
7373
var image = result.ToImage();

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public async Task RunAsync()
5555
OutputHelpers.WriteConsole($"Generating '{generationPrompt.Key}'", ConsoleColor.Green);
5656

5757
// Run pipeline
58-
var result = await pipeline.RunAsync(promptOptions);
58+
var result = await pipeline.RunAsync(promptOptions, progressCallback: OutputHelpers.ProgressCallback);
5959

6060
// Create Image from Tensor result
6161
var image = result.ToImage();

OnnxStack.Console/Examples/VideoToVideoExample.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ public async Task RunAsync()
3434
var videoInfo = await _videoService.GetVideoInfoAsync(inputFile);
3535
var videoInput = await _videoService.CreateFramesAsync(inputFile, videoInfo.FPS);
3636

37-
// Progress Callback (optional)
38-
var progressCallback = (DiffusionProgress progress) => OutputHelpers.WriteConsole($"Frame: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan);
39-
40-
4137
// Loop though the appsettings.json model sets
4238
foreach (var modelSet in _configuration.ModelSets)
4339
{
@@ -58,7 +54,7 @@ public async Task RunAsync()
5854
};
5955

6056
// Run pipeline
61-
var result = await pipeline.RunAsync(promptOptions, progressCallback: progressCallback);
57+
var result = await pipeline.RunAsync(promptOptions, progressCallback: OutputHelpers.FrameProgressCallback);
6258

6359
// Create Video from Tensor result
6460
var videoResult = await _videoService.CreateVideoAsync(result, videoInfo.FPS);

OnnxStack.Console/OutputHelpers.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace OnnxStack.Console
1+
using OnnxStack.StableDiffusion.Common;
2+
3+
namespace OnnxStack.Console
24
{
35
internal static class OutputHelpers
46
{
@@ -22,5 +24,9 @@ public static void WriteConsole(string value, ConsoleColor color, bool line = tr
2224
System.Console.Write(value);
2325
System.Console.ForegroundColor = previous;
2426
}
27+
28+
public static Action<DiffusionProgress> ProgressCallback => (DiffusionProgress progress) => WriteConsole($"Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Gray);
29+
public static Action<DiffusionProgress> BatchProgressCallback => (DiffusionProgress progress) => WriteConsole($"Batch: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Gray);
30+
public static Action<DiffusionProgress> FrameProgressCallback => (DiffusionProgress progress) => WriteConsole($"Frame: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Gray);
2531
}
2632
}

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -88,43 +88,5 @@ internal static int[] GetScaledDimension(this SchedulerOptions options, int batc
8888
{
8989
return new[] { batch, channels, options.GetScaledHeight(), options.GetScaledWidth() };
9090
}
91-
92-
93-
/// <summary>
94-
/// Gets the pipeline schedulers.
95-
/// </summary>
96-
/// <param name="pipelineType">Type of the pipeline.</param>
97-
/// <returns></returns>
98-
public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipelineType)
99-
{
100-
switch (pipelineType)
101-
{
102-
case DiffuserPipelineType.StableDiffusion:
103-
case DiffuserPipelineType.StableDiffusionXL:
104-
return new[]
105-
{
106-
SchedulerType.LMS,
107-
SchedulerType.Euler,
108-
SchedulerType.EulerAncestral,
109-
SchedulerType.DDPM,
110-
SchedulerType.DDIM,
111-
SchedulerType.KDPM2
112-
};
113-
case DiffuserPipelineType.LatentConsistency:
114-
case DiffuserPipelineType.LatentConsistencyXL:
115-
return new[]
116-
{
117-
SchedulerType.LCM
118-
};
119-
case DiffuserPipelineType.InstaFlow:
120-
return new[]
121-
{
122-
SchedulerType.InstaFlow
123-
};
124-
default:
125-
return default;
126-
}
127-
}
128-
12991
}
13092
}

OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using OnnxStack.StableDiffusion.Config;
22
using OnnxStack.StableDiffusion.Enums;
3+
using OnnxStack.StableDiffusion.Pipelines;
34
using System;
45
using System.Collections.Generic;
56
using System.Linq;
@@ -14,7 +15,7 @@ public static class BatchGenerator
1415
/// <param name="batchOptions">The batch options.</param>
1516
/// <param name="schedulerOptions">The scheduler options.</param>
1617
/// <returns></returns>
17-
public static List<SchedulerOptions> GenerateBatch(DiffuserPipelineType pipelineType, BatchOptions batchOptions, SchedulerOptions schedulerOptions)
18+
public static List<SchedulerOptions> GenerateBatch(IPipeline pipeline, BatchOptions batchOptions, SchedulerOptions schedulerOptions)
1819
{
1920
if (batchOptions.BatchType == BatchOptionType.Seed)
2021
{
@@ -46,7 +47,7 @@ public static List<SchedulerOptions> GenerateBatch(DiffuserPipelineType pipeline
4647
}
4748
else if (batchOptions.BatchType == BatchOptionType.Scheduler)
4849
{
49-
return pipelineType.GetSchedulerTypes()
50+
return pipeline.SupportedSchedulers
5051
.Select(x => schedulerOptions with { SchedulerType = x })
5152
.ToList();
5253
}

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ public override async IAsyncEnumerable<BatchResult> RunBatchAsync(BatchOptions b
209209
var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
210210

211211
// Generate batch options
212-
var batchSchedulerOptions = BatchGenerator.GenerateBatch(DiffuserPipelineType.StableDiffusion, batchOptions, schedulerOptions);
212+
var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, schedulerOptions);
213213

214214
// Create Diffuser
215215
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);

0 commit comments

Comments
 (0)