Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 98fc358

Browse files
committed
Tidy up some rounding issues
1 parent 87eba95 commit 98fc358

File tree

4 files changed

+61
-42
lines changed

4 files changed

+61
-42
lines changed

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using NumSharp;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Enums;
5+
46
using System;
57
using System.Linq;
8+
using System.Numerics;
9+
using System.Threading.Tasks;
610

711
namespace OnnxStack.StableDiffusion
812
{
@@ -114,5 +118,6 @@ public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipeli
114118
_ => default
115119
};
116120
}
121+
117122
}
118123
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Linq;
3+
4+
namespace OnnxStack.StableDiffusion.Helpers
5+
{
6+
internal class ArrayHelpers
7+
{
8+
public static float[] Linspace(float start, float end, int partitions, bool round = false)
9+
{
10+
var result = Enumerable.Range(0, partitions)
11+
.Select(idx => idx != partitions ? start + (end - start) / (partitions - 1) * idx : end);
12+
return !round
13+
? result.ToArray()
14+
: result.Select(x => MathF.Round(x)).ToArray();
15+
}
16+
17+
}
18+
}

OnnxStack.StableDiffusion/Schedulers/SchedulerBase.cs

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using NumSharp;
2+
using OnnxStack.Core;
33
using OnnxStack.StableDiffusion.Common;
44
using OnnxStack.StableDiffusion.Config;
55
using OnnxStack.StableDiffusion.Enums;
@@ -119,15 +119,13 @@ protected virtual float[] GetBetaSchedule()
119119
}
120120
else if (Options.BetaSchedule == BetaScheduleType.Linear)
121121
{
122-
betas = np.linspace(Options.BetaStart, Options.BetaEnd, Options.TrainTimesteps).ToArray<float>();
122+
betas = ArrayHelpers.Linspace(Options.BetaStart, Options.BetaEnd, Options.TrainTimesteps);
123123
}
124124
else if (Options.BetaSchedule == BetaScheduleType.ScaledLinear)
125125
{
126-
var start = (float)Math.Sqrt(Options.BetaStart);
127-
var end = (float)Math.Sqrt(Options.BetaEnd);
128-
betas = np.linspace(start, end, Options.TrainTimesteps)
129-
.ToArray<float>()
130-
.Select(x => x * x);
126+
var start = MathF.Sqrt(Options.BetaStart);
127+
var end = MathF.Sqrt(Options.BetaEnd);
128+
betas = ArrayHelpers.Linspace(start, end, Options.TrainTimesteps).Select(x => x * x);
131129
}
132130
else if (Options.BetaSchedule == BetaScheduleType.SquaredCosCapV2)
133131
{
@@ -136,9 +134,9 @@ protected virtual float[] GetBetaSchedule()
136134
else if (Options.BetaSchedule == BetaScheduleType.Sigmoid)
137135
{
138136
var mul = Options.BetaEnd - Options.BetaStart;
139-
var betaSig = np.linspace(-6f, 6f, Options.TrainTimesteps).ToArray<float>();
137+
var betaSig = ArrayHelpers.Linspace(-6f, 6f, Options.TrainTimesteps);
140138
var sigmoidBetas = betaSig
141-
.Select(beta => 1.0f / (1.0f + (float)Math.Exp(-beta)))
139+
.Select(beta => 1.0f / (1.0f + MathF.Exp(-beta)))
142140
.ToArray();
143141
betas = sigmoidBetas
144142
.Select(x => (x * mul) + Options.BetaStart)
@@ -158,7 +156,7 @@ protected virtual float GetInitNoiseSigma(float[] sigmas)
158156
var maxSigma = sigmas.Max();
159157
return Options.TimestepSpacing == TimestepSpacingType.Linspace
160158
|| Options.TimestepSpacing == TimestepSpacingType.Trailing
161-
? maxSigma : (float)Math.Sqrt(maxSigma * maxSigma + 1);
159+
? maxSigma : MathF.Sqrt(maxSigma * maxSigma + 1);
162160
}
163161

164162

@@ -168,28 +166,27 @@ protected virtual float GetInitNoiseSigma(float[] sigmas)
168166
/// <returns></returns>
169167
protected virtual float[] GetTimesteps()
170168
{
171-
NDArray timestepsArray = null;
172169
if (Options.TimestepSpacing == TimestepSpacingType.Linspace)
173170
{
174-
timestepsArray = np.linspace(0, Options.TrainTimesteps - 1, Options.InferenceSteps);
175-
timestepsArray = np.around(timestepsArray)["::1"];
171+
return ArrayHelpers.Linspace(0, Options.TrainTimesteps - 1, Options.InferenceSteps, true);
176172
}
177173
else if (Options.TimestepSpacing == TimestepSpacingType.Leading)
178174
{
179175
var stepRatio = Options.TrainTimesteps / Options.InferenceSteps;
180-
timestepsArray = np.arange(0, (float)Options.InferenceSteps) * stepRatio;
181-
timestepsArray = np.around(timestepsArray)["::1"];
182-
timestepsArray += Options.StepsOffset;
176+
return Enumerable.Range(0, Options.InferenceSteps)
177+
.Select(x => MathF.Round((float)x * stepRatio) + Options.StepsOffset)
178+
.ToArray();
183179
}
184180
else if (Options.TimestepSpacing == TimestepSpacingType.Trailing)
185181
{
186182
var stepRatio = Options.TrainTimesteps / (Options.InferenceSteps - 1);
187-
timestepsArray = np.arange((float)Options.TrainTimesteps, 0, -stepRatio)["::-1"];
188-
timestepsArray = np.around(timestepsArray);
189-
timestepsArray -= 1;
183+
return Enumerable.Range(0, Options.TrainTimesteps)
184+
.Where((number, index) => index % stepRatio == 0)
185+
.Select(x => (float)x)
186+
.ToArray();
190187
}
191188

192-
return timestepsArray.ToArray<float>();
189+
throw new NotImplementedException();
193190
}
194191

195192

@@ -209,7 +206,7 @@ protected virtual DenseTensor<float> GetPredictedSample(DenseTensor<float> model
209206
}
210207
else if (Options.PredictionType == PredictionType.VariablePrediction)
211208
{
212-
var sigmaSqrt = (float)Math.Sqrt(sigma * sigma + 1);
209+
var sigmaSqrt = MathF.Sqrt(sigma * sigma + 1);
213210
predOriginalSample = sample.DivideTensorByFloat(sigmaSqrt)
214211
.AddTensors(modelOutput.MultiplyTensorByFloat(-sigma / sigmaSqrt));
215212
}
@@ -253,11 +250,11 @@ protected float[] GetBetasForAlphaBar()
253250
Func<float, float> alphaBarFn = null;
254251
if (_options.AlphaTransformType == AlphaTransformType.Cosine)
255252
{
256-
alphaBarFn = t => (float)Math.Pow(Math.Cos((t + 0.008f) / 1.008f * Math.PI / 2.0f), 2.0f);
253+
alphaBarFn = t => MathF.Pow(MathF.Cos((t + 0.008f) / 1.008f * MathF.PI / 2.0f), 2.0f);
257254
}
258255
else if (_options.AlphaTransformType == AlphaTransformType.Exponential)
259256
{
260-
alphaBarFn = t => (float)Math.Exp(t * -12.0f);
257+
alphaBarFn = t => MathF.Exp(t * -12.0f);
261258
}
262259

263260
return Enumerable
@@ -266,7 +263,7 @@ protected float[] GetBetasForAlphaBar()
266263
{
267264
var t1 = (float)i / _options.TrainTimesteps;
268265
var t2 = (float)(i + 1) / _options.TrainTimesteps;
269-
return Math.Min(1f - alphaBarFn(t2) / alphaBarFn(t1), _options.MaximumBeta);
266+
return MathF.Min(1f - alphaBarFn(t2) / alphaBarFn(t1), _options.MaximumBeta);
270267
}).ToArray();
271268
}
272269

