Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit f335cd1

Browse files
committed
Simplify OnnxSession configuration
1 parent 1274349 commit f335cd1

File tree

14 files changed

+121
-120
lines changed

14 files changed

+121
-120
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,66 +10,42 @@
1010
"OnnxModelSets": [
1111
{
1212
"Name": "StableDiffusion 1.5",
13+
"IsEnabled": true,
1314
"PadTokenId": 49407,
1415
"BlankTokenId": 49407,
1516
"InputTokenLimit": 512,
1617
"TokenizerLimit": 77,
1718
"EmbeddingsLength": 768,
1819
"ScaleFactor": 0.18215,
20+
"DeviceId": 0,
21+
"InterOpNumThreads": 0,
22+
"IntraOpNumThreads": 0,
23+
"ExecutionMode": "ORT_PARALLEL",
24+
"ExecutionProvider": "DirectML",
1925
"ModelConfigurations": [
2026
{
2127
"Type": "Unet",
22-
"DeviceId": 0,
23-
"InterOpNumThreads": 0,
24-
"IntraOpNumThreads": 0,
25-
"ExecutionMode": "ORT_PARALLEL",
26-
"ExecutionProvider": "DirectML",
2728
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx"
2829
},
2930
{
3031
"Type": "Tokenizer",
31-
"DeviceId": 0,
32-
"InterOpNumThreads": 0,
33-
"IntraOpNumThreads": 0,
34-
"ExecutionMode": "ORT_PARALLEL",
35-
"ExecutionProvider": "DirectML",
3632
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\cliptokenizer.onnx"
3733
},
3834
{
3935
"Type": "TextEncoder",
40-
"DeviceId": 0,
41-
"InterOpNumThreads": 0,
42-
"IntraOpNumThreads": 0,
43-
"ExecutionMode": "ORT_PARALLEL",
44-
"ExecutionProvider": "DirectML",
4536
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\text_encoder\\model.onnx"
4637
},
4738
{
4839
"Type": "VaeEncoder",
49-
"DeviceId": 0,
50-
"InterOpNumThreads": 0,
51-
"IntraOpNumThreads": 0,
52-
"ExecutionMode": "ORT_PARALLEL",
53-
"ExecutionProvider": "DirectML",
5440
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_encoder\\model.onnx"
5541
},
5642
{
5743
"Type": "VaeDecoder",
58-
"DeviceId": 0,
59-
"InterOpNumThreads": 0,
60-
"IntraOpNumThreads": 0,
61-
"ExecutionMode": "ORT_PARALLEL",
62-
"ExecutionProvider": "DirectML",
6344
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_decoder\\model.onnx"
6445
},
6546
{
47+
"IsEnabled": false,
6648
"Type": "SafetyChecker",
67-
"IsDisabled": true,
68-
"DeviceId": 0,
69-
"InterOpNumThreads": 0,
70-
"IntraOpNumThreads": 0,
71-
"ExecutionMode": "ORT_PARALLEL",
72-
"ExecutionProvider": "DirectML",
7349
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\safety_checker\\model.onnx"
7450
}
7551
]
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
using System.Collections.Generic;
1+
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
23

34
namespace OnnxStack.Core.Config
45
{
56
public interface IOnnxModelSetConfig : IOnnxModel
67
{
8+
public int DeviceId { get; set; }
9+
public string OnnxModelPath { get; set; }
10+
public int InterOpNumThreads { get; set; }
11+
public int IntraOpNumThreads { get; set; }
12+
public ExecutionMode ExecutionMode { get; set; }
13+
public ExecutionProvider ExecutionProvider { get; set; }
714
List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
815
}
916
}

OnnxStack.Core/Config/OnnxModelSessionConfig.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ namespace OnnxStack.Core.Config
55
public class OnnxModelSessionConfig
66
{
77
public OnnxModelType Type { get; set; }
8-
public bool IsDisabled { get; set; }
9-
public int DeviceId { get; set; }
108
public string OnnxModelPath { get; set; }
11-
public int InterOpNumThreads { get; set; }
12-
public int IntraOpNumThreads { get; set; }
13-
public ExecutionMode ExecutionMode { get; set; }
14-
public ExecutionProvider ExecutionProvider { get; set; }
9+
10+
public bool? IsEnabled { get; set; }
11+
public int? DeviceId { get; set; }
12+
public int? InterOpNumThreads { get; set; }
13+
public int? IntraOpNumThreads { get; set; }
14+
public ExecutionMode? ExecutionMode { get; set; }
15+
public ExecutionProvider? ExecutionProvider { get; set; }
16+
1517
}
1618
}
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
using System.Collections.Generic;
1+
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
23

34
namespace OnnxStack.Core.Config
45
{
56
public class OnnxModelSetConfig : IOnnxModelSetConfig
67
{
78
public string Name { get; set; }
9+
public bool IsEnabled { get; set; }
10+
public int DeviceId { get; set; }
11+
public string OnnxModelPath { get; set; }
12+
public int InterOpNumThreads { get; set; }
13+
public int IntraOpNumThreads { get; set; }
14+
public ExecutionMode ExecutionMode { get; set; }
15+
public ExecutionProvider ExecutionProvider { get; set; }
816
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
917
}
1018
}

OnnxStack.Core/Config/OnnxStackConfig.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using OnnxStack.Common.Config;
22
using System.Collections.Generic;
3+
using System.Linq;
34

45
namespace OnnxStack.Core.Config
56
{
@@ -9,6 +10,13 @@ public class OnnxStackConfig : IConfigSection
910

1011
public void Initialize()
1112
{
13+
if (OnnxModelSets.IsNullOrEmpty())
14+
return;
15+
16+
foreach (var modelSet in OnnxModelSets)
17+
{
18+
modelSet.ApplyConfigurationOverrides();
19+
}
1220
}
1321
}
1422
}

