Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion impl/ascend/ascend_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class AscendTensor final {
return tensor_;
}

bool isSame(const AscendTensor& t) const { return this->tensor_ == t.tensor_; }
bool isSame(const AscendTensor& t) const { return this->tensorHandle() == t.tensorHandle(); }

int64_t getAclMemBufferSize() const;
std::vector<int64_t> getAclMemShape() const;
Expand Down
39 changes: 39 additions & 0 deletions impl/ascend/common/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -182,6 +183,44 @@ diopiError_t reshape(diopiContextHandle_t ctx, const AscendTensor& src, AscendTe
return diopiSuccess;
}

AscendTensor reshape(diopiContextHandle_t ctx, const AscendTensor& src, const std::vector<int64_t>& shape) {
std::cout << "come into reshape......" << std::endl;
ASCEND_CHECK_ABORT(src.defined(), "input tensor is nullptr.");

if (src.shape() == shape) {
std::cout << "shape is same, return src tensor pointer=" << src.tensorHandle() << std::endl;
return src;
}

int64_t expectedSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
ASCEND_CHECK_THROW(expectedSize == src.numel(), "reshape size not match, expect %ld, but got %ld", expectedSize, src.numel());

{
diopiTensorHandle_t out = nullptr;
diopiSize_t outShape{shape.data(), static_cast<int64_t>(shape.size())};
diopiRequireTensor(ctx, &out, &outShape, nullptr, src.dtype(), diopi_device);
AscendTensor outAt(out), tmp(src);
tmp.view(shape);
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, outAt, tmp);
return outAt;
}

AscendTensor result;
if (src.isContiguous()) {
result = src;
result.view(shape);
std::cout << "isContiguous reshape tensor pointer=" << result.tensorHandle() << std::endl;
return result;
}

std::cout << "not isContiguous, make tensor......" << std::endl;
makeTensor(ctx, result, shape, src.dtype());
std::cout << "make tensor pointer=" << result.tensorHandle() << std::endl;
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, result, src);

return AscendTensor(result.tensorHandle());
}

diopiError_t aclAsStridedCore(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst) {
diopiTensorHandle_t targetObj = const_cast<diopiTensorHandle_t>(static_cast<diopiConstTensorHandle_t>(dst));
AclOpRunner<4, 1>("AsStrided", ctx)
Expand Down
2 changes: 2 additions & 0 deletions impl/ascend/common/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ diopiError_t makeTensorFromScalar(diopiContextHandle_t ctx, AscendTensor& dst, c

diopiError_t reshape(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst, const std::vector<int64_t>& shape);

AscendTensor reshape(diopiContextHandle_t ctx, const AscendTensor& src, const std::vector<int64_t>& shape);

diopiError_t contiguous(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous);

diopiError_t castTensor(diopiContextHandle_t ctx, const AscendTensor& src, AscendTensor& dst);
Expand Down
22 changes: 16 additions & 6 deletions impl/ascend/functions/unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
ASCEND_CHECK_ABORT(ret == 0, "get out aclGetViewShape failed");

// fill out tensor
AscendTensor outReshapeAt;
reshape(ctx, outTmpAt, outReshapeAt, {viewDims, viewDims + viewDimNum});
*out = const_cast<diopiTensorHandle_t>(outReshapeAt.tensorHandle());
if (1) {
AscendTensor outReshapeAt;
reshape(ctx, outTmpAt, outReshapeAt, {viewDims, viewDims + viewDimNum});
*out = const_cast<diopiTensorHandle_t>(outReshapeAt.tensorHandle());
} else {
AscendTensor outReshapeAt = reshape(ctx, outTmpAt, {viewDims, viewDims + viewDimNum});
*out = const_cast<diopiTensorHandle_t>(outReshapeAt.tensorHandle());
}

// fill indices tensor
if (returnInverse) {
Expand All @@ -78,9 +83,14 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
int ret2 = aclGetViewShape(std::get<countsTensorIndex>(params), &viewDims, &viewDimNum);
ASCEND_CHECK_ABORT(ret2 == 0, "get count aclGetViewShape failed");

AscendTensor countsReshapeAt;
reshape(ctx, countsTmpAt, countsReshapeAt, {viewDims, viewDims + viewDimNum});
*counts = const_cast<diopiTensorHandle_t>(countsReshapeAt.tensorHandle());
if (1) {
AscendTensor countsReshapeAt;
reshape(ctx, countsTmpAt, countsReshapeAt, {viewDims, viewDims + viewDimNum});
*counts = const_cast<diopiTensorHandle_t>(countsReshapeAt.tensorHandle());
} else {
AscendTensor countsReshapeAt = reshape(ctx, countsTmpAt, {viewDims, viewDims + viewDimNum});
*counts = const_cast<diopiTensorHandle_t>(countsReshapeAt.tensorHandle());
}
}

// delete viewDims pointer
Expand Down