diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index 3d6e9e06e..cd36f7a2d 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -1,4 +1,5 @@
 import argparse
+import json
 import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
@@ -599,6 +600,11 @@ class SGLangConfig:
     # The interval (in decoding iterations) to log throughput
     # and update prometheus metrics
     decode_log_interval: int = 1
+    # Extra loader arguments
+    # NOTE: These arguments are serialized into a JSON string and passed
+    # to SGLang as `model_loader_extra_config`.
+    enable_multithread_load: bool = False
+    enable_fast_load: bool = False
 
     # Use staticmethod to make OmegaConf happy.
     @staticmethod
@@ -649,6 +655,19 @@ def build_args(
     ):
         # Map "all-linear" to "all"
         args: Dict = conf_as_dict(sglang_config)
+        if sglang_config.enable_multithread_load or sglang_config.enable_fast_load:
+            assert pkg_version.is_version_equal(
+                "sglang", "0.5.2"
+            ), "Customized model loading requires SGLang version 0.5.2 exactly."
+            model_loader_extra_config = dict(
+                enable_multithread_load=sglang_config.enable_multithread_load,
+                enable_fast_load=sglang_config.enable_fast_load,
+            )
+            args.pop("enable_multithread_load", None)
+            args.pop("enable_fast_load", None)
+            args["model_loader_extra_config"] = json.dumps(
+                model_loader_extra_config, separators=(",", ":")
+            )
         # Map "all-linear" to "all"
         if "lora_target_modules" in args and args["lora_target_modules"]:
             args["lora_target_modules"] = [
diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py
index 4f7842258..5b14d4f7a 100644
--- a/areal/experimental/megatron_engine.py
+++ b/areal/experimental/megatron_engine.py
@@ -612,8 +612,6 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
 
         # dist.barrier() are called when _save_model_to_hf finished
         if dist.get_rank() == 0:
-            fut.result()
-
             update_name = names.update_weights_from_disk(
                 self.config.experiment_name,
                 self.config.trial_name,
@@ -623,6 +621,8 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
                 update_name, str(datetime.now().timestamp()), keepalive_ttl=120
             )
 
+        fut.result()
+
         dist.barrier(device_ids=[self.device.index])
         current_platform.synchronize()
 
@@ -642,7 +642,10 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta):
         )
         self.rollout_engine = engine
 
-        if not self.weight_update_group_initialized:
+        if (
+            meta.type == current_platform.communication_backend
+            and not self.weight_update_group_initialized
+        ):
             self._init_weight_update_from_distributed(meta)
             self.weight_update_group_initialized = True
 
diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py
index f88d0549e..56fef2eaf 100644
--- a/areal/launcher/sglang_server.py
+++ b/areal/launcher/sglang_server.py
@@ -22,7 +22,7 @@
 )
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
-from areal.utils.launcher import TRITON_CACHE_PATH
+from areal.utils.launcher import TRITON_CACHE_PATH, apply_sglang_patch
 from areal.utils.network import find_free_ports, gethostip
 
 logger = logging.getLogger("SGLangServer Wrapper")
