Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 34e3d86

Browse files
committed
Support more Tokenizer, TextEncoder combinations
1 parent f1d13e5 commit 34e3d86

File tree

11 files changed

+180
-133
lines changed

11 files changed

+180
-133
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ public async Task RunAsync()
4343
{
4444
SchedulerType = SchedulerType.LMS,
4545
Seed = 624461087,
46-
//Seed = Random.Shared.Next(),
4746
GuidanceScale = 8,
4847
InferenceSteps = 22,
4948
Strength = 0.6f
@@ -54,6 +53,9 @@ public async Task RunAsync()
5453
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5554
await _stableDiffusionService.LoadModelAsync(model);
5655

56+
schedulerOptions.Width = model.SampleSize;
57+
schedulerOptions.Height = model.SampleSize;
58+
5759
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5860
{
5961
schedulerOptions.SchedulerType = schedulerType;

OnnxStack.Console/appsettings.json

Lines changed: 15 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
"PadTokenId": 49407,
1515
"BlankTokenId": 49407,
1616
"TokenizerLimit": 77,
17-
"EmbeddingsLength": 768,
17+
"TokenizerLength": 768,
18+
"TokenizerType": "One",
1819
"ScaleFactor": 0.18215,
1920
"SampleSize": 512,
2021
"PipelineType": "StableDiffusion",
@@ -57,7 +58,8 @@
5758
"PadTokenId": 49407,
5859
"BlankTokenId": 49407,
5960
"TokenizerLimit": 77,
60-
"EmbeddingsLength": 768,
61+
"TokenizerLength": 768,
62+
"TokenizerType": "One",
6163
"ScaleFactor": 0.18215,
6264
"SampleSize": 512,
6365
"PipelineType": "LatentConsistency",
@@ -93,56 +95,14 @@
9395
}
9496
]
9597
},
96-
{
97-
"Name": "Photon",
98-
"IsEnabled": true,
99-
"PadTokenId": 49407,
100-
"BlankTokenId": 49407,
101-
"TokenizerLimit": 77,
102-
"EmbeddingsLength": 768,
103-
"ScaleFactor": 0.18215,
104-
"SampleSize": 512,
105-
"PipelineType": "StableDiffusion",
106-
"Diffusers": [
107-
"TextToImage",
108-
"ImageToImage",
109-
"ImageInpaintLegacy"
110-
],
111-
"DeviceId": 0,
112-
"InterOpNumThreads": 0,
113-
"IntraOpNumThreads": 0,
114-
"ExecutionMode": "ORT_SEQUENTIAL",
115-
"ExecutionProvider": "DirectML",
116-
"ModelConfigurations": [
117-
{
118-
"Type": "Tokenizer",
119-
"OnnxModelPath": "D:\\Repositories\\photon\\tokenizer\\model.onnx"
120-
},
121-
{
122-
"Type": "Unet",
123-
"OnnxModelPath": "D:\\Repositories\\photon\\unet\\model.onnx"
124-
},
125-
{
126-
"Type": "TextEncoder",
127-
"OnnxModelPath": "D:\\Repositories\\photon\\text_encoder\\model.onnx"
128-
},
129-
{
130-
"Type": "VaeEncoder",
131-
"OnnxModelPath": "D:\\Repositories\\photon\\vae_encoder\\model.onnx"
132-
},
133-
{
134-
"Type": "VaeDecoder",
135-
"OnnxModelPath": "D:\\Repositories\\photon\\vae_decoder\\model.onnx"
136-
}
137-
]
138-
},
13998
{
14099
"Name": "InstaFlow",
141100
"IsEnabled": true,
142101
"PadTokenId": 49407,
143102
"BlankTokenId": 49407,
144103
"TokenizerLimit": 77,
145-
"EmbeddingsLength": 768,
104+
"TokenizerLength": 768,
105+
"TokenizerType": "One",
146106
"ScaleFactor": 0.18215,
147107
"SampleSize": 512,
148108
"PipelineType": "InstaFlow",
@@ -178,14 +138,14 @@
178138
]
179139
},
180140
{
181-
"Name": "DreamShaper XL",
141+
"Name": "Stable Diffusion XL",
182142
"IsEnabled": true,
183143
"PadTokenId": 1,
184144
"BlankTokenId": 49407,
185145
"TokenizerLimit": 77,
186-
"EmbeddingsLength": 768,
187-
"DualEmbeddingsLength": 1280,
188-
"IsDualTokenizer": true,
146+
"TokenizerLength": 768,
147+
"Tokenizer2Length": 1280,
148+
"TokenizerType": "Both",
189149
"ScaleFactor": 0.13025,
190150
"SampleSize": 1024,
191151
"PipelineType": "StableDiffusionXL",
@@ -200,23 +160,23 @@
200160
"ModelConfigurations": [
201161
{
202162
"Type": "Tokenizer",
203-
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\tokenizer\\model.onnx"
163+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-Olive-Onnx\\tokenizer\\model.onnx"
204164
},
205165
{
206166
"Type": "Tokenizer2",
207-
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\tokenizer_2\\model.onnx"
167+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-Olive-Onnx\\tokenizer_2\\model.onnx"
208168
},
209169
{
210170
"Type": "Unet",
211-
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\unet\\model.onnx"
171+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-Olive-Onnx\\unet\\model.onnx"
212172
},
213173
{
214174
"Type": "TextEncoder",
215-
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\text_encoder\\model.onnx"
175+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-Olive-Onnx\\text_encoder\\model.onnx"
216176
},
217177
{
218178
"Type": "TextEncoder2",
219-
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\text_encoder_2\\model.onnx"
179+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-Olive-Onnx\\text_encoder_2\\model.onnx"
220180
},
221181
{
222182
"Type": "VaeEncoder",

OnnxStack.StableDiffusion/Common/IModelOptions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ public interface IModelOptions : IOnnxModel
1313
int SampleSize { get; set; }
1414
float ScaleFactor { get; set; }
1515
int TokenizerLimit { get; set; }
16-
int EmbeddingsLength { get; set; }
17-
int DualEmbeddingsLength { get; set; }
18-
bool IsDualTokenizer { get; set; }
16+
int TokenizerLength { get; set; }
17+
int Tokenizer2Length { get; set; }
18+
TokenizerType TokenizerType { get; set; }
1919
DiffuserPipelineType PipelineType { get; set; }
2020
List<DiffuserType> Diffusers { get; set; }
2121
ImmutableArray<int> BlankTokenValueArray { get; set; }

OnnxStack.StableDiffusion/Config/ModelOptions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ public class ModelOptions : IModelOptions, IOnnxModelSetConfig
1616
public int PadTokenId { get; set; }
1717
public int BlankTokenId { get; set; }
1818
public int TokenizerLimit { get; set; }
19-
public int EmbeddingsLength { get; set; }
20-
public int DualEmbeddingsLength { get; set; }
21-
public bool IsDualTokenizer { get; set; }
19+
public int TokenizerLength { get; set; }
20+
public int Tokenizer2Length { get; set; }
21+
public TokenizerType TokenizerType { get; set; }
2222
public int SampleSize { get; set; } = 512;
2323
public float ScaleFactor { get; set; }
2424
public DiffuserPipelineType PipelineType { get; set; }
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace OnnxStack.StableDiffusion.Enums
2+
{
3+
public enum TokenizerType
4+
{
5+
None = 0,
6+
One = 1,
7+
Two = 2,
8+
Both = 3
9+
}
10+
}

OnnxStack.StableDiffusion/Services/PromptService.cs

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
using Microsoft.ML.OnnxRuntime;
2-
using Microsoft.ML.OnnxRuntime.Tensors;
1+
using Microsoft.ML.OnnxRuntime.Tensors;
32
using OnnxStack.Core;
43
using OnnxStack.Core.Config;
54
using OnnxStack.Core.Model;
65
using OnnxStack.Core.Services;
76
using OnnxStack.StableDiffusion.Common;
87
using OnnxStack.StableDiffusion.Config;
8+
using OnnxStack.StableDiffusion.Enums;
99
using OnnxStack.StableDiffusion.Helpers;
1010
using System;
1111
using System.Collections.Generic;
@@ -40,6 +40,25 @@ public record EmbedsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> P
4040
/// <param name="negativePrompt">The negative prompt.</param>
4141
/// <returns>Tensor containing all text embeds generated from the prompt and negative prompt</returns>
4242
public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
43+
{
44+
return model.TokenizerType switch
45+
{
46+
TokenizerType.One => await CreateEmbedsOneAsync(model, promptOptions, isGuidanceEnabled),
47+
TokenizerType.Two => await CreateEmbedsTwoAsync(model, promptOptions, isGuidanceEnabled),
48+
TokenizerType.Both => await CreateEmbedsBothAsync(model, promptOptions, isGuidanceEnabled),
49+
_ => throw new ArgumentException("TokenizerType is not set")
50+
};
51+
}
52+
53+
54+
/// <summary>
55+
/// Creates the embeds using Tokenizer and TextEncoder
56+
/// </summary>
57+
/// <param name="model">The model.</param>
58+
/// <param name="promptOptions">The prompt options.</param>
59+
/// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
60+
/// <returns></returns>
61+
private async Task<PromptEmbeddingsResult> CreateEmbedsOneAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
4362
{
4463
// Tokenize Prompt and NegativePrompt
4564
var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt);
@@ -50,31 +69,74 @@ public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model,
5069
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
5170
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
5271

53-
if (model.IsDualTokenizer)
54-
{
55-
/// Tokenize Prompt and NegativePrompt with Tokenizer2
56-
var dualPromptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
57-
var dualNegativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
72+
if (isGuidanceEnabled)
73+
return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
5874

59-
// Generate embeds for tokens
60-
var dualPromptEmbeddings = await GenerateEmbedsAsync(model, dualPromptTokens, maxPromptTokenCount);
61-
var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(model, dualNegativePromptTokens, maxPromptTokenCount);
75+
return new PromptEmbeddingsResult(promptEmbeddings);
76+
}
6277

63-
var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
64-
var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
65-
var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
66-
var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
78+
/// <summary>
79+
/// Creates the embeds using Tokenizer2 and TextEncoder2
80+
/// </summary>
81+
/// <param name="model">The model.</param>
82+
/// <param name="promptOptions">The prompt options.</param>
83+
/// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
84+
/// <returns></returns>
85+
private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
86+
{
87+
/// Tokenize Prompt and NegativePrompt with Tokenizer2
88+
var promptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
89+
var negativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
90+
var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
6791

68-
if (isGuidanceEnabled)
69-
return new PromptEmbeddingsResult(dualNegativePrompt.Concatenate(dualPrompt), pooledNegativePromptEmbeds.Concatenate(pooledPromptEmbeds));
92+
// Generate embeds for tokens
93+
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
94+
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
7095

71-
return new PromptEmbeddingsResult(dualPrompt, pooledPromptEmbeds);
72-
}
96+
if (isGuidanceEnabled)
97+
return new PromptEmbeddingsResult(
98+
negativePromptEmbeddings.PromptEmbeds.Concatenate(promptEmbeddings.PromptEmbeds),
99+
negativePromptEmbeddings.PooledPromptEmbeds.Concatenate(promptEmbeddings.PooledPromptEmbeds));
100+
101+
return new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds, promptEmbeddings.PooledPromptEmbeds);
102+
}
103+
104+
105+
/// <summary>
106+
/// Creates the embeds using Tokenizer, Tokenizer2, TextEncoder and TextEncoder2
107+
/// </summary>
108+
/// <param name="model">The model.</param>
109+
/// <param name="promptOptions">The prompt options.</param>
110+
/// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
111+
/// <returns></returns>
112+
private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
113+
{
114+
// Tokenize Prompt and NegativePrompt
115+
var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt);
116+
var negativePromptTokens = await DecodeTextAsIntAsync(model, promptOptions.NegativePrompt);
117+
var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
118+
119+
// Generate embeds for tokens
120+
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
121+
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
122+
123+
/// Tokenize Prompt and NegativePrompt with Tokenizer2
124+
var dualPromptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
125+
var dualNegativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
126+
127+
// Generate embeds for tokens
128+
var dualPromptEmbeddings = await GenerateEmbedsAsync(model, dualPromptTokens, maxPromptTokenCount);
129+
var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(model, dualNegativePromptTokens, maxPromptTokenCount);
130+
131+
var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
132+
var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
133+
var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
134+
var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
73135

74136
if (isGuidanceEnabled)
75-
return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
137+
return new PromptEmbeddingsResult(dualNegativePrompt.Concatenate(dualPrompt), pooledNegativePromptEmbeds.Concatenate(pooledPromptEmbeds));
76138

77-
return new PromptEmbeddingsResult(promptEmbeddings);
139+
return new PromptEmbeddingsResult(dualPrompt, pooledPromptEmbeds);
78140
}
79141

