Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit a8bdc65

Browse files
committed
Add IAsyncEnumerable return type to base diffusers
1 parent 77a2b6a commit a8bdc65

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
7+
namespace OnnxStack.StableDiffusion.Config
8+
{
9+
public class BatchOptions
10+
{
11+
12+
}
13+
}

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.StableDiffusion.Enums;
55
using System;
6+
using System.Collections.Generic;
67
using System.Threading;
78
using System.Threading.Tasks;
89

@@ -33,5 +34,18 @@ public interface IDiffuser
3334
/// <param name="cancellationToken">The cancellation token.</param>
3435
/// <returns></returns>
3536
Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
37+
38+
39+
/// <summary>
40+
/// Runs the stable diffusion batch loop
41+
/// </summary>
42+
/// <param name="modelOptions">The model options.</param>
43+
/// <param name="promptOptions">The prompt options.</param>
44+
/// <param name="schedulerOptions">The scheduler options.</param>
45+
/// <param name="batchOptions">The batch options.</param>
46+
/// <param name="progressCallback">The progress callback.</param>
47+
/// <param name="cancellationToken">The cancellation token.</param>
48+
/// <returns></returns>
49+
IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
3650
}
3751
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,23 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
144144
}
145145

146146

147+
/// <summary>
148+
/// Runs the stable diffusion batch loop
149+
/// </summary>
150+
/// <param name="modelOptions">The model options.</param>
151+
/// <param name="promptOptions">The prompt options.</param>
152+
/// <param name="schedulerOptions">The scheduler options.</param>
153+
/// <param name="batchOptions">The batch options.</param>
154+
/// <param name="progressCallback">The progress callback.</param>
155+
/// <param name="cancellationToken">The cancellation token.</param>
156+
/// <returns></returns>
157+
/// <exception cref="System.NotImplementedException"></exception>
158+
public IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
159+
{
160+
throw new NotImplementedException();
161+
}
162+
163+
147164
/// <summary>
148165
/// Decodes the latents.
149166
/// </summary>
@@ -261,5 +278,7 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
261278
{
262279
return parameters.ToList();
263280
}
281+
282+
264283
}
265284
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,23 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
141141
}
142142

143143

144+
/// <summary>
145+
/// Runs the stable diffusion batch loop
146+
/// </summary>
147+
/// <param name="modelOptions">The model options.</param>
148+
/// <param name="promptOptions">The prompt options.</param>
149+
/// <param name="schedulerOptions">The scheduler options.</param>
150+
/// <param name="batchOptions">The batch options.</param>
151+
/// <param name="progressCallback">The progress callback.</param>
152+
/// <param name="cancellationToken">The cancellation token.</param>
153+
/// <returns></returns>
154+
/// <exception cref="System.NotImplementedException"></exception>
155+
public IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
156+
{
157+
throw new NotImplementedException();
158+
}
159+
160+
144161
/// <summary>
145162
/// Decodes the latents.
146163
/// </summary>

0 commit comments

Comments
 (0)