Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 091cb0d

Browse files
committed
SDXL Refiner model support
1 parent 34e3d86 commit 091cb0d

File tree

9 files changed

+49
-12
lines changed

9 files changed

+49
-12
lines changed

OnnxStack.StableDiffusion/Common/IModelOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public interface IModelOptions : IOnnxModel
1515
int TokenizerLimit { get; set; }
1616
int TokenizerLength { get; set; }
1717
int Tokenizer2Length { get; set; }
18+
ModelType ModelType { get; set; }
1819
TokenizerType TokenizerType { get; set; }
1920
DiffuserPipelineType PipelineType { get; set; }
2021
List<DiffuserType> Diffusers { get; set; }

OnnxStack.StableDiffusion/Config/ModelOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class ModelOptions : IModelOptions, IOnnxModelSetConfig
2121
public TokenizerType TokenizerType { get; set; }
2222
public int SampleSize { get; set; } = 512;
2323
public float ScaleFactor { get; set; }
24+
public ModelType ModelType { get; set; }
2425
public DiffuserPipelineType PipelineType { get; set; }
2526
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();
2627

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ public record SchedulerOptions
8787

8888
public int OriginalInferenceSteps { get; set; } = 50;
8989

90+
public float AestheticScore { get; set; } = 6f;
91+
public float AestheticNegativeScore { get; set; } = 2.5f;
92+
9093
public bool IsKarrasScheduler
9194
{
9295
get

OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
7171
// Get Model metadata
7272
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
7373

74+
// Get Time ids
75+
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);
76+
7477
// Loop though the timesteps
7578
var step = 0;
7679
foreach (var timestep in timesteps)
@@ -83,7 +86,6 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
8386
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8487
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8588
var timestepTensor = CreateTimestepTensor(timestep);
86-
var addTimeIds = GetAddTimeIds(schedulerOptions, performGuidance);
8789

8890
var outputChannels = performGuidance ? 2 : 1;
8991
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);

OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
5959
// Get Model metadata
6060
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
6161

62+
// Get Time ids
63+
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);
64+
6265
// Loop though the timesteps
6366
var step = 0;
6467
foreach (var timestep in timesteps)
@@ -71,7 +74,6 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
7174
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
7275
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
7376
var timestepTensor = CreateTimestepTensor(timestep);
74-
var addTimeIds = GetAddTimeIds(schedulerOptions, performGuidance);
7577

7678
var outputChannels = performGuidance ? 2 : 1;
7779
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
@@ -113,19 +115,27 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
113115
/// </summary>
114116
/// <param name="schedulerOptions">The scheduler options.</param>
115117
/// <returns></returns>
116-
protected DenseTensor<float> GetAddTimeIds(SchedulerOptions schedulerOptions, bool performGuidance)
118+
protected DenseTensor<float> GetAddTimeIds(IModelOptions model, SchedulerOptions schedulerOptions, bool performGuidance)
117119
{
118-
var addTimeIds = new float[]
120+
float[] result;
121+
if (model.ModelType == ModelType.Refiner)
119122
{
120-
schedulerOptions.Height, schedulerOptions.Width, //original_size
121-
0, 0, //crops_coords_top_left
122-
schedulerOptions.Height, schedulerOptions.Width //negative_target_size
123-
};
124-
var result = TensorHelper.CreateTensor(addTimeIds, new[] { 1, addTimeIds.Length });
125-
if (performGuidance)
126-
return result.Repeat(2);
123+
//original_size + crops_coords_top_left + aesthetic_score
124+
//original_size + crops_coords_top_left + negative_aesthetic_score
125+
result = !performGuidance
126+
? new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticScore }
127+
: new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticNegativeScore, schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticScore };
128+
}
129+
else
130+
{
131+
//original_size + crops_coords_top_left + target_size
132+
//original_size + crops_coords_top_left + negative_target_size
133+
result = !performGuidance
134+
? new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width }
135+
: new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width, schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width };
136+
}
127137

128-
return result;
138+
return TensorHelper.CreateTensor(result, new[] { 1, result.Length });
129139
}
130140