@@ -298,13 +295,13 @@ protected float[] Interpolate(float[] timesteps, float[] range, float[] sigmas)
298295
// If timesteps[i] is less than the first element in range, use the first value in sigmas
299296
else if (index == -1)
300297
{
301-
result[i] = sigmas[sigmas.Length - 1];
298+
result[i] = sigmas[0];
302299
}
303300

304301
// If timesteps[i] is greater than the last element in range, use the last value in sigmas
305-
else if (index == -range.Length - 1)
302+
else if (index > range.Length - 1)
306303
{
307-
result[i] = sigmas[0];
304+
result[i] = sigmas[sigmas.Length - 1];
308305
}
309306

310307
// Otherwise, interpolate linearly between two adjacent values in sigmas
@@ -340,14 +337,14 @@ protected float[] ConvertToKarras(float[] inSigmas)
340337
.ToArray();
341338

342339
// Calculate the inverse of sigmaMin and sigmaMax raised to the power of 1/rho
343-
float minInvRho = (float)Math.Pow(sigmaMin, 1.0 / rho);
344-
float maxInvRho = (float)Math.Pow(sigmaMax, 1.0 / rho);
340+
float minInvRho = MathF.Pow(sigmaMin, 1.0f / rho);
341+
float maxInvRho = MathF.Pow(sigmaMax, 1.0f / rho);
345342

346343
// Calculate the Karras noise schedule using the formula from the paper
347344
float[] sigmas = new float[_options.InferenceSteps];
348345
for (int i = 0; i < _options.InferenceSteps; i++)
349346
{
350-
sigmas[i] = (float)Math.Pow(maxInvRho + ramp[i] * (minInvRho - maxInvRho), rho);
347+
sigmas[i] = MathF.Pow(maxInvRho + ramp[i] * (minInvRho - maxInvRho), rho);
351348
}
352349