@@ -130,6 +130,9 @@ def __init__(
         self.server_process = None
         self.n_gpus_per_node = n_gpus_per_node
 
+        if self.config.enable_fast_load or self.config.enable_multithread_load:
+            apply_sglang_patch()
+
     def run(self):
         gpus_per_server = self.allocation_mode.gen_instance_size
         cross_nodes = False
diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py
index f111e4ca5..6125030fe 100644
--- a/areal/utils/launcher.py
+++ b/areal/utils/launcher.py
@@ -3,11 +3,15 @@
 import getpass
 import os
 import pathlib
+import shutil
+import subprocess
+import sys
 import time
+from pathlib import Path
 from typing import Dict, Optional
 
 from areal.api.alloc_mode import AllocationMode, AllocationType
-from areal.utils import logging, name_resolve, names
+from areal.utils import logging, name_resolve, names, pkg_version
 
 logger = logging.getLogger("Launcher Utils")
 
@@ -154,3 +158,66 @@ def validate_config_for_distributed_launcher(config):
     assert (
         allocation_mode.gen.tp_size <= config.cluster.n_gpus_per_node
     ), "Currently only support vLLM TP size less <= #GPUs per node."
+
+
+def apply_sglang_patch():
+    p = Path(os.path.dirname(__file__))
+    patch_path = str(
+        p.parent.parent
+        / "patch"
+        / "sglang"
+        / f"v{pkg_version.get_version('sglang')}.patch"
+    )
+    target_path = None
+    sglang_meta = subprocess.check_output(
+        [sys.executable, "-m", "pip", "show", "sglang"]
+    ).decode("utf-8")
+    # Prioritize editable install location, since pip show lists both locations
+    # if installed in editable mode.
+    for line in sglang_meta.split("\n"):
+        line = line.strip()
+        if line.startswith("Editable project location: "):
+            target_path = str(Path(line.split(": ")[1]) / "sglang")
+            break
+    else:
+        for line in sglang_meta.split("\n"):
+            line = line.strip()
+            if line.startswith("Location: "):
+                target_path = str(Path(line.split(": ")[1]) / "sglang")
+                break
+
+    if not target_path or not os.path.exists(target_path):
+        raise RuntimeError("Could not determine the installation path of SGLang.")
+
+    patch_binary = shutil.which("patch")
+    if not patch_binary:
+        raise RuntimeError(
+            "Could not locate the `patch` command; SGLang patch application failed."
+        )
+    result = subprocess.run(
+        [patch_binary, "-p1", "-N", "-i", patch_path],
+        cwd=target_path,
+        capture_output=True,
+        text=True,
+    )
+
+    output = (result.stdout or "") + (result.stderr or "")
+    if result.returncode == 0:
+        logger.info(f"Applied SGLang patch {patch_path} to {target_path}")
+    elif (
+        "Reversed (or previously applied) patch detected" in output
+        or "Skipping patch." in output
+    ):
+        logger.warning(
+            f"SGLang patch {patch_path} appears to be already applied for {target_path}."
+        )
+    else:
+        logger.error(
+            "Failed to apply SGLang patch %s to %s. Output:\n%s",
+            patch_path,
+            target_path,
+            output.strip(),
+        )
+        raise RuntimeError(
+            f"SGLang patch {patch_path} failed with exit code {result.returncode}."
+        )
diff --git a/areal/utils/pkg_version.py b/areal/utils/pkg_version.py
index c6ab125fb..48216ef79 100644
--- a/areal/utils/pkg_version.py
+++ b/areal/utils/pkg_version.py
@@ -51,3 +51,15 @@ def is_version_less(package_name: str, target_version: str) -> bool:
     """
     installed_version = get_version(package_name)
     return compare_versions(installed_version, target_version) < 0
+
+
+def is_version_equal(package_name: str, target_version: str) -> bool:
+    """
+    Check if the installed version of a package is equal to the target version.
+
+    :param package_name: Name of the package.
+    :param target_version: Target version to compare against.
+    :return: True if the installed version is equal to the target version, False otherwise.
+ """ + installed_version = get_version(package_name) + return compare_versions(installed_version, target_version) == 0 diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 7410a0d9b..39293808f 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -527,6 +527,8 @@ https://github.com/sgl-project/sglang for detailed documentation. | `show_time_cost` | boolean | `False` | - | | `enable_metrics` | boolean | `True` | - | | `decode_log_interval` | integer | `1` | - | +| `enable_multithread_load` | boolean | `False` | - | +| `enable_fast_load` | boolean | `False` | - | (section-v-llm)= diff --git a/patch/sglang/v0.5.2.patch b/patch/sglang/v0.5.2.patch new file mode 100644 index 000000000..65986c635 --- /dev/null +++ b/patch/sglang/v0.5.2.patch @@ -0,0 +1,538 @@ +diff -ruN ./srt/model_executor/model_runner.py ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py +--- ./srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800 ++++ ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800 +@@ -258,6 +258,7 @@ + self._model_update_group = {} + + def initialize(self, min_per_gpu_memory: float): ++ logger.warning("SGLang v0.5.2 is patched with customized weight loading.") + server_args = self.server_args + + self.memory_saver_adapter = TorchMemorySaverAdapter.create( +@@ -823,8 +824,11 @@ + + target_device = torch.device(self.device) + self.model_config.model_path = model_path +- load_config = LoadConfig(load_format=load_format) +- ++ load_config = LoadConfig( ++ load_format=load_format, ++ # XXX: This should be in function args, passed in by requests ++ model_loader_extra_config=self.server_args.model_loader_extra_config, ++ ) + # Only support DefaultModelLoader for now + loader = get_model_loader(load_config) + if not isinstance(loader, DefaultModelLoader): +@@ -838,7 +842,9 @@ + return iter + + def model_load_weights(model, iter): +- DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) ++ DefaultModelLoader.load_weights_and_postprocess( ++ model, iter, target_device, load_config=load_config ++ ) + return model + + with set_default_torch_dtype(self.model_config.dtype): +diff -ruN ./srt/model_loader/loader.py ../../../my-sglang/python/sglang/srt/model_loader/loader.py +--- ./srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800 ++++ ../../../my-sglang/python/sglang/srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800 +@@ -278,7 +278,7 @@ + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = load_config.model_loader_extra_config +- allowed_keys = {"enable_multithread_load", "num_threads"} ++ allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} + unexpected_keys = set(extra_config.keys()) - allowed_keys + + if unexpected_keys: +@@ -399,6 +399,9 @@ + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ return source.model_or_path ++ + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt + ) +@@ -441,8 +444,6 @@ + ) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) +- +- # Apply the prefix. 
+         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
+ 
+     def _get_all_weights(
+@@ -450,7 +451,6 @@
+         model_config: ModelConfig,
+         model: nn.Module,
+     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+-
+         primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
+         yield from self._get_weights_iterator(primary_weights)
+ 
+@@ -479,15 +479,23 @@
+                     self.load_config,
+                 )
+ 
++            extra_config = self.load_config.model_loader_extra_config
++            if extra_config.get("enable_fast_load"):
++                weights_iter_or_path = model_config.model_path
++            else:
++                weights_iter_or_path = self._get_all_weights(model_config, model)
+             self.load_weights_and_postprocess(
+-                model, self._get_all_weights(model_config, model), target_device
++                model, weights_iter_or_path, target_device, load_config=self.load_config
+             )
+-
+         return model.eval()
+ 
+     @staticmethod
+-    def load_weights_and_postprocess(model, weights, target_device):
+-        model.load_weights(weights)
++    def load_weights_and_postprocess(model, weights, target_device, load_config=None):
++        extra_config = load_config.model_loader_extra_config
++        if extra_config.get("enable_fast_load"):
++            model.load_weights_from_path(weights)
++        else:
++            model.load_weights(weights)
+ 
+         for _, module in model.named_modules():
+             quant_method = getattr(module, "quant_method", None)
+diff -ruN ./srt/models/qwen3_moe.py ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py
+--- ./srt/models/qwen3_moe.py	2025-10-13 13:20:52.077417877 +0800
++++ ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py	2025-10-11 15:03:31.207989560 +0800
+@@ -770,6 +770,130 @@
+         else:
+             self.model.layers_to_capture = [val + 1 for val in layer_ids]
+ 
++    @property
++    def stacked_params_mapping(self) -> List[Tuple[str, str, str]]:
++        return [
++            # (param_name, shard_name, shard_id)
++            ("qkv_proj", "q_proj", "q"),
++            ("qkv_proj", "k_proj", "k"),
++            ("qkv_proj", "v_proj", "v"),
++            ("gate_up_proj", "gate_proj", 0),
++            ("gate_up_proj", "up_proj", 1),
++        ]
++
++    @property
++    def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]:
++        return get_moe_impl_class().make_expert_params_mapping(
++            ckpt_gate_proj_name="gate_proj",
++            ckpt_down_proj_name="down_proj",
++            ckpt_up_proj_name="up_proj",
++            num_experts=self.config.num_experts,
++        )
++
++    def _load_weights_with_worker(
++        self,
++        params: Dict[str, torch.nn.Parameter],
++        local_names: List[str],
++        filenames: List[str],
++        weight_path: str,
++    ):
++        import os
++
++        from safetensors import safe_open
++
++        from sglang.srt.model_loader.weight_utils import default_weight_loader
++
++        all_slices = {}
++        for filename in filenames:
++            safetensor_file = os.path.join(weight_path, filename)
++            with safe_open(safetensor_file, framework="pt", device="cpu") as f:
++                for name in f.keys():
++                    # all_slices[name] = f.get_slice(name)
++                    all_slices[name] = f.get_tensor(name)
++
++        for local_name in local_names:
++            # Skip loading extra bias for GPTQ models.
++            if local_name.endswith(".bias") and local_name not in params:
++                continue
++            # Handle special cases
++            if "rotary_emb.inv_freq" in local_name or "projector" in local_name:
++                continue
++            if (
++                "rotary_emb.cos_cached" in local_name
++                or "rotary_emb.sin_cached" in local_name
++            ):
++                # Models trained using ColossalAI may include these tensors in
++                # the checkpoint. Skip them.
++                continue
++            if local_name.startswith("model.vision_tower") and local_name not in params:
++                continue
++
++            param = params[local_name]
++            # Handle weight tying
++            if self.config.tie_word_embeddings and "lm_head.weight" in local_name:
++                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
++                    local_name = "model.embed_tokens.weight"
++
++            loaded = False
++            for param_name, shard_name, shard_id in self.stacked_params_mapping:
++                if param_name not in local_name:
++                    continue
++                if "mlp.experts" in local_name:
++                    # Skip experts here, handled below
++                    continue
++                # If the local weight is sharded across multiple checkpoint keys
++                weight_loader = param.weight_loader
++                slice_name = local_name.replace(param_name, shard_name)
++                loaded_weight = all_slices[slice_name]
++                weight_loader(param, loaded_weight, shard_id)
++                loaded = True
++
++            for (
++                param_name,
++                shard_name,
++                expert_id,
++                shard_id,
++            ) in self.expert_params_mapping:
++                if param_name not in local_name:
++                    continue
++                # If the local weight is sharded across multiple checkpoint keys
++                weight_loader = param.weight_loader
++                slice_name = local_name.replace(param_name, shard_name)
++                loaded_weight = all_slices[slice_name]
++                weight_loader(
++                    param,
++                    loaded_weight,
++                    local_name,
++                    shard_id=shard_id,
++                    expert_id=expert_id,
++                )
++                loaded = True
++
++            if not loaded:
++                # If the local weight is not sharded
++                if local_name in all_slices:
++                    loaded_weight = all_slices[local_name]
++                    weight_loader = getattr(
++                        param, "weight_loader", default_weight_loader
++                    )
++                    weight_loader(param, loaded_weight)
++                else:
++                    raise KeyError(
++                        f"Cannot find weight {local_name} in the loaded slices."
++                    )
++
++    def load_weights_from_path(self, path: str):
++        # Customized weight loading from a given HuggingFace model path
++        from sglang.srt.models.utils.load import load_weights_with_hf_path_fast
++
++        load_weights_with_hf_path_fast(
++            model=self,
++            weight_path=path,
++            load_weights_with_worker_fn=self._load_weights_with_worker,
++            stacked_params_mapping=self.stacked_params_mapping,
++            expert_params_mapping=self.expert_params_mapping,
++        )
++
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+diff -ruN ./srt/models/qwen3.py ../../../my-sglang/python/sglang/srt/models/qwen3.py
+--- ./srt/models/qwen3.py	2025-10-13 13:20:52.076417833 +0800
++++ ../../../my-sglang/python/sglang/srt/models/qwen3.py	2025-10-11 15:03:31.207989560 +0800
+@@ -418,6 +418,97 @@
+     def end_layer(self):
+         return self.model.end_layer
+ 
++    @property
++    def stacked_params_mapping(self) -> List[Tuple[str, str, str]]:
++        return [
++            # (param_name, shard_name, shard_id)
++            ("qkv_proj", "q_proj", "q"),
++            ("qkv_proj", "k_proj", "k"),
++            ("qkv_proj", "v_proj", "v"),
++            ("gate_up_proj", "gate_proj", 0),
++            ("gate_up_proj", "up_proj", 1),
++        ]
++
++    def _load_weights_with_worker(
++        self,
++        params: Dict[str, torch.nn.Parameter],
++        local_names: List[str],
++        filenames: List[str],
++        weight_path: str,
++    ):
++        import os
++
++        from safetensors import safe_open
++
++        from sglang.srt.model_loader.weight_utils import default_weight_loader
++
++        all_slices = {}
++        for filename in filenames:
++            safetensor_file = os.path.join(weight_path, filename)
++            with safe_open(safetensor_file, framework="pt", device="cpu") as f:
++                for name in f.keys():
++                    # all_slices[name] = f.get_slice(name)
++                    all_slices[name] = f.get_tensor(name)
++
++        for local_name in local_names:
++            # Skip loading extra bias for GPTQ models.
++            if local_name.endswith(".bias") and local_name not in params:
++                continue
++            # Handle special cases
++            if "rotary_emb.inv_freq" in local_name or "projector" in local_name:
++                continue
++            if (
++                "rotary_emb.cos_cached" in local_name
++                or "rotary_emb.sin_cached" in local_name
++            ):
++                # Models trained using ColossalAI may include these tensors in
++                # the checkpoint. Skip them.
++                continue
++            if local_name.startswith("model.vision_tower") and local_name not in params:
++                continue
++
++            param = params[local_name]
++            # Handle weight tying
++            if self.config.tie_word_embeddings and "lm_head.weight" in local_name:
++                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
++                    local_name = "model.embed_tokens.weight"
++
++            loaded = False
++            for param_name, shard_name, shard_id in self.stacked_params_mapping:
++                if param_name not in local_name:
++                    continue
++                # If the local weight is sharded across multiple checkpoint keys
++                weight_loader = param.weight_loader
++                slice_name = local_name.replace(param_name, shard_name)
++                loaded_weight = all_slices[slice_name]
++                weight_loader(param, loaded_weight, shard_id)
++                loaded = True
++
++            if not loaded:
++                # If the local weight is not sharded
++                if local_name in all_slices:
++                    loaded_weight = all_slices[local_name]
++                    weight_loader = getattr(
++                        param, "weight_loader", default_weight_loader
++                    )
++                    weight_loader(param, loaded_weight)
++                else:
++                    raise KeyError(
++                        f"Cannot find weight {local_name} in the loaded slices."
++                    )
++
++    def load_weights_from_path(self, path: str):
++        # Customized weight loading from a given HuggingFace model path
++        from sglang.srt.models.utils.load import load_weights_with_hf_path_fast
++
++        load_weights_with_hf_path_fast(
++            model=self,
++            weight_path=path,
++            load_weights_with_worker_fn=self._load_weights_with_worker,
++            stacked_params_mapping=self.stacked_params_mapping,
++            tie_word_embeddings=self.config.tie_word_embeddings,
++        )
++
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+@@ -468,11 +559,11 @@
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+-                name = name.replace(weight_name, param_name)
++                _name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+-                if name.endswith(".bias") and name not in params_dict:
++                if _name.endswith(".bias") and _name not in params_dict:
+                     continue
+-                param = params_dict[name]
++                param = params_dict[_name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+diff -ruN ./srt/models/utils/load.py ../../../my-sglang/python/sglang/srt/models/utils/load.py
+--- ./srt/models/utils/load.py	1970-01-01 08:00:00.000000000 +0800
++++ ../../../my-sglang/python/sglang/srt/models/utils/load.py	2025-10-11 15:01:36.754004483 +0800
+@@ -0,0 +1,170 @@
++import json
++import os
++from collections import defaultdict
++from concurrent.futures import ThreadPoolExecutor
++from glob import glob
++from typing import Callable, Dict, List, Tuple
++
++import torch
++from safetensors import safe_open
++from transformers.utils.hub import cached_file
++
++from sglang.srt.model_loader.weight_utils import default_weight_loader
++
++
++def get_actual_hf_path(weight_path: str):
++    return os.path.dirname(cached_file(weight_path, "config.json"))
++
++
++def make_filename_bins(
++    local_to_file_map: Dict[str, List[str]],
++) -> Tuple[List[List[str]], List[List[str]]]:
++    # Allocate local weight names into bins, where each bin accesses an independent
++    # set of files. Multiple threads can then load each bin's parameters concurrently.
++    # This function has a complexity of O(F + L²),
++    # where F = total number of files and L = number of local names.
++    if not local_to_file_map:
++        return [], []
++
++    local_names = list(local_to_file_map.keys())
++    n = len(local_names)
++
++    # Convert file lists to sets for O(1) lookups and create file-to-locals mapping
++    local_to_files = {name: set(local_to_file_map[name]) for name in local_names}
++    file_to_locals = defaultdict(set)
++    for local_name, files in local_to_files.items():
++        for file in files:
++            file_to_locals[file].add(local_name)
++
++    # Union-Find with path compression and union by rank
++    parent = list(range(n))
++    rank = [0] * n
++
++    def find(x):
++        if parent[x] != x:
++            parent[x] = find(parent[x])  # Path compression
++        return parent[x]
++
++    def union(x, y):
++        root_x, root_y = find(x), find(y)
++        if root_x == root_y:
++            return
++
++        # Union by rank
++        if rank[root_x] < rank[root_y]:
++            root_x, root_y = root_y, root_x
++        parent[root_y] = root_x
++        if rank[root_x] == rank[root_y]:
++            rank[root_x] += 1
++
++    # Create name-to-index mapping for O(1) lookups
++    name_to_idx = {name: i for i, name in enumerate(local_names)}
++
++    # Union locals that share files - O(F) where F is total number of files
++    for locals_sharing_file in file_to_locals.values():
++        if len(locals_sharing_file) > 1:
++            locals_list = list(locals_sharing_file)
++            first_idx = name_to_idx[locals_list[0]]
++            for local_name in locals_list[1:]:
++                union(first_idx, name_to_idx[local_name])
++
++    # Group by root - O(L)
++    root_to_group = defaultdict(list)
++    for i, name in enumerate(local_names):
++        root_to_group[find(i)].append(name)
++
++    # Build result groups - O(L + F)
++    grouped_local_names = []
++    grouped_filenames = []
++
++    for group in root_to_group.values():
++        grouped_local_names.append(group)
++        # Use set union to merge files from all locals in group
++        all_files = set()
++        for local_name in group:
++            all_files.update(local_to_files[local_name])
++        grouped_filenames.append(list(all_files))
++
++    return grouped_local_names, grouped_filenames
++
++
++def load_weights_with_hf_path_fast(
++    model: torch.nn.Module,
++    weight_path: str,
++    load_weights_with_worker_fn: Callable,
++    stacked_params_mapping: List[Tuple[str, str, str]] | None = None,
++    expert_params_mapping: List[Tuple[str, str, int, str]] | None = None,
++    tie_word_embeddings: bool = False,
++    max_workers: int | None = None,
++):
++    if not os.path.exists(weight_path):
++        weight_path = get_actual_hf_path(weight_path)
++    index_file = os.path.join(weight_path, "model.safetensors.index.json")
++    index = {}
++    if os.path.exists(index_file):
++        with open(index_file, "r") as f:
++            index = json.load(f)["weight_map"]
++    else:
++        # Search all safetensors files
++        safetensor_files = glob(os.path.join(weight_path, "*.safetensors"))
++        # If there are safetensors files
++        if safetensor_files:
++            # Iterate through each safetensors file
++            for safetensor_file in safetensor_files:
++                with safe_open(safetensor_file, framework="pt", device="cpu") as f:
++                    for k in f.keys():
++                        index[k] = safetensor_file
++        else:
++            raise FileNotFoundError("No safetensors found in the model path to load.")
++
++    params = dict(model.named_parameters())
++    local_names = list(params.keys())
++
++    worker_args = []
++
++    # local name -> set of filenames that contain the weight
++    local_to_file_map = defaultdict(set)
++    # e.g., model.layers.31.mlp.experts
++    for local_name in local_names:
++        hf_names = []
++        if "mlp.experts" not in local_name and stacked_params_mapping is not None:
++            for param_name, shard_name, _ in stacked_params_mapping:
++                if param_name in local_name:
++                    hf_names.append(local_name.replace(param_name, shard_name))
++        if expert_params_mapping is not None:
++            for param_name, shard_name, _, _ in expert_params_mapping:
++                if param_name in local_name:
++                    hf_names.append(local_name.replace(param_name, shard_name))
++        if tie_word_embeddings and "lm_head.weight" in local_name:
++            hf_names.append("model.embed_tokens.weight")
++        if len(hf_names) == 0:
++            hf_names.append(local_name)
++        for name in hf_names:
++            filename = index[name]
++            if filename not in local_to_file_map[local_name]:
++                local_to_file_map[local_name].add(filename)
++
++    grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map)
++
++    if max_workers is None:
++        # assume all GPUs are used by SGLang servers
++        max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count()))
++
++    for local_names, filenames in zip(grouped_local_names, grouped_filenames):
++        worker_args.append(
++            dict(
++                params=params,
++                local_names=local_names,
++                filenames=filenames,
++                weight_path=weight_path,
++            )
++        )
++
++    max_workers = min(max_workers, len(worker_args))
++    with ThreadPoolExecutor(max_workers=max_workers) as executor:
++        results = executor.map(
++            lambda kwargs: load_weights_with_worker_fn(**kwargs), worker_args
++        )
++        # Consume all results to ensure all tasks complete
++        for _ in results:
++            pass
+diff -ruN ./srt/utils.py ../../../my-sglang/python/sglang/srt/utils.py
+--- ./srt/utils.py	2025-10-13 13:20:52.081418051 +0800
++++ ../../../my-sglang/python/sglang/srt/utils.py	2025-10-11 15:03:52.182903128 +0800
+@@ -21,6 +21,7 @@
+ import dataclasses
+ import functools
+ import importlib
++import inspect
+ import io
+ import ipaddress
+ import itertools