131141

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace OnnxStack.StableDiffusion.Enums
2+
{
3+
public enum ModelType
4+
{
5+
Base = 0,
6+
Refiner = 1
7+
}
8+
}

OnnxStack.UI/Models/ModelConfigTemplate.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public class ModelConfigTemplate
2323
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();
2424
public List<string> ModelFiles { get; set; } = new List<string>();
2525
public List<string> Images { get; set; } = new List<string>();
26+
public ModelType ModelType { get; set; }
2627
}
2728

2829
}

OnnxStack.UI/Models/ModelSetViewModel.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public class ModelSetViewModel : INotifyPropertyChanged
3939
private int _dualEmbeddingsLength;
4040
private TokenizerType _tokenizerType;
4141
private int _sampleSize;
42+
private ModelType _modelType;
4243

4344
public string Name
4445
{
@@ -215,6 +216,11 @@ public bool HasChanged
215216
set { _hasChanged = value; NotifyPropertyChanged(); }
216217
}
217218

219+
public ModelType ModelType
220+
{
221+
get { return _modelType; }
222+
set { _modelType = value; NotifyPropertyChanged(); }
223+
}
218224

219225
public IEnumerable<DiffuserType> GetDiffusers()
220226
{

OnnxStack.UI/Views/ModelView.xaml.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,7 @@ private ModelSetViewModel CreateViewModel(ModelConfigTemplate modelTemplate)
944944
ScaleFactor = modelTemplate.ScaleFactor,
945945
TokenizerLimit = modelTemplate.TokenizerLimit,
946946
SampleSize = modelTemplate.SampleSize,
947+
ModelType = modelTemplate.ModelType,
947948
PipelineType = modelTemplate.PipelineType,
948949
EnableTextToImage = modelTemplate.Diffusers.Contains(DiffuserType.TextToImage),
949950
EnableImageToImage = modelTemplate.Diffusers.Contains(DiffuserType.ImageToImage),
@@ -969,6 +970,7 @@ private ModelSetViewModel CreateViewModel(ModelConfigTemplate modelTemplate)
969970
ModelFiles = modelTemplate.ModelFiles.ToList(),
970971
Repository = modelTemplate.Repository,
971972
SampleSize = modelTemplate.SampleSize,
973+
ModelType = modelTemplate.ModelType,
972974
Status = ModelTemplateStatus.Installed
973975
}
974976
};
@@ -1001,6 +1003,7 @@ private ModelSetViewModel CreateViewModel(ModelOptions modelOptions)
10011003
PadTokenId = modelOptions.PadTokenId,
10021004
ScaleFactor = modelOptions.ScaleFactor,
10031005
SampleSize = modelOptions.SampleSize,
1006+
ModelType = modelOptions.ModelType,
10041007
TokenizerLimit = modelOptions.TokenizerLimit,
10051008
PipelineType = modelOptions.PipelineType,
10061009
EnableTextToImage = modelOptions.Diffusers.Contains(DiffuserType.TextToImage),
@@ -1036,6 +1039,7 @@ private ModelSetViewModel CreateViewModel(ModelOptions modelOptions)
10361039
TokenizerLength = modelOptions.TokenizerLength,
10371040
Tokenizer2Length = modelOptions.Tokenizer2Length,
10381041
SampleSize = modelOptions.SampleSize,
1042+
ModelType = modelOptions.ModelType,
10391043
Description = "",
10401044
Diffusers = modelOptions.Diffusers,
10411045
ImageIcon = "",
@@ -1074,6 +1078,7 @@ private ModelOptions CreateModelOptions(ModelSetViewModel editModel)
10741078
PipelineType = editModel.PipelineType,
10751079
Diffusers = new List<DiffuserType>(editModel.GetDiffusers()),
10761080
SampleSize = editModel.SampleSize,
1081+
ModelType = editModel.ModelType,
10771082
ModelConfigurations = new List<OnnxModelSessionConfig>(editModel.ModelFiles.Select(x => new OnnxModelSessionConfig
10781083
{
10791084
Type = x.Type,

0 commit comments

Comments
 (0)