Skip to content

Commit 44d1c68

Browse files
authored
Replicate results of run commands instead of verbatim (#157)
* Replicate results of run commands instead of verbatim * Remove leftover ReplicateVerbatim * Add --use-slaves to test invocation * Fix rebase leftover * Bump Redis version in ramp file
1 parent 30384e0 commit 44d1c68

File tree

6 files changed

+96
-49
lines changed

6 files changed

+96
-49
lines changed

opt/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ ifneq ($(NO_LFS),1)
134134
endif
135135
$(SHOW)set -e ;\
136136
cd $(ROOT)/test ;\
137-
python3 -m RLTest $(TEST_ARGS) --test basic_tests.py --module $(INSTALL_DIR)/redisai.so ;\
137+
python3 -m RLTest $(TEST_ARGS) --test basic_tests.py --module $(INSTALL_DIR)/redisai.so --use-slaves ;\
138138
python3 -m RLTest $(TEST_ARGS) --test double-panda.py --module $(INSTALL_DIR)/redisai.so
139139

140140
#----------------------------------------------------------------------------------------------

ramp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description: Serving tensors and executing deep learning graphs
55
homepage: https://oss.redislabs.com/redisai/
66
license: GNU Affero General Public License v3.0
77
command_line_args: ""
8-
min_redis_version: "5.0"
8+
min_redis_version: "5.0.7"
99
min_redis_pack_version: "5.4"
1010
capabilities:
1111
- types

src/redisai.c

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ int RedisAI_TensorGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
404404
}
405405
}
406406
else {
407-
assert(0);
407+
RedisModule_ReplyWithError(ctx, "ERR unsupported dtype");
408408
}
409409

410410
RedisModule_ReplyWithArray(ctx, ndims);
@@ -609,8 +609,6 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
609609

610610
RedisModule_ReplyWithSimpleString(ctx, "OK");
611611

612-
RedisModule_ReplicateVerbatim(ctx);
613-
614612
return REDISMODULE_OK;
615613
}
616614

@@ -772,6 +770,29 @@ void RedisAI_Disconnected(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc) {
772770
RedisModule_Log(ctx, "warning", "Blocked client %p disconnected!", (void*)bc);
773771
}
774772

773+
void RedisAI_ReplicateTensorSet(RedisModuleCtx *ctx, RedisModuleString *key, RAI_Tensor *t) {
774+
long long ndims = RAI_TensorNumDims(t);
775+
776+
char *dtypestr = NULL;
777+
Tensor_DataTypeStr(RAI_TensorDataType(t), &dtypestr);
778+
779+
assert(dtypestr);
780+
781+
char *data = RAI_TensorData(t);
782+
long long size = RAI_TensorByteSize(t);
783+
784+
RedisModuleString* dims[ndims];
785+
786+
for (long long i=0; i<ndims; i++) {
787+
dims[i] = RedisModule_CreateStringFromLongLong(ctx, RAI_TensorDim(t, i));
788+
}
789+
790+
RedisModule_Replicate(ctx, "AI.TENSORSET", "scvcb", key, dtypestr,
791+
dims, ndims, "BLOB", data, size);
792+
793+
RedisModule_Free(dtypestr);
794+
}
795+
775796
int RedisAI_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
776797
REDISMODULE_NOT_USED(argv);
777798
REDISMODULE_NOT_USED(argc);
@@ -817,6 +838,10 @@ int RedisAI_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
817838
RedisModule_ModuleTypeSetValue(outkey, RedisAI_TensorType, RAI_TensorGetShallowCopy(t));
818839
}
819840
RedisModule_CloseKey(outkey);
841+
842+
if (t) {
843+
RedisAI_ReplicateTensorSet(ctx, rinfo->outkeys[i], t);
844+
}
820845
}
821846

822847
// FIXME This crashes Redis, we need to investigate.
@@ -981,7 +1006,8 @@ int RedisAI_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
9811006
// RedisAI_RunSession(rinfo);
9821007
// RedisAI_FreeRunInfo(ctx, rinfo);
9831008
// return RedisModule_ReplyWithSimpleString(ctx, "foo");
984-
RedisModule_ReplicateVerbatim(ctx);
1009+
1010+
// RedisModule_ReplicateVerbatim(ctx);
9851011

9861012
return REDISMODULE_OK;
9871013
}

