 
 import torch
 
-from ._quant_common.quant_config import local_rank, world_size, HpDtype
+from ._quant_common.quant_config import HpDtype
 from ._core.quant_dequant import QuantDequantBase
 from ._core.scale_handler import update_state_dict_method, ScaleFormat
 from ._core.quantized_func_wrappers import (
     get_quantized_func_wrapper,
     OP_TYPE,
 )
+from .prepare_quant.prepare_model import get_world_size, get_local_rank
 from .utils.logger import logger
 from neural_compressor.common import options
 from neural_compressor.torch.utils import (
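The hunk above swaps import-time constants for lazy getters. For readers outside the repo, a minimal sketch of what such getters typically do (an assumption; the real implementations live in .prepare_quant.prepare_model and may differ):

# Hypothetical sketch only -- the real get_world_size/get_local_rank are
# defined in .prepare_quant.prepare_model. Assumes the WORLD_SIZE/LOCAL_RANK
# environment variables set by distributed launchers such as torchrun/deepspeed.
import os

def get_world_size() -> int:
    return int(os.getenv("WORLD_SIZE", "1"))

def get_local_rank() -> int:
    return int(os.getenv("LOCAL_RANK", "-1"))

Reading the values on each call means the module no longer caches a rank/world size captured at import, before the distributed environment is necessarily set up.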
@@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs):
     """Save state_dict for model from each rank."""
     # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer
     cur_accelerator.synchronize()
-    save_directory = add_rank_suffix(folder_prefix, local_rank, world_size)
+    save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size())
     os.makedirs(save_directory, exist_ok=True)
     safe_serialization = kwargs.get("safe_serialization", True)
     max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
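save_rank_model now resolves the rank and world size at call time before building the per-rank folder name. A sketch of the naming this produces, assuming add_rank_suffix simply appends both values (its exact format is not shown in this diff):

# Hypothetical illustration of the per-rank folder naming; the suffix
# format of add_rank_suffix is assumed here, not taken from the diff.
def add_rank_suffix(prefix: str, rank: int, world_size: int) -> str:
    return f"{prefix}_{rank}_{world_size}"

# With world_size=4, rank 0 would save into "saved_results_0_4",
# rank 1 into "saved_results_1_4", and so on.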
@@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]):
     """Gather state_dict from files saved by each rank."""
     from safetensors.torch import load_file as safe_load_file
 
+    world_size = get_world_size()
+
     def _is_in_list(name, tp_mod_list):
         for tp_name in tp_mod_list:
             if tp_name in name:
@@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list):
 
 def clean_rank_files(folder_prefix, file_name=None):
     """Clean files saved by each rank after gathering."""
+    world_size = get_world_size()
     for i in range(world_size):  # TODO: assuming tp_size == world_size
         folder_name = add_rank_suffix(folder_prefix, i, world_size)
         if file_name is None:
@@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
         format (str, optional): defaults to 'huggingface'.
     """
+    world_size = get_world_size()
+    local_rank = get_local_rank()
     format = get_enum_from_format(format)
     model = process_model_for_scalar_scale(model)
     if world_size > 1:
@@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     if model is None:
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
+    world_size = get_world_size()
     if world_size > 1:
         import deepspeed
         from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
@@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
         FP8 model.
     """
     format = get_enum_from_format(format)
-    global world_size
-    world_size = kwargs.get("world_size", world_size)
+    world_size = kwargs.get("world_size", get_world_size())
     assert format == SaveLoadFormat.HUGGINGFACE, "Currently, only huggingface models are supported."
     assert device in ["hpu", "cpu"], "Currently, only hpu & cpu device is supported for FP8 model."
 
@@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params):
             param.data = new_scale
 
 
-def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()):
     """Get new rank state_dict for world_size.
 
     Args:
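One Python subtlety in the last hunk: default-argument expressions run once, when the def statement executes, so get_world_size()/get_local_rank() in the new signature are still evaluated at import time rather than on each call. A minimal sketch of that behavior:

# Demonstrates that Python evaluates default-argument expressions once,
# at function definition time, not on every call.
counter = 0

def bump():
    global counter
    counter += 1
    return counter

def f(x=bump()):  # bump() runs exactly once, right here
    return x

assert f() == 1
assert f() == 1  # still 1: the default was captured when f was defined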