353350
// Return the resulting noise schedule
@@ -366,7 +363,7 @@ protected float[] SigmaToTimestep(float[] sigmas, float[] logSigmas)
366363
var timesteps = new float[sigmas.Length];
367364
for (int i = 0; i < sigmas.Length; i++)
368365
{
369-
float logSigma = (float)Math.Log(sigmas[i]);
366+
float logSigma = MathF.Log(sigmas[i]);
370367
float[] dists = new float[logSigmas.Length];
371368

372369
for (int j = 0; j < logSigmas.Length; j++)

OnnxStack.StableDiffusion/Schedulers/StableDiffusion/LMSScheduler.cs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
using NumSharp;
44
using OnnxStack.Core;
55
using OnnxStack.StableDiffusion.Config;
6-
using OnnxStack.StableDiffusion.Enums;
76
using OnnxStack.StableDiffusion.Helpers;
87
using System;
98
using System.Collections.Generic;
@@ -43,7 +42,7 @@ protected override void Initialize()
4342
var cumulativeProduct = alphas.Select((alpha, i) => alphas.Take(i + 1).Aggregate((a, b) => a * b));
4443

4544
_sigmas = cumulativeProduct
46-
.Select(alpha_prod => (float)Math.Sqrt((1 - alpha_prod) / alpha_prod))
45+
.Select(alpha_prod => MathF.Sqrt((1 - alpha_prod) / alpha_prod))
4746
.ToArray();
4847

4948
var initNoiseSigma = GetInitNoiseSigma(_sigmas);
@@ -57,12 +56,13 @@ protected override void Initialize()
5756
/// <returns></returns>
5857
protected override int[] SetTimesteps()
5958
{
60-
var sigmas = _sigmas.ToArray();
6159
var timesteps = GetTimesteps();
62-
var log_sigmas = np.log(sigmas).ToArray<float>();
63-
var range = np.arange(0, (float)_sigmas.Length).ToArray<float>();
64-
sigmas = Interpolate(timesteps, range, _sigmas);
60+
var log_sigmas = _sigmas.Select(x => MathF.Log(x)).ToArray();
61+
var range = Enumerable.Range(0, _sigmas.Length)
62+
.Select(x => (float)x)
63+
.ToArray();
6564

65+
var sigmas = Interpolate(timesteps, range, _sigmas);
6666
if (Options.UseKarrasSigmas)
6767
{
6868
sigmas = ConvertToKarras(sigmas);
@@ -72,7 +72,6 @@ protected override int[] SetTimesteps()
7272
_sigmas = sigmas
7373
.Append(0.000f)
7474
.ToArray();
75-
7675
return timesteps.Select(x => (int)x)
7776
.OrderByDescending(x => x)
7877
.ToArray();
@@ -92,7 +91,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
9291

9392
// Get sigma at stepIndex
9493
var sigma = _sigmas[stepIndex];
95-
sigma = (float)Math.Sqrt(Math.Pow(sigma, 2) + 1);
94+
sigma = MathF.Sqrt(MathF.Pow(sigma, 2f) + 1f);
9695

9796
// Divide sample tensor shape {2,4,(H/8),(W/8)} by sigma
9897
return sample.DivideTensorByFloat(sigma);
@@ -144,7 +143,7 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
144143
{
145144
// Multiply to coeff by each derivatives to create the new tensors
146145
var (lmsCoeff, derivative) = lmsCoeffsAndDerivatives[i];
147-
lmsDerProduct[i] = derivative.MultiplyTensorByFloat((float)lmsCoeff);
146+
lmsDerProduct[i] = derivative.MultiplyTensorByFloat(lmsCoeff);
148147
}
149148

150149
// Add the summed tensor to the sample
@@ -180,7 +179,7 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
180179
/// <param name="t">The t.</param>
181180
/// <param name="currentOrder">The current order.</param>
182181
/// <returns></returns>
183-
private double GetLmsCoefficient(int order, int t, int currentOrder)
182+
private float GetLmsCoefficient(int order, int t, int currentOrder)
184183
{
185184
// python line 135 of scheduling_lms_discrete.py
186185
// Compute a linear multistep coefficient.
@@ -197,7 +196,7 @@ double LmsDerivative(double tau)
197196
}
198197
return prod;
199198
}
200-
return Integrate.OnClosedInterval(LmsDerivative, _sigmas[t], _sigmas[t + 1], 1e-4);
199+
return (float)Integrate.OnClosedInterval(LmsDerivative, _sigmas[t], _sigmas[t + 1], 1e-4);
201200
}
202201

203202
protected override void Dispose(bool disposing)

0 commit comments

Comments
 (0)