src/tensor.c

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,18 @@ static size_t Tensor_DataTypeSize(DLDataType dtype) {
4646
return dtype.bits / 8;
4747
}
4848

49-
static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
49+
void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
5050
*dtypestr = RedisModule_Calloc(8, sizeof(char));
5151
if (dtype.code == kDLFloat) {
5252
if (dtype.bits == 32) {
53-
strcpy(*dtypestr, "FLOAT32");
53+
strcpy(*dtypestr, "FLOAT");
5454
}
5555
else if (dtype.bits == 64) {
56-
strcpy(*dtypestr, "FLOAT64");
56+
strcpy(*dtypestr, "DOUBLE");
57+
}
58+
else {
59+
RedisModule_Free(*dtypestr);
60+
*dtypestr = NULL;
5761
}
5862
}
5963
else if (dtype.code == kDLInt) {
@@ -69,6 +73,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
6973
else if (dtype.bits == 64) {
7074
strcpy(*dtypestr, "INT64");
7175
}
76+
else {
77+
RedisModule_Free(*dtypestr);
78+
*dtypestr = NULL;
79+
}
7280
}
7381
else if (dtype.code == kDLUInt) {
7482
if (dtype.bits == 8) {
@@ -77,6 +85,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
7785
else if (dtype.bits == 16) {
7886
strcpy(*dtypestr, "UINT16");
7987
}
88+
else {
89+
RedisModule_Free(*dtypestr);
90+
*dtypestr = NULL;
91+
}
8092
}
8193
}
8294

@@ -175,51 +187,21 @@ static void RAI_Tensor_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, vo
175187
RAI_Tensor *tensor = (RAI_Tensor*)value;
176188

177189
char *dtypestr = NULL;
178-
179190
Tensor_DataTypeStr(RAI_TensorDataType(tensor), &dtypestr);
180191

181-
int64_t* shape = tensor->tensor.dl_tensor.shape;
182-
char* data = RAI_TensorData(tensor);
183-
size_t size = RAI_TensorByteSize(tensor);
192+
char *data = RAI_TensorData(tensor);
193+
long long size = RAI_TensorByteSize(tensor);
194+
195+
long long ndims = RAI_TensorNumDims(tensor);
196+
197+
RedisModuleString* dims[ndims];
184198

185-
// We switch over the dimensions of the tensor up to 7
186-
// The reason is that we don't have a way to pass a vector of long long to RedisModule_EmitAOF,
187-
// there's no format for it. Vector of strings is supported (format 'v').
188-
// This might change in the future, but it needs to change in redis/src/module.c
189-
190-
switch (RAI_TensorNumDims(tensor)) {
191-
case 1:
192-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "sllcb",
193-
key, dtypestr, RAI_SPLICE_SHAPE_1(shape), "BLOB", data, size);
194-
break;
195-
case 2:
196-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "slllcb",
197-
key, dtypestr, RAI_SPLICE_SHAPE_2(shape), "BLOB", data, size);
198-
break;
199-
case 3:
200-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "sllllcb",
201-
key, dtypestr, RAI_SPLICE_SHAPE_3(shape), "BLOB", data, size);
202-
break;
203-
case 4:
204-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "slllllcb",
205-
key, dtypestr, RAI_SPLICE_SHAPE_4(shape), "BLOB", data, size);
206-
break;
207-
case 5:
208-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "sllllllcb",
209-
key, dtypestr, RAI_SPLICE_SHAPE_5(shape), "BLOB", data, size);
210-
break;
211-
case 6:
212-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "slllllllcb",
213-
key, dtypestr, RAI_SPLICE_SHAPE_6(shape), "BLOB", data, size);
214-
break;
215-
case 7:
216-
RedisModule_EmitAOF(aof, "AI.TENSORSET", "sllllllllcb",
217-
key, dtypestr, RAI_SPLICE_SHAPE_7(shape), "BLOB", data, size);
218-
break;
219-
default:
220-
printf("ERR: AOF serialization supports tensors of dimension up to 7\n");
199+
for (long long i=0; i<ndims; i++) {
200+
dims[i] = RedisModule_CreateStringFromLongLong(RedisModule_GetContextFromIO(aof), RAI_TensorDim(tensor, i));
221201
}
222202

