Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7eccd73

Browse files
committed
Dependency Injection for Pipelines and Diffusers, Logging added
1 parent cb9fa44 commit 7eccd73

18 files changed

+320
-68
lines changed

OnnxStack.Core/LogExtensions.cs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using Microsoft.Extensions.Logging;
2+
using System.Diagnostics;
3+
using System.Runtime.CompilerServices;
4+
5+
namespace OnnxStack.Core
6+
{
7+
public static class LogExtensions
8+
{
9+
10+
public static void Log(this ILogger logger, string message, [CallerMemberName] string caller = default)
11+
{
12+
LogInternal(logger, LogLevel.Information, message, caller);
13+
}
14+
15+
16+
public static void Log(this ILogger logger, LogLevel logLevel, string message, [CallerMemberName] string caller = default)
17+
{
18+
LogInternal(logger, logLevel, message, caller);
19+
}
20+
21+
22+
public static long LogBegin(this ILogger logger, string message, [CallerMemberName] string caller = default)
23+
{
24+
return LogBeginInternal(logger, LogLevel.Information, message, caller);
25+
}
26+
27+
28+
public static long LogBegin(this ILogger logger, LogLevel logLevel, string message, [CallerMemberName] string caller = default)
29+
{
30+
31+
return LogBeginInternal(logger, logLevel, message, caller);
32+
}
33+
34+
35+
public static void LogEnd(this ILogger logger, string message, long? timestamp, [CallerMemberName] string caller = default)
36+
{
37+
LogEndInternal(logger, LogLevel.Information, message, timestamp, caller);
38+
}
39+
40+
41+
public static void LogEnd(this ILogger logger, LogLevel logLevel, string message, long? timestamp, [CallerMemberName] string caller = default)
42+
{
43+
LogEndInternal(logger, logLevel, message, timestamp, caller);
44+
}
45+
46+
47+
private static long LogBeginInternal(ILogger logger, LogLevel logLevel, string message, string caller)
48+
{
49+
LogInternal(logger, logLevel, message, caller);
50+
return Stopwatch.GetTimestamp();
51+
}
52+
53+
54+
private static void LogEndInternal(ILogger logger, LogLevel logLevel, string message, long? timestamp, string caller)
55+
{
56+
var elapsed = Stopwatch.GetElapsedTime(timestamp ?? 0);
57+
var timeString = elapsed.TotalSeconds >= 1
58+
? $"{message}, Elapsed: {elapsed.TotalSeconds:F4}sec"
59+
: $"{message}, Elapsed: {elapsed.TotalMilliseconds:F0}ms";
60+
LogInternal(logger, logLevel, timeString, caller);
61+
}
62+
63+
private static void LogInternal(ILogger logger, LogLevel logLevel, string message, string caller)
64+
{
65+
logger.Log(logLevel, string.IsNullOrEmpty(caller) ? message : $"[{caller}] - {message}", args: default);
66+
}
67+
}
68+
}

