diff --git a/impl/ascend/ascend_tensor.hpp b/impl/ascend/ascend_tensor.hpp index 87b6d4672..46e78ed00 100644 --- a/impl/ascend/ascend_tensor.hpp +++ b/impl/ascend/ascend_tensor.hpp @@ -233,7 +233,7 @@ class AscendTensor final { return tensor_; } - bool isSame(const AscendTensor& t) const { return this->tensor_ == t.tensor_; } + bool isSame(const AscendTensor& t) const { return this->tensorHandle() == t.tensorHandle(); } int64_t getAclMemBufferSize() const; std::vector getAclMemShape() const; diff --git a/impl/ascend/common/utils.cpp b/impl/ascend/common/utils.cpp index d3f3e0f59..55e610b19 100644 --- a/impl/ascend/common/utils.cpp +++ b/impl/ascend/common/utils.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -182,6 +183,44 @@ diopiError_t reshape(diopiContextHandle_t ctx, const AscendTensor& src, AscendTe return diopiSuccess; } +AscendTensor reshape(diopiContextHandle_t ctx, const AscendTensor& src, const std::vector& shape) { + std::cout << "come into reshape......" << std::endl; + ASCEND_CHECK_ABORT(src.defined(), "input tensor is nullptr."); + + if (src.shape() == shape) { + std::cout << "shape is same, return src tensor pointer=" << src.tensorHandle() << std::endl; + return src; + } + + int64_t expectedSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + ASCEND_CHECK_THROW(expectedSize == src.numel(), "reshape size not match, expect %ld, but got %ld", expectedSize, src.numel()); + + { + diopiTensorHandle_t out = nullptr; + diopiSize_t outShape{shape.data(), static_cast(shape.size())}; + diopiRequireTensor(ctx, &out, &outShape, nullptr, src.dtype(), diopi_device); + AscendTensor outAt(out), tmp(src); + tmp.view(shape); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, outAt, tmp); + return outAt; + } + + AscendTensor result; + if (src.isContiguous()) { + result = src; + result.view(shape); + std::cout << "isContiguous reshape tensor pointer=" << result.tensorHandle() << std::endl; + return result; + } + + std::cout << "not isContiguous, make tensor......" << std::endl; + makeTensor(ctx, result, shape, src.dtype()); + std::cout << "make tensor pointer=" << result.tensorHandle() << std::endl; + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, result, src); + + return AscendTensor(result.tensorHandle()); +} + diopiError_t aclAsStridedCore(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst) { diopiTensorHandle_t targetObj = const_cast(static_cast(dst)); AclOpRunner<4, 1>("AsStrided", ctx) diff --git a/impl/ascend/common/utils.hpp b/impl/ascend/common/utils.hpp index 80b1ce056..05314907d 100644 --- a/impl/ascend/common/utils.hpp +++ b/impl/ascend/common/utils.hpp @@ -89,6 +89,8 @@ diopiError_t makeTensorFromScalar(diopiContextHandle_t ctx, AscendTensor& dst, c diopiError_t reshape(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst, const std::vector& shape); +AscendTensor reshape(diopiContextHandle_t ctx, const AscendTensor& src, const std::vector& shape); + diopiError_t contiguous(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous); diopiError_t castTensor(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst); diff --git a/impl/ascend/functions/unique.cpp b/impl/ascend/functions/unique.cpp index dc1e9f9ae..13b80a481 100644 --- a/impl/ascend/functions/unique.cpp +++ b/impl/ascend/functions/unique.cpp @@ -62,9 +62,14 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio ASCEND_CHECK_ABORT(ret == 0, "get out aclGetViewShape failed"); // fill out tensor - AscendTensor outReshapeAt; - reshape(ctx, outTmpAt, outReshapeAt, {viewDims, viewDims + viewDimNum}); - *out = const_cast(outReshapeAt.tensorHandle()); + if (1) { + AscendTensor outReshapeAt; + reshape(ctx, outTmpAt, outReshapeAt, {viewDims, viewDims + viewDimNum}); + *out = const_cast(outReshapeAt.tensorHandle()); + } else { + AscendTensor outReshapeAt = reshape(ctx, outTmpAt, {viewDims, viewDims + viewDimNum}); + *out = const_cast(outReshapeAt.tensorHandle()); + } // fill indices tensor if (returnInverse) { @@ -78,9 +83,14 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio int ret2 = aclGetViewShape(std::get(params), &viewDims, &viewDimNum); ASCEND_CHECK_ABORT(ret2 == 0, "get count aclGetViewShape failed"); - AscendTensor countsReshapeAt; - reshape(ctx, countsTmpAt, countsReshapeAt, {viewDims, viewDims + viewDimNum}); - *counts = const_cast(countsReshapeAt.tensorHandle()); + if (1) { + AscendTensor countsReshapeAt; + reshape(ctx, countsTmpAt, countsReshapeAt, {viewDims, viewDims + viewDimNum}); + *counts = const_cast(countsReshapeAt.tensorHandle()); + } else { + AscendTensor countsReshapeAt = reshape(ctx, countsTmpAt, {viewDims, viewDims + viewDimNum}); + *counts = const_cast(countsReshapeAt.tensorHandle()); + } } // delete viewDims pointer