Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 47836c1

Browse files
committed
StableDiffusionService initial video support
1 parent e2c31c2 commit 47836c1

File tree

10 files changed

+70
-26
lines changed

10 files changed

+70
-26
lines changed

OnnxStack.Core/Services/IVideoService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public interface IVideoService
7676
/// <param name="videoFPS">The video FPS.</param>
7777
/// <param name="cancellationToken">The cancellation token.</param>
7878
/// <returns></returns>
79-
Task<VideoResult> CreateVideoAsync(IEnumerable<byte[]> videoFrames, float videoFPS, CancellationToken cancellationToken = default);
79+
Task<VideoOutput> CreateVideoAsync(IEnumerable<byte[]> videoFrames, float videoFPS, CancellationToken cancellationToken = default);
8080

8181

8282
/// <summary>
@@ -85,7 +85,7 @@ public interface IVideoService
8585
/// <param name="videoFrames">The video frames.</param>
8686
/// <param name="cancellationToken">The cancellation token.</param>
8787
/// <returns></returns>
88-
Task<VideoResult> CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default);
88+
Task<VideoOutput> CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default);
8989

9090

9191
/// <summary>

OnnxStack.Core/Services/VideoService.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ public async Task<VideoInfo> GetVideoInfoAsync(byte[] videoBytes, CancellationTo
8888
/// <param name="videoFrames">The video frames.</param>
8989
/// <param name="cancellationToken">The cancellation token.</param>
9090
/// <returns></returns>
91-
public async Task<VideoOutput> CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default)
{
    // Unpack the frame set: the frame data plus the FPS recorded in its VideoInfo.
    var frames = videoFrames.Frames;
    var fps = videoFrames.Info.FPS;

    return await CreateVideoInternalAsync(frames, fps, cancellationToken);
}
9595

9696

@@ -101,7 +101,7 @@ public async Task<VideoResult> CreateVideoAsync(VideoFrames videoFrames, Cancell
101101
/// <param name="videoFPS">The video FPS.</param>
102102
/// <param name="cancellationToken">The cancellation token.</param>
103103
/// <returns></returns>
104-
public async Task<VideoOutput> CreateVideoAsync(IEnumerable<byte[]> videoFrames, float videoFPS, CancellationToken cancellationToken = default)
{
    // Public entry point; the work is delegated to the internal helper unchanged.
    var video = await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken);
    return video;
}
@@ -139,7 +139,9 @@ public async Task<VideoFrames> CreateFramesAsync(VideoInput videoInput, float vi
139139
/// <returns></returns>
140140
public async Task<VideoFrames> CreateFramesAsync(byte[] videoBytes, float videoFPS, CancellationToken cancellationToken = default)
{
    // Read the video's info first so it can be returned together with the frames.
    var sourceInfo = await GetVideoInfoAsync(videoBytes, cancellationToken);

    // videoFPS is passed through to the internal frame creation; the async sequence is buffered into a list.
    var frames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken);

    return new VideoFrames(sourceInfo, frames);
}
144146

145147

@@ -155,7 +157,10 @@ public async Task<VideoFrames> CreateFramesAsync(Stream videoStream, float video
155157
using (var memoryStream = new MemoryStream())
{
    // Buffer the incoming video into memory. The copy must run from videoStream INTO memoryStream:
    // memoryStream.CopyToAsync(videoStream) would copy the still-empty buffer into the input and
    // leave ToArray() with zero bytes to hand to the frame creation below.
    await videoStream.CopyToAsync(memoryStream, cancellationToken).ConfigureAwait(false);

    var videoBytes = memoryStream.ToArray();
    var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken);
    var videoFrames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken);
    return new VideoFrames(videoInfo, videoFrames);
}
160165
}
161166

