|
3 | 3 | import subprocess
|
4 | 4 | import redis
|
5 | 5 | from includes import *
|
| 6 | +from RLTest import Env |
6 | 7 |
|
7 | 8 | '''
|
8 | 9 | python -m RLTest --test tests_onnx.py --module path/to/redisai.so
|
@@ -413,3 +414,40 @@ def tests_onnx_info(env):
|
413 | 414 | ret = con.execute_command('AI.INFO')
|
414 | 415 | env.assertEqual(8, len(ret))
|
415 | 416 | env.assertEqual(b'ONNX version', ret[6])
|
| 417 | + |
| 418 | + |
| 419 | +def test_parallelism(): |
| 420 | + env = Env(moduleArgs='INTRA_OP_PARALLELISM 1 INTER_OP_PARALLELISM 1') |
| 421 | + if not TEST_ONNX: |
| 422 | + env.debugPrint("skipping {} since TEST_ONNX=0".format(sys._getframe().f_code.co_name), force=True) |
| 423 | + return |
| 424 | + |
| 425 | + con = env.getConnection() |
| 426 | + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') |
| 427 | + model_filename = os.path.join(test_data_path, 'mnist.onnx') |
| 428 | + sample_filename = os.path.join(test_data_path, 'one.raw') |
| 429 | + with open(model_filename, 'rb') as f: |
| 430 | + model_pb = f.read() |
| 431 | + with open(sample_filename, 'rb') as f: |
| 432 | + sample_raw = f.read() |
| 433 | + |
| 434 | + ret = con.execute_command('AI.MODELSET', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_pb) |
| 435 | + env.assertEqual(ret, b'OK') |
| 436 | + con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) |
| 437 | + |
| 438 | + con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}') |
| 439 | + ensureSlaveSynced(con, env) |
| 440 | + values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES') |
| 441 | + argmax = max(range(len(values)), key=lambda i: values[i]) |
| 442 | + env.assertEqual(argmax, 1) |
| 443 | + |
| 444 | + load_time_config = {k.split(":")[0]: k.split(":")[1] |
| 445 | + for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} |
| 446 | + env.assertEqual(load_time_config["ai_inter_op_parallelism"], "1") |
| 447 | + env.assertEqual(load_time_config["ai_intra_op_parallelism"], "1") |
| 448 | + |
| 449 | + env = Env(moduleArgs='INTRA_OP_PARALLELISM 2 INTER_OP_PARALLELISM 2') |
| 450 | + load_time_config = {k.split(":")[0]: k.split(":")[1] |
| 451 | + for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} |
| 452 | + env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2") |
| 453 | + env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2") |
0 commit comments