203+
RedisModule_EmitAOF(aof, "AI.TENSORSET", "scvcb", key, dtypestr, dims, ndims, "BLOB", data, size);
204+
223205
RedisModule_Free(dtypestr);
224206
}
225207

src/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ RAI_Tensor* RAI_TensorCreateFromDLTensor(DLManagedTensor* dl_tensor);
1414
size_t RAI_TensorLength(RAI_Tensor* t);
1515
size_t RAI_TensorGetDataSize(const char* dataTypeStr);
1616
DLDataType RAI_TensorDataType(RAI_Tensor* t);
17+
void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr);
1718
void RAI_TensorFree(RAI_Tensor* t);
1819
int RAI_TensorSetData(RAI_Tensor* t, const char* data, size_t len);
1920
int RAI_TensorSetValueFromLongLong(RAI_Tensor* t, long long i, long long val);

test/basic_tests.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def test_set_tensor(env):
102102
exception = e
103103
env.assertEqual(type(exception), redis.exceptions.ResponseError)
104104

105+
time.sleep(0.1)
106+
105107
for _ in con.reloadingIterator():
106108
env.assertExists('x')
107109

@@ -243,6 +245,12 @@ def test_run_tf_model(env):
243245
values = tensor[-1]
244246
con.assertEqual(values, [b'4', b'9', b'4', b'9'])
245247

248+
if env.useSlaves:
249+
con2 = env.getSlaveConnection()
250+
time.sleep(0.1)
251+
tensor2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
252+
con.assertEqual(tensor2, tensor)
253+
246254
for _ in con.reloadingIterator():
247255
env.assertExists('m')
248256
env.assertExists('a')
@@ -344,6 +352,12 @@ def test_run_torch_model(env):
344352
values = tensor[-1]
345353
con.assertEqual(values, [b'4', b'6', b'4', b'6'])
346354

355+
if env.useSlaves:
356+
con2 = env.getSlaveConnection()
357+
time.sleep(0.1)
358+
tensor2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
359+
con.assertEqual(tensor2, tensor)
360+
347361
for _ in con.reloadingIterator():
348362
env.assertExists('m')
349363
env.assertExists('a')
@@ -450,6 +464,12 @@ def test_run_onnx_model(env):
450464

451465
env.assertEqual(argmax, 1)
452466

467+
if env.useSlaves:
468+
con2 = env.getSlaveConnection()
469+
time.sleep(0.1)
470+
tensor2 = con2.execute_command('AI.TENSORGET', 'b', 'VALUES')
471+
con.assertEqual(tensor2, tensor)
472+
453473
for _ in con.reloadingIterator():
454474
env.assertExists('m')
455475
env.assertExists('a')
@@ -490,6 +510,14 @@ def test_run_onnxml_model(env):
490510
env.assertEqual(float(linear_out[2][0]), -0.090524077415466309)
491511
env.assertEqual(logreg_out[2][0], 0)
492512

513+
if env.useSlaves:
514+
con2 = env.getSlaveConnection()
515+
time.sleep(0.1)
516+
linear_out2 = con2.execute_command('AI.TENSORGET', 'linear_out', 'VALUES')
517+
logreg_out2 = con2.execute_command('AI.TENSORGET', 'logreg_out', 'VALUES')
518+
env.assertEqual(linear_out, linear_out2)
519+
env.assertEqual(logreg_out, logreg_out2)
520+
493521
for _ in con.reloadingIterator():
494522
env.assertExists('linear')
495523
env.assertExists('logreg')
@@ -743,6 +771,8 @@ def test_set_correct_script(env):
743771

744772
env.execute_command('AI.SCRIPTSET', 'ket', 'CPU', script)
745773

774+
time.sleep(0.1)
775+
746776
for _ in env.reloadingIterator():
747777
env.assertExists('ket')
748778

@@ -807,6 +837,14 @@ def test_run_script(env):
807837
values = tensor[-1]
808838
env.assertEqual(values, [b'4', b'6', b'4', b'6'])
809839

840+
time.sleep(0.1)
841+
842+
if env.useSlaves:
843+
con2 = env.getSlaveConnection()
844+
time.sleep(0.1)
845+
tensor2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
846+
env.assertEqual(tensor2, tensor)
847+
810848
for _ in env.reloadingIterator():
811849
env.assertExists('ket')
812850
env.assertExists('a')

0 commit comments

Comments
 (0)