Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 850b32a

Browse files
committed
Fix previous timestep selection
1 parent 6387cce commit 850b32a

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

OnnxStack.StableDiffusion/Schedulers/LatentConsistency/LCMScheduler.cs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Enums;
45
using OnnxStack.StableDiffusion.Helpers;
@@ -109,13 +110,23 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
109110
int currentTimestep = timestep;
110111

111112
// 1. get previous step value
112-
int previousTimestep = GetPreviousTimestep(currentTimestep);
113+
int prevIndex = Timesteps.IndexOf(currentTimestep) + 1;
114+
int previousTimestep = prevIndex < Timesteps.Count
115+
? Timesteps[prevIndex]
116+
: currentTimestep;
113117

114118
//# 2. compute alphas, betas
115119
float alphaProdT = _alphasCumProd[currentTimestep];
116-
float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd[previousTimestep] : _finalAlphaCumprod;
120+
float alphaProdTPrev = previousTimestep >= 0
121+
? _alphasCumProd[previousTimestep]
122+
: _finalAlphaCumprod;
117123
float betaProdT = 1f - alphaProdT;
118124
float betaProdTPrev = 1f - alphaProdTPrev;
125+
float alphaSqrt = MathF.Sqrt(alphaProdT);
126+
float betaSqrt = MathF.Sqrt(betaProdT);
127+
float betaProdTPrevSqrt = MathF.Sqrt(betaProdTPrev);
128+
float alphaProdTPrevSqrt = MathF.Sqrt(alphaProdTPrev);
129+
119130

120131
// 3.Get scalings for boundary conditions
121132
(float cSkip, float cOut) = GetBoundaryConditionScalings(currentTimestep);
@@ -125,17 +136,16 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
125136
DenseTensor<float> predOriginalSample = null;
126137
if (Options.PredictionType == PredictionType.Epsilon)
127138
{
128-
var sampleBeta = sample.SubtractTensors(modelOutput.MultipleTensorByFloat((float)Math.Sqrt(betaProdT)));
129-
predOriginalSample = sampleBeta.DivideTensorByFloat((float)Math.Sqrt(alphaProdT));
139+
predOriginalSample = sample
140+
.SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt))
141+
.DivideTensorByFloat(alphaSqrt);
130142
}
131143
else if (Options.PredictionType == PredictionType.Sample)
132144
{
133145
predOriginalSample = modelOutput;
134146
}
135147
else if (Options.PredictionType == PredictionType.VariablePrediction)
136148
{
137-
var alphaSqrt = (float)Math.Sqrt(alphaProdT);
138-
var betaSqrt = (float)Math.Sqrt(betaProdT);
139149
predOriginalSample = sample
140150
.MultipleTensorByFloat(alphaSqrt)
141151
.SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt));
@@ -155,8 +165,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
155165
//# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
156166
var prevSample = Timesteps.Count > 1
157167
? CreateRandomSample(modelOutput.Dimensions)
158-
.MultipleTensorByFloat(MathF.Sqrt(betaProdTPrev))
159-
.AddTensors(denoised.MultipleTensorByFloat(MathF.Sqrt(alphaProdTPrev)))
168+
.MultipleTensorByFloat(betaProdTPrevSqrt)
169+
.AddTensors(denoised.MultipleTensorByFloat(alphaProdTPrevSqrt))
160170
: denoised;
161171

162172
return new SchedulerStepResult(prevSample, denoised);
@@ -175,8 +185,8 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
175185
// Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
176186
int timestep = timesteps[0];
177187
float alphaProd = _alphasCumProd[timestep];
178-
float sqrtAlpha = (float)Math.Sqrt(alphaProd);
179-
float sqrtOneMinusAlpha = (float)Math.Sqrt(1.0f - alphaProd);
188+
float sqrtAlpha = MathF.Sqrt(alphaProd);
189+
float sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);
180190

181191
return noise
182192
.MultipleTensorByFloat(sqrtOneMinusAlpha)

0 commit comments

Comments
 (0)