diff --git a/cinn/common/target.cc b/cinn/common/target.cc
index a250a6f1c1..b5a28ab1e0 100644
--- a/cinn/common/target.cc
+++ b/cinn/common/target.cc
@@ -11,13 +11,16 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
-#include "cinn/common/target.h"
+#ifdef CINN_WITH_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+#endif
 
 #include <glog/logging.h>
 
 #include <sstream>
 
+#include "cinn/common/target.h"
 #include "cinn/runtime/cinn_runtime.h"
 
 #ifdef CINN_WITH_CUDA
@@ -54,6 +57,24 @@ int Target::max_num_threads() const {
   return 1024;
 }
 
+int Target::get_multi_processor_count() const {
+  CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get multi processor count.";
+  int num_sm = 0;
+#ifdef CINN_WITH_CUDA
+  cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0);
+#endif
+  return num_sm;
+}
+
+int Target::get_max_threads_per_sm() const {
+  CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max threads per SM.";
+  int max_thread = 0;
+#ifdef CINN_WITH_CUDA
+  cudaDeviceGetAttribute(&max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0);
+#endif
+  return max_thread;
+}
+
 std::vector<std::string> Target::get_target_libs() const { return libs; }
 
 int Target::get_target_bits() const {
diff --git a/cinn/common/target.h b/cinn/common/target.h
index 58b20f4b24..97e8e5b70d 100755
--- a/cinn/common/target.h
+++ b/cinn/common/target.h
@@ -80,6 +80,10 @@ struct Target {
 
   int max_num_threads() const;
 
+  int get_multi_processor_count() const;
+
+  int get_max_threads_per_sm() const;
+
   int get_target_bits() const;
 
   std::vector<std::string> get_target_libs() const;
diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc
index ad1069a6e5..23993c896d 100644
--- a/cinn/frontend/net_builder.cc
+++ b/cinn/frontend/net_builder.cc
@@ -117,7 +117,14 @@ Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const
       return Reshape(x, new_shape);
     }
   }
-  return CustomInstr(op_type, {x}, {{"dim", dim}, {"keep_dim", keep_dim}}).front();
+  // Convert negative dims to positive indices.
+  std::vector<int> reduce_dim(dim.begin(), dim.end());
+  for (int i = 0; i < dim.size(); i++) {
+    if (reduce_dim[i] < 0) {
+      reduce_dim[i] = x->shape.size() + reduce_dim[i];
+    }
+  }
+  return CustomInstr(op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}).front();
 }
 
 #define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \
diff --git a/cinn/hlir/framework/op_lowering_test.cc b/cinn/hlir/framework/op_lowering_test.cc
index 747dcf2e37..55a5a2fd44 100644
--- a/cinn/hlir/framework/op_lowering_test.cc
+++ b/cinn/hlir/framework/op_lowering_test.cc
@@ -1194,6 +1194,75 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) {
 }
 */
 
+TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+  int h = warp_reduce_threshold - 10;
+  int w = 256;
+  NetBuilder net_builder("Block_Reduce_Fuse_Broadcast");
+  // create model
+  {
+    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
+    auto B = net_builder.ReduceSum(A, {1}, true);
+    auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1});
+  }
+
+  Compile(net_builder);
+}
+
+TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
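+  // h above warp_reduce_threshold (= sm_count * max_threads_per_sm / 32)
+  // makes the lowering choose warp reduce; h below it chooses block reduce.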
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+  int h = warp_reduce_threshold - 10;
+  int w = 256;
+  NetBuilder net_builder("Block_Reduce_Fuse_Elementwise");
+  // create model
+  {
+    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
+    auto B = net_builder.CreateInput(Float(32), {h}, "B");
+    auto C = net_builder.ReduceSum(A, {1}, true);
+    auto D = net_builder.Add(B, C);
+  }
+
+  Compile(net_builder);
+}
+
+TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+  int h = warp_reduce_threshold + 10;
+  int w = 256;
+  NetBuilder net_builder("Warp_Reduce_Fuse_Broadcast");
+  // create model
+  {
+    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
+    auto B = net_builder.ReduceSum(A, {1}, true);
+    auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1});
+  }
+
+  Compile(net_builder);
+}
+
+TEST(OpFusionPass, Warp_Reduce_Fuse_Elementwise) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+  int h = warp_reduce_threshold + 10;
+  int w = 256;
+  NetBuilder net_builder("Warp_Reduce_Fuse_Elementwise");
+  // create model
+  {
+    auto A = net_builder.CreateInput(Float(32), {h, w}, "A");
+    auto B = net_builder.CreateInput(Float(32), {h}, "B");
+    auto C = net_builder.ReduceSum(A, {1}, true);
+    auto D = net_builder.Add(B, C);
+  }
+
+  Compile(net_builder);
+}
+
 }  // namespace framework
 }  // namespace hlir
 }  // namespace cinn
diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc
index ba993a219f..f5e0fb9ce8 100644
--- a/cinn/hlir/framework/op_lowering_util.cc
+++ b/cinn/hlir/framework/op_lowering_util.cc
@@ -631,10 +631,25 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,
                               const std::vector<int>& inshape,
                               const std::vector<int>& axes,
                               const common::Target& target) {
+  // If the current device has fewer SMs than warp reduce would need,
+  // warp reduce performs better; otherwise, use block reduce.
+  auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads();
+  int need_reduce_last_count = 1;
+  for (int i = 0; i < inshape.size(); i++) {
+    if (find(axes.begin(), axes.end(), i) == axes.end()) {
+      need_reduce_last_count *= inshape[i];
+    }
+  }
+  int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / float(target.get_max_threads_per_sm()));
+  // Setting max_num_threads to 32 selects warp reduce.
+  if (target.get_multi_processor_count() < warp_reduce_need_sm_count) {
+    max_num_threads = 32;
+  }
   // find first reduce and second reduce axis.
-  int lane = 1;
-  int index = static_cast<int>(axes.size()) - 1;
-  auto max_num_threads = target.max_num_threads();
+  int lane = 1;
+  int index = static_cast<int>(axes.size()) - 1;
+
   for (; index >= 0; --index) {
     if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) {
       break;
diff --git a/cinn/hlir/op/reduction_test.cc b/cinn/hlir/op/reduction_test.cc
index abfb767511..9b214f467f 100644
--- a/cinn/hlir/op/reduction_test.cc
+++ b/cinn/hlir/op/reduction_test.cc
@@ -509,6 +509,53 @@ TEST(Operator, Operator_Reduction_Case_11) {
   GenReduceCode(shape, dim, "Operator_Reduction_Case_11");
 }
 
+TEST(Operator, Operator_Reduction_Case_Warp_Reduce) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+
+  std::vector<int> shape = {warp_reduce_threshold + 10, 256};
+  std::vector<int> dim = {1};
+
+  auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce");
+  CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
+}
+
+TEST(Operator, Operator_Reduction_Case_Block_Reduce) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+
+  std::vector<int> shape = {warp_reduce_threshold - 10, 33};
+  std::vector<int> dim = {1};
+
+  auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce");
+  CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
+}
+
+TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+
+  std::vector<int> shape = {(warp_reduce_threshold + 32) / 2, 2, 10, 256};
+  std::vector<int> dim = {2, 3};
+
+  auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1");
+  CHECK(res.second.find("threadIdx.x < 32") != std::string::npos);
+}
+
+TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) {
+  int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count();
+  int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm();
+  int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
+
+  std::vector<int> shape = {(warp_reduce_threshold - 32) / 2, 2, 10, 33};
+  std::vector<int> dim = {2, 3};
+
+  auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_1");
+  CHECK(res.second.find("threadIdx.x < 32") == std::string::npos);
+}
+
 }  // namespace framework
 }  // namespace hlir
 }  // namespace cinn
diff --git a/cinn/hlir/pe/reduction.cc b/cinn/hlir/pe/reduction.cc
index 80ccb6ef15..9a1519620a 100644
--- a/cinn/hlir/pe/reduction.cc
+++ b/cinn/hlir/pe/reduction.cc
@@ -665,10 +665,25 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(const ir::Tensor& A,
                                                    const std::vector<int>& axes,
                                                    BlockReduceFunc block_reduce_func,
                                                    ir::Expr initial) {
   CHECK(!WithoutLastDimInReduce(A->shape, axes)) << "Can't find last axis in reduce!";
+  // If the current device has fewer SMs than warp reduce would need,
+  // warp reduce performs better; otherwise, use block reduce.
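+  // The SM count warp reduce needs is estimated as
+  //   ceil(num_outputs * 32 / max_threads_per_sm),
+  // i.e. one 32-thread warp per reduction output, where num_outputs is the
+  // product of the non-reduced dimensions (need_reduce_last_count below).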
+  auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads();
+  int need_reduce_last_count = 1;
+  for (int i = 0; i < A->shape.size(); i++) {
+    if (find(axes.begin(), axes.end(), i) == axes.end()) {
+      need_reduce_last_count *= A->shape[i].as_int32();
+    }
+  }
+  int warp_reduce_need_sm_count =
+      ceil((need_reduce_last_count * 32) / float(common::DefaultNVGPUTarget().get_max_threads_per_sm()));
+  // Setting max_num_threads to 32 selects warp reduce.
+  if (common::DefaultNVGPUTarget().get_multi_processor_count() < warp_reduce_need_sm_count) {
+    max_num_threads = 32;
+  }
 
-  int lane = A->shape[axes.back()].as_int32();
-  int index = static_cast<int>(axes.size()) - 2;
-  auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads();
+  int lane = A->shape[axes.back()].as_int32();
+  int index = static_cast<int>(axes.size()) - 2;
   for (; index >= 0; --index) {
     if (lane >= max_num_threads / 2) {
       break;
diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc
index 47ba3b50ee..1b11659dab 100644
--- a/cinn/optim/transform_gpu_forloop.cc
+++ b/cinn/optim/transform_gpu_forloop.cc
@@ -474,12 +474,28 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
     auto &shape = store_tensor->shape;
     auto &buffer = store_tensor->buffer->shape;
 
+    int cnt = buffer.size() - indices.size();
+    // for (int i = 0; i < cnt; i++) { ... }
+
     ir::IRMutator<>::Visit(op, expr);
   }
@@ -494,7 +510,32 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
       return;
     }
 
+    // Debug traces for the shape/buffer mismatch handled below.
+    VLOG(-1) << load->tensor;
+    VLOG(-1) << load->tensor.as_tensor_ref()->shape;
+    VLOG(-1) << load->tensor.as_tensor_ref()->buffer->shape;
+    VLOG(-1) << load->indices.size();
+    bool tmp_flag = false;
+    auto org_shape = load->tensor.as_tensor_ref()->shape;
+    if (load->tensor.as_tensor_ref()->shape != load->tensor.as_tensor_ref()->buffer->shape) {
+      tmp_flag = true;
+    }
+    load->tensor.as_tensor_ref()->shape = load->tensor.as_tensor_ref()->buffer->shape;
+
+    if (tmp_flag) {
+      // Prepend 1s so both shapes match the number of load indices.
+      int cnt = load->indices.size() - load->tensor.as_tensor_ref()->shape.size();
+      for (int i = 0; i < cnt; i++) {
+        auto &xx = load->tensor.as_tensor_ref()->shape;
+        xx.insert(xx.begin(), Expr(1));
+        auto &yy = load->tensor.as_tensor_ref()->buffer->shape;
+        yy.insert(yy.begin(), Expr(1));
+      }
+    }
+    VLOG(-1) << load->indices.size();
+    for (auto t : load->indices) VLOG(-1) << t;
+    VLOG(-1) << load->tensor.as_tensor_ref()->shape;
+    VLOG(-1) << load->tensor.as_tensor_ref()->buffer->shape;
+    VLOG(-1) << load->tensor;
 
     ir::IRMutator<>::Visit(op, expr);
   }
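
Note: the Load visitor above aligns a tensor's shape (and its buffer's shape) with the number of load indices by prepending 1s. A minimal standalone sketch of that alignment, using plain ints instead of CINN's ir::Expr; PadLeadingOnes is a made-up name, not part of the patch:

#include <iostream>
#include <vector>

// Prepend 1s until `shape` has `rank` dimensions, mirroring how the visitor
// pads a tensor's shape to match the rank implied by the load indices.
std::vector<int> PadLeadingOnes(std::vector<int> shape, size_t rank) {
  while (shape.size() < rank) {
    shape.insert(shape.begin(), 1);
  }
  return shape;
}

int main() {
  for (int d : PadLeadingOnes({128, 256}, 4)) {
    std::cout << d << ' ';  // prints: 1 1 128 256
  }
  std::cout << '\n';
  return 0;
}
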
diff --git a/test_group_62.py b/test_group_62.py
new file mode 100644
index 0000000000..d79070d6b6
--- /dev/null
+++ b/test_group_62.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2023 CINN Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Please set "export PYTHONPATH=${CINN_ROOT}/build/python:${PYTHONPATH}" first
+
+import unittest
+import numpy as np
+from cinn.frontend import NetBuilder
+from cinn.common import DefaultNVGPUTarget
+
+
+def random(shape, dtype="float32", low=0.0, high=1.0):
+    assert bool(shape), "Shape should not be empty!"
+    assert -1 not in shape, "Shape should not contain -1!"
+    if dtype in ["float16", "float32", "float64"]:
+        return np.random.uniform(low, high, shape).astype(dtype)
+    elif dtype == "bool":
+        return np.random.choice(a=[False, True], size=shape).astype(dtype)
+    elif dtype in ["int8", "uint8", "int32", "int64"]:
+        return np.random.randint(low, high, shape).astype(dtype)
+    else:
+        raise Exception("Not supported yet.")
+
+
+class TestGroup(unittest.TestCase):
+    def test_group(self):
+        builder = NetBuilder("group_test")
+
+        var_1545 = builder.create_input(
+            type="float32", shape=[128, 12, 128, 128], id_hint="var_1545")
+        eager_in_tmp_2 = builder.create_input(
+            type="float32", shape=[128, 1, 1, 128], id_hint="eager_in_tmp_2")
+
+        var_3713 = builder.broadcast_to(
+            eager_in_tmp_2,
+            broadcast_axes=[0, 1, 2, 3],
+            out_shape=[128, 12, 128, 128])
+        var_1547 = builder.elementwise_add(var_1545, var_3713, axis=-1)
+        var_1549 = builder.reduce_max(var_1547, dim=[3], keep_dim=True)
+        var_5993 = builder.broadcast_to(
+            var_1549,
+            broadcast_axes=[0, 1, 2, 3],
+            out_shape=[128, 12, 128, 128])
+        var_1551 = builder.subtract(var_1547, var_5993, axis=-1)
+        var_1553 = builder.exp(var_1551)
+
+        feed_list = [var_1545, eager_in_tmp_2]
+        fetch_list = [var_1553]
+
+        prog = builder.build()
+
+        feed_data = [
+            random(shape=var.shape(), dtype=var.type()) for var in feed_list
+        ]
+        result = prog.build_and_get_output(DefaultNVGPUTarget(), feed_list,
+                                           feed_data, fetch_list)
+
+        # result = [res.numpy(DefaultNVGPUTarget()) for res in result]
+        # for i in range(len(result)):
+        #     info_str = fetch_list[i].name()
+        #     info_str += ", shape=" + str(result[i].shape)
+        #     info_str += ", dtype=" + str(result[i].dtype) + ":\n"
+        #     info_str += str(result[i])
+        #     print(info_str)
+
+
+if __name__ == "__main__":
+    import os
+    PID = os.getpid()
+    print('Program pid:', PID)
+    print('Pause here to enter DBG')
+    # input("read")
+    unittest.main()
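
Note: the warp-vs-block reduce heuristic introduced above (duplicated in LoopAssignReduceWithLast and TwoStepBlockReduceInternal) boils down to the following standalone sketch. This is an illustration only: ChooseMaxNumThreads is a made-up name, and the 80-SM / 2048-threads-per-SM device in main() is hypothetical, not queried from CUDA.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns 32 (one warp per reduction output) when the device has fewer SMs
// than warp reduce would need to stay fully occupied; otherwise returns the
// default block width, which selects block reduce.
int ChooseMaxNumThreads(const std::vector<int>& shape,
                        const std::vector<int>& axes,
                        int sm_count,
                        int max_threads_per_sm,
                        int default_max_threads = 1024) {
  // Number of reduction outputs == product of the non-reduced dimensions.
  int64_t num_outputs = 1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
      num_outputs *= shape[i];
    }
  }
  // One 32-thread warp per output element.
  int need_sm_count = static_cast<int>(
      std::ceil(num_outputs * 32 / static_cast<float>(max_threads_per_sm)));
  return sm_count < need_sm_count ? 32 : default_max_threads;
}

int main() {
  // Hypothetical device: 80 SMs, 2048 resident threads per SM, so
  // warp_reduce_threshold = 80 * 2048 / 32 = 5120 outputs.
  const int kSmCount = 80, kMaxThreadsPerSm = 2048;
  std::cout << ChooseMaxNumThreads({5130, 256}, {1}, kSmCount, kMaxThreadsPerSm)
            << "\n";  // prints 32: above the threshold -> warp reduce
  std::cout << ChooseMaxNumThreads({5110, 256}, {1}, kSmCount, kMaxThreadsPerSm)
            << "\n";  // prints 1024: below the threshold -> block reduce
  return 0;
}

The tests added in op_lowering_test.cc and reduction_test.cc exercise exactly these two sides of the threshold by sizing the reduced tensor at warp_reduce_threshold ± 10 rows.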