Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8c00ba9

Browse files
committed
LCM VideoToVideo diffuser
1 parent 47836c1 commit 8c00ba9

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using System;
12+
using System.Collections.Generic;
13+
using System.Diagnostics;
14+
using System.Linq;
15+
using System.Threading;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
19+
{
20+
public sealed class VideoDiffuser : LatentConsistencyDiffuser
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="VideoDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public VideoDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
28+
: base(onnxModelService, promptService, logger) { }
29+
30+
31+
/// <summary>
32+
/// Gets the type of the diffuser.
33+
/// </summary>
34+
public override DiffuserType DiffuserType => DiffuserType.VideoToVideo;
35+
36+
37+
/// <summary>
38+
/// Runs the scheduler steps.
39+
/// </summary>
40+
/// <param name="modelOptions">The model options.</param>
41+
/// <param name="promptOptions">The prompt options.</param>
42+
/// <param name="schedulerOptions">The scheduler options.</param>
43+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
44+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
45+
/// <param name="progressCallback">The progress callback.</param>
46+
/// <param name="cancellationToken">The cancellation token.</param>
47+
/// <returns></returns>
48+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
49+
{
50+
DenseTensor<float> resultTensor = null;
51+
foreach (var videoFrame in promptOptions.InputVideo.Frames)
52+
{
53+
// Get Scheduler
54+
using (var scheduler = GetScheduler(schedulerOptions))
55+
{
56+
// Get timesteps
57+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
58+
59+
// Create latent sample
60+
var latents = await PrepareFrameLatentsAsync(modelOptions, videoFrame, schedulerOptions, scheduler, timesteps);
61+
62+
// Get Guidance Scale Embedding
63+
var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
64+
65+
// Denoised result
66+
DenseTensor<float> denoised = null;
67+
68+
// Get Model metadata
69+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
70+
71+
// Loop though the timesteps
72+
var step = 0;
73+
foreach (var timestep in timesteps)
74+
{
75+
step++;
76+
var stepTime = Stopwatch.GetTimestamp();
77+
cancellationToken.ThrowIfCancellationRequested();
78+
79+
// Create input tensor.
80+
var inputTensor = scheduler.ScaleInput(latents, timestep);
81+
var timestepTensor = CreateTimestepTensor(timestep);
82+
83+
var outputChannels = 1;
84+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
85+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
86+
{
87+
inferenceParameters.AddInputTensor(inputTensor);
88+
inferenceParameters.AddInputTensor(timestepTensor);
89+
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
90+
inferenceParameters.AddInputTensor(guidanceEmbeddings);
91+
inferenceParameters.AddOutputBuffer(outputDimension);
92+
93+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
94+
using (var result = results.First())
95+
{
96+
var noisePred = result.ToDenseTensor();
97+
98+
// Scheduler Step
99+
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
100+
101+
latents = schedulerResult.Result;
102+
denoised = schedulerResult.SampleData;
103+
}
104+
}
105+
106+
progressCallback?.Invoke(step, timesteps.Count);
107+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
108+
}
109+
110+
// Decode Latents
111+
var frameResultTensor = await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised);
112+
resultTensor = resultTensor is null
113+
? frameResultTensor
114+
: resultTensor.Concatenate(frameResultTensor);
115+
}
116+
}
117+
return resultTensor;
118+
}
119+
120+
121+
/// <summary>
122+
/// Gets the timesteps.
123+
/// </summary>
124+
/// <param name="prompt">The prompt.</param>
125+
/// <param name="options">The options.</param>
126+
/// <param name="scheduler">The scheduler.</param>
127+
/// <returns></returns>
128+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
129+
{
130+
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
131+
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
132+
return scheduler.Timesteps.Skip(start).ToList();
133+
}
134+
135+
136+
/// <summary>
137+
/// Prepares the latents for inference.
138+
/// </summary>
139+
/// <param name="prompt">The prompt.</param>
140+
/// <param name="options">The options.</param>
141+
/// <param name="scheduler">The scheduler.</param>
142+
/// <returns></returns>
143+
private async Task<DenseTensor<float>> PrepareFrameLatentsAsync(StableDiffusionModelSet model, byte[] videoFrame, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
144+
{
145+
var imageTensor = ImageHelpers.TensorFromBytes(videoFrame, new[] { 1, 3, options.Height, options.Width });
146+
147+
//TODO: Model Config, Channels
148+
var outputDimension = options.GetScaledDimension();
149+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
150+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
151+
{
152+
inferenceParameters.AddInputTensor(imageTensor);
153+
inferenceParameters.AddOutputBuffer(outputDimension);
154+
155+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
156+
using (var result = results.First())
157+
{
158+
var outputResult = result.ToDenseTensor();
159+
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
160+
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
161+
}
162+
}
163+
}
164+
165+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
166+
{
167+
throw new NotImplementedException();
168+
}
169+
}
170+
}

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using Microsoft.Extensions.DependencyInjection;
22
using OnnxStack.Core.Config;
3+
using OnnxStack.Core.Services;
34
using OnnxStack.StableDiffusion.Common;
45
using OnnxStack.StableDiffusion.Config;
56
using OnnxStack.StableDiffusion.Diffusers;
7+
using OnnxStack.StableDiffusion.Diffusers.LatentConsistency;
68
using OnnxStack.StableDiffusion.Pipelines;
79
using OnnxStack.StableDiffusion.Services;
810
using SixLabors.ImageSharp;
@@ -44,6 +46,7 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
4446
ConfigureLibraries();
4547

4648
// Services
49+
serviceCollection.AddSingleton<IVideoService, VideoService>();
4750
serviceCollection.AddSingleton<IPromptService, PromptService>();
4851
serviceCollection.AddSingleton<IStableDiffusionService, StableDiffusionService>();
4952

@@ -69,6 +72,7 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
6972
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.TextDiffuser>();
7073
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.ImageDiffuser>();
7174
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.InpaintLegacyDiffuser>();
75+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.VideoDiffuser>();
7276

7377
//LatentConsistencyXL
7478
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistencyXL.TextDiffuser>();

0 commit comments

Comments
 (0)