Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 38f60b5

Browse files
committed
Support models with Float timestep tensor input
1 parent e56a7b6 commit 38f60b5

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,18 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
193193
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
194194
{
195195
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
196+
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);
197+
198+
// Some models support Long or Float, could be more but fornow just support these 2
199+
var timesepMetaKey = inputNames[1];
200+
var timestepMetaData = inputMetaData[timesepMetaKey];
201+
var timestepNamedOnnxValue = timestepMetaData.ElementDataType == TensorElementType.Int64
202+
? NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
203+
: NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
204+
196205
return CreateInputParameters(
197206
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
198-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
207+
timestepNamedOnnxValue,
199208
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
200209
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
201210
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,18 @@ protected virtual DenseTensor<float> PerformGuidance(DenseTensor<float> noisePre
213213
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, int timestep)
214214
{
215215
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
216+
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);
217+
218+
// Some models support Long or Float, could be more but fornow just support these 2
219+
var timesepMetaKey = inputNames[1];
220+
var timestepMetaData = inputMetaData[timesepMetaKey];
221+
var timestepNamedOnnxValue = timestepMetaData.ElementDataType == TensorElementType.Int64
222+
? NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
223+
: NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
224+
216225
return CreateInputParameters(
217226
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
218-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
227+
timestepNamedOnnxValue,
219228
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
220229
}
221230

0 commit comments

Comments
 (0)