@@ -197,7 +202,7 @@ private async Task<VideoInfo> GetVideoInfoInternalAsync(MemoryStream videoStream
197202
/// <param name="fps">The FPS.</param>
198203
/// <param name="cancellationToken">The cancellation token.</param>
199204
/// <returns></returns>
200-
private async Task<VideoResult> CreateVideoInternalAsync(IEnumerable<byte[]> imageData, float fps = 15, CancellationToken cancellationToken = default)
205+
private async Task<VideoOutput> CreateVideoInternalAsync(IEnumerable<byte[]> imageData, float fps = 15, CancellationToken cancellationToken = default)
201206
{
202207
string tempVideoPath = GetTempFilename();
203208
try
@@ -224,7 +229,7 @@ private async Task<VideoResult> CreateVideoInternalAsync(IEnumerable<byte[]> ima
224229

225230
// Analyze the result
226231
var videoInfo = await GetVideoInfoAsync(videoResult);
227-
return new VideoResult(videoResult, videoInfo);
232+
return new VideoOutput(videoResult, videoInfo);
228233
}
229234
}
230235
finally

OnnxStack.Core/Video/VideoFrames.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
namespace OnnxStack.Core.Video
44
{
5-
public record VideoFrames(float FPS, List<byte[]> Frames);
5+
public record VideoFrames(VideoInfo Info, IReadOnlyList<byte[]> Frames);
66
}

OnnxStack.Core/Video/VideoInfo.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
namespace OnnxStack.Core.Video
44
{
5-
public record VideoInfo(int Width, int Height, TimeSpan Duration, int Fps);
5+
public record VideoInfo(int Width, int Height, TimeSpan Duration, int FPS);
66
}

OnnxStack.Core/Video/VideoOutput.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
namespace OnnxStack.Core.Video
2+
{
3+
public record VideoOutput(byte[] Data, VideoInfo Info);
4+
}

OnnxStack.Core/Video/VideoResult.cs

Lines changed: 0 additions & 4 deletions
This file was deleted.

OnnxStack.StableDiffusion/Config/PromptOptions.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using OnnxStack.Core.Image;
2+
using OnnxStack.Core.Video;
23
using OnnxStack.StableDiffusion.Enums;
34
using System.ComponentModel.DataAnnotations;
45

@@ -19,6 +20,9 @@ public class PromptOptions
1920

2021
public InputImage InputImageMask { get; set; }
2122

23+
public VideoFrames InputVideo { get; set; }
24+
25+
public bool HasInputVideo => InputVideo?.Frames?.Count > 0;
2226
public bool HasInputImage => InputImage?.HasImage ?? false;
2327
public bool HasInputImageMask => InputImageMask?.HasImage ?? false;
2428
}

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ public record SchedulerOptions
8484
public float AestheticScore { get; set; } = 6f;
8585
public float AestheticNegativeScore { get; set; } = 2.5f;
8686

87+
public float VideoFPS { get; set; }
88+
8789
public bool IsKarrasScheduler
8890
{
8991
get

OnnxStack.StableDiffusion/Enums/DiffuserType.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ public enum DiffuserType
1717
ImageInpaintLegacy = 3,
1818

1919
[Description("Image To Animation")]
20-
ImageToAnimation = 4
20+
ImageToAnimation = 4,
21+
22+
[Description("Video To Video")]
23+
VideoToVideo = 5
2124
}
2225
}

OnnxStack.StableDiffusion/Services/StableDiffusionService.cs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.Core;
3-
using OnnxStack.Core.Config;
43
using OnnxStack.Core.Services;
54
using OnnxStack.StableDiffusion.Common;
65
using OnnxStack.StableDiffusion.Config;
@@ -26,6 +25,7 @@ namespace OnnxStack.StableDiffusion.Services
2625
/// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
2726
public sealed class StableDiffusionService : IStableDiffusionService
2827
{
28+
private readonly IVideoService _videoService;
2929
private readonly IOnnxModelService _modelService;
3030
private readonly StableDiffusionConfig _configuration;
3131
private readonly ConcurrentDictionary<DiffuserPipelineType, IPipeline> _pipelines;
@@ -34,10 +34,11 @@ public sealed class StableDiffusionService : IStableDiffusionService
3434
/// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
3535
/// </summary>
3636
/// <param name="configuration">The StableDiffusion configuration.</param>
/// <param name="onnxModelService">The ONNX model service.</param>
/// <param name="videoService">The video service.</param>
/// <param name="pipelines">The pipelines to register; they are stored keyed by their <c>PipelineType</c>.</param>
public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelService onnxModelService, IVideoService videoService, IEnumerable<IPipeline> pipelines)
{
    _configuration = configuration;
    _modelService = onnxModelService;
    _videoService = videoService;

    // Index the injected pipelines by pipeline type.
    _pipelines = pipelines.ToConcurrentDictionary(k => k.PipelineType, k => k);
}
4344

