From a129597180d5517ca1435693282a857709a8749f Mon Sep 17 00:00:00 2001
From: ShawnXuan
Date: Thu, 14 Nov 2024 10:36:51 +0000
Subject: [PATCH 1/2] Remove CUDA-only restriction for multi-tensor model
 updates in optimizer

---
 oneflow/core/job_rewriter/multi_tensor_model_update.cpp | 6 +++---
 python/oneflow/nn/optimizer/adamw.py                    | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/oneflow/core/job_rewriter/multi_tensor_model_update.cpp b/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
index a2d4e83ad83..1fd73e66359 100644
--- a/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
+++ b/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
@@ -232,9 +232,9 @@ Maybe<void> MultiTensorModelUpdatePass::Apply(const OpGraph& op_graph,
     const user_op::UserOpConfWrapper model_update_user_conf(
         find_model_update_update_node->op().op_conf());
     // Multi tensor update pass only support for CUDA currently.
-    if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {
-      continue;
-    }
+    // if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {
+    //   continue;
+    // }

     // Multi tensor update pass only support Data Parallel.
     bool if_data_parallel = true;
diff --git a/python/oneflow/nn/optimizer/adamw.py b/python/oneflow/nn/optimizer/adamw.py
index 10ed9e12640..aec07649d8e 100644
--- a/python/oneflow/nn/optimizer/adamw.py
+++ b/python/oneflow/nn/optimizer/adamw.py
@@ -163,9 +163,9 @@ def __init__(
                     warnings.warn("Fused Adamw is not supported when amsgrad=True.")
                     param_group["fused"] = False

-                if param_group["fused"] and not param.is_cuda:
-                    warnings.warn("Fused Adamw only support cuda parameters.")
-                    param_group["fused"] = False
+                # if param_group["fused"] and not param.is_cuda:
+                #     warnings.warn("Fused Adamw only support cuda parameters.")
+                #     param_group["fused"] = False

         self._op_with_amsgrad = (
             flow.stateful_op("adam_update")

From 434204da912b0e6c9aa17c04109038c6656dc987 Mon Sep 17 00:00:00 2001
From: ShawnXuan
Date: Thu, 14 Nov 2024 10:42:39 +0000
Subject: [PATCH 2/2] rm lines

---
 oneflow/core/job_rewriter/multi_tensor_model_update.cpp | 4 ----
 python/oneflow/nn/optimizer/adamw.py                    | 4 ----
 2 files changed, 8 deletions(-)

diff --git a/oneflow/core/job_rewriter/multi_tensor_model_update.cpp b/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
index 1fd73e66359..334b0860744 100644
--- a/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
+++ b/oneflow/core/job_rewriter/multi_tensor_model_update.cpp
@@ -231,10 +231,6 @@ Maybe<void> MultiTensorModelUpdatePass::Apply(const OpGraph& op_graph,
     }
     const user_op::UserOpConfWrapper model_update_user_conf(
         find_model_update_update_node->op().op_conf());
-    // Multi tensor update pass only support for CUDA currently.
-    // if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {
-    //   continue;
-    // }

     // Multi tensor update pass only support Data Parallel.
     bool if_data_parallel = true;
diff --git a/python/oneflow/nn/optimizer/adamw.py b/python/oneflow/nn/optimizer/adamw.py
index aec07649d8e..17e650598f7 100644
--- a/python/oneflow/nn/optimizer/adamw.py
+++ b/python/oneflow/nn/optimizer/adamw.py
@@ -163,10 +163,6 @@ def __init__(
                     warnings.warn("Fused Adamw is not supported when amsgrad=True.")
                     param_group["fused"] = False

-                # if param_group["fused"] and not param.is_cuda:
-                #     warnings.warn("Fused Adamw only support cuda parameters.")
-                #     param_group["fused"] = False
-
         self._op_with_amsgrad = (
             flow.stateful_op("adam_update")
             .Input("model")
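
Usage sketch (editorial addition, not part of the patch series): with the two guards above removed, requesting fused AdamW on non-CUDA parameters is no longer silently downgraded to the unfused path. The sketch assumes the `fused` keyword on `flow.optim.AdamW` (the knob the Python hunk edits) and uses "cpu" as a stand-in device; whether a fused/multi-tensor kernel is actually registered for that device is decided elsewhere and is not changed by these commits.

    # Hedged sketch: "cpu" and the toy model are assumptions, not from the patch.
    import oneflow as flow
    import oneflow.nn as nn

    device = "cpu"
    model = nn.Linear(16, 4).to(device)

    # Before PATCH 1/2 this warned "Fused Adamw only support cuda parameters."
    # and reset fused=False; now the fused request is left enabled.
    optimizer = flow.optim.AdamW(
        model.parameters(), lr=1e-3, weight_decay=0.01, fused=True
    )

    x = flow.randn(8, 16, device=device)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()       # fused update path is now reachable off-CUDA
    optimizer.zero_grad()

The C++ hunk plays the same role for graph mode: MultiTensorModelUpdatePass no longer bails out of the rewrite when the node's parallel description is not DeviceType::kCUDA.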