1
- using Microsoft . ML . OnnxRuntime ;
2
- using Microsoft . ML . OnnxRuntime . Tensors ;
1
+ using Microsoft . ML . OnnxRuntime . Tensors ;
3
2
using OnnxStack . Core ;
4
3
using OnnxStack . Core . Config ;
5
4
using OnnxStack . Core . Model ;
6
5
using OnnxStack . Core . Services ;
7
6
using OnnxStack . StableDiffusion . Common ;
8
7
using OnnxStack . StableDiffusion . Config ;
8
+ using OnnxStack . StableDiffusion . Enums ;
9
9
using OnnxStack . StableDiffusion . Helpers ;
10
10
using System ;
11
11
using System . Collections . Generic ;
@@ -40,6 +40,25 @@ public record EmbedsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> P
/// <summary>
/// Creates the text embeddings for the prompt and negative prompt, dispatching on the model's tokenizer configuration.
/// </summary>
/// <param name="model">The model options; <c>TokenizerType</c> selects which tokenizer/text-encoder pipeline is used.</param>
/// <param name="promptOptions">The prompt options containing the prompt and negative prompt.</param>
/// <param name="isGuidanceEnabled">if set to <c>true</c> classifier-free guidance is enabled and negative embeddings are batched with the prompt embeddings.</param>
/// <returns>Tensor containing all text embeds generated from the prompt and negative prompt</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when the model's <c>TokenizerType</c> is unset or not a supported value.</exception>
public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
{
    return model.TokenizerType switch
    {
        TokenizerType.One => await CreateEmbedsOneAsync(model, promptOptions, isGuidanceEnabled),
        TokenizerType.Two => await CreateEmbedsTwoAsync(model, promptOptions, isGuidanceEnabled),
        TokenizerType.Both => await CreateEmbedsBothAsync(model, promptOptions, isGuidanceEnabled),
        // Carry the offending value so a misconfigured model is diagnosable from the exception alone.
        // ArgumentOutOfRangeException derives from ArgumentException, so existing catch blocks still match.
        _ => throw new ArgumentOutOfRangeException(nameof(model), model.TokenizerType, "TokenizerType is not set or not supported")
    };
}
52
+
53
+
54
+ /// <summary>
55
+ /// Creates the embeds using Tokenizer and TextEncoder
56
+ /// </summary>
57
+ /// <param name="model">The model.</param>
58
+ /// <param name="promptOptions">The prompt options.</param>
59
+ /// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
60
+ /// <returns></returns>
61
+ private async Task < PromptEmbeddingsResult > CreateEmbedsOneAsync ( IModelOptions model , PromptOptions promptOptions , bool isGuidanceEnabled )
43
62
{
44
63
// Tokenize Prompt and NegativePrompt
45
64
var promptTokens = await DecodeTextAsIntAsync ( model , promptOptions . Prompt ) ;
@@ -50,31 +69,74 @@ public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model,
50
69
var promptEmbeddings = await GenerateEmbedsAsync ( model , promptTokens , maxPromptTokenCount ) ;
51
70
var negativePromptEmbeddings = await GenerateEmbedsAsync ( model , negativePromptTokens , maxPromptTokenCount ) ;
52
71
53
- if ( model . IsDualTokenizer )
54
- {
55
- /// Tokenize Prompt and NegativePrompt with Tokenizer2
56
- var dualPromptTokens = await DecodeTextAsLongAsync ( model , promptOptions . Prompt ) ;
57
- var dualNegativePromptTokens = await DecodeTextAsLongAsync ( model , promptOptions . NegativePrompt ) ;
72
+ if ( isGuidanceEnabled )
73
+ return new PromptEmbeddingsResult ( negativePromptEmbeddings . Concatenate ( promptEmbeddings ) ) ;
58
74
59
- // Generate embeds for tokens
60
- var dualPromptEmbeddings = await GenerateEmbedsAsync ( model , dualPromptTokens , maxPromptTokenCount ) ;
61
- var dualNegativePromptEmbeddings = await GenerateEmbedsAsync ( model , dualNegativePromptTokens , maxPromptTokenCount ) ;
75
+ return new PromptEmbeddingsResult ( promptEmbeddings ) ;
76
+ }
62
77
63
- var dualPrompt = promptEmbeddings . Concatenate ( dualPromptEmbeddings . PromptEmbeds , 2 ) ;
64
- var dualNegativePrompt = negativePromptEmbeddings . Concatenate ( dualNegativePromptEmbeddings . PromptEmbeds , 2 ) ;
65
- var pooledPromptEmbeds = dualPromptEmbeddings . PooledPromptEmbeds ;
66
- var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings . PooledPromptEmbeds ;
78
/// <summary>
/// Creates the embeds using Tokenizer2 and TextEncoder2
/// </summary>
/// <param name="model">The model.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
/// <returns>Prompt and pooled embeddings for the prompt (and negative prompt when guidance is enabled).</returns>
private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
{
    // Tokenize the positive and negative prompts with Tokenizer2.
    var positiveTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
    var negativeTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);

    // Both embeddings are generated against the longer of the two token sequences.
    var tokenCount = Math.Max(positiveTokens.Length, negativeTokens.Length);

    // Encode tokens into prompt embeddings plus pooled embeddings.
    var positive = await GenerateEmbedsAsync(model, positiveTokens, tokenCount);
    var negative = await GenerateEmbedsAsync(model, negativeTokens, tokenCount);

    if (!isGuidanceEnabled)
        return new PromptEmbeddingsResult(positive.PromptEmbeds, positive.PooledPromptEmbeds);

    // Guidance enabled: batch negative embeddings first, then positive.
    var promptEmbeds = negative.PromptEmbeds.Concatenate(positive.PromptEmbeds);
    var pooledEmbeds = negative.PooledPromptEmbeds.Concatenate(positive.PooledPromptEmbeds);
    return new PromptEmbeddingsResult(promptEmbeds, pooledEmbeds);
}
103
+
104
+
105
/// <summary>
/// Creates the embeds using Tokenizer, Tokenizer2, TextEncoder and TextEncoder2
/// </summary>
/// <param name="model">The model.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
/// <returns>Combined prompt embeddings from both encoders, with pooled embeddings from encoder 2.</returns>
private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
{
    // First pass: Tokenizer/TextEncoder (int tokens).
    var positiveTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt);
    var negativeTokens = await DecodeTextAsIntAsync(model, promptOptions.NegativePrompt);
    var tokenCount = Math.Max(positiveTokens.Length, negativeTokens.Length);

    var positiveEmbeds = await GenerateEmbedsAsync(model, positiveTokens, tokenCount);
    var negativeEmbeds = await GenerateEmbedsAsync(model, negativeTokens, tokenCount);

    // Second pass: Tokenizer2/TextEncoder2 (long tokens), producing prompt + pooled embeddings.
    // NOTE: the max token count from the first tokenizer is reused here, matching existing behavior.
    var dualPositiveTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
    var dualNegativeTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);

    var dualPositiveEmbeds = await GenerateEmbedsAsync(model, dualPositiveTokens, tokenCount);
    var dualNegativeEmbeds = await GenerateEmbedsAsync(model, dualNegativeTokens, tokenCount);

    // Join the two encoders' outputs along axis 2 (presumably the hidden/feature dimension);
    // pooled embeddings come from encoder 2 only.
    var positivePrompt = positiveEmbeds.Concatenate(dualPositiveEmbeds.PromptEmbeds, 2);
    var negativePrompt = negativeEmbeds.Concatenate(dualNegativeEmbeds.PromptEmbeds, 2);
    var positivePooled = dualPositiveEmbeds.PooledPromptEmbeds;
    var negativePooled = dualNegativeEmbeds.PooledPromptEmbeds;

    if (!isGuidanceEnabled)
        return new PromptEmbeddingsResult(positivePrompt, positivePooled);

    // Guidance enabled: batch negative embeddings first, then positive.
    return new PromptEmbeddingsResult(
        negativePrompt.Concatenate(positivePrompt),
        negativePooled.Concatenate(positivePooled));
}
79
141
80
142
@@ -138,7 +200,7 @@ private Task<long[]> DecodeTextAsLongAsync(IModelOptions model, string inputText
138
200
private async Task < float [ ] > EncodeTokensAsync ( IModelOptions model , int [ ] tokenizedInput )
139
201
{
140
202
var inputDim = new [ ] { 1 , tokenizedInput . Length } ;
141
- var outputDim = new [ ] { 1 , tokenizedInput . Length , model . EmbeddingsLength } ;
203
+ var outputDim = new [ ] { 1 , tokenizedInput . Length , model . TokenizerLength } ;
142
204
var metadata = _onnxModelService . GetModelMetadata ( model , OnnxModelType . TextEncoder ) ;
143
205
var inputTensor = new DenseTensor < int > ( tokenizedInput , inputDim ) ;
144
206
using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
@@ -164,8 +226,8 @@ private async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokeniz
164
226
private async Task < EncoderResult > EncodeTokensAsync ( IModelOptions model , long [ ] tokenizedInput )
165
227
{
166
228
var inputDim = new [ ] { 1 , tokenizedInput . Length } ;
167
- var promptOutputDim = new [ ] { 1 , tokenizedInput . Length , model . DualEmbeddingsLength } ;
168
- var pooledOutputDim = new [ ] { 1 , model . DualEmbeddingsLength } ;
229
+ var promptOutputDim = new [ ] { 1 , tokenizedInput . Length , model . Tokenizer2Length } ;
230
+ var pooledOutputDim = new [ ] { 1 , model . Tokenizer2Length } ;
169
231
var metadata = _onnxModelService . GetModelMetadata ( model , OnnxModelType . TextEncoder2 ) ;
170
232
var inputTensor = new DenseTensor < long > ( tokenizedInput , inputDim ) ;
171
233
using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
@@ -206,12 +268,12 @@ private async Task<EmbedsResult> GenerateEmbedsAsync(IModelOptions model, long[]
206
268
pooledEmbeds . AddRange ( result . PooledPromptEmbeds ) ;
207
269
}
208
270
209
- var embeddingsDim = new [ ] { 1 , embeddings . Count / model . DualEmbeddingsLength , model . DualEmbeddingsLength } ;
271
+ var embeddingsDim = new [ ] { 1 , embeddings . Count / model . Tokenizer2Length , model . Tokenizer2Length } ;
210
272
var promptTensor = TensorHelper . CreateTensor ( embeddings . ToArray ( ) , embeddingsDim ) ;
211
273
212
274
//TODO: Pooled embeds do not support more than 77 tokens, just grab first set
213
- var pooledDim = new [ ] { 1 , model . DualEmbeddingsLength } ;
214
- var pooledTensor = TensorHelper . CreateTensor ( pooledEmbeds . Take ( model . DualEmbeddingsLength ) . ToArray ( ) , pooledDim ) ;
275
+ var pooledDim = new [ ] { 1 , model . Tokenizer2Length } ;
276
+ var pooledTensor = TensorHelper . CreateTensor ( pooledEmbeds . Take ( model . Tokenizer2Length ) . ToArray ( ) , pooledDim ) ;
215
277
return new EmbedsResult ( promptTensor , pooledTensor ) ;
216
278
}
217
279
@@ -236,7 +298,7 @@ private async Task<DenseTensor<float>> GenerateEmbedsAsync(IModelOptions model,
236
298
embeddings . AddRange ( await EncodeTokensAsync ( model , tokens . ToArray ( ) ) ) ;
237
299
}
238
300
239
- var dim = new [ ] { 1 , embeddings . Count / model . EmbeddingsLength , model . EmbeddingsLength } ;
301
+ var dim = new [ ] { 1 , embeddings . Count / model . TokenizerLength , model . TokenizerLength } ;
240
302
return TensorHelper . CreateTensor ( embeddings . ToArray ( ) , dim ) ;
241
303
}
242
304
0 commit comments