From bc67ecbe5d03d40e88897e3bc7fc5301689c59de Mon Sep 17 00:00:00 2001
From: Jianhua Zheng
Date: Wed, 18 Dec 2024 10:37:37 +0800
Subject: [PATCH] fix autocast to support global tensor

---
 oneflow/core/framework/autocast.cpp          | 2 +-
 oneflow/core/functional/tensor_processor.cpp | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/oneflow/core/framework/autocast.cpp b/oneflow/core/framework/autocast.cpp
index f3ce8320183..f5ee0505b0d 100644
--- a/oneflow/core/framework/autocast.cpp
+++ b/oneflow/core/framework/autocast.cpp
@@ -94,7 +94,7 @@ Maybe cached_cast(const std::shared_ptr& tensor, Symbo
        && cast_type == get_lower_precision_fp_from_device_type(device_type)
        && tensor->dtype()->data_type() == DataType::kFloat && tensor->is_leaf()
        && !tensor->is_view());
-  if (use_cache) {
+  if (use_cache && tensor->is_local()) {
     auto it = cached_casts()->find(
         std::make_pair(JUST(tensor->mut_eager_local_tensor_impl()), cast_type->data_type()));
     if (it == cached_casts()->end() || it->second.first.lock() == nullptr) {
diff --git a/oneflow/core/functional/tensor_processor.cpp b/oneflow/core/functional/tensor_processor.cpp
index b13ebc57507..079ed2f24d1 100644
--- a/oneflow/core/functional/tensor_processor.cpp
+++ b/oneflow/core/functional/tensor_processor.cpp
@@ -225,8 +225,10 @@ Maybe TensorAutoCastProcessor::Apply() {
   for (int i = 0; i < inputs_.size(); ++i) {
     if (args_eligible[i] && JUST(IsDeviceType(inputs_[i], autocast_device_type))
         && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != autocast_dtype) {
-      autocast_inputs_[i] = JUST(autocast::cached_cast(inputs_[i], autocast_dtype,
-                                                       JUST(inputs_[i]->device())->enum_type()));
+      auto device_type = inputs_[i]->is_local()
+                             ? JUST(inputs_[i]->device())->enum_type()
+                             : JUST(inputs_[i]->parallel_desc())->device_type();
+      autocast_inputs_[i] = JUST(autocast::cached_cast(inputs_[i], autocast_dtype, device_type));
     } else {
       autocast_inputs_[i] = inputs_[i];
     }