Skip to content

Commit 08bb9a4

Browse files
author
DvirDukhan
committed
Merge pull request #593 from RedisAI/onnx_support_threads
Support threading in onnx config
1 parent b3f38ed commit 08bb9a4

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

src/backends/onnxruntime.c

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,6 @@ OrtValue *RAI_OrtValueFromTensors(RAI_Tensor **ts, size_t count, RAI_Error *erro
9191
return NULL;
9292
}
9393

94-
if (count == 0) {
95-
return NULL;
96-
}
97-
9894
size_t batch_size = 0;
9995
size_t batch_byte_size = 0;
10096

@@ -328,6 +324,9 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
328324
ort->ReleaseSessionOptions(session_options);
329325
goto error;
330326
}
327+
ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism);
328+
ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism);
329+
331330
// TODO: we will need to propose a more dynamic way to request a specific provider,
332331
// e.g. given the name, in ONNXRuntime
333332
#if RAI_ONNXRUNTIME_USE_CUDA
@@ -344,7 +343,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
344343
#endif
345344

346345
OrtSession *session;
347-
348346
status = ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session);
349347

350348
ort->ReleaseSessionOptions(session_options);

tests/flow/tests_onnx.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import subprocess
44
import redis
55
from includes import *
6+
from RLTest import Env
67

78
'''
89
python -m RLTest --test tests_onnx.py --module path/to/redisai.so
@@ -413,3 +414,40 @@ def tests_onnx_info(env):
413414
ret = con.execute_command('AI.INFO')
414415
env.assertEqual(8, len(ret))
415416
env.assertEqual(b'ONNX version', ret[6])
417+
418+
419+
def _load_time_config(con):
    # Parse the module's load-time config key/value pairs out of INFO MODULES.
    # Section index 3 and ":"-splitting follow the redis INFO text format;
    # presumably stable for this module's output — verify if redis changes INFO layout.
    return {k.split(":")[0]: k.split(":")[1]
            for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]}


def test_parallelism():
    """Check that INTRA/INTER_OP_PARALLELISM module args are applied and reported.

    Runs an ONNX mnist model under parallelism 1/1 to confirm inference still
    works, then relaunches the server with 2/2 and confirms INFO MODULES
    reflects the new values.
    """
    env = Env(moduleArgs='INTRA_OP_PARALLELISM 1 INTER_OP_PARALLELISM 1')
    if not TEST_ONNX:
        env.debugPrint("skipping {} since TEST_ONNX=0".format(sys._getframe().f_code.co_name), force=True)
        return

    con = env.getConnection()
    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
    model_filename = os.path.join(test_data_path, 'mnist.onnx')
    sample_filename = os.path.join(test_data_path, 'one.raw')
    with open(model_filename, 'rb') as f:
        model_pb = f.read()
    with open(sample_filename, 'rb') as f:
        sample_raw = f.read()

    ret = con.execute_command('AI.MODELSET', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_pb)
    env.assertEqual(ret, b'OK')
    con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)

    con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')
    ensureSlaveSynced(con, env)
    values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES')
    # The sample is a hand-written "1"; the model's argmax class should be 1.
    argmax = max(range(len(values)), key=lambda i: values[i])
    env.assertEqual(argmax, 1)

    load_time_config = _load_time_config(con)
    env.assertEqual(load_time_config["ai_inter_op_parallelism"], "1")
    env.assertEqual(load_time_config["ai_intra_op_parallelism"], "1")

    env = Env(moduleArgs='INTRA_OP_PARALLELISM 2 INTER_OP_PARALLELISM 2')
    # BUGFIX: the old connection points at the torn-down 1/1 server; a fresh
    # connection is required to read the new server's config.
    con = env.getConnection()
    load_time_config = _load_time_config(con)
    env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2")
    env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2")

0 commit comments

Comments
 (0)