Skip to content

Commit f8f395b

Browse files
authored
[ascend]optimize: remove sync 2 (DeepLink-org#898)
* feat: remove sync
1 parent 3d84367 commit f8f395b

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

impl/ascend_npu/torch_npu/csrc/CopyKernel.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -385,6 +385,10 @@ at::Tensor& NPUNativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, b
385385
internal_set_names_inplace(self, names);
386386
}
387387

388+
// Param `non_blocking`: if True and this copy is between CPU and GPU,
389+
// the copy may occur asynchronously with respect to the host.
390+
// For other cases, this argument has no effect.
391+
// https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html
388392
if (at_npu::key::isDeviceTensor(self)) {
389393
if (at_npu::key::isDeviceTensor(src)) {
390394
copy_d2d(self, src, non_blocking);
@@ -396,9 +400,6 @@ at::Tensor& NPUNativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, b
396400
copy_d2h(self, src, non_blocking);
397401
}
398402
}
399-
if (!non_blocking) {
400-
c10_npu::getCurrentNPUStream().synchronize();
401-
}
402403
return self;
403404
}
404405

impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2730,10 +2730,7 @@ NPUStream getCurrentNPUStream(c10::DeviceIndex device_index) {
27302730

27312731
// Returns whatever getCurrentNPUStream() yields for the same device index;
// this function does not look up or create a distinct stream of its own.
NPUStream getCurrentSecondaryStream(c10::DeviceIndex device_index) {
    return getCurrentNPUStream(device_index);
}
27322732

2733-
void NPUStream::synchronize() const {
2734-
NPU_CHECK_ERROR(aclrtSynchronizeStream(aclStream_));
2735-
NPU_CHECK_ERROR(aclrtSynchronizeDevice());
2736-
}
2733+
// Calls aclrtSynchronizeStream() on the wrapped aclStream_ and passes the
// returned status to NPU_CHECK_ERROR. No device-wide call
// (aclrtSynchronizeDevice) is made here.
void NPUStream::synchronize() const {
    NPU_CHECK_ERROR(aclrtSynchronizeStream(aclStream_));
}
27372734

27382735
aclError queue::LaunchAsyncCopyTask(void* dst, size_t dstLen, void* src, size_t srcLen, aclrtMemcpyKind kind) {
27392736
c10_npu::NPUStream stream = c10_npu::getCurrentNPUStream();

0 commit comments

Comments
 (0)