1
1
using Microsoft . ML . OnnxRuntime . Tensors ;
2
2
using OnnxStack . Core ;
3
- using OnnxStack . Core . Config ;
4
3
using OnnxStack . Core . Services ;
5
4
using OnnxStack . StableDiffusion . Common ;
6
5
using OnnxStack . StableDiffusion . Config ;
@@ -26,6 +25,7 @@ namespace OnnxStack.StableDiffusion.Services
26
25
/// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
27
26
public sealed class StableDiffusionService : IStableDiffusionService
28
27
{
28
+ private readonly IVideoService _videoService ;
29
29
private readonly IOnnxModelService _modelService ;
30
30
private readonly StableDiffusionConfig _configuration ;
31
31
private readonly ConcurrentDictionary < DiffuserPipelineType , IPipeline > _pipelines ;
@@ -34,10 +34,11 @@ public sealed class StableDiffusionService : IStableDiffusionService
34
34
/// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
35
35
/// </summary>
36
36
/// <param name="schedulerService">The scheduler service.</param>
37
- public StableDiffusionService ( StableDiffusionConfig configuration , IOnnxModelService onnxModelService , IEnumerable < IPipeline > pipelines )
37
+ public StableDiffusionService ( StableDiffusionConfig configuration , IOnnxModelService onnxModelService , IVideoService videoService , IEnumerable < IPipeline > pipelines )
38
38
{
39
39
_configuration = configuration ;
40
40
_modelService = onnxModelService ;
41
+ _videoService = videoService ;
41
42
_pipelines = pipelines . ToConcurrentDictionary ( k => k . PipelineType , k => k ) ;
42
43
}
43
44
@@ -115,9 +116,11 @@ public async Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet mo
115
116
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
116
117
public async Task < byte [ ] > GenerateAsBytesAsync ( StableDiffusionModelSet model , PromptOptions prompt , SchedulerOptions options , Action < int , int > progressCallback = null , CancellationToken cancellationToken = default )
117
118
{
118
- return await GenerateAsync ( model , prompt , options , progressCallback , cancellationToken )
119
- . ContinueWith ( t => t . Result . ToImageBytes ( ) , cancellationToken )
120
- . ConfigureAwait ( false ) ;
119
+ var generateResult = await GenerateAsync ( model , prompt , options , progressCallback , cancellationToken ) . ConfigureAwait ( false ) ;
120
+ if ( ! prompt . HasInputVideo )
121
+ return generateResult . ToImageBytes ( ) ;
122
+
123
+ return await GetVideoResultAsBytesAsync ( options , generateResult , cancellationToken ) . ConfigureAwait ( false ) ;
121
124
}
122
125
123
126
@@ -131,9 +134,11 @@ public async Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, Pr
131
134
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
132
135
public async Task < Stream > GenerateAsStreamAsync ( StableDiffusionModelSet model , PromptOptions prompt , SchedulerOptions options , Action < int , int > progressCallback = null , CancellationToken cancellationToken = default )
133
136
{
134
- return await GenerateAsync ( model , prompt , options , progressCallback , cancellationToken )
135
- . ContinueWith ( t => t . Result . ToImageStream ( ) , cancellationToken )
136
- . ConfigureAwait ( false ) ;
137
+ var generateResult = await GenerateAsync ( model , prompt , options , progressCallback , cancellationToken ) . ConfigureAwait ( false ) ;
138
+ if ( ! prompt . HasInputVideo )
139
+ return generateResult . ToImageStream ( ) ;
140
+
141
+ return await GetVideoResultAsStreamAsync ( options , generateResult , cancellationToken ) . ConfigureAwait ( false ) ;
137
142
}
138
143
139
144
@@ -183,7 +188,12 @@ public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDif
183
188
public async IAsyncEnumerable < byte [ ] > GenerateBatchAsBytesAsync ( StableDiffusionModelSet modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
184
189
{
185
190
await foreach ( var result in GenerateBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) )
186
- yield return result . ImageResult . ToImageBytes ( ) ;
191
+ {
192
+ if ( ! promptOptions . HasInputVideo )
193
+ yield return result . ImageResult . ToImageBytes ( ) ;
194
+
195
+ yield return await GetVideoResultAsBytesAsync ( schedulerOptions , result . ImageResult , cancellationToken ) . ConfigureAwait ( false ) ;
196
+ }
187
197
}
188
198
189
199
@@ -200,7 +210,12 @@ public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionM
200
210
public async IAsyncEnumerable < Stream > GenerateBatchAsStreamAsync ( StableDiffusionModelSet modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
201
211
{
202
212
await foreach ( var result in GenerateBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) )
203
- yield return result . ImageResult . ToImageStream ( ) ;
213
+ {
214
+ if ( ! promptOptions . HasInputVideo )
215
+ yield return result . ImageResult . ToImageStream ( ) ;
216
+
217
+ yield return await GetVideoResultAsStreamAsync ( schedulerOptions , result . ImageResult , cancellationToken ) . ConfigureAwait ( false ) ;
218
+ }
204
219
}
205
220
206
221
@@ -237,6 +252,21 @@ private IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet
237
252
return diffuser . DiffuseBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progress , cancellationToken ) ;
238
253
}
239
254
255
+ private async Task < byte [ ] > GetVideoResultAsBytesAsync ( SchedulerOptions options , DenseTensor < float > tensorResult , CancellationToken cancellationToken = default )
256
+ {
257
+ var frameTensors = tensorResult
258
+ . Split ( tensorResult . Dimensions [ 0 ] )
259
+ . Select ( x => x . ToImageBytes ( ) ) ;
260
+
261
+ var videoResult = await _videoService . CreateVideoAsync ( frameTensors , options . VideoFPS , cancellationToken ) ;
262
+ return videoResult . Data ;
263
+ }
264
+
265
+ private async Task < MemoryStream > GetVideoResultAsStreamAsync ( SchedulerOptions options , DenseTensor < float > tensorResult , CancellationToken cancellationToken = default )
266
+ {
267
+ return new MemoryStream ( await GetVideoResultAsBytesAsync ( options , tensorResult , cancellationToken ) ) ;
268
+ }
269
+
240
270
241
271
}
242
272
}
0 commit comments