OnnxStack.Core/OnnxStack.Core.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
<ItemGroup>
3838
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
3939
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="7.0.0" />
40+
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
4041
<PackageReference Include="Microsoft.ML" Version="2.0.1" />
4142
<PackageReference Include="Microsoft.ML.OnnxRuntime.Extensions" Version="0.9.0" />
4243
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.16.1" />

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4+
using OnnxStack.StableDiffusion.Enums;
45
using System;
56
using System.Threading;
67
using System.Threading.Tasks;
@@ -9,6 +10,28 @@ namespace OnnxStack.StableDiffusion.Diffusers
910
{
1011
public interface IDiffuser
1112
{
13+
14+
/// <summary>
15+
/// Gets the type of the diffuser.
16+
/// </summary>
17+
DiffuserType DiffuserType { get; }
18+
19+
20+
/// <summary>
21+
/// Gets the type of the pipeline.
22+
/// </summary>
23+
DiffuserPipelineType PipelineType { get; }
24+
25+
26+
/// <summary>
27+
/// Runs the stable diffusion loop
28+
/// </summary>
29+
/// <param name="modelOptions">The model options.</param>
30+
/// <param name="promptOptions">The prompt options.</param>
31+
/// <param name="schedulerOptions">The scheduler options.</param>
32+
/// <param name="progressCallback">The progress callback.</param>
33+
/// <param name="cancellationToken">The cancellation token.</param>
34+
/// <returns></returns>
1235
Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
1336
}
1437
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
using Microsoft.ML.OnnxRuntime;
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime;
23
using Microsoft.ML.OnnxRuntime.Tensors;
34
using OnnxStack.Core.Config;
45
using OnnxStack.Core.Services;
56
using OnnxStack.StableDiffusion.Common;
67
using OnnxStack.StableDiffusion.Config;
8+
using OnnxStack.StableDiffusion.Diffusers.StableDiffusion;
9+
using OnnxStack.StableDiffusion.Enums;
710
using OnnxStack.StableDiffusion.Helpers;
811
using SixLabors.ImageSharp;
912
using System;
@@ -19,12 +22,18 @@ public sealed class ImageDiffuser : LatentConsistencyDiffuser
1922
/// </summary>
2023
/// <param name="configuration">The configuration.</param>
2124
/// <param name="onnxModelService">The onnx model service.</param>
22-
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
23-
: base(onnxModelService, promptService)
25+
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
26+
: base(onnxModelService, promptService, logger)
2427
{
2528
}
2629

2730

31+
/// <summary>
32+
/// Gets the type of the diffuser.
33+
/// </summary>
34+
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;
35+
36+
2837
/// <summary>
2938
/// Gets the timesteps.
3039
/// </summary>

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using Microsoft.ML.OnnxRuntime;
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime;
23
using Microsoft.ML.OnnxRuntime.Tensors;
4+
using OnnxStack.Core;
35
using OnnxStack.Core.Config;
46
using OnnxStack.Core.Services;
57
using OnnxStack.StableDiffusion.Common;
@@ -9,6 +11,7 @@
911
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
1012
using System;
1113
using System.Collections.Generic;
14+
using System.Diagnostics;
1215
using System.Linq;
1316
using System.Threading;
1417
using System.Threading.Tasks;
@@ -19,19 +22,33 @@ public abstract class LatentConsistencyDiffuser : IDiffuser
1922
{
2023
protected readonly IPromptService _promptService;
2124
protected readonly IOnnxModelService _onnxModelService;
25+
protected readonly ILogger<LatentConsistencyDiffuser> _logger;
2226

2327
/// <summary>
2428
/// Initializes a new instance of the <see cref="LatentConsistencyDiffuser"/> class.
2529
/// </summary>
2630
/// <param name="configuration">The configuration.</param>
2731
/// <param name="onnxModelService">The onnx model service.</param>
28-
public LatentConsistencyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
32+
public LatentConsistencyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
2933
{
34+
_logger = logger;
3035
_promptService = promptService;
3136
_onnxModelService = onnxModelService;
3237
}
3338

3439

40+
/// <summary>
41+
/// Gets the type of the pipeline.
42+
/// </summary>
43+
public DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistency;
44+
45+
46+
/// <summary>
47+
/// Gets the type of the diffuser.
48+
/// </summary>
49+
public abstract DiffuserType DiffuserType { get; }
50+
51+
3552
/// <summary>
3653
/// Gets the timesteps.
3754
/// </summary>
@@ -65,6 +82,9 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
6582
// Create random seed if none was set
6683
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
6784

85+
var diffuseTime = _logger?.LogBegin("Begin...");
86+
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}");
87+
6888
// LCM does not support negative prompting
6989
promptOptions.NegativePrompt = string.Empty;
7090

@@ -91,6 +111,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
91111
foreach (var timestep in timesteps)
92112
{
93113
step++;
114+
var stepTime = Stopwatch.GetTimestamp();
94115
cancellationToken.ThrowIfCancellationRequested();
95116

96117
// Create input tensor.
@@ -112,10 +133,13 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
112133
}
113134

114135
progressCallback?.Invoke(step, timesteps.Count);
136+
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
115137
}
116138

