Skip to content

Commit 777fe79

Browse files
authored
[WebNN EP] Support Sign and CumSum operators (microsoft#22616)
This PR supports Sign and CumSum operators for WebNN EP. @Honry @fdwr PTAL, thanks.
1 parent ac6fe48 commit 777fe79

File tree

7 files changed

+101
-8
lines changed

7 files changed

+101
-8
lines changed

js/web/docs/webnn-operators.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
2525
| Conv | ai.onnx(7-10, 11+) | conv2d ||| Only supports 3-D or 4-D input and 'W' (weight) |
2626
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d ||| Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
2727
| Cos | ai.onnx(7+) | cos ||| |
28+
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum ||| |
2829
| Div | ai.onnx(7-12, 13, 14+) | div ||| |
2930
| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear ||| |
3031
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity ||| Only supports test mode |
@@ -87,6 +88,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
8788
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
8889
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
8990
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
91+
| Sign | ai.onnx(9-12, 13+) | sign ||| |
9092
| Softplus | ai.onnx(7+) | softplus ||| |
9193
| Softsign | ai.onnx(7+) | softsign ||| |
9294
| Sin | ai.onnx(7+) | sin ||| |

js/web/test/suite-test-list.jsonc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,13 +1699,13 @@
16991699
"test_cos",
17001700
// "test_cosh_example",
17011701
// "test_cosh",
1702-
// "test_cumsum_1d_exclusive",
1703-
// "test_cumsum_1d_reverse_exclusive",
1704-
// "test_cumsum_1d_reverse",
1705-
// "test_cumsum_1d",
1706-
// "test_cumsum_2d_axis_0",
1707-
// "test_cumsum_2d_axis_1",
1708-
// "test_cumsum_2d_negative_axis",
1702+
"test_cumsum_1d_exclusive",
1703+
"test_cumsum_1d_reverse_exclusive",
1704+
"test_cumsum_1d_reverse",
1705+
"test_cumsum_1d",
1706+
"test_cumsum_2d_axis_0",
1707+
"test_cumsum_2d_axis_1",
1708+
"test_cumsum_2d_negative_axis",
17091709
// "test_depthtospace_crd_mode_example",
17101710
// "test_depthtospace_crd_mode",
17111711
// "test_depthtospace_dcr_mode",
@@ -2352,7 +2352,7 @@
23522352
// "test_shrink_soft",
23532353
"test_sigmoid_example",
23542354
"test_sigmoid",
2355-
// "test_sign",
2355+
"test_sign",
23562356
// "test_simple_rnn_batchwise",
23572357
// "test_simple_rnn_defaults",
23582358
// "test_simple_rnn_with_initial_bias",

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
204204
{"ConvInteger", "conv2dInteger"},
205205
{"ConvTranspose", "convTranspose2d"},
206206
{"Cos", "cos"},
207+
{"CumSum", "cumulativeSum"},
207208
{"Div", "div"},
208209
{"DequantizeLinear", "dequantizeLinear"},
209210
{"Dropout", "identity"},
@@ -268,6 +269,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
268269
{"ScatterND", "scatterND"},
269270
{"Shape", "slice"},
270271
{"Sigmoid", "sigmoid"},
272+
{"Sign", "sign"},
271273
{"Softplus", "softplus"},
272274
{"Softsign", "softsign"},
273275
{"Sin", "sin"},
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/common/safeint.h"
6+
#include "core/framework/tensorprotoutils.h"
7+
#include "core/optimizer/initializer.h"
8+
#include "core/providers/common.h"
9+
#include "core/providers/shared/utils/utils.h"
10+
#include "core/providers/webnn/builders/helper.h"
11+
#include "core/providers/webnn/builders/model_builder.h"
12+
#include "core/providers/webnn/builders/op_builder_factory.h"
13+
14+
#include "base_op_builder.h"
15+
16+
namespace onnxruntime {
17+
namespace webnn {
18+
19+
class CumSumOpBuilder : public BaseOpBuilder {
20+
// Add operator related.
21+
22+
private:
23+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
24+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
25+
26+
// Operator support related.
27+
private:
28+
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
29+
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
30+
};
31+
32+
// Add operator related.
33+
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
34+
const Node& node,
35+
const logging::Logger& logger) const {
36+
const auto& input_defs = node.InputDefs();
37+
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
38+
std::vector<int64_t> input_shape;
39+
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
40+
const auto input_rank = input_shape.size();
41+
42+
NodeAttrHelper helper(node);
43+
int64_t axis = helper.Get("axis", 0);
44+
axis = HandleNegativeAxis(axis, input_rank);
45+
46+
const auto exclusive = helper.Get("exclusive", 0);
47+
const auto reverse = helper.Get("reverse", 0);
48+
49+
emscripten::val options = emscripten::val::object();
50+
options.set("exclusive", exclusive == 1);
51+
options.set("reversed", reverse == 1);
52+
options.set("label", node.Name());
53+
54+
emscripten::val output = emscripten::val::object();
55+
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(axis), options);
56+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
57+
return Status::OK();
58+
}
59+
60+
// Operator support related.
61+
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
62+
const Node& node,
63+
WebnnDeviceType /* device_type */,
64+
const logging::Logger& logger) const {
65+
const auto& input_defs = node.InputDefs();
66+
67+
std::vector<int64_t> input_shape;
68+
if (!GetShape(*input_defs[0], input_shape, logger))
69+
return false;
70+
71+
return true;
72+
}
73+
74+
void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
75+
op_registrations.builders.push_back(std::make_unique<CumSumOpBuilder>());
76+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
77+
}
78+
79+
} // namespace webnn
80+
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
5151
output = model_builder.GetBuilder().call<emscripten::val>("neg", input, options);
5252
} else if (op_type == "Reciprocal") {
5353
output = model_builder.GetBuilder().call<emscripten::val>("reciprocal", input, options);
54+
} else if (op_type == "Sign") {
55+
output = model_builder.GetBuilder().call<emscripten::val>("sign", input, options);
5456
} else if (op_type == "Sin") {
5557
output = model_builder.GetBuilder().call<emscripten::val>("sin", input, options);
5658
} else if (op_type == "Sqrt") {
@@ -82,6 +84,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
8284
"Log",
8385
"Neg",
8486
"Reciprocal",
87+
"Sign",
8588
"Sin",
8689
"Sqrt",
8790
"Tan",

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
2626
CreateUnaryOpBuilder("Log", op_registrations);
2727
CreateUnaryOpBuilder("Neg", op_registrations);
2828
CreateUnaryOpBuilder("Reciprocal", op_registrations);
29+
CreateUnaryOpBuilder("Sign", op_registrations);
2930
CreateUnaryOpBuilder("Sin", op_registrations);
3031
CreateUnaryOpBuilder("Sqrt", op_registrations);
3132
CreateUnaryOpBuilder("Tan", op_registrations);
@@ -80,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
8081
CreateConcatOpBuilder("Concat", op_registrations);
8182
}
8283

84+
{ // CumSum
85+
CreateConcatOpBuilder("CumSum", op_registrations);
86+
}
87+
8388
{ // Dropout
8489
CreateDropoutOpBuilder("Dropout", op_registrations);
8590
}

onnxruntime/core/providers/webnn/builders/op_builder_factory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
2626
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
2727
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
2828
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
29+
void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
2930
void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3031
void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3132
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

0 commit comments

Comments
 (0)