Skip to content

Commit 4f4aadc

Browse files
authored
Update TensorSet and TensorGet in DAG upon execution (#811)
* Set status of dag tensorSet op at parse time (when it is executed). Also, set tensorGet op status at reply time. * Refactor validation of DAGCommand enum upon dag op execution
1 parent 5f51f1d commit 4f4aadc

File tree

5 files changed

+25
-43
lines changed

5 files changed

+25
-43
lines changed

src/execution/DAG/dag.c

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -480,31 +480,17 @@ void Dag_SetTensorInGlobalCtx(RedisAI_RunInfo *rinfo, size_t index, RAI_Tensor *
480480
void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) {
481481
RAI_DagOp *currentOp = RedisAI_DagCurrentOp(rinfo);
482482

483-
switch (currentOp->commandType) {
484-
case REDISAI_DAG_CMD_TENSORSET: {
485-
// TENSORSET op is done in parsing stage (consider removing it from dag ops).
486-
currentOp->result = REDISMODULE_OK;
487-
break;
488-
}
489-
case REDISAI_DAG_CMD_TENSORGET: {
490-
// TENSORSET op is done when we finish (consider removing it from dag ops).
491-
currentOp->result = REDISMODULE_OK;
492-
break;
493-
}
494-
case REDISAI_DAG_CMD_MODELRUN: {
483+
// Verify that the op type belongs to the DAGCommand enum.
484+
VALIDATE_DAG_COMMAND(currentOp->commandType)
485+
486+
if (currentOp->commandType == REDISAI_DAG_CMD_MODELRUN) {
495487
RedisAI_DagRunSession_ModelRun_Step(rinfo, currentOp);
496-
break;
497-
}
498-
case REDISAI_DAG_CMD_SCRIPTRUN: {
488+
} else if (currentOp->commandType == REDISAI_DAG_CMD_SCRIPTRUN) {
499489
RedisAI_DagRunSession_ScriptRun_Step(rinfo, currentOp);
500-
break;
501-
}
502-
default: {
503-
/* unsupported DAG's command */
504-
RAI_SetError(currentOp->err, RAI_EDAGRUN, "ERR unsupported command within DAG");
505-
currentOp->result = REDISMODULE_ERR;
506-
break;
507-
}
490+
} else {
491+
// do nothing for tensorset (executed on parsing) and for tensorget (done
492+
// on dag reply).
493+
return;
508494
}
509495

510496
if (currentOp->result != REDISMODULE_OK) {
@@ -568,26 +554,21 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
568554

569555
for (size_t i = 0; i < n_dagOps; i++) {
570556
RAI_DagOp *currentOp = rinfo->dagOps[i];
557+
571558
switch (currentOp->commandType) {
572559
case REDISAI_DAG_CMD_TENSORSET: {
573560
rinfo->dagReplyLength++;
574-
if (currentOp->result == REDISMODULE_ERR) {
575-
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
576-
dag_error = 1;
577-
} else if (currentOp->result == -1) {
578-
RedisModule_ReplyWithSimpleString(ctx, "NA");
579-
} else {
580-
RedisModule_ReplyWithSimpleString(ctx, "OK");
581-
}
561+
RedisModule_Assert(currentOp->result == REDISMODULE_OK);
562+
RedisModule_ReplyWithSimpleString(ctx, "OK");
582563
break;
583564
}
584565

585566
case REDISAI_DAG_CMD_TENSORGET: {
586567
rinfo->dagReplyLength++;
587-
if (currentOp->result == -1) {
568+
RAI_Tensor *t = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[0]);
569+
if (t == NULL) {
588570
RedisModule_ReplyWithSimpleString(ctx, "NA");
589571
} else {
590-
RAI_Tensor *t = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[0]);
591572
ReplyWithTensor(ctx, currentOp->fmt, t);
592573
}
593574
break;
@@ -636,8 +617,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
636617
break;
637618
}
638619
default:
639-
/* no-op */
640-
break;
620+
RedisModule_Assert(false && "Dag reply - invalid op");
641621
}
642622
}
643623
if (dag_error) {

src/execution/DAG/dag_op.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ typedef enum DAGCommand {
1313
REDISAI_DAG_CMD_SCRIPTRUN
1414
} DAGCommand;
1515

16+
#define VALIDATE_DAG_COMMAND(cmd) \
17+
RedisModule_Assert(cmd >= REDISAI_DAG_CMD_TENSORSET && cmd <= REDISAI_DAG_CMD_SCRIPTRUN);
18+
1619
typedef struct RAI_DagOp {
1720
DAGCommand commandType;
1821
RedisModuleString *runkey;

src/execution/parsing/dag_parser.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ int ParseDAGExecuteOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops, bool ro) {
190190
rinfo->err) == -1) {
191191
return REDISMODULE_ERR;
192192
}
193+
currentOp->result = REDISMODULE_OK;
193194
continue;
194195
}
195196
if (!strcasecmp(arg_string, "AI.MODELEXECUTE")) {

src/execution/parsing/deprecated.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,10 @@ int ParseDAGRunOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops) {
599599
RAI_HoldString(currentOp->argv[1]);
600600
currentOp->outkeys = array_append(currentOp->outkeys, currentOp->argv[1]);
601601
if (RAI_parseTensorSetArgs(currentOp->argv, currentOp->argc, &currentOp->outTensor, 0,
602-
rinfo->err) == -1)
602+
rinfo->err) == -1) {
603603
goto cleanup;
604+
}
605+
currentOp->result = REDISMODULE_OK;
604606
continue;
605607
}
606608
if (!strcasecmp(arg_string, "AI.MODELRUN")) {

tests/flow/tests_sanitizer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,11 @@ def test_sanitizer_dagrun_mobilenet_v1(env):
3131
class_key = 'output{s}'
3232

3333
ret = con.execute_command(
34-
'AI.DAGRUN', '|>',
34+
'AI.DAGEXECUTE', 'ROUTING', '{s}', '|>',
3535
'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(),
3636
'|>',
37-
'AI.MODELRUN', model_name,
38-
'INPUTS', image_key,
39-
'OUTPUTS', class_key,
40-
'|>',
41-
'AI.TENSORGET', class_key, 'blob'
42-
)
37+
'AI.MODELEXECUTE', model_name, 'INPUTS', 1, image_key, 'OUTPUTS', 1, class_key,
38+
'|>', 'AI.TENSORGET', class_key, 'blob')
4339
env.assertEqual([b'OK', b'OK'], ret[:2])
4440
env.assertEqual(1001.0, len(ret[2])/4)
4541

0 commit comments

Comments
 (0)