80142

@@ -138,7 +200,7 @@ private Task<long[]> DecodeTextAsLongAsync(IModelOptions model, string inputText
138200
private async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput)
139201
{
140202
var inputDim = new[] { 1, tokenizedInput.Length };
141-
var outputDim = new[] { 1, tokenizedInput.Length, model.EmbeddingsLength };
203+
var outputDim = new[] { 1, tokenizedInput.Length, model.TokenizerLength };
142204
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.TextEncoder);
143205
var inputTensor = new DenseTensor<int>(tokenizedInput, inputDim);
144206
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
@@ -164,8 +226,8 @@ private async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokeniz
164226
private async Task<EncoderResult> EncodeTokensAsync(IModelOptions model, long[] tokenizedInput)
165227
{
166228
var inputDim = new[] { 1, tokenizedInput.Length };
167-
var promptOutputDim = new[] { 1, tokenizedInput.Length, model.DualEmbeddingsLength };
168-
var pooledOutputDim = new[] { 1, model.DualEmbeddingsLength };
229+
var promptOutputDim = new[] { 1, tokenizedInput.Length, model.Tokenizer2Length };
230+
var pooledOutputDim = new[] { 1, model.Tokenizer2Length };
169231
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.TextEncoder2);
170232
var inputTensor = new DenseTensor<long>(tokenizedInput, inputDim);
171233
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
@@ -206,12 +268,12 @@ private async Task<EmbedsResult> GenerateEmbedsAsync(IModelOptions model, long[]
206268
pooledEmbeds.AddRange(result.PooledPromptEmbeds);
207269
}
208270