@@ -115,9 +116,11 @@ public async Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet mo
115116
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
116117
public async Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
    var diffusionTensor = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);

    // Prompts with an input video go through the video helper; everything else is returned as image bytes.
    if (prompt.HasInputVideo)
    {
        return await GetVideoResultAsBytesAsync(options, diffusionTensor, cancellationToken).ConfigureAwait(false);
    }

    return diffusionTensor.ToImageBytes();
}
122125

123126

@@ -131,9 +134,11 @@ public async Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, Pr
131134
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
132135
public async Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
    var diffusionTensor = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);

    // Prompts with an input video go through the video helper; everything else is returned as an image stream.
    if (prompt.HasInputVideo)
    {
        return await GetVideoResultAsStreamAsync(options, diffusionTensor, cancellationToken).ConfigureAwait(false);
    }

    return diffusionTensor.ToImageStream();
}
138143

139144

@@ -183,7 +188,12 @@ public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDif
183188
public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
    {
        // Yield exactly one item per batch result. `yield return` does not leave the loop body,
        // so without the else branch an image prompt would yield its image bytes and then
        // fall through and yield a video encoding of the same result as a second item.
        if (!promptOptions.HasInputVideo)
        {
            yield return result.ImageResult.ToImageBytes();
        }
        else
        {
            yield return await GetVideoResultAsBytesAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false);
        }
    }
}
188198

189199

@@ -200,7 +210,12 @@ public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionM
200210
public async IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
    {
        // Yield exactly one stream per batch result. `yield return` does not leave the loop body,
        // so without the else branch an image prompt would yield its image stream and then
        // fall through and yield a video stream built from the same result as a second item.
        if (!promptOptions.HasInputVideo)
        {
            yield return result.ImageResult.ToImageStream();
        }
        else
        {
            yield return await GetVideoResultAsStreamAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false);
        }
    }
}
205220

206221

@@ -237,6 +252,21 @@ private IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet
237252
return diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken);
238253
}
239254

255+
/// <summary>
/// Creates a video from a generated tensor and returns the video data.
/// </summary>
/// <param name="options">The scheduler options; <c>VideoFPS</c> is passed to the video service as the frame rate.</param>
/// <param name="tensorResult">The generated tensor; it is split with <c>Split(Dimensions[0])</c> and each part is converted with <c>ToImageBytes()</c> to form one frame.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The <c>Data</c> of the video created by the video service.</returns>
private async Task<byte[]> GetVideoResultAsBytesAsync(SchedulerOptions options, DenseTensor<float> tensorResult, CancellationToken cancellationToken = default)
{
    // Lazily converted: the frames are encoded as the video service enumerates them.
    var frameBytes = tensorResult
        .Split(tensorResult.Dimensions[0])
        .Select(frame => frame.ToImageBytes());

    // ConfigureAwait(false) matches the awaits in the public methods of this class that call this helper.
    var videoResult = await _videoService.CreateVideoAsync(frameBytes, options.VideoFPS, cancellationToken).ConfigureAwait(false);
    return videoResult.Data;
}
264+
265+
/// <summary>
/// Creates a video from a generated tensor and returns the video data wrapped in a <see cref="MemoryStream"/>.
/// </summary>
/// <param name="options">The scheduler options, forwarded to <c>GetVideoResultAsBytesAsync</c>.</param>
/// <param name="tensorResult">The generated tensor, forwarded to <c>GetVideoResultAsBytesAsync</c>.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new <see cref="MemoryStream"/> over the video bytes.</returns>
private async Task<MemoryStream> GetVideoResultAsStreamAsync(SchedulerOptions options, DenseTensor<float> tensorResult, CancellationToken cancellationToken = default)
{
    var videoBytes = await GetVideoResultAsBytesAsync(options, tensorResult, cancellationToken);
    return new MemoryStream(videoBytes);
}
269+
240270

241271
}
242272
}

0 commit comments

Comments
 (0)