Skip to content

Commit c0c2f2d

Browse files
Add LeakyReluFusion transformation (openvinotoolkit#6816)
1 parent 518ec79 commit c0c2f2d

File tree

4 files changed

+188
-0
lines changed

4 files changed

+188
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2018-2021 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <vector>
8+
#include <memory>
9+
10+
#include <transformations_visibility.hpp>
11+
12+
#include <ngraph/pass/graph_rewrite.hpp>
13+
14+
namespace ngraph {
15+
namespace pass {
16+
17+
class TRANSFORMATIONS_API LeakyReluFusion;
18+
19+
} // namespace pass
20+
} // namespace ngraph
21+
22+
/**
23+
* @ingroup ie_transformation_common_api
24+
* @brief LeakyReluFusion transformation replaces following graph:
25+
* Multiply->Maximum to LeakyRelu
26+
*/
27+
28+
class ngraph::pass::LeakyReluFusion: public ngraph::pass::MatcherPass {
29+
public:
30+
NGRAPH_RTTI_DECLARATION;
31+
LeakyReluFusion();
32+
};

inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "transformations/common_optimizations/swish_fusion.hpp"
2222
#include "transformations/common_optimizations/normalize_l2_fusion.hpp"
2323
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
24+
#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
2425
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
2526
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
2627
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
@@ -133,6 +134,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
133134
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
134135
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
135136
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
137+
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
136138
common_fusions->set_name("ngraph::pass::CommonFusions");
137139

138140
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (C) 2018-2021 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
6+
#include "transformations/utils/utils.hpp"
7+
8+
#include <memory>
9+
#include <vector>
10+
11+
#include <ngraph/opsets/opset8.hpp>
12+
#include <ngraph/rt_info.hpp>
13+
#include <ngraph/pattern/op/wrap_type.hpp>
14+
#include "itt.hpp"
15+
16+
17+
NGRAPH_RTTI_DEFINITION(ngraph::pass::LeakyReluFusion, "LeakyReluFusion", 0);
18+
19+
ngraph::pass::LeakyReluFusion::LeakyReluFusion() {
20+
MATCHER_SCOPE(LeakyReluFusion);
21+
auto data_pattern = ngraph::pattern::any_input();
22+
auto alpha_pattern = ngraph::pattern::any_input(pattern::has_static_shape());
23+
auto multiply_pattern = ngraph::pattern::wrap_type<opset8::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1));
24+
auto max_pattern = ngraph::pattern::wrap_type<opset8::Maximum>({data_pattern, multiply_pattern});
25+
26+
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
27+
auto pattern_map = m.get_pattern_value_map();
28+
auto data = pattern_map.at(data_pattern);
29+
const auto & original_alpha_pattern = pattern_map.at(alpha_pattern);
30+
31+
if (shape_size(original_alpha_pattern.get_shape()) != 1)
32+
return false;
33+
34+
auto leaky_relu = register_new_node<ngraph::opset8::PRelu>(data, original_alpha_pattern);
35+
auto maximum = pattern_map.at(max_pattern);
36+
leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name());
37+
38+
copy_runtime_info({
39+
pattern_map.at(multiply_pattern).get_node_shared_ptr(),
40+
maximum.get_node_shared_ptr()
41+
},
42+
leaky_relu);
43+
replace_node(maximum.get_node_shared_ptr(), leaky_relu);
44+
45+
return true;
46+
};
47+
48+
auto m = std::make_shared<ngraph::pattern::Matcher>(max_pattern, matcher_name);
49+
this->register_matcher(m, callback);
50+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright (C) 2018-2021 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <gtest/gtest.h>
6+
7+
#include <string>
8+
#include <memory>
9+
#include <queue>
10+
11+
#include <ngraph/function.hpp>
12+
#include <ngraph/opsets/opset8.hpp>
13+
#include <transformations/common_optimizations/leaky_relu_fusion.hpp>
14+
#include <transformations/init_node_info.hpp>
15+
#include <transformations/utils/utils.hpp>
16+
#include <ngraph/pass/manager.hpp>
17+
#include <ngraph/pass/constant_folding.hpp>
18+
19+
#include "common_test_utils/ngraph_test_utils.hpp"
20+
21+
22+
using namespace testing;
23+
using namespace ngraph;
24+
25+
TEST(TransformationTests, LeakyReluFusionConstant) {
26+
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
27+
{
28+
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
29+
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
30+
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
31+
auto max = std::make_shared<opset8::Maximum>(data, multiply);
32+
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});
33+
34+
pass::Manager m;
35+
m.register_pass<pass::InitNodeInfo>();
36+
m.register_pass<pass::LeakyReluFusion>();
37+
m.run_passes(f);
38+
ASSERT_NO_THROW(check_rt_info(f));
39+
}
40+
41+
{
42+
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
43+
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
44+
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
45+
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
46+
}
47+
48+
auto res = compare_functions(f, f_ref);
49+
ASSERT_TRUE(res.first) << res.second;
50+
}
51+
52+
TEST(TransformationTests, LeakyReluFusionScalar) {
53+
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
54+
{
55+
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
56+
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
57+
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
58+
auto max = std::make_shared<opset8::Maximum>(data, multiply);
59+
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});
60+
61+
pass::Manager m;
62+
m.register_pass<pass::InitNodeInfo>();
63+
m.register_pass<pass::LeakyReluFusion>();
64+
m.run_passes(f);
65+
ASSERT_NO_THROW(check_rt_info(f));
66+
}
67+
68+
{
69+
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
70+
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
71+
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
72+
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
73+
}
74+
75+
auto res = compare_functions(f, f_ref);
76+
ASSERT_TRUE(res.first) << res.second;
77+
}
78+
79+
TEST(TransformationTests, LeakyReluFusionParameter) {
80+
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
81+
{
82+
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
83+
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
84+
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
85+
auto max = std::make_shared<opset8::Maximum>(data, multiply);
86+
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data, alpha});
87+
88+
pass::Manager m;
89+
m.register_pass<pass::InitNodeInfo>();
90+
m.register_pass<pass::LeakyReluFusion>();
91+
m.run_passes(f);
92+
ASSERT_NO_THROW(check_rt_info(f));
93+
}
94+
95+
{
96+
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
97+
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
98+
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
99+
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data, alpha});
100+
}
101+
102+
auto res = compare_functions(f, f_ref);
103+
ASSERT_TRUE(res.first) << res.second;
104+
}

0 commit comments

Comments
 (0)