OnnxStack.Core/Extensions.cs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using Microsoft.ML.OnnxRuntime;
22
using OnnxStack.Core.Config;
33
using System;
4+
using System.Collections.Concurrent;
45
using System.Collections.Generic;
56
using System.Linq;
7+
using System.Xml.Linq;
68

79
namespace OnnxStack.Core
810
{
@@ -12,16 +14,16 @@ public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig confi
1214
{
1315
var sessionOptions = new SessionOptions
1416
{
15-
ExecutionMode = configuration.ExecutionMode,
16-
InterOpNumThreads = configuration.InterOpNumThreads,
17-
IntraOpNumThreads = configuration.InterOpNumThreads
17+
ExecutionMode = configuration.ExecutionMode.Value,
18+
InterOpNumThreads = configuration.InterOpNumThreads.Value,
19+
IntraOpNumThreads = configuration.IntraOpNumThreads.Value
1820
};
1921
switch (configuration.ExecutionProvider)
2022
{
2123
case ExecutionProvider.DirectML:
2224
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
2325
sessionOptions.EnableMemoryPattern = false;
24-
sessionOptions.AppendExecutionProvider_DML(configuration.DeviceId);
26+
sessionOptions.AppendExecutionProvider_DML(configuration.DeviceId.Value);
2527
sessionOptions.AppendExecutionProvider_CPU();
2628
return sessionOptions;
2729
case ExecutionProvider.Cpu:
@@ -30,7 +32,7 @@ public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig confi
3032
default:
3133
case ExecutionProvider.Cuda:
3234
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
33-
sessionOptions.AppendExecutionProvider_CUDA(configuration.DeviceId);
35+
sessionOptions.AppendExecutionProvider_CUDA(configuration.DeviceId.Value);
3436
sessionOptions.AppendExecutionProvider_CPU();
3537
return sessionOptions;
3638
case ExecutionProvider.CoreML:
@@ -41,6 +43,25 @@ public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig confi
4143
}
4244
}
4345

46+
/// <summary>
47+
/// Applies the configuration overrides.
48+
/// </summary>
49+
public static void ApplyConfigurationOverrides(this IOnnxModelSetConfig innxModelSetConfig)
50+
{
51+
if (innxModelSetConfig.ModelConfigurations.IsNullOrEmpty())
52+
return;
53+
54+
foreach (var modelConfig in innxModelSetConfig.ModelConfigurations)
55+
{
56+
modelConfig.IsEnabled = modelConfig.IsEnabled != false;
57+
modelConfig.DeviceId ??= innxModelSetConfig.DeviceId;
58+
modelConfig.ExecutionMode ??= innxModelSetConfig.ExecutionMode;
59+
modelConfig.InterOpNumThreads ??= innxModelSetConfig.InterOpNumThreads;
60+
modelConfig.IntraOpNumThreads ??= innxModelSetConfig.IntraOpNumThreads;
61+
modelConfig.ExecutionProvider ??= innxModelSetConfig.ExecutionProvider;
62+
}
63+
}
64+
4465
/// <summary>
4566
/// Determines whether the the source sequence is null or empty
4667
/// </summary>
@@ -139,5 +160,16 @@ public static int IndexOf<T>(this IReadOnlyList<T> list, T item) where T : IEqua
139160
}
140161

141162