209-
var embeddingsDim = new[] { 1, embeddings.Count / model.DualEmbeddingsLength, model.DualEmbeddingsLength };
271+
var embeddingsDim = new[] { 1, embeddings.Count / model.Tokenizer2Length, model.Tokenizer2Length };
210272
var promptTensor = TensorHelper.CreateTensor(embeddings.ToArray(), embeddingsDim);
211273

212274
//TODO: Pooled embeds do not support more than 77 tokens, just grab first set
213-
var pooledDim = new[] { 1, model.DualEmbeddingsLength };
214-
var pooledTensor = TensorHelper.CreateTensor(pooledEmbeds.Take(model.DualEmbeddingsLength).ToArray(), pooledDim);
275+
var pooledDim = new[] { 1, model.Tokenizer2Length };
276+
var pooledTensor = TensorHelper.CreateTensor(pooledEmbeds.Take(model.Tokenizer2Length).ToArray(), pooledDim);
215277
return new EmbedsResult(promptTensor, pooledTensor);
216278
}
217279

@@ -236,7 +298,7 @@ private async Task<DenseTensor<float>> GenerateEmbedsAsync(IModelOptions model,
236298
embeddings.AddRange(await EncodeTokensAsync(model, tokens.ToArray()));
237299
}
238300

239-
var dim = new[] { 1, embeddings.Count / model.EmbeddingsLength, model.EmbeddingsLength };
301+
var dim = new[] { 1, embeddings.Count / model.TokenizerLength, model.TokenizerLength };
240302
return TensorHelper.CreateTensor(embeddings.ToArray(), dim);
241303
}
242304

OnnxStack.UI/Models/ModelConfigTemplate.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ public class ModelConfigTemplate
1515
public int PadTokenId { get; set; }
1616
public int BlankTokenId { get; set; }
1717
public int TokenizerLimit { get; set; }
18-
public bool IsDualTokenizer { get; set; }
19-
public int EmbeddingsLength { get; set; }
20-
public int DualEmbeddingsLength { get; set; }
18+
public TokenizerType TokenizerType { get; set; }
19+
public int TokenizerLength { get; set; }
20+
public int Tokenizer2Length { get; set; }
2121
public float ScaleFactor { get; set; }
2222
public DiffuserPipelineType PipelineType { get; set; }
2323
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();

0 commit comments

Comments (0)