117139
// Decode Latents
118-
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
140+
var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
141+
_logger?.LogEnd($"End", diffuseTime);
142+
return result;
119143
}
120144
}
121145

@@ -128,6 +152,8 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
128152
/// <returns></returns>
129153
protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
130154
{
155+
var timestamp = _logger?.LogBegin("Begin...");
156+
131157
// Scale and decode the image latents with vae.
132158
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
133159

@@ -144,13 +170,15 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
144170
using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
145171
{
146172
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
147-
if (prompt.BatchCount == 1)
148-
return resultTensor.ToDenseTensor();
149-
150173
imageTensors.Add(resultTensor.ToDenseTensor());
151174
}
152175
}
153-
return imageTensors.Join();
176+
177+
var result = prompt.BatchCount > 1
178+
? imageTensors.Join()
179+
: imageTensors.FirstOrDefault();
180+
_logger?.LogEnd("End", timestamp);
181+
return result;
154182
}
155183

156184

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
using Microsoft.ML.OnnxRuntime.Tensors;
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using OnnxStack.Core.Services;
34
using OnnxStack.StableDiffusion.Common;
45
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
57
using System.Collections.Generic;
68

79
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
@@ -13,11 +15,16 @@ public sealed class TextDiffuser : LatentConsistencyDiffuser
1315
/// </summary>
1416
/// <param name="configuration">The configuration.</param>
1517
/// <param name="onnxModelService">The onnx model service.</param>
16-
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
17-
: base(onnxModelService, promptService)
18+
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
19+
: base(onnxModelService, promptService, logger)
1820
{
1921
}
2022

23+
/// <summary>
24+
/// Gets the type of the diffuser.
25+
/// </summary>
26+
public override DiffuserType DiffuserType => DiffuserType.TextToImage;
27+
2128

2229
/// <summary>
2330
/// Gets the timesteps.

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
using Microsoft.ML.OnnxRuntime.Tensors;
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using OnnxStack.Core.Services;
34
using OnnxStack.StableDiffusion.Common;
45
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
57
using System.Collections.Generic;
68

79
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
@@ -14,8 +16,14 @@ public sealed class AnimateDiffuser : StableDiffusionDiffuser
1416
/// </summary>
1517
/// <param name="onnxModelService">The onnx model service.</param>
1618
/// <param name="promptService"></param>
17-
public AnimateDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
18-
: base(onnxModelService, promptService) { }
19+
public AnimateDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionDiffuser> logger)
20+
: base(onnxModelService, promptService, logger) { }
21+
22+
23+
/// <summary>
24+
/// Gets the type of the diffuser.
25+
/// </summary>
26+
public override DiffuserType DiffuserType => DiffuserType.ImageToAnimation;
1927

2028

2129
/// <summary>

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
using Microsoft.ML.OnnxRuntime;
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime;
23
using Microsoft.ML.OnnxRuntime.Tensors;
34
using OnnxStack.Core.Config;
45
using OnnxStack.Core.Services;
56
using OnnxStack.StableDiffusion.Common;
67
using OnnxStack.StableDiffusion.Config;
8+
using OnnxStack.StableDiffusion.Enums;
79
using OnnxStack.StableDiffusion.Helpers;
810
using SixLabors.ImageSharp;
911
using System;
@@ -20,12 +22,18 @@ public sealed class ImageDiffuser : StableDiffusionDiffuser
2022
/// </summary>
2123
/// <param name="configuration">The configuration.</param>
2224
/// <param name="onnxModelService">The onnx model service.</param>
23-
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
24-
: base(onnxModelService, promptService)
25+
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionDiffuser> logger)
26+
: base(onnxModelService, promptService, logger)
2527
{
2628
}
2729

2830

31+
/// <summary>
32+
/// Gets the type of the diffuser.
33+
/// </summary>
34+
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;
35+
36+
2937
/// <summary>
3038
/// Gets the timesteps.
3139
/// </summary>

0 commit comments

Comments
 (0)