163+
/// <summary>
164+
/// Converts to source IEnumerable to a ConcurrentDictionary.
165+
/// </summary>
166+
/// <param name="source">The source.</param>
167+
/// <param name="keySelector">The key selector.</param>
168+
/// <param name="elementSelector">The element selector.</param>
169+
/// <returns></returns>
170+
public static ConcurrentDictionary<T, U> ToConcurrentDictionary<S, T, U>(this IEnumerable<S> source, Func<S, T> keySelector, Func<S, U> elementSelector) where T : notnull
171+
{
172+
return new ConcurrentDictionary<T, U>(source.ToDictionary(keySelector, elementSelector));
173+
}
142174
}
143175
}

OnnxStack.Core/Model/OnnxModelSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public OnnxModelSet(IOnnxModelSetConfig configuration)
2121
_configuration = configuration;
2222
_prePackedWeightsContainer = new PrePackedWeightsContainer();
2323
_modelSessions = configuration.ModelConfigurations
24-
.Where(x => !x.IsDisabled)
24+
.Where(x => x.IsEnabled == true)
2525
.ToImmutableDictionary(
2626
modelConfig => modelConfig.Type,
2727
modelConfig => new OnnxModelSession(modelConfig, _prePackedWeightsContainer));

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public OnnxModelService(OnnxStackConfig configuration)
2727
{
2828
_configuration = configuration;
2929
_onnxModelSets = new ConcurrentDictionary<string, OnnxModelSet>();
30-
_onnxModelSetConfigs = new ConcurrentDictionary<string, OnnxModelSetConfig>(configuration.OnnxModelSets.ToDictionary(x => x.Name, x => x));
30+
_onnxModelSetConfigs = _configuration.OnnxModelSets.ToConcurrentDictionary(x => x.Name, x => x);
3131
}
3232

3333

@@ -253,10 +253,12 @@ private OnnxModelSet LoadModelSet(IOnnxModel model)
253253
if (_onnxModelSets.ContainsKey(model.Name))
254254
return _onnxModelSets[model.Name];
255255

256-
if (!_onnxModelSetConfigs.ContainsKey(model.Name))
256+
if (!_onnxModelSetConfigs.TryGetValue(model.Name, out var modelSetConfig))
257257
throw new Exception($"Model {model.Name} not found in configuration");
258258

259-
var modelSetConfig = _onnxModelSetConfigs[model.Name];
259+
if (!modelSetConfig.IsEnabled)
260+
throw new Exception($"Model {model.Name} is not enabled");
261+
260262
var modelSet = new OnnxModelSet(modelSetConfig);
261263
_onnxModelSets.TryAdd(model.Name, modelSet);
262264
return modelSet;

OnnxStack.StableDiffusion/Common/IModelOptions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ namespace OnnxStack.StableDiffusion.Common
55
{
66
public interface IModelOptions : IOnnxModel
77
{
8-
int BlankTokenId { get; set; }
9-
ImmutableArray<int> BlankTokenValueArray { get; set; }
10-
int EmbeddingsLength { get; set; }
11-
int InputTokenLimit { get; set; }
128
int PadTokenId { get; set; }
9+
int BlankTokenId { get; set; }
1310
float ScaleFactor { get; set; }
1411
int TokenizerLimit { get; set; }
12+
int InputTokenLimit { get; set; }
13+
int EmbeddingsLength { get; set; }
14+
ImmutableArray<int> BlankTokenValueArray { get; set; }
1515
}
1616
}
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
using OnnxStack.StableDiffusion.Common;
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
3+
using OnnxStack.StableDiffusion.Common;
4+
using System.Collections.Generic;
25
using System.Collections.Immutable;
36

47
namespace OnnxStack.StableDiffusion.Config
58
{
6-
public class ModelOptions : IModelOptions
9+
public class ModelOptions : IModelOptions, IOnnxModelSetConfig
710
{
811
public string Name { get; set; }
912
public int PadTokenId { get; set; }
@@ -13,5 +16,13 @@ public class ModelOptions : IModelOptions
1316
public int EmbeddingsLength { get; set; }
1417
public float ScaleFactor { get; set; }
1518
public ImmutableArray<int> BlankTokenValueArray { get; set; }
19+
20+
public int DeviceId { get; set; }
21+
public string OnnxModelPath { get; set; }
22+
public int InterOpNumThreads { get; set; }
23+
public int IntraOpNumThreads { get; set; }
24+
public ExecutionMode ExecutionMode { get; set; }
25+
public ExecutionProvider ExecutionProvider { get; set; }
26+
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
1627
}
1728
}

0 commit comments

Comments
 (0)