From 78fb6144d2edb594b0841520dbc5febdc6a36ce5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Thu, 11 Sep 2025 17:23:53 +0800 Subject: [PATCH 01/24] add patch to optimize sglang loading --- areal/api/cli_args.py | 23 ++ areal/launcher/local.py | 13 +- areal/utils/launcher.py | 34 ++- areal/utils/pkg_version.py | 12 + patch/sglang/v0.4.9.post2.patch | 497 ++++++++++++++++++++++++++++++++ 5 files changed, 577 insertions(+), 2 deletions(-) create mode 100644 patch/sglang/v0.4.9.post2.patch diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 5a0afacd6..207d0d2d9 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1,4 +1,5 @@ import argparse +import json import os from dataclasses import asdict, dataclass, field from pathlib import Path @@ -422,6 +423,15 @@ class SGLangConfig: # The interval (in decoding iterations) to log throughput # and update prometheus metrics decode_log_interval: int = 1 + # Extra loader arguments + # NOTE: These arguments will be parsed into a dict json-string + # and passed as `model_loader_extra_config` to SGLang. + enable_multithread_load: bool = False + enable_fast_load: bool = False + + @property + def if_apply_sglang_patch(self): + return self.enable_multithread_load or self.enable_fast_load # Use staticmethod to make OmegaConf happy. @staticmethod @@ -472,6 +482,19 @@ def build_args( ): args: Dict = conf_as_dict(sglang_config) + if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: + assert pkg_version.is_version_equal( + "sglang", "0.4.9.post2" + ), f"Customized model loading requires exact SGLang version 0.4.9.post2" + model_loader_extra_config = dict( + enable_multithread_load=sglang_config.enable_multithread_load, + enable_fast_load=sglang_config.enable_fast_load, + ) + args.pop("enable_multithread_load", None) + args.pop("enable_fast_load", None) + args["model_loader_extra_config"] = json.dumps( + model_loader_extra_config, separators=(",", ":") + ) args = dict( host=host, port=port, diff --git a/areal/launcher/local.py b/areal/launcher/local.py index 679b54ff8..1738a1e78 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -21,7 +21,13 @@ ) from areal.utils import logging, name_resolve, names from areal.utils.device import gpu_count -from areal.utils.launcher import JobException, JobInfo, JobState, get_env_vars +from areal.utils.launcher import ( + JobException, + JobInfo, + JobState, + apply_sglang_patch, + get_env_vars, +) from areal.utils.network import find_free_ports, gethostip from areal.utils.recover import check_if_recover @@ -132,6 +138,8 @@ def submit_array( + cmd[i] ) c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" + # SGLang will somehow remove quotes in the command, so we need to escape the quotes + c = c.replace('"', '\\"') logger.info("Starting local process with command: %s", c) process = subprocess.Popen(c, shell=isinstance(c, str)) self._jobs[f"{job_name}/{offset + i}"] = process @@ -280,6 +288,9 @@ def local_main(config, run_id: int = 0): ports = find_free_ports(alloc_mode.gen.dp_size * 2, port_range=(10000, 50000)) host_ip = gethostip() host = "localhost" if not config.sglang.enable_metrics else host_ip + # Directly apply sglang patch on local node since we do not use sglang_server.py + if config.sglang.if_apply_sglang_patch: + apply_sglang_patch() for i in range(alloc_mode.gen.dp_size): config.sglang.random_seed = base_seed + i cmd = SGLangConfig.build_cmd( diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 7178d1530..a778ad442 
100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -3,11 +3,14 @@ import getpass import os import pathlib +import subprocess +import sys import time +from pathlib import Path from typing import Dict, Optional from areal.api.alloc_mode import AllocationMode, AllocationType -from areal.utils import logging, name_resolve, names +from areal.utils import logging, name_resolve, names, pkg_version logger = logging.getLogger("Launcher Utils") @@ -134,3 +137,32 @@ def validate_config_for_distributed_launcher(config): assert ( allocation_mode.gen.pp_size == 1 ), "Pipeline generation in SGLang is not supported for now." + + +def apply_sglang_patch(): + p = Path(os.path.dirname(__file__)) + patch_path = str( + p.parent.parent + / "patch" + / "sglang" + / f"v{pkg_version.get_version('sglang')}.patch" + ) + + target_path = "" + sglang_meta = subprocess.check_output( + "python3 -m pip show sglang", shell=True + ).decode("ascii") + for line in sglang_meta.split("\n"): + line = line.strip() + if line.startswith("Editable project location: "): + target_path = str(Path(line.split(": ")[1]).parent) + + if target_path: + proc = subprocess.Popen( + ["git", "apply", patch_path], + cwd=target_path, + stderr=sys.stdout, + stdout=sys.stdout, + ) + proc.wait() + logger.info(f"Applied SGLang patch {patch_path} to {target_path}") diff --git a/areal/utils/pkg_version.py b/areal/utils/pkg_version.py index c6ab125fb..48216ef79 100644 --- a/areal/utils/pkg_version.py +++ b/areal/utils/pkg_version.py @@ -51,3 +51,15 @@ def is_version_less(package_name: str, target_version: str) -> bool: """ installed_version = get_version(package_name) return compare_versions(installed_version, target_version) < 0 + + +def is_version_equal(package_name: str, target_version: str) -> bool: + """ + Check if the installed version of a package is equal to the target version. + + :param package_name: Name of the package. + :param target_version: Target version to compare against. + :return: True if the installed version is equal to the target version, False otherwise. 
+ """ + installed_version = get_version(package_name) + return compare_versions(installed_version, target_version) == 0 diff --git a/patch/sglang/v0.4.9.post2.patch b/patch/sglang/v0.4.9.post2.patch new file mode 100644 index 000000000..e9557736b --- /dev/null +++ b/patch/sglang/v0.4.9.post2.patch @@ -0,0 +1,497 @@ +锘縟iff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 051f2b75e..b1ea1a140 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -239,6 +239,7 @@ class ModelRunner: + self._model_update_group = {} + + def initialize(self, min_per_gpu_memory: float): ++ logger.info("SGLang v0.4.9.post2 is patched with customized weight loading.") + server_args = self.server_args + + self.memory_saver_adapter = TorchMemorySaverAdapter.create( +@@ -698,8 +699,11 @@ class ModelRunner: + + target_device = torch.device(self.device) + self.model_config.model_path = model_path +- load_config = LoadConfig(load_format=load_format) +- ++ load_config = LoadConfig( ++ load_format=load_format, ++ # XXX: This should be in function args, passed in by requests ++ model_loader_extra_config=self.server_args.model_loader_extra_config, ++ ) + # Only support DefaultModelLoader for now + loader = get_model_loader(load_config) + if not isinstance(loader, DefaultModelLoader): +@@ -713,7 +717,7 @@ class ModelRunner: + return iter + + def model_load_weights(model, iter): +- DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) ++ DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device, load_config=load_config) + return model + + with set_default_torch_dtype(self.model_config.dtype): +@@ -733,7 +737,7 @@ class ModelRunner: + iter = get_weight_iter(self.model_config) + self.model = model_load_weights(self.model, iter) + return False, message +- ++ + self.model = model + self.server_args.model_path = model_path + self.server_args.load_format = load_format +diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py +index 733e6df9e..579b91640 100644 +--- a/python/sglang/srt/model_loader/loader.py ++++ b/python/sglang/srt/model_loader/loader.py +@@ -233,7 +233,7 @@ class DefaultModelLoader(BaseModelLoader): + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = load_config.model_loader_extra_config +- allowed_keys = {"enable_multithread_load", "num_threads"} ++ allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} + unexpected_keys = set(extra_config.keys()) - allowed_keys + + if unexpected_keys: +@@ -354,6 +354,9 @@ class DefaultModelLoader(BaseModelLoader): + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ return source.model_or_path ++ + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt + ) +@@ -396,8 +399,6 @@ class DefaultModelLoader(BaseModelLoader): + ) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) +- +- # Apply the prefix. 
+ return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + + def _get_all_weights( +@@ -405,7 +406,6 @@ class DefaultModelLoader(BaseModelLoader): + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: +- + primary_weights = DefaultModelLoader.Source.init_new(model_config, model) + yield from self._get_weights_iterator(primary_weights) + +@@ -434,15 +434,23 @@ class DefaultModelLoader(BaseModelLoader): + self.load_config, + ) + ++ extra_config = self.load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ weights_iter_or_path = model_config.model_path ++ else: ++ weights_iter_or_path = self._get_all_weights(model_config, model) + self.load_weights_and_postprocess( +- model, self._get_all_weights(model_config, model), target_device ++ model, weights_iter_or_path, target_device, load_config=self.load_config + ) +- + return model.eval() + + @staticmethod +- def load_weights_and_postprocess(model, weights, target_device): +- model.load_weights(weights) ++ def load_weights_and_postprocess(model, weights, target_device, load_config=None): ++ extra_config = load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ model.load_weights_from_path(weights) ++ else: ++ model.load_weights(weights) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) +@@ -455,7 +463,6 @@ class DefaultModelLoader(BaseModelLoader): + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + +- + class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 9c3659839..57276c91f 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -374,7 +374,89 @@ class Qwen3ForCausalLM(nn.Module): + @property + def end_layer(self): + return self.model.end_layer ++ ++ @property ++ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: ++ return [ ++ # (param_name, shard_name, shard_id) ++ ("qkv_proj", "q_proj", "q"), ++ ("qkv_proj", "k_proj", "k"), ++ ("qkv_proj", "v_proj", "v"), ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ ++ def _load_weights_with_worker( ++ self, ++ params: Dict[str, torch.nn.Parameter], ++ local_names: List[str], ++ filenames: List[str], ++ weight_path: str, ++ ): ++ import os ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ from safetensors import safe_open ++ all_slices = {} ++ for filename in filenames: ++ safetensor_file = os.path.join(weight_path, filename) ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for name in f.keys(): ++ # all_slices[name] = f.get_slice(name) ++ all_slices[name] = f.get_tensor(name) ++ ++ for local_name in local_names: ++ # Skip loading extra bias for GPTQ models. ++ if local_name.endswith(".bias") and local_name not in params: ++ continue ++ # Handle special cases ++ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: ++ continue ++ if "rotary_emb.cos_cached" in local_name or "rotary_emb.sin_cached" in local_name: ++ # Models trained using ColossalAI may include these tensors in ++ # the checkpoint. Skip them. 
++ continue ++ if local_name.startswith("model.vision_tower") and local_name not in params: ++ continue ++ ++ param = params[local_name] ++ # Handle weight tying ++ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: ++ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: ++ local_name = "model.embed_tokens.weight" + ++ loaded = False ++ for param_name, shard_name, shard_id in self.stacked_params_mapping: ++ if param_name not in local_name: ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader(param, loaded_weight, shard_id) ++ loaded = True ++ ++ if not loaded: ++ # If local_name weight is not sharded ++ if local_name in all_slices: ++ loaded_weight = all_slices[local_name] ++ weight_loader = getattr( ++ param, "weight_loader", default_weight_loader ++ ) ++ weight_loader(param, loaded_weight) ++ else: ++ raise KeyError(f"Cannot find weight {local_name} in the loaded slices.") ++ ++ def load_weights_from_path(self, path: str): ++ # Customized weights loading from a given path of huggingface model ++ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ load_weights_with_hf_path_fast( ++ model=self, ++ weight_path=path, ++ load_weights_with_worker_fn=self._load_weights_with_worker, ++ stacked_params_mapping=self.stacked_params_mapping, ++ tie_word_embeddings=self.config.tie_word_embeddings, ++ ) ++ + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) +@@ -422,11 +504,11 @@ class Qwen3ForCausalLM(nn.Module): + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue +- name = name.replace(weight_name, param_name) ++ _name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+- if name.endswith(".bias") and name not in params_dict: ++ if _name.endswith(".bias") and _name not in params_dict: + continue +- param = params_dict[name] ++ param = params_dict[_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index 7c7c7551b..14f2ca7f1 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -771,6 +771,116 @@ class Qwen3MoeForCausalLM(nn.Module): + else: + self.model.layers_to_capture = [val + 1 for val in layer_ids] + ++ @property ++ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: ++ return [ ++ # (param_name, shard_name, shard_id) ++ ("qkv_proj", "q_proj", "q"), ++ ("qkv_proj", "k_proj", "k"), ++ ("qkv_proj", "v_proj", "v"), ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ ++ @property ++ def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: ++ return get_moe_impl_class().make_expert_params_mapping( ++ ckpt_gate_proj_name="gate_proj", ++ ckpt_down_proj_name="down_proj", ++ ckpt_up_proj_name="up_proj", ++ num_experts=self.config.num_experts, ++ ) ++ ++ def _load_weights_with_worker( ++ self, ++ params: Dict[str, torch.nn.Parameter], ++ local_names: List[str], ++ filenames: List[str], ++ weight_path: str, ++ ): ++ import os ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ from safetensors import safe_open ++ all_slices = {} ++ for filename in filenames: ++ safetensor_file = os.path.join(weight_path, filename) ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for name in f.keys(): ++ # all_slices[name] = f.get_slice(name) ++ all_slices[name] = f.get_tensor(name) ++ ++ for local_name in local_names: ++ # Skip loading extra bias for GPTQ models. ++ if local_name.endswith(".bias") and local_name not in params: ++ continue ++ # Handle special cases ++ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: ++ continue ++ if "rotary_emb.cos_cached" in local_name or "rotary_emb.sin_cached" in local_name: ++ # Models trained using ColossalAI may include these tensors in ++ # the checkpoint. Skip them. 
++ continue ++ if local_name.startswith("model.vision_tower") and local_name not in params: ++ continue ++ ++ param = params[local_name] ++ # Handle weight tying ++ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: ++ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: ++ local_name = "model.embed_tokens.weight" ++ ++ loaded = False ++ for param_name, shard_name, shard_id in self.stacked_params_mapping: ++ if param_name not in local_name: ++ continue ++ if "mlp.experts" in local_name: ++ # Skip experts here, handled below ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader(param, loaded_weight, shard_id) ++ loaded = True ++ ++ for param_name, shard_name, expert_id, shard_id in self.expert_params_mapping: ++ if param_name not in local_name: ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader( ++ param, ++ loaded_weight, ++ local_name, ++ shard_id=shard_id, ++ expert_id=expert_id, ++ ) ++ loaded = True ++ ++ if not loaded: ++ # If local_name weight is not sharded ++ if local_name in all_slices: ++ loaded_weight = all_slices[local_name] ++ weight_loader = getattr( ++ param, "weight_loader", default_weight_loader ++ ) ++ weight_loader(param, loaded_weight) ++ else: ++ raise KeyError(f"Cannot find weight {local_name} in the loaded slices.") ++ ++ def load_weights_from_path(self, path: str): ++ # Customized weights loading from a given path of huggingface model ++ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ load_weights_with_hf_path_fast( ++ model=self, ++ weight_path=path, ++ load_weights_with_worker_fn=self._load_weights_with_worker, ++ stacked_params_mapping=self.stacked_params_mapping, ++ expert_params_mapping=self.expert_params_mapping, ++ ) ++ + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) +diff --git a/python/sglang/srt/models/utils/load.py b/python/sglang/srt/models/utils/load.py +new file mode 100644 +index 000000000..2baa1ba4e +--- /dev/null ++++ b/python/sglang/srt/models/utils/load.py +@@ -0,0 +1,140 @@ ++import os ++import json ++from glob import glob ++from concurrent.futures import ThreadPoolExecutor ++from collections import defaultdict ++from typing import Tuple, List, Dict, Callable ++ ++import torch ++from safetensors import safe_open ++ ++from transformers.utils.hub import cached_file ++ ++from sglang.srt.model_loader.weight_utils import default_weight_loader ++ ++def get_actual_hf_path(weight_path: str): ++ return os.path.dirname(cached_file(weight_path, "config.json")) ++ ++ ++def load_weights_with_hf_path_fast( ++ model: torch.nn.Module, ++ weight_path: str, ++ load_weights_with_worker_fn: Callable, ++ stacked_params_mapping: List[Tuple[str, str, str]] | None = None, ++ expert_params_mapping: List[Tuple[str, str, str]] | None = None, ++ tie_word_embeddings: bool = False, ++ max_workers: int = None, ++): ++ if not os.path.exists(weight_path): ++ weight_path = get_actual_hf_path(weight_path) ++ index_file = os.path.join(weight_path, "model.safetensors.index.json") ++ index = {} ++ if os.path.exists(index_file): ++ with open(index_file, "r") as f: ++ index = 
json.load(f)["weight_map"] ++ else: ++ # Search all safetensors files ++ safetensor_files = glob(os.path.join(weight_path, "*.safetensors")) ++ # If there are safetensors files ++ if safetensor_files: ++ # Iterate through each safetensors file ++ for safetensor_file in safetensor_files: ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for k in f.keys(): ++ index[k] = safetensor_file ++ else: ++ raise FileNotFoundError("No safetensors found in the model path to load.") ++ ++ params = dict(model.named_parameters()) ++ local_names = list(params.keys()) ++ ++ worker_args = [] ++ ++ # local name -> set of filenames that contains the weight ++ local_to_file_map = defaultdict(set) ++ # model.layers.31.mlp.experts ++ for local_name in local_names: ++ hf_names = [] ++ if "mlp.experts" not in local_name and stacked_params_mapping is not None: ++ for param_name, shard_name, _ in stacked_params_mapping: ++ if param_name in local_name: ++ hf_names.append(local_name.replace(param_name, shard_name)) ++ if expert_params_mapping is not None: ++ for param_name, shard_name, _, _ in expert_params_mapping: ++ if param_name in local_name: ++ hf_names.append(local_name.replace(param_name, shard_name)) ++ if tie_word_embeddings and "lm_head.weight" in local_name: ++ hf_names.append("model.embed_tokens.weight") ++ if len(hf_names) == 0: ++ hf_names.append(local_name) ++ for name in hf_names: ++ filename = index[name] ++ if filename not in local_to_file_map[local_name]: ++ local_to_file_map[local_name].add(filename) ++ ++ # Use union find to create local_name groups with no file conflicts ++ parent = {name: name for name in local_names} ++ weight_groups = {name: [name] for name in local_names} ++ file_groups = {name: local_to_file_map[name] for name in local_names} ++ roots = [name for name in local_names] ++ ranks = {name: 0 for name in local_names} ++ def find(x): ++ if parent[x] != x: ++ parent[x] = find(parent[x]) ++ return parent[x] ++ ++ def union(x, y): ++ root_x = find(x) ++ root_y = find(y) ++ if root_x != root_y: ++ if ranks[root_x] > ranks[root_y]: ++ parent[root_y] = root_x ++ roots.remove(root_y) ++ elif ranks[root_x] < ranks[root_y]: ++ parent[root_x] = root_y ++ roots.remove(root_x) ++ else: ++ parent[root_y] = root_x ++ roots.remove(root_y) ++ ranks[root_x] += 1 ++ # Merge file groups ++ file_groups[root_x].update(file_groups[root_y]) ++ file_groups[root_y] = file_groups[root_x] ++ # Merge weight groups ++ weight_groups[root_x].extend(weight_groups[root_y]) ++ weight_groups[root_y] = weight_groups[root_x] ++ return True ++ return False ++ ++ for i, weight1 in enumerate(local_names): ++ for weight2 in local_names[i+1:]: ++ # If two weights share any files, they conflict ++ if any(fn in file_groups[weight1] for fn in file_groups[weight2]): ++ union(weight1, weight2) ++ ++ grouped_local_names = [weight_groups[root] for root in roots] ++ grouped_filenames = [list(file_groups[root]) for root in roots] ++ ++ if max_workers is None: ++ # assume all GPUs are used by SGLang servers ++ max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count())) ++ ++ for local_names, filenames in zip(grouped_local_names, grouped_filenames): ++ worker_args.append( ++ dict( ++ params=params, ++ local_names=local_names, ++ filenames=filenames, ++ weight_path=weight_path, ++ ) ++ ) ++ ++ max_workers = min(max_workers, len(worker_args)) ++ with ThreadPoolExecutor(max_workers=max_workers) as executor: ++ results = executor.map( ++ lambda kwargs: load_weights_with_worker_fn(**kwargs), 
worker_args ++ ) ++ # Consume all results to make result all tasks complete ++ for _ in results: ++ pass ++ +\ No newline at end of file From e05bf7400f897dcf062351bd4113a544941547e4 Mon Sep 17 00:00:00 2001 From: nuzant Date: Fri, 10 Oct 2025 17:05:50 +0800 Subject: [PATCH 02/24] . --- areal/api/cli_args.py | 4 +- ...{v0.4.9.post2.patch => v0.5.1.post3.patch} | 140 +++++++++--------- 2 files changed, 75 insertions(+), 69 deletions(-) rename patch/sglang/{v0.4.9.post2.patch => v0.5.1.post3.patch} (88%) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index b78d9bca3..5283b993e 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -660,8 +660,8 @@ def build_args( args: Dict = conf_as_dict(sglang_config) if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: assert pkg_version.is_version_equal( - "sglang", "0.4.9.post2" - ), f"Customized model loading requires exact SGLang version 0.4.9.post2" + "sglang", "0.5.1.post3" + ), f"Customized model loading requires exact SGLang version 0.5.1.post3" model_loader_extra_config = dict( enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, diff --git a/patch/sglang/v0.4.9.post2.patch b/patch/sglang/v0.5.1.post3.patch similarity index 88% rename from patch/sglang/v0.4.9.post2.patch rename to patch/sglang/v0.5.1.post3.patch index e9557736b..4ea96fe0e 100644 --- a/patch/sglang/v0.4.9.post2.patch +++ b/patch/sglang/v0.5.1.post3.patch @@ -1,8 +1,8 @@ -锘縟iff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 051f2b75e..b1ea1a140 100644 +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 8d5b7c715..ba4d52982 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -239,6 +239,7 @@ class ModelRunner: +@@ -253,6 +253,7 @@ class ModelRunner: self._model_update_group = {} def initialize(self, min_per_gpu_memory: float): @@ -10,7 +10,7 @@ index 051f2b75e..b1ea1a140 100644 server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( -@@ -698,8 +699,11 @@ class ModelRunner: +@@ -775,8 +776,11 @@ class ModelRunner: target_device = torch.device(self.device) self.model_config.model_path = model_path @@ -24,29 +24,22 @@ index 051f2b75e..b1ea1a140 100644 # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): -@@ -713,7 +717,7 @@ class ModelRunner: +@@ -790,7 +794,9 @@ class ModelRunner: return iter def model_load_weights(model, iter): - DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) -+ DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device, load_config=load_config) ++ DefaultModelLoader.load_weights_and_postprocess( ++ model, iter, target_device, load_config=load_config ++ ) return model with set_default_torch_dtype(self.model_config.dtype): -@@ -733,7 +737,7 @@ class ModelRunner: - iter = get_weight_iter(self.model_config) - self.model = model_load_weights(self.model, iter) - return False, message -- -+ - self.model = model - self.server_args.model_path = model_path - self.server_args.load_format = load_format diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index 733e6df9e..579b91640 100644 +index 23d70be44..7e8968743 100644 --- a/python/sglang/srt/model_loader/loader.py +++ 
b/python/sglang/srt/model_loader/loader.py -@@ -233,7 +233,7 @@ class DefaultModelLoader(BaseModelLoader): +@@ -263,7 +263,7 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = load_config.model_loader_extra_config @@ -55,17 +48,17 @@ index 733e6df9e..579b91640 100644 unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -354,6 +354,9 @@ class DefaultModelLoader(BaseModelLoader): +@@ -384,6 +384,9 @@ class DefaultModelLoader(BaseModelLoader): ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config + if extra_config.get("enable_fast_load"): + return source.model_or_path -+ ++ hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -396,8 +399,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -426,8 +429,6 @@ class DefaultModelLoader(BaseModelLoader): ) else: weights_iterator = pt_weights_iterator(hf_weights_files) @@ -74,7 +67,7 @@ index 733e6df9e..579b91640 100644 return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def _get_all_weights( -@@ -405,7 +406,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -435,7 +436,6 @@ class DefaultModelLoader(BaseModelLoader): model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: @@ -82,7 +75,7 @@ index 733e6df9e..579b91640 100644 primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -434,15 +434,23 @@ class DefaultModelLoader(BaseModelLoader): +@@ -464,15 +464,23 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) @@ -110,23 +103,14 @@ index 733e6df9e..579b91640 100644 for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) -@@ -455,7 +463,6 @@ class DefaultModelLoader(BaseModelLoader): - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) - -- - class LayeredModelLoader(DefaultModelLoader): - """Model loader that loads weights layer by layer so that one can quantize a - layer before loading another to make the peak memory envelope smaller.""" diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 9c3659839..57276c91f 100644 +index 042159a50..399600b63 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py -@@ -374,7 +374,89 @@ class Qwen3ForCausalLM(nn.Module): - @property +@@ -415,6 +415,97 @@ class Qwen3ForCausalLM(nn.Module): def end_layer(self): return self.model.end_layer -+ + + @property + def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: + return [ @@ -137,7 +121,7 @@ index 9c3659839..57276c91f 100644 + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] -+ ++ + def _load_weights_with_worker( + self, + params: Dict[str, torch.nn.Parameter], @@ -146,8 +130,11 @@ index 9c3659839..57276c91f 100644 + weight_path: str, + ): + import os -+ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ + from safetensors import safe_open ++ ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ + all_slices = {} + for filename in filenames: + safetensor_file = os.path.join(weight_path, filename) @@ -163,7 +150,10 @@ index 9c3659839..57276c91f 100644 + # 
Handle special cases + if "rotary_emb.inv_freq" in local_name or "projector" in local_name: + continue -+ if "rotary_emb.cos_cached" in local_name or "rotary_emb.sin_cached" in local_name: ++ if ( ++ "rotary_emb.cos_cached" in local_name ++ or "rotary_emb.sin_cached" in local_name ++ ): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue @@ -175,7 +165,7 @@ index 9c3659839..57276c91f 100644 + if self.config.tie_word_embeddings and "lm_head.weight" in local_name: + if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: + local_name = "model.embed_tokens.weight" - ++ + loaded = False + for param_name, shard_name, shard_id in self.stacked_params_mapping: + if param_name not in local_name: @@ -186,7 +176,7 @@ index 9c3659839..57276c91f 100644 + loaded_weight = all_slices[slice_name] + weight_loader(param, loaded_weight, shard_id) + loaded = True -+ ++ + if not loaded: + # If local_name weight is not sharded + if local_name in all_slices: @@ -196,11 +186,14 @@ index 9c3659839..57276c91f 100644 + ) + weight_loader(param, loaded_weight) + else: -+ raise KeyError(f"Cannot find weight {local_name} in the loaded slices.") -+ ++ raise KeyError( ++ f"Cannot find weight {local_name} in the loaded slices." ++ ) ++ + def load_weights_from_path(self, path: str): + # Customized weights loading from a given path of huggingface model + from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ + load_weights_with_hf_path_fast( + model=self, + weight_path=path, @@ -208,11 +201,11 @@ index 9c3659839..57276c91f 100644 + stacked_params_mapping=self.stacked_params_mapping, + tie_word_embeddings=self.config.tie_word_embeddings, + ) -+ ++ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -422,11 +504,11 @@ class Qwen3ForCausalLM(nn.Module): +@@ -462,11 +553,11 @@ class Qwen3ForCausalLM(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -228,10 +221,10 @@ index 9c3659839..57276c91f 100644 weight_loader(param, loaded_weight, shard_id) break diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index 7c7c7551b..14f2ca7f1 100644 +index fcb45b947..e354e7a47 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py -@@ -771,6 +771,116 @@ class Qwen3MoeForCausalLM(nn.Module): +@@ -739,6 +739,130 @@ class Qwen3MoeForCausalLM(nn.Module): else: self.model.layers_to_capture = [val + 1 for val in layer_ids] @@ -245,7 +238,7 @@ index 7c7c7551b..14f2ca7f1 100644 + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] -+ ++ + @property + def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: + return get_moe_impl_class().make_expert_params_mapping( @@ -263,8 +256,11 @@ index 7c7c7551b..14f2ca7f1 100644 + weight_path: str, + ): + import os -+ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ + from safetensors import safe_open ++ ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ + all_slices = {} + for filename in filenames: + safetensor_file = os.path.join(weight_path, filename) @@ -280,7 +276,10 @@ index 7c7c7551b..14f2ca7f1 100644 + # Handle special cases + if "rotary_emb.inv_freq" in local_name or "projector" in local_name: + continue -+ if "rotary_emb.cos_cached" in local_name or "rotary_emb.sin_cached" in local_name: ++ if ( ++ 
"rotary_emb.cos_cached" in local_name ++ or "rotary_emb.sin_cached" in local_name ++ ): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue @@ -306,8 +305,13 @@ index 7c7c7551b..14f2ca7f1 100644 + loaded_weight = all_slices[slice_name] + weight_loader(param, loaded_weight, shard_id) + loaded = True -+ -+ for param_name, shard_name, expert_id, shard_id in self.expert_params_mapping: ++ ++ for ( ++ param_name, ++ shard_name, ++ expert_id, ++ shard_id, ++ ) in self.expert_params_mapping: + if param_name not in local_name: + continue + # If local_name weight is sharded into multiple keys @@ -322,7 +326,7 @@ index 7c7c7551b..14f2ca7f1 100644 + expert_id=expert_id, + ) + loaded = True -+ ++ + if not loaded: + # If local_name weight is not sharded + if local_name in all_slices: @@ -332,11 +336,14 @@ index 7c7c7551b..14f2ca7f1 100644 + ) + weight_loader(param, loaded_weight) + else: -+ raise KeyError(f"Cannot find weight {local_name} in the loaded slices.") -+ ++ raise KeyError( ++ f"Cannot find weight {local_name} in the loaded slices." ++ ) ++ + def load_weights_from_path(self, path: str): + # Customized weights loading from a given path of huggingface model + from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ + load_weights_with_hf_path_fast( + model=self, + weight_path=path, @@ -350,31 +357,31 @@ index 7c7c7551b..14f2ca7f1 100644 # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/utils/load.py b/python/sglang/srt/models/utils/load.py new file mode 100644 -index 000000000..2baa1ba4e +index 000000000..f3cb03012 --- /dev/null +++ b/python/sglang/srt/models/utils/load.py @@ -0,0 +1,140 @@ -+import os +import json -+from glob import glob -+from concurrent.futures import ThreadPoolExecutor ++import os +from collections import defaultdict -+from typing import Tuple, List, Dict, Callable ++from concurrent.futures import ThreadPoolExecutor ++from glob import glob ++from typing import Callable, Dict, List, Tuple + +import torch +from safetensors import safe_open -+ +from transformers.utils.hub import cached_file + +from sglang.srt.model_loader.weight_utils import default_weight_loader + ++ +def get_actual_hf_path(weight_path: str): + return os.path.dirname(cached_file(weight_path, "config.json")) -+ ++ + +def load_weights_with_hf_path_fast( -+ model: torch.nn.Module, -+ weight_path: str, ++ model: torch.nn.Module, ++ weight_path: str, + load_weights_with_worker_fn: Callable, + stacked_params_mapping: List[Tuple[str, str, str]] | None = None, + expert_params_mapping: List[Tuple[str, str, str]] | None = None, @@ -434,11 +441,12 @@ index 000000000..2baa1ba4e + file_groups = {name: local_to_file_map[name] for name in local_names} + roots = [name for name in local_names] + ranks = {name: 0 for name in local_names} ++ + def find(x): + if parent[x] != x: + parent[x] = find(parent[x]) + return parent[x] -+ ++ + def union(x, y): + root_x = find(x) + root_y = find(y) @@ -463,11 +471,11 @@ index 000000000..2baa1ba4e + return False + + for i, weight1 in enumerate(local_names): -+ for weight2 in local_names[i+1:]: ++ for weight2 in local_names[i + 1 :]: + # If two weights share any files, they conflict + if any(fn in file_groups[weight1] for fn in file_groups[weight2]): + union(weight1, weight2) -+ ++ + grouped_local_names = [weight_groups[root] for root in roots] + grouped_filenames = [list(file_groups[root]) for root in roots] + @@ -493,5 +501,3 @@ index 000000000..2baa1ba4e + # Consume all results to make result 
all tasks complete + for _ in results: + pass -+ -\ No newline at end of file From d408ba8f237cc816c464704e8bad7d08c1fcebf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Fri, 10 Oct 2025 17:22:50 +0800 Subject: [PATCH 03/24] . --- areal/api/cli_args.py | 4 ---- areal/launcher/local.py | 6 ++++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 5283b993e..027561ae5 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -605,10 +605,6 @@ class SGLangConfig: enable_multithread_load: bool = False enable_fast_load: bool = False - @property - def if_apply_sglang_patch(self): - return self.enable_multithread_load or self.enable_fast_load - # Use staticmethod to make OmegaConf happy. @staticmethod def build_cmd( diff --git a/areal/launcher/local.py b/areal/launcher/local.py index fae30db53..7af18f70d 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -296,8 +296,10 @@ def local_main(config, run_id: int = 0): config.vllm = to_structured_cfg(config.vllm, vLLMConfig) random_seed = config.vllm.seed - if alloc_mode.gen_backend == "sglang" and config.sglang.if_apply_sglang_patch: - apply_sglang_patch() + if alloc_mode.gen_backend == "sglang": + apply_patch = config.launcher.enable_multithread_load or config.launcher.enable_fast_load + if apply_patch: + apply_sglang_patch() backend_spec = { "sglang": { From 5f8e990c5a1d0e46e2f194fe72b322cae2e07355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Fri, 10 Oct 2025 17:26:32 +0800 Subject: [PATCH 04/24] . --- areal/launcher/local.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/areal/launcher/local.py b/areal/launcher/local.py index 7af18f70d..a6fb25152 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -297,7 +297,10 @@ def local_main(config, run_id: int = 0): random_seed = config.vllm.seed if alloc_mode.gen_backend == "sglang": - apply_patch = config.launcher.enable_multithread_load or config.launcher.enable_fast_load + apply_patch = ( + config.sglang.enable_multithread_load + or config.sglang.enable_fast_load + ) if apply_patch: apply_sglang_patch() From 81028146b32c5aea024e590a111d157edc3f75d8 Mon Sep 17 00:00:00 2001 From: nuzant Date: Fri, 10 Oct 2025 19:07:30 +0800 Subject: [PATCH 05/24] . 
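Instrument the patched SGLang loading path with a `debug_function` decorator, added to
python/sglang/srt/utils.py, and apply it to the hot paths (`ModelRunner.load_model`,
`update_weights_from_disk`, the `DefaultModelLoader` helpers, and the Qwen3
`load_weights*` methods) so each call logs entry/exit plus wall-clock time.

A stripped-down, runnable sketch of how the decorator behaves; `slow_square` is a
made-up stand-in for a weight-loading step, and the full implementation in the diff
below additionally resolves the owning class name and source file:

    import functools
    import time

    def debug_function(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Log entry, run the wrapped call, then log exit with elapsed time.
            print(f"[Debug] Entering {func.__name__}", flush=True)
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000
                print(
                    f"[Debug] Exiting {func.__name__} - elapsed {elapsed_ms:.3f} ms",
                    flush=True,
                )
        return wrapper

    @debug_function
    def slow_square(x: int) -> int:
        time.sleep(0.05)  # stand-in for reading a safetensors shard
        return x * x

    slow_square(4)
    # [Debug] Entering slow_square
    # [Debug] Exiting slow_square - elapsed ~50 ms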
--- patch/sglang/v0.5.1.post3.patch | 161 +++++++++++++++++++++++++++++--- 1 file changed, 147 insertions(+), 14 deletions(-) diff --git a/patch/sglang/v0.5.1.post3.patch b/patch/sglang/v0.5.1.post3.patch index 4ea96fe0e..f6bffac37 100644 --- a/patch/sglang/v0.5.1.post3.patch +++ b/patch/sglang/v0.5.1.post3.patch @@ -1,16 +1,41 @@ diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 8d5b7c715..ba4d52982 100644 +index 8d5b7c715..c83b99191 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -253,6 +253,7 @@ class ModelRunner: +@@ -108,6 +108,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter + from sglang.srt.utils import ( + MultiprocessingSerializer, + cpu_has_amx_support, ++ debug_function, + dynamic_import, + enable_show_time_cost, + get_available_gpu_memory, +@@ -253,6 +254,8 @@ class ModelRunner: self._model_update_group = {} def initialize(self, min_per_gpu_memory: float): -+ logger.info("SGLang v0.4.9.post2 is patched with customized weight loading.") ++ logger.info("SGLang v0.5.1.post3 is patched with customized weight loading.") ++ print("[Debug] SGLang v0.5.1.post3 is patched with customized weight loading.") server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( -@@ -775,8 +776,11 @@ class ModelRunner: +@@ -644,6 +647,7 @@ class ModelRunner: + ) + return min_per_gpu_memory + ++ @debug_function + def load_model(self): + before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( +@@ -764,6 +768,7 @@ class ModelRunner: + rank=self.tp_rank, + ) + ++ @debug_function + def update_weights_from_disk( + self, model_path: str, load_format: str + ) -> tuple[bool, str]: +@@ -775,8 +780,11 @@ class ModelRunner: target_device = torch.device(self.device) self.model_config.model_path = model_path @@ -24,7 +49,7 @@ index 8d5b7c715..ba4d52982 100644 # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): -@@ -790,7 +794,9 @@ class ModelRunner: +@@ -790,7 +798,9 @@ class ModelRunner: return iter def model_load_weights(model, iter): @@ -36,10 +61,18 @@ index 8d5b7c715..ba4d52982 100644 with set_default_torch_dtype(self.model_config.dtype): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index 23d70be44..7e8968743 100644 +index 23d70be44..a3dba6687 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py -@@ -263,7 +263,7 @@ class DefaultModelLoader(BaseModelLoader): +@@ -62,6 +62,7 @@ from sglang.srt.model_loader.weight_utils import ( + set_runai_streamer_env, + ) + from sglang.srt.utils import ( ++ debug_function, + get_bool_env_var, + get_device_capability, + is_npu, +@@ -263,7 +264,7 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = load_config.model_loader_extra_config @@ -48,7 +81,21 @@ index 23d70be44..7e8968743 100644 unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -384,6 +384,9 @@ class DefaultModelLoader(BaseModelLoader): +@@ -299,6 +300,7 @@ class DefaultModelLoader(BaseModelLoader): + return model_path + return None + ++ @debug_function + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Tuple[str, List[str], bool]: +@@ 
-379,11 +381,15 @@ class DefaultModelLoader(BaseModelLoader): + + return hf_folder, hf_weights_files, use_safetensors + ++ @debug_function + def _get_weights_iterator( + self, source: "Source" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -58,7 +105,7 @@ index 23d70be44..7e8968743 100644 hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -426,8 +429,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -426,16 +432,14 @@ class DefaultModelLoader(BaseModelLoader): ) else: weights_iterator = pt_weights_iterator(hf_weights_files) @@ -66,8 +113,9 @@ index 23d70be44..7e8968743 100644 - # Apply the prefix. return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) ++ @debug_function def _get_all_weights( -@@ -435,7 +436,6 @@ class DefaultModelLoader(BaseModelLoader): + self, model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: @@ -75,7 +123,15 @@ index 23d70be44..7e8968743 100644 primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -464,15 +464,23 @@ class DefaultModelLoader(BaseModelLoader): +@@ -450,6 +454,7 @@ class DefaultModelLoader(BaseModelLoader): + model_config.model_path, model_config.revision, fall_back_to_pt=True + ) + ++ @debug_function + def load_model( + self, + *, +@@ -464,15 +469,24 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) @@ -91,6 +147,7 @@ index 23d70be44..7e8968743 100644 - return model.eval() ++ @debug_function @staticmethod - def load_weights_and_postprocess(model, weights, target_device): - model.load_weights(weights) @@ -104,10 +161,19 @@ index 23d70be44..7e8968743 100644 for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 042159a50..399600b63 100644 +index 042159a50..4f8b8a1d7 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py -@@ -415,6 +415,97 @@ class Qwen3ForCausalLM(nn.Module): +@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe + from sglang.srt.model_loader.weight_utils import default_weight_loader + from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP + from sglang.srt.models.qwen2 import Qwen2Model +-from sglang.srt.utils import add_prefix, is_cuda ++from sglang.srt.utils import add_prefix, debug_function, is_cuda + + Qwen3Config = None + +@@ -415,6 +415,99 @@ class Qwen3ForCausalLM(nn.Module): def end_layer(self): return self.model.end_layer @@ -190,6 +256,7 @@ index 042159a50..399600b63 100644 + f"Cannot find weight {local_name} in the loaded slices." 
+ ) + ++ @debug_function + def load_weights_from_path(self, path: str): + # Customized weights loading from a given path of huggingface model + from sglang.srt.models.utils.load import load_weights_with_hf_path_fast @@ -202,10 +269,11 @@ index 042159a50..399600b63 100644 + tie_word_embeddings=self.config.tie_word_embeddings, + ) + ++ @debug_function def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -462,11 +553,11 @@ class Qwen3ForCausalLM(nn.Module): +@@ -462,11 +555,11 @@ class Qwen3ForCausalLM(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -501,3 +569,68 @@ index 000000000..f3cb03012 + # Consume all results to make result all tasks complete + for _ in results: + pass +diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py +index b5f6626a2..a5ac2a2bd 100644 +--- a/python/sglang/srt/utils.py ++++ b/python/sglang/srt/utils.py +@@ -21,6 +21,7 @@ import ctypes + import dataclasses + import functools + import importlib ++import inspect + import io + import ipaddress + import itertools +@@ -3011,3 +3012,52 @@ def check_cuda_result(raw_output): + raise Exception(f"CUDA error: {err}") + + return results ++ ++ ++def debug_function(func: Callable): ++ """Decorator that logs function entry/exit with contextual information.""" ++ ++ if func is None: ++ raise ValueError("debug_function decorator requires a function") ++ ++ @functools.wraps(func) ++ def wrapper(*args, **kwargs): ++ def _resolve_class_name(): ++ qualname = getattr(func, "__qualname__", "") ++ if qualname: ++ parts = qualname.split(".") ++ if len(parts) > 1 and parts[-2] != "": ++ return parts[-2] ++ if args: ++ instance = args[0] ++ if hasattr(instance, "__class__"): ++ return instance.__class__.__name__ ++ return None ++ ++ def _resolve_filename(): ++ try: ++ source = inspect.getsourcefile(func) or inspect.getfile(func) ++ except (TypeError, OSError): ++ source = None ++ if source is None: ++ return "" ++ return Path(source).name ++ ++ class_name = _resolve_class_name() ++ filename = _resolve_filename() ++ display_name = func.__name__ ++ if class_name: ++ display_name = f"{class_name}.{display_name}" ++ ++ print(f"[Debug] Entering {display_name} (file: {filename})", flush=True) ++ ++ start_time = time.perf_counter() ++ try: ++ return func(*args, **kwargs) ++ finally: ++ elapsed_ms = (time.perf_counter() - start_time) * 1000 ++ print( ++ f"[Debug] Exiting {display_name} (file: {filename}) - elapsed {elapsed_ms:.3f} ms" ++ ) ++ ++ return wrapper From 635fad37a9b47296b2d4426ab1ea5a9f815c5c75 Mon Sep 17 00:00:00 2001 From: nuzant Date: Fri, 10 Oct 2025 19:30:58 +0800 Subject: [PATCH 06/24] . 
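Drop the `@debug_function` that the previous commit stacked above `@staticmethod` on
`DefaultModelLoader.load_weights_and_postprocess`. The commit message does not state a
reason; a likely motivation is the decorator-ordering pitfall sketched below, since
decorators apply bottom-up and `debug_function` would wrap the staticmethod descriptor
rather than the underlying function. `Broken`/`Fixed` are hypothetical names, and the
behavior notes assume CPython (staticmethod objects are only callable on 3.10+):

    import functools

    def debug_function(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # On 3.9, a staticmethod descriptor has no __name__; fall back to its type.
            print(f"[Debug] calling {getattr(func, '__name__', type(func).__name__)}")
            return func(*args, **kwargs)
        return wrapper

    class Broken:
        @debug_function   # applied last, so it wraps the staticmethod descriptor
        @staticmethod
        def post(x):
            return x + 1

    class Fixed:
        @staticmethod     # applied last, so the class attribute stays a staticmethod
        @debug_function
        def post(x):
            return x + 1

    Broken.post(1)      # 2 on Python >= 3.10, where staticmethod objects are
                        # callable; TypeError on 3.9
    # Broken().post(1)  # TypeError: the wrapper is a plain function, so instance
                        # access binds it and passes the instance as x
    Fixed.post(1)       # 2; Fixed().post(1) works as well

Reordering the decorators (keeping `@staticmethod` outermost, as in `Fixed`) would also
avoid the problem; this commit instead removes the decorator from that method entirely.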
--- patch/sglang/v0.5.1.post3.patch | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/patch/sglang/v0.5.1.post3.patch b/patch/sglang/v0.5.1.post3.patch index f6bffac37..5e3dcfcd8 100644 --- a/patch/sglang/v0.5.1.post3.patch +++ b/patch/sglang/v0.5.1.post3.patch @@ -61,7 +61,7 @@ index 8d5b7c715..c83b99191 100644 with set_default_torch_dtype(self.model_config.dtype): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index 23d70be44..a3dba6687 100644 +index 23d70be44..5c44a9278 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -62,6 +62,7 @@ from sglang.srt.model_loader.weight_utils import ( @@ -131,7 +131,7 @@ index 23d70be44..a3dba6687 100644 def load_model( self, *, -@@ -464,15 +469,24 @@ class DefaultModelLoader(BaseModelLoader): +@@ -464,15 +469,23 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) @@ -147,7 +147,6 @@ index 23d70be44..a3dba6687 100644 - return model.eval() -+ @debug_function @staticmethod - def load_weights_and_postprocess(model, weights, target_device): - model.load_weights(weights) From 677bd21d0e3b4910426601ae696ab0c7b5332472 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 11 Oct 2025 11:18:26 +0800 Subject: [PATCH 07/24] filename bins --- patch/sglang/v0.5.1.post3.patch | 120 ++++++++++++++++++++------------ 1 file changed, 75 insertions(+), 45 deletions(-) diff --git a/patch/sglang/v0.5.1.post3.patch b/patch/sglang/v0.5.1.post3.patch index 5e3dcfcd8..eb1c54a9d 100644 --- a/patch/sglang/v0.5.1.post3.patch +++ b/patch/sglang/v0.5.1.post3.patch @@ -424,10 +424,10 @@ index fcb45b947..e354e7a47 100644 # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/utils/load.py b/python/sglang/srt/models/utils/load.py new file mode 100644 -index 000000000..f3cb03012 +index 000000000..035835a0f --- /dev/null +++ b/python/sglang/srt/models/utils/load.py -@@ -0,0 +1,140 @@ +@@ -0,0 +1,170 @@ +import json +import os +from collections import defaultdict @@ -446,6 +446,78 @@ index 000000000..f3cb03012 + return os.path.dirname(cached_file(weight_path, "config.json")) + + ++def make_filename_bins( ++ local_to_file_map: Dict[str, List[str]], ++) -> Tuple[List[List[str]], List[List[str]]]: ++ # Allocate local weight name into bins, where each bin access independent files ++ # Then we can use multiple threads to concurrently load each bin's parameters. 
++ # This function has a complexity of O(F + L²) ++ # where F = total number of files, L = number of local names ++ if not local_to_file_map: ++ return [], [] ++ ++ local_names = list(local_to_file_map.keys()) ++ n = len(local_names) ++ ++ # Convert file lists to sets for O(1) lookups and create file-to-locals mapping ++ local_to_files = {name: set(local_to_file_map[name]) for name in local_names} ++ file_to_locals = defaultdict(set) ++ for local_name, files in local_to_files.items(): ++ for file in files: ++ file_to_locals[file].add(local_name) ++ ++ # Union-Find with path compression and union by rank ++ parent = list(range(n)) ++ rank = [0] * n ++ ++ def find(x): ++ if parent[x] != x: ++ parent[x] = find(parent[x]) # Path compression ++ return parent[x] ++ ++ def union(x, y): ++ root_x, root_y = find(x), find(y) ++ if root_x == root_y: ++ return ++ ++ # Union by rank ++ if rank[root_x] < rank[root_y]: ++ root_x, root_y = root_y, root_x ++ parent[root_y] = root_x ++ if rank[root_x] == rank[root_y]: ++ rank[root_x] += 1 ++ ++ # Create name-to-index mapping for O(1) lookups ++ name_to_idx = {name: i for i, name in enumerate(local_names)} ++ ++ # Union locals that share files - O(F) where F is total number of files ++ for locals_sharing_file in file_to_locals.values(): ++ if len(locals_sharing_file) > 1: ++ locals_list = list(locals_sharing_file) ++ first_idx = name_to_idx[locals_list[0]] ++ for local_name in locals_list[1:]: ++ union(first_idx, name_to_idx[local_name]) ++ ++ # Group by root - O(L) ++ root_to_group = defaultdict(list) ++ for i, name in enumerate(local_names): ++ root_to_group[find(i)].append(name) ++ ++ # Build result groups - O(L + F) ++ grouped_local_names = [] ++ grouped_filenames = [] ++ ++ for group in root_to_group.values(): ++ grouped_local_names.append(group) ++ # Use set union to merge files from all locals in group ++ all_files = set() ++ for local_name in group: ++ all_files.update(local_to_files[local_name]) ++ grouped_filenames.append(list(all_files)) ++ ++ return grouped_local_names, grouped_filenames ++ ++ +def load_weights_with_hf_path_fast( + model: torch.nn.Module, + weight_path: str, @@ -502,49 +574,7 @@ index 000000000..f3cb03012 + if filename not in local_to_file_map[local_name]: + local_to_file_map[local_name].add(filename) + -+ # Use union find to create local_name groups with no file conflicts -+ parent = {name: name for name in local_names} -+ weight_groups = {name: [name] for name in local_names} -+ file_groups = {name: local_to_file_map[name] for name in local_names} -+ roots = [name for name in local_names] -+ ranks = {name: 0 for name in local_names} -+ -+ def find(x): -+ if parent[x] != x: -+ parent[x] = find(parent[x]) -+ return parent[x] -+ -+ def union(x, y): -+ root_x = find(x) -+ root_y = find(y) -+ if root_x != root_y: -+ if ranks[root_x] > ranks[root_y]: -+ parent[root_y] = root_x -+ roots.remove(root_y) -+ elif ranks[root_x] < ranks[root_y]: -+ parent[root_x] = root_y -+ roots.remove(root_x) -+ else: -+ parent[root_y] = root_x -+ roots.remove(root_y) -+ ranks[root_x] += 1 -+ # Merge file groups -+ file_groups[root_x].update(file_groups[root_y]) -+ file_groups[root_y] = file_groups[root_x] -+ # Merge weight groups -+ weight_groups[root_x].extend(weight_groups[root_y]) -+ weight_groups[root_y] = weight_groups[root_x] -+ return True -+ return False -+ -+ for i, weight1 in enumerate(local_names): -+ for weight2 in local_names[i + 1 :]: -+ # If two weights share any files, they conflict -+ if any(fn in file_groups[weight1] for fn in 
file_groups[weight2]): -+ union(weight1, weight2) -+ -+ grouped_local_names = [weight_groups[root] for root in roots] -+ grouped_filenames = [list(file_groups[root]) for root in roots] ++ grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) + + if max_workers is None: + # assume all GPUs are used by SGLang servers From 69043aaf76e10f9357d501ff8b0dab68f37b49da Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 11 Oct 2025 14:12:44 +0800 Subject: [PATCH 08/24] . --- areal/api/cli_args.py | 4 ---- areal/launcher/local.py | 5 +++-- areal/launcher/sglang_server.py | 5 ++++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 5283b993e..027561ae5 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -605,10 +605,6 @@ class SGLangConfig: enable_multithread_load: bool = False enable_fast_load: bool = False - @property - def if_apply_sglang_patch(self): - return self.enable_multithread_load or self.enable_fast_load - # Use staticmethod to make OmegaConf happy. @staticmethod def build_cmd( diff --git a/areal/launcher/local.py b/areal/launcher/local.py index fae30db53..05fc83014 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -296,8 +296,9 @@ def local_main(config, run_id: int = 0): config.vllm = to_structured_cfg(config.vllm, vLLMConfig) random_seed = config.vllm.seed - if alloc_mode.gen_backend == "sglang" and config.sglang.if_apply_sglang_patch: - apply_sglang_patch() + if alloc_mode.gen_backend == "sglang": + if config.sglang.enable_multithread_load or config.sglang.enable_fast_load: + apply_sglang_patch() backend_spec = { "sglang": { diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py index f88d0549e..56fef2eaf 100644 --- a/areal/launcher/sglang_server.py +++ b/areal/launcher/sglang_server.py @@ -22,7 +22,7 @@ ) from areal.platforms import current_platform from areal.utils import logging, name_resolve, names -from areal.utils.launcher import TRITON_CACHE_PATH +from areal.utils.launcher import TRITON_CACHE_PATH, apply_sglang_patch from areal.utils.network import find_free_ports, gethostip logger = logging.getLogger("SGLangServer Wrapper") @@ -130,6 +130,9 @@ def __init__( self.server_process = None self.n_gpus_per_node = n_gpus_per_node + if self.config.enable_fast_load or self.config.enable_multithread_load: + apply_sglang_patch() + def run(self): gpus_per_server = self.allocation_mode.gen_instance_size cross_nodes = False From 74d2580592801f80bb655ebd51aedd77f8884fa0 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 11 Oct 2025 14:39:55 +0800 Subject: [PATCH 09/24] remove debug info --- patch/sglang/v0.5.1.post3.patch | 147 +++----------------------------- 1 file changed, 14 insertions(+), 133 deletions(-) diff --git a/patch/sglang/v0.5.1.post3.patch b/patch/sglang/v0.5.1.post3.patch index eb1c54a9d..80583c626 100644 --- a/patch/sglang/v0.5.1.post3.patch +++ b/patch/sglang/v0.5.1.post3.patch @@ -1,16 +1,8 @@ diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 8d5b7c715..c83b99191 100644 +index 8d5b7c715..af294364a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -108,6 +108,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter - from sglang.srt.utils import ( - MultiprocessingSerializer, - cpu_has_amx_support, -+ debug_function, - dynamic_import, - enable_show_time_cost, - 
get_available_gpu_memory, -@@ -253,6 +254,8 @@ class ModelRunner: +@@ -253,6 +253,8 @@ class ModelRunner: self._model_update_group = {} def initialize(self, min_per_gpu_memory: float): @@ -19,23 +11,7 @@ index 8d5b7c715..c83b99191 100644 server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( -@@ -644,6 +647,7 @@ class ModelRunner: - ) - return min_per_gpu_memory - -+ @debug_function - def load_model(self): - before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) - logger.info( -@@ -764,6 +768,7 @@ class ModelRunner: - rank=self.tp_rank, - ) - -+ @debug_function - def update_weights_from_disk( - self, model_path: str, load_format: str - ) -> tuple[bool, str]: -@@ -775,8 +780,11 @@ class ModelRunner: +@@ -775,8 +777,11 @@ class ModelRunner: target_device = torch.device(self.device) self.model_config.model_path = model_path @@ -49,7 +25,7 @@ index 8d5b7c715..c83b99191 100644 # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): -@@ -790,7 +798,9 @@ class ModelRunner: +@@ -790,7 +795,9 @@ class ModelRunner: return iter def model_load_weights(model, iter): @@ -61,18 +37,10 @@ index 8d5b7c715..c83b99191 100644 with set_default_torch_dtype(self.model_config.dtype): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index 23d70be44..5c44a9278 100644 +index 23d70be44..7e8968743 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py -@@ -62,6 +62,7 @@ from sglang.srt.model_loader.weight_utils import ( - set_runai_streamer_env, - ) - from sglang.srt.utils import ( -+ debug_function, - get_bool_env_var, - get_device_capability, - is_npu, -@@ -263,7 +264,7 @@ class DefaultModelLoader(BaseModelLoader): +@@ -263,7 +263,7 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = load_config.model_loader_extra_config @@ -81,21 +49,7 @@ index 23d70be44..5c44a9278 100644 unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -299,6 +300,7 @@ class DefaultModelLoader(BaseModelLoader): - return model_path - return None - -+ @debug_function - def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool - ) -> Tuple[str, List[str], bool]: -@@ -379,11 +381,15 @@ class DefaultModelLoader(BaseModelLoader): - - return hf_folder, hf_weights_files, use_safetensors - -+ @debug_function - def _get_weights_iterator( - self, source: "Source" +@@ -384,6 +384,9 @@ class DefaultModelLoader(BaseModelLoader): ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -105,7 +59,7 @@ index 23d70be44..5c44a9278 100644 hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -426,16 +432,14 @@ class DefaultModelLoader(BaseModelLoader): +@@ -426,8 +429,6 @@ class DefaultModelLoader(BaseModelLoader): ) else: weights_iterator = pt_weights_iterator(hf_weights_files) @@ -113,9 +67,8 @@ index 23d70be44..5c44a9278 100644 - # Apply the prefix. 
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) -+ @debug_function def _get_all_weights( - self, +@@ -435,7 +436,6 @@ class DefaultModelLoader(BaseModelLoader): model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: @@ -123,15 +76,7 @@ index 23d70be44..5c44a9278 100644 primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -450,6 +454,7 @@ class DefaultModelLoader(BaseModelLoader): - model_config.model_path, model_config.revision, fall_back_to_pt=True - ) - -+ @debug_function - def load_model( - self, - *, -@@ -464,15 +469,23 @@ class DefaultModelLoader(BaseModelLoader): +@@ -464,15 +464,23 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) @@ -160,19 +105,10 @@ index 23d70be44..5c44a9278 100644 for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 042159a50..4f8b8a1d7 100644 +index 042159a50..399600b63 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py -@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe - from sglang.srt.model_loader.weight_utils import default_weight_loader - from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP - from sglang.srt.models.qwen2 import Qwen2Model --from sglang.srt.utils import add_prefix, is_cuda -+from sglang.srt.utils import add_prefix, debug_function, is_cuda - - Qwen3Config = None - -@@ -415,6 +415,99 @@ class Qwen3ForCausalLM(nn.Module): +@@ -415,6 +415,97 @@ class Qwen3ForCausalLM(nn.Module): def end_layer(self): return self.model.end_layer @@ -255,7 +191,6 @@ index 042159a50..4f8b8a1d7 100644 + f"Cannot find weight {local_name} in the loaded slices." 
+ ) + -+ @debug_function + def load_weights_from_path(self, path: str): + # Customized weights loading from a given path of huggingface model + from sglang.srt.models.utils.load import load_weights_with_hf_path_fast @@ -268,11 +203,10 @@ index 042159a50..4f8b8a1d7 100644 + tie_word_embeddings=self.config.tie_word_embeddings, + ) + -+ @debug_function def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -462,11 +555,11 @@ class Qwen3ForCausalLM(nn.Module): +@@ -462,11 +553,11 @@ class Qwen3ForCausalLM(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -599,7 +533,7 @@ index 000000000..035835a0f + for _ in results: + pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py -index b5f6626a2..a5ac2a2bd 100644 +index b5f6626a2..8a8a69805 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -21,6 +21,7 @@ import ctypes @@ -610,56 +544,3 @@ index b5f6626a2..a5ac2a2bd 100644 import io import ipaddress import itertools -@@ -3011,3 +3012,52 @@ def check_cuda_result(raw_output): - raise Exception(f"CUDA error: {err}") - - return results -+ -+ -+def debug_function(func: Callable): -+ """Decorator that logs function entry/exit with contextual information.""" -+ -+ if func is None: -+ raise ValueError("debug_function decorator requires a function") -+ -+ @functools.wraps(func) -+ def wrapper(*args, **kwargs): -+ def _resolve_class_name(): -+ qualname = getattr(func, "__qualname__", "") -+ if qualname: -+ parts = qualname.split(".") -+ if len(parts) > 1 and parts[-2] != "": -+ return parts[-2] -+ if args: -+ instance = args[0] -+ if hasattr(instance, "__class__"): -+ return instance.__class__.__name__ -+ return None -+ -+ def _resolve_filename(): -+ try: -+ source = inspect.getsourcefile(func) or inspect.getfile(func) -+ except (TypeError, OSError): -+ source = None -+ if source is None: -+ return "" -+ return Path(source).name -+ -+ class_name = _resolve_class_name() -+ filename = _resolve_filename() -+ display_name = func.__name__ -+ if class_name: -+ display_name = f"{class_name}.{display_name}" -+ -+ print(f"[Debug] Entering {display_name} (file: {filename})", flush=True) -+ -+ start_time = time.perf_counter() -+ try: -+ return func(*args, **kwargs) -+ finally: -+ elapsed_ms = (time.perf_counter() - start_time) * 1000 -+ print( -+ f"[Debug] Exiting {display_name} (file: {filename}) - elapsed {elapsed_ms:.3f} ms" -+ ) -+ -+ return wrapper From 9bab81479ed63fab9534629d35e235071c078629 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 11 Oct 2025 15:07:35 +0800 Subject: [PATCH 10/24] patch to 0.5.2 --- areal/api/cli_args.py | 4 +-- .../{v0.5.1.post3.patch => v0.5.2.patch} | 35 +++++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) rename patch/sglang/{v0.5.1.post3.patch => v0.5.2.patch} (95%) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 027561ae5..851141285 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -656,8 +656,8 @@ def build_args( args: Dict = conf_as_dict(sglang_config) if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: assert pkg_version.is_version_equal( - "sglang", "0.5.1.post3" - ), f"Customized model loading requires exact SGLang version 0.5.1.post3" + "sglang", "0.5.2" + ), f"Customized model loading requires exact SGLang version 0.5.2" model_loader_extra_config = dict( 
enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, diff --git a/patch/sglang/v0.5.1.post3.patch b/patch/sglang/v0.5.2.patch similarity index 95% rename from patch/sglang/v0.5.1.post3.patch rename to patch/sglang/v0.5.2.patch index 80583c626..7420ce238 100644 --- a/patch/sglang/v0.5.1.post3.patch +++ b/patch/sglang/v0.5.2.patch @@ -1,17 +1,16 @@ diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 8d5b7c715..af294364a 100644 +index aa0e2e0e6..2a79ad827 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -253,6 +253,8 @@ class ModelRunner: +@@ -258,6 +258,7 @@ class ModelRunner: self._model_update_group = {} def initialize(self, min_per_gpu_memory: float): -+ logger.info("SGLang v0.5.1.post3 is patched with customized weight loading.") -+ print("[Debug] SGLang v0.5.1.post3 is patched with customized weight loading.") ++ logger.warning("SGLang v0.5.2 is patched with customized weight loading.") server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( -@@ -775,8 +777,11 @@ class ModelRunner: +@@ -823,8 +824,11 @@ class ModelRunner: target_device = torch.device(self.device) self.model_config.model_path = model_path @@ -25,7 +24,7 @@ index 8d5b7c715..af294364a 100644 # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): -@@ -790,7 +795,9 @@ class ModelRunner: +@@ -838,7 +842,9 @@ class ModelRunner: return iter def model_load_weights(model, iter): @@ -37,10 +36,10 @@ index 8d5b7c715..af294364a 100644 with set_default_torch_dtype(self.model_config.dtype): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index 23d70be44..7e8968743 100644 +index d2b4c6bfc..9240ca9f5 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py -@@ -263,7 +263,7 @@ class DefaultModelLoader(BaseModelLoader): +@@ -278,7 +278,7 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = load_config.model_loader_extra_config @@ -49,7 +48,7 @@ index 23d70be44..7e8968743 100644 unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -384,6 +384,9 @@ class DefaultModelLoader(BaseModelLoader): +@@ -399,6 +399,9 @@ class DefaultModelLoader(BaseModelLoader): ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -59,7 +58,7 @@ index 23d70be44..7e8968743 100644 hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -426,8 +429,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -441,8 +444,6 @@ class DefaultModelLoader(BaseModelLoader): ) else: weights_iterator = pt_weights_iterator(hf_weights_files) @@ -68,7 +67,7 @@ index 23d70be44..7e8968743 100644 return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def _get_all_weights( -@@ -435,7 +436,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -450,7 +451,6 @@ class DefaultModelLoader(BaseModelLoader): model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: @@ -76,7 +75,7 @@ index 23d70be44..7e8968743 100644 
primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -464,15 +464,23 @@ class DefaultModelLoader(BaseModelLoader): +@@ -479,15 +479,23 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) @@ -105,10 +104,10 @@ index 23d70be44..7e8968743 100644 for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 042159a50..399600b63 100644 +index bc5f054d7..aff429be0 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py -@@ -415,6 +415,97 @@ class Qwen3ForCausalLM(nn.Module): +@@ -418,6 +418,97 @@ class Qwen3ForCausalLM(nn.Module): def end_layer(self): return self.model.end_layer @@ -206,7 +205,7 @@ index 042159a50..399600b63 100644 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -462,11 +553,11 @@ class Qwen3ForCausalLM(nn.Module): +@@ -468,11 +559,11 @@ class Qwen3ForCausalLM(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -222,10 +221,10 @@ index 042159a50..399600b63 100644 weight_loader(param, loaded_weight, shard_id) break diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index fcb45b947..e354e7a47 100644 +index c1c4c3638..c88de9d19 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py -@@ -739,6 +739,130 @@ class Qwen3MoeForCausalLM(nn.Module): +@@ -770,6 +770,130 @@ class Qwen3MoeForCausalLM(nn.Module): else: self.model.layers_to_capture = [val + 1 for val in layer_ids] @@ -533,7 +532,7 @@ index 000000000..035835a0f + for _ in results: + pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py -index b5f6626a2..8a8a69805 100644 +index 846baeb01..5235c0563 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -21,6 +21,7 @@ import ctypes From 368c76162dfa28e13762d01820f5e5e5695e7192 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 11 Oct 2025 16:42:32 +0800 Subject: [PATCH 11/24] fix gemini comments --- areal/launcher/local.py | 7 ------- areal/utils/launcher.py | 6 +++++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/areal/launcher/local.py b/areal/launcher/local.py index 05fc83014..ff360ad49 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -26,7 +26,6 @@ JobException, JobInfo, JobState, - apply_sglang_patch, get_env_vars, wait_llm_server_addrs, ) @@ -142,8 +141,6 @@ def submit_array( + cmd[i] ) c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" - # SGLang will somehow remove quotes in the command, so we need to escape the quotes - c = c.replace('"', '\\"') logger.info("Starting local process with command: %s", c) process = subprocess.Popen( c, shell=isinstance(c, str), stdout=sys.stdout, stderr=sys.stdout @@ -296,10 +293,6 @@ def local_main(config, run_id: int = 0): config.vllm = to_structured_cfg(config.vllm, vLLMConfig) random_seed = config.vllm.seed - if alloc_mode.gen_backend == "sglang": - if config.sglang.enable_multithread_load or config.sglang.enable_fast_load: - apply_sglang_patch() - backend_spec = { "sglang": { "module": "areal.launcher.sglang_server", diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index a2f905177..17edff96f 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -170,12 
+170,16 @@ def apply_sglang_patch(): target_path = "" sglang_meta = subprocess.check_output( - "python3 -m pip show sglang", shell=True + [sys.executable, "-m", "pip", "show", "sglang"] ).decode("ascii") for line in sglang_meta.split("\n"): line = line.strip() if line.startswith("Editable project location: "): target_path = str(Path(line.split(": ")[1]).parent) + break + elif line.startswith("Location: "): + target_path = str(Path(line.split(": ")[1]) / "sglang") + break if target_path: proc = subprocess.Popen( From 26c871f81f478fe87d93bd9a87b23f2f50470867 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 10:58:33 +0800 Subject: [PATCH 12/24] use patch instead of git apply --- areal/utils/launcher.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 17edff96f..277b11811 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -3,6 +3,7 @@ import getpass import os import pathlib +import shutil import subprocess import sys import time @@ -182,11 +183,34 @@ def apply_sglang_patch(): break if target_path: - proc = subprocess.Popen( - ["git", "apply", patch_path], + patch_binary = shutil.which("patch") + if not patch_binary: + logger.warning( + "Could not locate the `patch` command; skipping SGLang patch application." + ) + return + + result = subprocess.run( + [patch_binary, "-p1", "-N", "-i", patch_path], cwd=target_path, - stderr=sys.stdout, - stdout=sys.stdout, + capture_output=True, + text=True, ) - proc.wait() - logger.info(f"Applied SGLang patch {patch_path} to {target_path}") + + output = (result.stdout or "") + (result.stderr or "") + if result.returncode == 0: + logger.info(f"Applied SGLang patch {patch_path} to {target_path}") + elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: + logger.warning( + f"SGLang patch {patch_path} appears to be already applied for {target_path}." + ) + else: + logger.error( + "Failed to apply SGLang patch %s to %s. Output:\n%s", + patch_path, + target_path, + output.strip(), + ) + raise RuntimeError( + f"SGLang patch {patch_path} failed with exit code {result.returncode}." + ) From ff21d52deb16437bc9263252172a865edcf62819 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:01:57 +0800 Subject: [PATCH 13/24] . --- areal/launcher/local.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/areal/launcher/local.py b/areal/launcher/local.py index ff360ad49..23e29cf46 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -141,6 +141,9 @@ def submit_array( + cmd[i] ) c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" + # SGLang will somehow remove quotes in the command, so we need to escape the quotes + c = c.replace('"', '\\"') + logger.info("Starting local process with command: %s", c) process = subprocess.Popen( c, shell=isinstance(c, str), stdout=sys.stdout, stderr=sys.stdout From 028d1f3490ad35baa5d8f94e999132d635e5e9ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:04:24 +0800 Subject: [PATCH 14/24] . 
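
Log the raw output of the `patch` command before inspecting its exit code,
so that failures and already-applied patches can be diagnosed directly from
the launcher log.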
--- areal/utils/launcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 277b11811..0c3d617b2 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -198,6 +198,7 @@ def apply_sglang_patch(): ) output = (result.stdout or "") + (result.stderr or "") + logger.info("Patch command output:\n%s", output.strip()) if result.returncode == 0: logger.info(f"Applied SGLang patch {patch_path} to {target_path}") elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: From c38a4261a44ae96ee41551867b56efda901f7e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:13:01 +0800 Subject: [PATCH 15/24] add debug info --- areal/utils/launcher.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 0c3d617b2..3f4939f3c 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -168,11 +168,13 @@ def apply_sglang_patch(): / "sglang" / f"v{pkg_version.get_version('sglang')}.patch" ) + logger.info(f"[Debug] p={p}, patch_path={patch_path}") target_path = "" sglang_meta = subprocess.check_output( [sys.executable, "-m", "pip", "show", "sglang"] ).decode("ascii") + logger.info(f"[Debug] sglang_meta=\n{sglang_meta}") for line in sglang_meta.split("\n"): line = line.strip() if line.startswith("Editable project location: "): @@ -182,13 +184,14 @@ def apply_sglang_patch(): target_path = str(Path(line.split(": ")[1]) / "sglang") break + logger.info(f"[Debug] target_path={target_path}") if target_path: patch_binary = shutil.which("patch") if not patch_binary: - logger.warning( - "Could not locate the `patch` command; skipping SGLang patch application." + raise RuntimeError( + "Could not locate the `patch` command; SGLang patch application failed." ) - return + result = subprocess.run( [patch_binary, "-p1", "-N", "-i", patch_path], @@ -215,3 +218,5 @@ def apply_sglang_patch(): raise RuntimeError( f"SGLang patch {patch_path} failed with exit code {result.returncode}." ) + else: + raise RuntimeError("Could not determine the installation path of SGLang.") From 50c328a1c17ba1f845868d039fc05682b8e8347e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:20:56 +0800 Subject: [PATCH 16/24] . --- areal/utils/launcher.py | 75 +++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 3f4939f3c..9b7089974 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -170,53 +170,56 @@ def apply_sglang_patch(): ) logger.info(f"[Debug] p={p}, patch_path={patch_path}") - target_path = "" + target_path = None sglang_meta = subprocess.check_output( [sys.executable, "-m", "pip", "show", "sglang"] ).decode("ascii") logger.info(f"[Debug] sglang_meta=\n{sglang_meta}") + # Prioritize editable install location, since pip show lists both locations + # if installed in editable mode. 
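+    # (The patch has to land on the sources Python actually imports; for an
+    # editable install that is the checked-out tree, not site-packages.)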
for line in sglang_meta.split("\n"): line = line.strip() if line.startswith("Editable project location: "): target_path = str(Path(line.split(": ")[1]).parent) break - elif line.startswith("Location: "): - target_path = str(Path(line.split(": ")[1]) / "sglang") - break + else: + for line in sglang_meta.split("\n"): + line = line.strip() + if line.startswith("Location: "): + target_path = str(Path(line.split(": ")[1]) / "sglang") + break + + if not target_path or not os.path.exists(target_path): + raise RuntimeError("Could not determine the installation path of SGLang.") logger.info(f"[Debug] target_path={target_path}") - if target_path: - patch_binary = shutil.which("patch") - if not patch_binary: - raise RuntimeError( - "Could not locate the `patch` command; SGLang patch application failed." - ) - - - result = subprocess.run( - [patch_binary, "-p1", "-N", "-i", patch_path], - cwd=target_path, - capture_output=True, - text=True, + patch_binary = shutil.which("patch") + if not patch_binary: + raise RuntimeError( + "Could not locate the `patch` command; SGLang patch application failed." ) + result = subprocess.run( + [patch_binary, "-p1", "-N", "-i", patch_path], + cwd=target_path, + capture_output=True, + text=True, + ) - output = (result.stdout or "") + (result.stderr or "") - logger.info("Patch command output:\n%s", output.strip()) - if result.returncode == 0: - logger.info(f"Applied SGLang patch {patch_path} to {target_path}") - elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: - logger.warning( - f"SGLang patch {patch_path} appears to be already applied for {target_path}." - ) - else: - logger.error( - "Failed to apply SGLang patch %s to %s. Output:\n%s", - patch_path, - target_path, - output.strip(), - ) - raise RuntimeError( - f"SGLang patch {patch_path} failed with exit code {result.returncode}." - ) + output = (result.stdout or "") + (result.stderr or "") + logger.info("Patch command output:\n%s", output.strip()) + if result.returncode == 0: + logger.info(f"Applied SGLang patch {patch_path} to {target_path}") + elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: + logger.warning( + f"SGLang patch {patch_path} appears to be already applied for {target_path}." + ) else: - raise RuntimeError("Could not determine the installation path of SGLang.") + logger.error( + "Failed to apply SGLang patch %s to %s. Output:\n%s", + patch_path, + target_path, + output.strip(), + ) + raise RuntimeError( + f"SGLang patch {patch_path} failed with exit code {result.returncode}." 
+ ) From 83e548f1ddad1b878cadf1e2f23a83c17a4c984f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:23:27 +0800 Subject: [PATCH 17/24] remove debug info --- areal/launcher/local.py | 3 --- areal/utils/launcher.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/areal/launcher/local.py b/areal/launcher/local.py index 23e29cf46..ff360ad49 100644 --- a/areal/launcher/local.py +++ b/areal/launcher/local.py @@ -141,9 +141,6 @@ def submit_array( + cmd[i] ) c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" - # SGLang will somehow remove quotes in the command, so we need to escape the quotes - c = c.replace('"', '\\"') - logger.info("Starting local process with command: %s", c) process = subprocess.Popen( c, shell=isinstance(c, str), stdout=sys.stdout, stderr=sys.stdout diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 9b7089974..e469ae30a 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -174,7 +174,6 @@ def apply_sglang_patch(): sglang_meta = subprocess.check_output( [sys.executable, "-m", "pip", "show", "sglang"] ).decode("ascii") - logger.info(f"[Debug] sglang_meta=\n{sglang_meta}") # Prioritize editable install location, since pip show lists both locations # if installed in editable mode. for line in sglang_meta.split("\n"): @@ -192,7 +191,6 @@ def apply_sglang_patch(): if not target_path or not os.path.exists(target_path): raise RuntimeError("Could not determine the installation path of SGLang.") - logger.info(f"[Debug] target_path={target_path}") patch_binary = shutil.which("patch") if not patch_binary: raise RuntimeError( @@ -206,7 +204,6 @@ def apply_sglang_patch(): ) output = (result.stdout or "") + (result.stderr or "") - logger.info("Patch command output:\n%s", output.strip()) if result.returncode == 0: logger.info(f"Applied SGLang patch {patch_path} to {target_path}") elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: From 4148945ea84f19cbdc1acf724b5155eda1d5da95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:36:21 +0800 Subject: [PATCH 18/24] . 
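
Skip the distributed weight-update group setup in connect_engine when the
update channel is disk: a WeightUpdateMeta of type "disk" hands weights over
through a shared checkpoint path, so no collective communication group is
needed. A minimal sketch of the resulting control flow (mirroring the hunk
below):

    def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta):
        ...
        self.rollout_engine = engine
        if meta.type == "disk":
            return  # disk-based sync: nothing to initialize
        if not self.weight_update_group_initialized:
            self._init_weight_update_from_distributed(meta)
            self.weight_update_group_initialized = True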
--- areal/experimental/megatron_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py index 4f7842258..f87bcd13e 100644 --- a/areal/experimental/megatron_engine.py +++ b/areal/experimental/megatron_engine.py @@ -642,6 +642,9 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): ) self.rollout_engine = engine + if meta.type == "disk": + return + if not self.weight_update_group_initialized: self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True From 4349c1c61bf05f240eff0c25cccdc3277bf3952d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E9=9B=B7?= Date: Mon, 13 Oct 2025 11:42:13 +0800 Subject: [PATCH 19/24] fix update weight from disk --- areal/experimental/megatron_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py index f87bcd13e..25bb1bb35 100644 --- a/areal/experimental/megatron_engine.py +++ b/areal/experimental/megatron_engine.py @@ -612,8 +612,6 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta): # dist.barrier() are called when _save_model_to_hf finished if dist.get_rank() == 0: - fut.result() - update_name = names.update_weights_from_disk( self.config.experiment_name, self.config.trial_name, @@ -622,6 +620,8 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta): name_resolve.add( update_name, str(datetime.now().timestamp()), keepalive_ttl=120 ) + + fut.result() dist.barrier(device_ids=[self.device.index]) current_platform.synchronize() From 683f95f2eb3c767d5303051615387b34b66989fa Mon Sep 17 00:00:00 2001 From: nuzant Date: Mon, 13 Oct 2025 11:46:21 +0800 Subject: [PATCH 20/24] format --- areal/experimental/megatron_engine.py | 4 ++-- areal/utils/launcher.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py index 25bb1bb35..98e127ec7 100644 --- a/areal/experimental/megatron_engine.py +++ b/areal/experimental/megatron_engine.py @@ -620,7 +620,7 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta): name_resolve.add( update_name, str(datetime.now().timestamp()), keepalive_ttl=120 ) - + fut.result() dist.barrier(device_ids=[self.device.index]) @@ -644,7 +644,7 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): if meta.type == "disk": return - + if not self.weight_update_group_initialized: self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index e469ae30a..db9887228 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -187,7 +187,7 @@ def apply_sglang_patch(): if line.startswith("Location: "): target_path = str(Path(line.split(": ")[1]) / "sglang") break - + if not target_path or not os.path.exists(target_path): raise RuntimeError("Could not determine the installation path of SGLang.") @@ -206,7 +206,10 @@ def apply_sglang_patch(): output = (result.stdout or "") + (result.stderr or "") if result.returncode == 0: logger.info(f"Applied SGLang patch {patch_path} to {target_path}") - elif "Reversed (or previously applied) patch detected" in output or "Skipping patch." in output: + elif ( + "Reversed (or previously applied) patch detected" in output + or "Skipping patch." 
in output + ): logger.warning( f"SGLang patch {patch_path} appears to be already applied for {target_path}." ) From 1848ea06f5eb759240efc2d81c807be9b2f7d89b Mon Sep 17 00:00:00 2001 From: nuzant Date: Mon, 13 Oct 2025 13:13:57 +0800 Subject: [PATCH 21/24] . --- areal/experimental/megatron_engine.py | 8 ++++---- areal/utils/launcher.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py index 98e127ec7..5b14d4f7a 100644 --- a/areal/experimental/megatron_engine.py +++ b/areal/experimental/megatron_engine.py @@ -642,10 +642,10 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): ) self.rollout_engine = engine - if meta.type == "disk": - return - - if not self.weight_update_group_initialized: + if ( + meta.type == current_platform.communication_backend + and not self.weight_update_group_initialized + ): self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index db9887228..e1ad36916 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -168,12 +168,10 @@ def apply_sglang_patch(): / "sglang" / f"v{pkg_version.get_version('sglang')}.patch" ) - logger.info(f"[Debug] p={p}, patch_path={patch_path}") - target_path = None sglang_meta = subprocess.check_output( [sys.executable, "-m", "pip", "show", "sglang"] - ).decode("ascii") + ).decode("utf-8") # Prioritize editable install location, since pip show lists both locations # if installed in editable mode. for line in sglang_meta.split("\n"): From e0783600bf99018c3a4c64dbc10ad88bcee628e7 Mon Sep 17 00:00:00 2001 From: nuzant Date: Mon, 13 Oct 2025 13:24:20 +0800 Subject: [PATCH 22/24] use diff instead of git diff --- patch/sglang/v0.5.2.patch | 937 +++++++++++++++++++------------------- 1 file changed, 465 insertions(+), 472 deletions(-) diff --git a/patch/sglang/v0.5.2.patch b/patch/sglang/v0.5.2.patch index 7420ce238..3a94b3253 100644 --- a/patch/sglang/v0.5.2.patch +++ b/patch/sglang/v0.5.2.patch @@ -1,545 +1,538 @@ -diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index aa0e2e0e6..2a79ad827 100644 ---- a/python/sglang/srt/model_executor/model_runner.py -+++ b/python/sglang/srt/model_executor/model_runner.py -@@ -258,6 +258,7 @@ class ModelRunner: +diff -ruN ./srt/model_executor/model_runner.py ../../../sglang/python/sglang/srt/model_executor/model_runner.py +--- ./srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800 ++++ ../../../sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800 +@@ -258,7 +258,6 @@ self._model_update_group = {} def initialize(self, min_per_gpu_memory: float): -+ logger.warning("SGLang v0.5.2 is patched with customized weight loading.") +- logger.warning("SGLang v0.5.2 is patched with customized weight loading.") server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( -@@ -823,8 +824,11 @@ class ModelRunner: +@@ -824,11 +823,8 @@ target_device = torch.device(self.device) self.model_config.model_path = model_path -- load_config = LoadConfig(load_format=load_format) -- -+ load_config = LoadConfig( -+ load_format=load_format, -+ # XXX: This should be in function args, passed in by requests -+ model_loader_extra_config=self.server_args.model_loader_extra_config, -+ ) +- load_config = LoadConfig( +- load_format=load_format, +- # XXX: This should 
be in function args, passed in by requests +- model_loader_extra_config=self.server_args.model_loader_extra_config, +- ) ++ load_config = LoadConfig(load_format=load_format) ++ # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): -@@ -838,7 +842,9 @@ class ModelRunner: +@@ -842,9 +838,7 @@ return iter def model_load_weights(model, iter): -- DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) -+ DefaultModelLoader.load_weights_and_postprocess( -+ model, iter, target_device, load_config=load_config -+ ) +- DefaultModelLoader.load_weights_and_postprocess( +- model, iter, target_device, load_config=load_config +- ) ++ DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) return model with set_default_torch_dtype(self.model_config.dtype): -diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py -index d2b4c6bfc..9240ca9f5 100644 ---- a/python/sglang/srt/model_loader/loader.py -+++ b/python/sglang/srt/model_loader/loader.py -@@ -278,7 +278,7 @@ class DefaultModelLoader(BaseModelLoader): +diff -ruN ./srt/model_loader/loader.py ../../../sglang/python/sglang/srt/model_loader/loader.py +--- ./srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800 ++++ ../../../sglang/python/sglang/srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800 +@@ -278,7 +278,7 @@ def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = load_config.model_loader_extra_config -- allowed_keys = {"enable_multithread_load", "num_threads"} -+ allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} +- allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} ++ allowed_keys = {"enable_multithread_load", "num_threads"} unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -399,6 +399,9 @@ class DefaultModelLoader(BaseModelLoader): +@@ -399,9 +399,6 @@ ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config -+ if extra_config.get("enable_fast_load"): -+ return source.model_or_path -+ +- if extra_config.get("enable_fast_load"): +- return source.model_or_path +- hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -441,8 +444,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -444,6 +441,8 @@ ) else: weights_iterator = pt_weights_iterator(hf_weights_files) -- -- # Apply the prefix. ++ ++ # Apply the prefix. 
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def _get_all_weights( -@@ -450,7 +451,6 @@ class DefaultModelLoader(BaseModelLoader): +@@ -451,6 +450,7 @@ model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: -- ++ primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -479,15 +479,23 @@ class DefaultModelLoader(BaseModelLoader): +@@ -479,23 +479,15 @@ self.load_config, ) -+ extra_config = self.load_config.model_loader_extra_config -+ if extra_config.get("enable_fast_load"): -+ weights_iter_or_path = model_config.model_path -+ else: -+ weights_iter_or_path = self._get_all_weights(model_config, model) +- extra_config = self.load_config.model_loader_extra_config +- if extra_config.get("enable_fast_load"): +- weights_iter_or_path = model_config.model_path +- else: +- weights_iter_or_path = self._get_all_weights(model_config, model) self.load_weights_and_postprocess( -- model, self._get_all_weights(model_config, model), target_device -+ model, weights_iter_or_path, target_device, load_config=self.load_config +- model, weights_iter_or_path, target_device, load_config=self.load_config ++ model, self._get_all_weights(model_config, model), target_device ) -- ++ return model.eval() @staticmethod -- def load_weights_and_postprocess(model, weights, target_device): -- model.load_weights(weights) -+ def load_weights_and_postprocess(model, weights, target_device, load_config=None): -+ extra_config = load_config.model_loader_extra_config -+ if extra_config.get("enable_fast_load"): -+ model.load_weights_from_path(weights) -+ else: -+ model.load_weights(weights) +- def load_weights_and_postprocess(model, weights, target_device, load_config=None): +- extra_config = load_config.model_loader_extra_config +- if extra_config.get("enable_fast_load"): +- model.load_weights_from_path(weights) +- else: +- model.load_weights(weights) ++ def load_weights_and_postprocess(model, weights, target_device): ++ model.load_weights(weights) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) -diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index bc5f054d7..aff429be0 100644 ---- a/python/sglang/srt/models/qwen3.py -+++ b/python/sglang/srt/models/qwen3.py -@@ -418,6 +418,97 @@ class Qwen3ForCausalLM(nn.Module): +diff -ruN ./srt/models/qwen3_moe.py ../../../sglang/python/sglang/srt/models/qwen3_moe.py +--- ./srt/models/qwen3_moe.py 2025-10-11 15:03:31.207989560 +0800 ++++ ../../../sglang/python/sglang/srt/models/qwen3_moe.py 2025-10-13 13:20:52.077417877 +0800 +@@ -770,130 +770,6 @@ + else: + self.model.layers_to_capture = [val + 1 for val in layer_ids] + +- @property +- def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: +- return [ +- # (param_name, shard_name, shard_id) +- ("qkv_proj", "q_proj", "q"), +- ("qkv_proj", "k_proj", "k"), +- ("qkv_proj", "v_proj", "v"), +- ("gate_up_proj", "gate_proj", 0), +- ("gate_up_proj", "up_proj", 1), +- ] +- +- @property +- def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: +- return get_moe_impl_class().make_expert_params_mapping( +- ckpt_gate_proj_name="gate_proj", +- ckpt_down_proj_name="down_proj", +- ckpt_up_proj_name="up_proj", +- num_experts=self.config.num_experts, +- ) +- +- def _load_weights_with_worker( +- self, +- params: Dict[str, torch.nn.Parameter], +- local_names: List[str], +- filenames: List[str], +- weight_path: str, 
+- ): +- import os +- +- from safetensors import safe_open +- +- from sglang.srt.model_loader.weight_utils import default_weight_loader +- +- all_slices = {} +- for filename in filenames: +- safetensor_file = os.path.join(weight_path, filename) +- with safe_open(safetensor_file, framework="pt", device="cpu") as f: +- for name in f.keys(): +- # all_slices[name] = f.get_slice(name) +- all_slices[name] = f.get_tensor(name) +- +- for local_name in local_names: +- # Skip loading extra bias for GPTQ models. +- if local_name.endswith(".bias") and local_name not in params: +- continue +- # Handle special cases +- if "rotary_emb.inv_freq" in local_name or "projector" in local_name: +- continue +- if ( +- "rotary_emb.cos_cached" in local_name +- or "rotary_emb.sin_cached" in local_name +- ): +- # Models trained using ColossalAI may include these tensors in +- # the checkpoint. Skip them. +- continue +- if local_name.startswith("model.vision_tower") and local_name not in params: +- continue +- +- param = params[local_name] +- # Handle weight tying +- if self.config.tie_word_embeddings and "lm_head.weight" in local_name: +- if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: +- local_name = "model.embed_tokens.weight" +- +- loaded = False +- for param_name, shard_name, shard_id in self.stacked_params_mapping: +- if param_name not in local_name: +- continue +- if "mlp.experts" in local_name: +- # Skip experts here, handled below +- continue +- # If local_name weight is sharded into multiple keys +- weight_loader = param.weight_loader +- slice_name = local_name.replace(param_name, shard_name) +- loaded_weight = all_slices[slice_name] +- weight_loader(param, loaded_weight, shard_id) +- loaded = True +- +- for ( +- param_name, +- shard_name, +- expert_id, +- shard_id, +- ) in self.expert_params_mapping: +- if param_name not in local_name: +- continue +- # If local_name weight is sharded into multiple keys +- weight_loader = param.weight_loader +- slice_name = local_name.replace(param_name, shard_name) +- loaded_weight = all_slices[slice_name] +- weight_loader( +- param, +- loaded_weight, +- local_name, +- shard_id=shard_id, +- expert_id=expert_id, +- ) +- loaded = True +- +- if not loaded: +- # If local_name weight is not sharded +- if local_name in all_slices: +- loaded_weight = all_slices[local_name] +- weight_loader = getattr( +- param, "weight_loader", default_weight_loader +- ) +- weight_loader(param, loaded_weight) +- else: +- raise KeyError( +- f"Cannot find weight {local_name} in the loaded slices." 
+- ) +- +- def load_weights_from_path(self, path: str): +- # Customized weights loading from a given path of huggingface model +- from sglang.srt.models.utils.load import load_weights_with_hf_path_fast +- +- load_weights_with_hf_path_fast( +- model=self, +- weight_path=path, +- load_weights_with_worker_fn=self._load_weights_with_worker, +- stacked_params_mapping=self.stacked_params_mapping, +- expert_params_mapping=self.expert_params_mapping, +- ) +- + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) +diff -ruN ./srt/models/qwen3.py ../../../sglang/python/sglang/srt/models/qwen3.py +--- ./srt/models/qwen3.py 2025-10-11 15:03:31.207989560 +0800 ++++ ../../../sglang/python/sglang/srt/models/qwen3.py 2025-10-13 13:20:52.076417833 +0800 +@@ -418,97 +418,6 @@ def end_layer(self): return self.model.end_layer -+ @property -+ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: -+ return [ -+ # (param_name, shard_name, shard_id) -+ ("qkv_proj", "q_proj", "q"), -+ ("qkv_proj", "k_proj", "k"), -+ ("qkv_proj", "v_proj", "v"), -+ ("gate_up_proj", "gate_proj", 0), -+ ("gate_up_proj", "up_proj", 1), -+ ] -+ -+ def _load_weights_with_worker( -+ self, -+ params: Dict[str, torch.nn.Parameter], -+ local_names: List[str], -+ filenames: List[str], -+ weight_path: str, -+ ): -+ import os -+ -+ from safetensors import safe_open -+ -+ from sglang.srt.model_loader.weight_utils import default_weight_loader -+ -+ all_slices = {} -+ for filename in filenames: -+ safetensor_file = os.path.join(weight_path, filename) -+ with safe_open(safetensor_file, framework="pt", device="cpu") as f: -+ for name in f.keys(): -+ # all_slices[name] = f.get_slice(name) -+ all_slices[name] = f.get_tensor(name) -+ -+ for local_name in local_names: -+ # Skip loading extra bias for GPTQ models. -+ if local_name.endswith(".bias") and local_name not in params: -+ continue -+ # Handle special cases -+ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: -+ continue -+ if ( -+ "rotary_emb.cos_cached" in local_name -+ or "rotary_emb.sin_cached" in local_name -+ ): -+ # Models trained using ColossalAI may include these tensors in -+ # the checkpoint. Skip them. -+ continue -+ if local_name.startswith("model.vision_tower") and local_name not in params: -+ continue -+ -+ param = params[local_name] -+ # Handle weight tying -+ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: -+ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: -+ local_name = "model.embed_tokens.weight" -+ -+ loaded = False -+ for param_name, shard_name, shard_id in self.stacked_params_mapping: -+ if param_name not in local_name: -+ continue -+ # If local_name weight is sharded into multiple keys -+ weight_loader = param.weight_loader -+ slice_name = local_name.replace(param_name, shard_name) -+ loaded_weight = all_slices[slice_name] -+ weight_loader(param, loaded_weight, shard_id) -+ loaded = True -+ -+ if not loaded: -+ # If local_name weight is not sharded -+ if local_name in all_slices: -+ loaded_weight = all_slices[local_name] -+ weight_loader = getattr( -+ param, "weight_loader", default_weight_loader -+ ) -+ weight_loader(param, loaded_weight) -+ else: -+ raise KeyError( -+ f"Cannot find weight {local_name} in the loaded slices." 
-+ ) -+ -+ def load_weights_from_path(self, path: str): -+ # Customized weights loading from a given path of huggingface model -+ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast -+ -+ load_weights_with_hf_path_fast( -+ model=self, -+ weight_path=path, -+ load_weights_with_worker_fn=self._load_weights_with_worker, -+ stacked_params_mapping=self.stacked_params_mapping, -+ tie_word_embeddings=self.config.tie_word_embeddings, -+ ) -+ +- @property +- def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: +- return [ +- # (param_name, shard_name, shard_id) +- ("qkv_proj", "q_proj", "q"), +- ("qkv_proj", "k_proj", "k"), +- ("qkv_proj", "v_proj", "v"), +- ("gate_up_proj", "gate_proj", 0), +- ("gate_up_proj", "up_proj", 1), +- ] +- +- def _load_weights_with_worker( +- self, +- params: Dict[str, torch.nn.Parameter], +- local_names: List[str], +- filenames: List[str], +- weight_path: str, +- ): +- import os +- +- from safetensors import safe_open +- +- from sglang.srt.model_loader.weight_utils import default_weight_loader +- +- all_slices = {} +- for filename in filenames: +- safetensor_file = os.path.join(weight_path, filename) +- with safe_open(safetensor_file, framework="pt", device="cpu") as f: +- for name in f.keys(): +- # all_slices[name] = f.get_slice(name) +- all_slices[name] = f.get_tensor(name) +- +- for local_name in local_names: +- # Skip loading extra bias for GPTQ models. +- if local_name.endswith(".bias") and local_name not in params: +- continue +- # Handle special cases +- if "rotary_emb.inv_freq" in local_name or "projector" in local_name: +- continue +- if ( +- "rotary_emb.cos_cached" in local_name +- or "rotary_emb.sin_cached" in local_name +- ): +- # Models trained using ColossalAI may include these tensors in +- # the checkpoint. Skip them. +- continue +- if local_name.startswith("model.vision_tower") and local_name not in params: +- continue +- +- param = params[local_name] +- # Handle weight tying +- if self.config.tie_word_embeddings and "lm_head.weight" in local_name: +- if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: +- local_name = "model.embed_tokens.weight" +- +- loaded = False +- for param_name, shard_name, shard_id in self.stacked_params_mapping: +- if param_name not in local_name: +- continue +- # If local_name weight is sharded into multiple keys +- weight_loader = param.weight_loader +- slice_name = local_name.replace(param_name, shard_name) +- loaded_weight = all_slices[slice_name] +- weight_loader(param, loaded_weight, shard_id) +- loaded = True +- +- if not loaded: +- # If local_name weight is not sharded +- if local_name in all_slices: +- loaded_weight = all_slices[local_name] +- weight_loader = getattr( +- param, "weight_loader", default_weight_loader +- ) +- weight_loader(param, loaded_weight) +- else: +- raise KeyError( +- f"Cannot find weight {local_name} in the loaded slices." 
+- ) +- +- def load_weights_from_path(self, path: str): +- # Customized weights loading from a given path of huggingface model +- from sglang.srt.models.utils.load import load_weights_with_hf_path_fast +- +- load_weights_with_hf_path_fast( +- model=self, +- weight_path=path, +- load_weights_with_worker_fn=self._load_weights_with_worker, +- stacked_params_mapping=self.stacked_params_mapping, +- tie_word_embeddings=self.config.tie_word_embeddings, +- ) +- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -468,11 +559,11 @@ class Qwen3ForCausalLM(nn.Module): +@@ -559,11 +468,11 @@ for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue -- name = name.replace(weight_name, param_name) -+ _name = name.replace(weight_name, param_name) +- _name = name.replace(weight_name, param_name) ++ name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. -- if name.endswith(".bias") and name not in params_dict: -+ if _name.endswith(".bias") and _name not in params_dict: +- if _name.endswith(".bias") and _name not in params_dict: ++ if name.endswith(".bias") and name not in params_dict: continue -- param = params_dict[name] -+ param = params_dict[_name] +- param = params_dict[_name] ++ param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break -diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index c1c4c3638..c88de9d19 100644 ---- a/python/sglang/srt/models/qwen3_moe.py -+++ b/python/sglang/srt/models/qwen3_moe.py -@@ -770,6 +770,130 @@ class Qwen3MoeForCausalLM(nn.Module): - else: - self.model.layers_to_capture = [val + 1 for val in layer_ids] - -+ @property -+ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: -+ return [ -+ # (param_name, shard_name, shard_id) -+ ("qkv_proj", "q_proj", "q"), -+ ("qkv_proj", "k_proj", "k"), -+ ("qkv_proj", "v_proj", "v"), -+ ("gate_up_proj", "gate_proj", 0), -+ ("gate_up_proj", "up_proj", 1), -+ ] -+ -+ @property -+ def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: -+ return get_moe_impl_class().make_expert_params_mapping( -+ ckpt_gate_proj_name="gate_proj", -+ ckpt_down_proj_name="down_proj", -+ ckpt_up_proj_name="up_proj", -+ num_experts=self.config.num_experts, -+ ) -+ -+ def _load_weights_with_worker( -+ self, -+ params: Dict[str, torch.nn.Parameter], -+ local_names: List[str], -+ filenames: List[str], -+ weight_path: str, -+ ): -+ import os -+ -+ from safetensors import safe_open -+ -+ from sglang.srt.model_loader.weight_utils import default_weight_loader -+ -+ all_slices = {} -+ for filename in filenames: -+ safetensor_file = os.path.join(weight_path, filename) -+ with safe_open(safetensor_file, framework="pt", device="cpu") as f: -+ for name in f.keys(): -+ # all_slices[name] = f.get_slice(name) -+ all_slices[name] = f.get_tensor(name) -+ -+ for local_name in local_names: -+ # Skip loading extra bias for GPTQ models. -+ if local_name.endswith(".bias") and local_name not in params: -+ continue -+ # Handle special cases -+ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: -+ continue -+ if ( -+ "rotary_emb.cos_cached" in local_name -+ or "rotary_emb.sin_cached" in local_name -+ ): -+ # Models trained using ColossalAI may include these tensors in -+ # the checkpoint. Skip them. 
-+ continue -+ if local_name.startswith("model.vision_tower") and local_name not in params: -+ continue -+ -+ param = params[local_name] -+ # Handle weight tying -+ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: -+ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: -+ local_name = "model.embed_tokens.weight" -+ -+ loaded = False -+ for param_name, shard_name, shard_id in self.stacked_params_mapping: -+ if param_name not in local_name: -+ continue -+ if "mlp.experts" in local_name: -+ # Skip experts here, handled below -+ continue -+ # If local_name weight is sharded into multiple keys -+ weight_loader = param.weight_loader -+ slice_name = local_name.replace(param_name, shard_name) -+ loaded_weight = all_slices[slice_name] -+ weight_loader(param, loaded_weight, shard_id) -+ loaded = True -+ -+ for ( -+ param_name, -+ shard_name, -+ expert_id, -+ shard_id, -+ ) in self.expert_params_mapping: -+ if param_name not in local_name: -+ continue -+ # If local_name weight is sharded into multiple keys -+ weight_loader = param.weight_loader -+ slice_name = local_name.replace(param_name, shard_name) -+ loaded_weight = all_slices[slice_name] -+ weight_loader( -+ param, -+ loaded_weight, -+ local_name, -+ shard_id=shard_id, -+ expert_id=expert_id, -+ ) -+ loaded = True -+ -+ if not loaded: -+ # If local_name weight is not sharded -+ if local_name in all_slices: -+ loaded_weight = all_slices[local_name] -+ weight_loader = getattr( -+ param, "weight_loader", default_weight_loader -+ ) -+ weight_loader(param, loaded_weight) -+ else: -+ raise KeyError( -+ f"Cannot find weight {local_name} in the loaded slices." -+ ) -+ -+ def load_weights_from_path(self, path: str): -+ # Customized weights loading from a given path of huggingface model -+ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast -+ -+ load_weights_with_hf_path_fast( -+ model=self, -+ weight_path=path, -+ load_weights_with_worker_fn=self._load_weights_with_worker, -+ stacked_params_mapping=self.stacked_params_mapping, -+ expert_params_mapping=self.expert_params_mapping, -+ ) -+ - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) -diff --git a/python/sglang/srt/models/utils/load.py b/python/sglang/srt/models/utils/load.py -new file mode 100644 -index 000000000..035835a0f ---- /dev/null -+++ b/python/sglang/srt/models/utils/load.py -@@ -0,0 +1,170 @@ -+import json -+import os -+from collections import defaultdict -+from concurrent.futures import ThreadPoolExecutor -+from glob import glob -+from typing import Callable, Dict, List, Tuple -+ -+import torch -+from safetensors import safe_open -+from transformers.utils.hub import cached_file -+ -+from sglang.srt.model_loader.weight_utils import default_weight_loader -+ -+ -+def get_actual_hf_path(weight_path: str): -+ return os.path.dirname(cached_file(weight_path, "config.json")) -+ -+ -+def make_filename_bins( -+ local_to_file_map: Dict[str, List[str]], -+) -> Tuple[List[List[str]], List[List[str]]]: -+ # Allocate local weight name into bins, where each bin access independent files -+ # Then we can use multiple threads to concurrently load each bin's parameters. 
-+ # This function has a complexity of O(F + L²) -+ # where F = total number of files, L = number of local names -+ if not local_to_file_map: -+ return [], [] -+ -+ local_names = list(local_to_file_map.keys()) -+ n = len(local_names) -+ -+ # Convert file lists to sets for O(1) lookups and create file-to-locals mapping -+ local_to_files = {name: set(local_to_file_map[name]) for name in local_names} -+ file_to_locals = defaultdict(set) -+ for local_name, files in local_to_files.items(): -+ for file in files: -+ file_to_locals[file].add(local_name) -+ -+ # Union-Find with path compression and union by rank -+ parent = list(range(n)) -+ rank = [0] * n -+ -+ def find(x): -+ if parent[x] != x: -+ parent[x] = find(parent[x]) # Path compression -+ return parent[x] -+ -+ def union(x, y): -+ root_x, root_y = find(x), find(y) -+ if root_x == root_y: -+ return -+ -+ # Union by rank -+ if rank[root_x] < rank[root_y]: -+ root_x, root_y = root_y, root_x -+ parent[root_y] = root_x -+ if rank[root_x] == rank[root_y]: -+ rank[root_x] += 1 -+ -+ # Create name-to-index mapping for O(1) lookups -+ name_to_idx = {name: i for i, name in enumerate(local_names)} -+ -+ # Union locals that share files - O(F) where F is total number of files -+ for locals_sharing_file in file_to_locals.values(): -+ if len(locals_sharing_file) > 1: -+ locals_list = list(locals_sharing_file) -+ first_idx = name_to_idx[locals_list[0]] -+ for local_name in locals_list[1:]: -+ union(first_idx, name_to_idx[local_name]) -+ -+ # Group by root - O(L) -+ root_to_group = defaultdict(list) -+ for i, name in enumerate(local_names): -+ root_to_group[find(i)].append(name) -+ -+ # Build result groups - O(L + F) -+ grouped_local_names = [] -+ grouped_filenames = [] -+ -+ for group in root_to_group.values(): -+ grouped_local_names.append(group) -+ # Use set union to merge files from all locals in group -+ all_files = set() -+ for local_name in group: -+ all_files.update(local_to_files[local_name]) -+ grouped_filenames.append(list(all_files)) -+ -+ return grouped_local_names, grouped_filenames -+ -+ -+def load_weights_with_hf_path_fast( -+ model: torch.nn.Module, -+ weight_path: str, -+ load_weights_with_worker_fn: Callable, -+ stacked_params_mapping: List[Tuple[str, str, str]] | None = None, -+ expert_params_mapping: List[Tuple[str, str, str]] | None = None, -+ tie_word_embeddings: bool = False, -+ max_workers: int = None, -+): -+ if not os.path.exists(weight_path): -+ weight_path = get_actual_hf_path(weight_path) -+ index_file = os.path.join(weight_path, "model.safetensors.index.json") -+ index = {} -+ if os.path.exists(index_file): -+ with open(index_file, "r") as f: -+ index = json.load(f)["weight_map"] -+ else: -+ # Search all safetensors files -+ safetensor_files = glob(os.path.join(weight_path, "*.safetensors")) -+ # If there are safetensors files -+ if safetensor_files: -+ # Iterate through each safetensors file -+ for safetensor_file in safetensor_files: -+ with safe_open(safetensor_file, framework="pt", device="cpu") as f: -+ for k in f.keys(): -+ index[k] = safetensor_file -+ else: -+ raise FileNotFoundError("No safetensors found in the model path to load.") -+ -+ params = dict(model.named_parameters()) -+ local_names = list(params.keys()) -+ -+ worker_args = [] -+ -+ # local name -> set of filenames that contains the weight -+ local_to_file_map = defaultdict(set) -+ # model.layers.31.mlp.experts -+ for local_name in local_names: -+ hf_names = [] -+ if "mlp.experts" not in local_name and stacked_params_mapping is not None: -+ for param_name, 
shard_name, _ in stacked_params_mapping: -+ if param_name in local_name: -+ hf_names.append(local_name.replace(param_name, shard_name)) -+ if expert_params_mapping is not None: -+ for param_name, shard_name, _, _ in expert_params_mapping: -+ if param_name in local_name: -+ hf_names.append(local_name.replace(param_name, shard_name)) -+ if tie_word_embeddings and "lm_head.weight" in local_name: -+ hf_names.append("model.embed_tokens.weight") -+ if len(hf_names) == 0: -+ hf_names.append(local_name) -+ for name in hf_names: -+ filename = index[name] -+ if filename not in local_to_file_map[local_name]: -+ local_to_file_map[local_name].add(filename) -+ -+ grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) -+ -+ if max_workers is None: -+ # assume all GPUs are used by SGLang servers -+ max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count())) -+ -+ for local_names, filenames in zip(grouped_local_names, grouped_filenames): -+ worker_args.append( -+ dict( -+ params=params, -+ local_names=local_names, -+ filenames=filenames, -+ weight_path=weight_path, -+ ) -+ ) -+ -+ max_workers = min(max_workers, len(worker_args)) -+ with ThreadPoolExecutor(max_workers=max_workers) as executor: -+ results = executor.map( -+ lambda kwargs: load_weights_with_worker_fn(**kwargs), worker_args -+ ) -+ # Consume all results to make result all tasks complete -+ for _ in results: -+ pass -diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py -index 846baeb01..5235c0563 100644 ---- a/python/sglang/srt/utils.py -+++ b/python/sglang/srt/utils.py -@@ -21,6 +21,7 @@ import ctypes +diff -ruN ./srt/models/utils/load.py ../../../sglang/python/sglang/srt/models/utils/load.py +--- ./srt/models/utils/load.py 2025-10-11 15:01:36.754004483 +0800 ++++ ../../../sglang/python/sglang/srt/models/utils/load.py 1970-01-01 08:00:00.000000000 +0800 +@@ -1,170 +0,0 @@ +-import json +-import os +-from collections import defaultdict +-from concurrent.futures import ThreadPoolExecutor +-from glob import glob +-from typing import Callable, Dict, List, Tuple +- +-import torch +-from safetensors import safe_open +-from transformers.utils.hub import cached_file +- +-from sglang.srt.model_loader.weight_utils import default_weight_loader +- +- +-def get_actual_hf_path(weight_path: str): +- return os.path.dirname(cached_file(weight_path, "config.json")) +- +- +-def make_filename_bins( +- local_to_file_map: Dict[str, List[str]], +-) -> Tuple[List[List[str]], List[List[str]]]: +- # Allocate local weight name into bins, where each bin access independent files +- # Then we can use multiple threads to concurrently load each bin's parameters. 
+- # This function has a complexity of O(F + L²) +- # where F = total number of files, L = number of local names +- if not local_to_file_map: +- return [], [] +- +- local_names = list(local_to_file_map.keys()) +- n = len(local_names) +- +- # Convert file lists to sets for O(1) lookups and create file-to-locals mapping +- local_to_files = {name: set(local_to_file_map[name]) for name in local_names} +- file_to_locals = defaultdict(set) +- for local_name, files in local_to_files.items(): +- for file in files: +- file_to_locals[file].add(local_name) +- +- # Union-Find with path compression and union by rank +- parent = list(range(n)) +- rank = [0] * n +- +- def find(x): +- if parent[x] != x: +- parent[x] = find(parent[x]) # Path compression +- return parent[x] +- +- def union(x, y): +- root_x, root_y = find(x), find(y) +- if root_x == root_y: +- return +- +- # Union by rank +- if rank[root_x] < rank[root_y]: +- root_x, root_y = root_y, root_x +- parent[root_y] = root_x +- if rank[root_x] == rank[root_y]: +- rank[root_x] += 1 +- +- # Create name-to-index mapping for O(1) lookups +- name_to_idx = {name: i for i, name in enumerate(local_names)} +- +- # Union locals that share files - O(F) where F is total number of files +- for locals_sharing_file in file_to_locals.values(): +- if len(locals_sharing_file) > 1: +- locals_list = list(locals_sharing_file) +- first_idx = name_to_idx[locals_list[0]] +- for local_name in locals_list[1:]: +- union(first_idx, name_to_idx[local_name]) +- +- # Group by root - O(L) +- root_to_group = defaultdict(list) +- for i, name in enumerate(local_names): +- root_to_group[find(i)].append(name) +- +- # Build result groups - O(L + F) +- grouped_local_names = [] +- grouped_filenames = [] +- +- for group in root_to_group.values(): +- grouped_local_names.append(group) +- # Use set union to merge files from all locals in group +- all_files = set() +- for local_name in group: +- all_files.update(local_to_files[local_name]) +- grouped_filenames.append(list(all_files)) +- +- return grouped_local_names, grouped_filenames +- +- +-def load_weights_with_hf_path_fast( +- model: torch.nn.Module, +- weight_path: str, +- load_weights_with_worker_fn: Callable, +- stacked_params_mapping: List[Tuple[str, str, str]] | None = None, +- expert_params_mapping: List[Tuple[str, str, str]] | None = None, +- tie_word_embeddings: bool = False, +- max_workers: int = None, +-): +- if not os.path.exists(weight_path): +- weight_path = get_actual_hf_path(weight_path) +- index_file = os.path.join(weight_path, "model.safetensors.index.json") +- index = {} +- if os.path.exists(index_file): +- with open(index_file, "r") as f: +- index = json.load(f)["weight_map"] +- else: +- # Search all safetensors files +- safetensor_files = glob(os.path.join(weight_path, "*.safetensors")) +- # If there are safetensors files +- if safetensor_files: +- # Iterate through each safetensors file +- for safetensor_file in safetensor_files: +- with safe_open(safetensor_file, framework="pt", device="cpu") as f: +- for k in f.keys(): +- index[k] = safetensor_file +- else: +- raise FileNotFoundError("No safetensors found in the model path to load.") +- +- params = dict(model.named_parameters()) +- local_names = list(params.keys()) +- +- worker_args = [] +- +- # local name -> set of filenames that contains the weight +- local_to_file_map = defaultdict(set) +- # model.layers.31.mlp.experts +- for local_name in local_names: +- hf_names = [] +- if "mlp.experts" not in local_name and stacked_params_mapping is not None: +- for param_name, 
shard_name, _ in stacked_params_mapping: +- if param_name in local_name: +- hf_names.append(local_name.replace(param_name, shard_name)) +- if expert_params_mapping is not None: +- for param_name, shard_name, _, _ in expert_params_mapping: +- if param_name in local_name: +- hf_names.append(local_name.replace(param_name, shard_name)) +- if tie_word_embeddings and "lm_head.weight" in local_name: +- hf_names.append("model.embed_tokens.weight") +- if len(hf_names) == 0: +- hf_names.append(local_name) +- for name in hf_names: +- filename = index[name] +- if filename not in local_to_file_map[local_name]: +- local_to_file_map[local_name].add(filename) +- +- grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) +- +- if max_workers is None: +- # assume all GPUs are used by SGLang servers +- max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count())) +- +- for local_names, filenames in zip(grouped_local_names, grouped_filenames): +- worker_args.append( +- dict( +- params=params, +- local_names=local_names, +- filenames=filenames, +- weight_path=weight_path, +- ) +- ) +- +- max_workers = min(max_workers, len(worker_args)) +- with ThreadPoolExecutor(max_workers=max_workers) as executor: +- results = executor.map( +- lambda kwargs: load_weights_with_worker_fn(**kwargs), worker_args +- ) +- # Consume all results to make result all tasks complete +- for _ in results: +- pass +diff -ruN ./srt/utils.py ../../../sglang/python/sglang/srt/utils.py +--- ./srt/utils.py 2025-10-11 15:03:52.182903128 +0800 ++++ ../../../sglang/python/sglang/srt/utils.py 2025-10-13 13:20:52.081418051 +0800 +@@ -21,7 +21,6 @@ import dataclasses import functools import importlib -+import inspect +-import inspect import io import ipaddress import itertools From 20ca31410c028713df43d31cb96aee2b89150e92 Mon Sep 17 00:00:00 2001 From: nuzant Date: Mon, 13 Oct 2025 13:33:55 +0800 Subject: [PATCH 23/24] fix patch --- areal/utils/launcher.py | 2 +- patch/sglang/v0.5.2.patch | 36 ++++++++++++++++++------------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index e1ad36916..6125030fe 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -177,7 +177,7 @@ def apply_sglang_patch(): for line in sglang_meta.split("\n"): line = line.strip() if line.startswith("Editable project location: "): - target_path = str(Path(line.split(": ")[1]).parent) + target_path = str(Path(line.split(": ")[1]) / "sglang") break else: for line in sglang_meta.split("\n"): diff --git a/patch/sglang/v0.5.2.patch b/patch/sglang/v0.5.2.patch index 3a94b3253..451d9def0 100644 --- a/patch/sglang/v0.5.2.patch +++ b/patch/sglang/v0.5.2.patch @@ -1,6 +1,6 @@ -diff -ruN ./srt/model_executor/model_runner.py ../../../sglang/python/sglang/srt/model_executor/model_runner.py ---- ./srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800 -+++ ../../../sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py ./srt/model_executor/model_runner.py +--- ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800 ++++ ./srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800 @@ -258,7 +258,6 @@ self._model_update_group = {} @@ -34,9 +34,9 @@ diff -ruN ./srt/model_executor/model_runner.py ../../../sglang/python/sglang/srt return model with 
set_default_torch_dtype(self.model_config.dtype): -diff -ruN ./srt/model_loader/loader.py ../../../sglang/python/sglang/srt/model_loader/loader.py ---- ./srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800 -+++ ../../../sglang/python/sglang/srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/model_loader/loader.py ./srt/model_loader/loader.py +--- ../../../my-sglang/python/sglang/srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800 ++++ ./srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800 @@ -278,7 +278,7 @@ def __init__(self, load_config: LoadConfig): super().__init__(load_config) @@ -101,9 +101,9 @@ diff -ruN ./srt/model_loader/loader.py ../../../sglang/python/sglang/srt/model_l for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) -diff -ruN ./srt/models/qwen3_moe.py ../../../sglang/python/sglang/srt/models/qwen3_moe.py ---- ./srt/models/qwen3_moe.py 2025-10-11 15:03:31.207989560 +0800 -+++ ../../../sglang/python/sglang/srt/models/qwen3_moe.py 2025-10-13 13:20:52.077417877 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py ./srt/models/qwen3_moe.py +--- ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py 2025-10-11 15:03:31.207989560 +0800 ++++ ./srt/models/qwen3_moe.py 2025-10-13 13:20:52.077417877 +0800 @@ -770,130 +770,6 @@ else: self.model.layers_to_capture = [val + 1 for val in layer_ids] @@ -235,9 +235,9 @@ diff -ruN ./srt/models/qwen3_moe.py ../../../sglang/python/sglang/srt/models/qwe def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -diff -ruN ./srt/models/qwen3.py ../../../sglang/python/sglang/srt/models/qwen3.py ---- ./srt/models/qwen3.py 2025-10-11 15:03:31.207989560 +0800 -+++ ../../../sglang/python/sglang/srt/models/qwen3.py 2025-10-13 13:20:52.076417833 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/models/qwen3.py ./srt/models/qwen3.py +--- ../../../my-sglang/python/sglang/srt/models/qwen3.py 2025-10-11 15:03:31.207989560 +0800 ++++ ./srt/models/qwen3.py 2025-10-13 13:20:52.076417833 +0800 @@ -418,97 +418,6 @@ def end_layer(self): return self.model.end_layer @@ -351,9 +351,9 @@ diff -ruN ./srt/models/qwen3.py ../../../sglang/python/sglang/srt/models/qwen3.p weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break -diff -ruN ./srt/models/utils/load.py ../../../sglang/python/sglang/srt/models/utils/load.py ---- ./srt/models/utils/load.py 2025-10-11 15:01:36.754004483 +0800 -+++ ../../../sglang/python/sglang/srt/models/utils/load.py 1970-01-01 08:00:00.000000000 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/models/utils/load.py ./srt/models/utils/load.py +--- ../../../my-sglang/python/sglang/srt/models/utils/load.py 2025-10-11 15:01:36.754004483 +0800 ++++ ./srt/models/utils/load.py 1970-01-01 08:00:00.000000000 +0800 @@ -1,170 +0,0 @@ -import json -import os @@ -525,9 +525,9 @@ diff -ruN ./srt/models/utils/load.py ../../../sglang/python/sglang/srt/models/ut - # Consume all results to make result all tasks complete - for _ in results: - pass -diff -ruN ./srt/utils.py ../../../sglang/python/sglang/srt/utils.py ---- ./srt/utils.py 2025-10-11 15:03:52.182903128 +0800 -+++ ../../../sglang/python/sglang/srt/utils.py 2025-10-13 13:20:52.081418051 +0800 +diff -ruN ../../../my-sglang/python/sglang/srt/utils.py ./srt/utils.py +--- ../../../my-sglang/python/sglang/srt/utils.py 2025-10-11 
15:03:52.182903128 +0800
++++ ./srt/utils.py 2025-10-13 13:20:52.081418051 +0800
 @@ -21,7 +21,6 @@
  import dataclasses
  import functools

From c2f13b402b606f3bb74b533d3a9c82fbf854f9fa Mon Sep 17 00:00:00 2001
From: nuzant
Date: Mon, 13 Oct 2025 13:37:40 +0800
Subject: [PATCH 24/24] reverse diff direction in sglang v0.5.2 patch

---
 patch/sglang/v0.5.2.patch | 916 +++++++++++++++++++------------------
 1 file changed, 458 insertions(+), 458 deletions(-)

diff --git a/patch/sglang/v0.5.2.patch b/patch/sglang/v0.5.2.patch
index 451d9def0..65986c635 100644
--- a/patch/sglang/v0.5.2.patch
+++ b/patch/sglang/v0.5.2.patch
@@ -1,538 +1,538 @@
-diff -ruN ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py ./srt/model_executor/model_runner.py
---- ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800
-+++ ./srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800
-@@ -258,7 +258,6 @@
+diff -ruN ./srt/model_executor/model_runner.py ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py
+--- ./srt/model_executor/model_runner.py 2025-10-13 13:20:52.071417615 +0800
++++ ../../../my-sglang/python/sglang/srt/model_executor/model_runner.py 2025-10-11 15:06:03.052603197 +0800
+@@ -258,6 +258,7 @@
 self._model_update_group = {}

 def initialize(self, min_per_gpu_memory: float):
-- logger.warning("SGLang v0.5.2 is patched with customized weight loading.")
++ logger.warning("SGLang v0.5.2 is patched with customized weight loading.")
 server_args = self.server_args
 self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-@@ -824,11 +823,8 @@
+@@ -823,8 +824,11 @@
 target_device = torch.device(self.device)
 self.model_config.model_path = model_path
-- load_config = LoadConfig(
-- load_format=load_format,
-- # XXX: This should be in function args, passed in by requests
-- model_loader_extra_config=self.server_args.model_loader_extra_config,
-- )
-+ load_config = LoadConfig(load_format=load_format)
-+
+- load_config = LoadConfig(load_format=load_format)
+-
++ load_config = LoadConfig(
++ load_format=load_format,
++ # XXX: This should be in function args, passed in by requests
++ model_loader_extra_config=self.server_args.model_loader_extra_config,
++ )
 # Only support DefaultModelLoader for now
 loader = get_model_loader(load_config)
 if not isinstance(loader, DefaultModelLoader):
-@@ -842,9 +838,7 @@
+@@ -838,7 +842,9 @@
 return iter

 def model_load_weights(model, iter):
-- DefaultModelLoader.load_weights_and_postprocess(
-- model, iter, target_device, load_config=load_config
-- )
-+ DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
+- DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
++ DefaultModelLoader.load_weights_and_postprocess(
++ model, iter, target_device, load_config=load_config
++ )
 return model

 with set_default_torch_dtype(self.model_config.dtype):
-diff -ruN ../../../my-sglang/python/sglang/srt/model_loader/loader.py ./srt/model_loader/loader.py
---- ../../../my-sglang/python/sglang/srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800
-+++ ./srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800
+diff -ruN ./srt/model_loader/loader.py ../../../my-sglang/python/sglang/srt/model_loader/loader.py
+--- ./srt/model_loader/loader.py 2025-10-13 13:20:52.071417615 +0800
++++ ../../../my-sglang/python/sglang/srt/model_loader/loader.py 2025-10-11 15:03:31.201989298 +0800
@@ -278,7 +278,7 @@
 def __init__(self, load_config: LoadConfig):
 super().__init__(load_config)
 extra_config = 
load_config.model_loader_extra_config -- allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} -+ allowed_keys = {"enable_multithread_load", "num_threads"} +- allowed_keys = {"enable_multithread_load", "num_threads"} ++ allowed_keys = {"enable_multithread_load", "num_threads", "enable_fast_load"} unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: -@@ -399,9 +399,6 @@ +@@ -399,6 +399,9 @@ ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config -- if extra_config.get("enable_fast_load"): -- return source.model_or_path -- ++ if extra_config.get("enable_fast_load"): ++ return source.model_or_path ++ hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) -@@ -444,6 +441,8 @@ +@@ -441,8 +444,6 @@ ) else: weights_iterator = pt_weights_iterator(hf_weights_files) -+ -+ # Apply the prefix. +- +- # Apply the prefix. return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def _get_all_weights( -@@ -451,6 +450,7 @@ +@@ -450,7 +451,6 @@ model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: -+ +- primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) -@@ -479,23 +479,15 @@ +@@ -479,15 +479,23 @@ self.load_config, ) -- extra_config = self.load_config.model_loader_extra_config -- if extra_config.get("enable_fast_load"): -- weights_iter_or_path = model_config.model_path -- else: -- weights_iter_or_path = self._get_all_weights(model_config, model) ++ extra_config = self.load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ weights_iter_or_path = model_config.model_path ++ else: ++ weights_iter_or_path = self._get_all_weights(model_config, model) self.load_weights_and_postprocess( -- model, weights_iter_or_path, target_device, load_config=self.load_config -+ model, self._get_all_weights(model_config, model), target_device +- model, self._get_all_weights(model_config, model), target_device ++ model, weights_iter_or_path, target_device, load_config=self.load_config ) -+ +- return model.eval() @staticmethod -- def load_weights_and_postprocess(model, weights, target_device, load_config=None): -- extra_config = load_config.model_loader_extra_config -- if extra_config.get("enable_fast_load"): -- model.load_weights_from_path(weights) -- else: -- model.load_weights(weights) -+ def load_weights_and_postprocess(model, weights, target_device): -+ model.load_weights(weights) +- def load_weights_and_postprocess(model, weights, target_device): +- model.load_weights(weights) ++ def load_weights_and_postprocess(model, weights, target_device, load_config=None): ++ extra_config = load_config.model_loader_extra_config ++ if extra_config.get("enable_fast_load"): ++ model.load_weights_from_path(weights) ++ else: ++ model.load_weights(weights) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) -diff -ruN ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py ./srt/models/qwen3_moe.py ---- ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py 2025-10-11 15:03:31.207989560 +0800 -+++ ./srt/models/qwen3_moe.py 2025-10-13 13:20:52.077417877 +0800 -@@ -770,130 +770,6 @@ +diff -ruN ./srt/models/qwen3_moe.py ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py +--- 
./srt/models/qwen3_moe.py 2025-10-13 13:20:52.077417877 +0800 ++++ ../../../my-sglang/python/sglang/srt/models/qwen3_moe.py 2025-10-11 15:03:31.207989560 +0800 +@@ -770,6 +770,130 @@ else: self.model.layers_to_capture = [val + 1 for val in layer_ids] -- @property -- def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: -- return [ -- # (param_name, shard_name, shard_id) -- ("qkv_proj", "q_proj", "q"), -- ("qkv_proj", "k_proj", "k"), -- ("qkv_proj", "v_proj", "v"), -- ("gate_up_proj", "gate_proj", 0), -- ("gate_up_proj", "up_proj", 1), -- ] -- -- @property -- def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: -- return get_moe_impl_class().make_expert_params_mapping( -- ckpt_gate_proj_name="gate_proj", -- ckpt_down_proj_name="down_proj", -- ckpt_up_proj_name="up_proj", -- num_experts=self.config.num_experts, -- ) -- -- def _load_weights_with_worker( -- self, -- params: Dict[str, torch.nn.Parameter], -- local_names: List[str], -- filenames: List[str], -- weight_path: str, -- ): -- import os -- -- from safetensors import safe_open -- -- from sglang.srt.model_loader.weight_utils import default_weight_loader -- -- all_slices = {} -- for filename in filenames: -- safetensor_file = os.path.join(weight_path, filename) -- with safe_open(safetensor_file, framework="pt", device="cpu") as f: -- for name in f.keys(): -- # all_slices[name] = f.get_slice(name) -- all_slices[name] = f.get_tensor(name) -- -- for local_name in local_names: -- # Skip loading extra bias for GPTQ models. -- if local_name.endswith(".bias") and local_name not in params: -- continue -- # Handle special cases -- if "rotary_emb.inv_freq" in local_name or "projector" in local_name: -- continue -- if ( -- "rotary_emb.cos_cached" in local_name -- or "rotary_emb.sin_cached" in local_name -- ): -- # Models trained using ColossalAI may include these tensors in -- # the checkpoint. Skip them. 
-- continue -- if local_name.startswith("model.vision_tower") and local_name not in params: -- continue -- -- param = params[local_name] -- # Handle weight tying -- if self.config.tie_word_embeddings and "lm_head.weight" in local_name: -- if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: -- local_name = "model.embed_tokens.weight" -- -- loaded = False -- for param_name, shard_name, shard_id in self.stacked_params_mapping: -- if param_name not in local_name: -- continue -- if "mlp.experts" in local_name: -- # Skip experts here, handled below -- continue -- # If local_name weight is sharded into multiple keys -- weight_loader = param.weight_loader -- slice_name = local_name.replace(param_name, shard_name) -- loaded_weight = all_slices[slice_name] -- weight_loader(param, loaded_weight, shard_id) -- loaded = True -- -- for ( -- param_name, -- shard_name, -- expert_id, -- shard_id, -- ) in self.expert_params_mapping: -- if param_name not in local_name: -- continue -- # If local_name weight is sharded into multiple keys -- weight_loader = param.weight_loader -- slice_name = local_name.replace(param_name, shard_name) -- loaded_weight = all_slices[slice_name] -- weight_loader( -- param, -- loaded_weight, -- local_name, -- shard_id=shard_id, -- expert_id=expert_id, -- ) -- loaded = True -- -- if not loaded: -- # If local_name weight is not sharded -- if local_name in all_slices: -- loaded_weight = all_slices[local_name] -- weight_loader = getattr( -- param, "weight_loader", default_weight_loader -- ) -- weight_loader(param, loaded_weight) -- else: -- raise KeyError( -- f"Cannot find weight {local_name} in the loaded slices." -- ) -- -- def load_weights_from_path(self, path: str): -- # Customized weights loading from a given path of huggingface model -- from sglang.srt.models.utils.load import load_weights_with_hf_path_fast -- -- load_weights_with_hf_path_fast( -- model=self, -- weight_path=path, -- load_weights_with_worker_fn=self._load_weights_with_worker, -- stacked_params_mapping=self.stacked_params_mapping, -- expert_params_mapping=self.expert_params_mapping, -- ) -- ++ @property ++ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: ++ return [ ++ # (param_name, shard_name, shard_id) ++ ("qkv_proj", "q_proj", "q"), ++ ("qkv_proj", "k_proj", "k"), ++ ("qkv_proj", "v_proj", "v"), ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ ++ @property ++ def expert_params_mapping(self) -> List[Tuple[str, str, int, str]]: ++ return get_moe_impl_class().make_expert_params_mapping( ++ ckpt_gate_proj_name="gate_proj", ++ ckpt_down_proj_name="down_proj", ++ ckpt_up_proj_name="up_proj", ++ num_experts=self.config.num_experts, ++ ) ++ ++ def _load_weights_with_worker( ++ self, ++ params: Dict[str, torch.nn.Parameter], ++ local_names: List[str], ++ filenames: List[str], ++ weight_path: str, ++ ): ++ import os ++ ++ from safetensors import safe_open ++ ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ ++ all_slices = {} ++ for filename in filenames: ++ safetensor_file = os.path.join(weight_path, filename) ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for name in f.keys(): ++ # all_slices[name] = f.get_slice(name) ++ all_slices[name] = f.get_tensor(name) ++ ++ for local_name in local_names: ++ # Skip loading extra bias for GPTQ models. 
++ if local_name.endswith(".bias") and local_name not in params: ++ continue ++ # Handle special cases ++ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: ++ continue ++ if ( ++ "rotary_emb.cos_cached" in local_name ++ or "rotary_emb.sin_cached" in local_name ++ ): ++ # Models trained using ColossalAI may include these tensors in ++ # the checkpoint. Skip them. ++ continue ++ if local_name.startswith("model.vision_tower") and local_name not in params: ++ continue ++ ++ param = params[local_name] ++ # Handle weight tying ++ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: ++ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: ++ local_name = "model.embed_tokens.weight" ++ ++ loaded = False ++ for param_name, shard_name, shard_id in self.stacked_params_mapping: ++ if param_name not in local_name: ++ continue ++ if "mlp.experts" in local_name: ++ # Skip experts here, handled below ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader(param, loaded_weight, shard_id) ++ loaded = True ++ ++ for ( ++ param_name, ++ shard_name, ++ expert_id, ++ shard_id, ++ ) in self.expert_params_mapping: ++ if param_name not in local_name: ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader( ++ param, ++ loaded_weight, ++ local_name, ++ shard_id=shard_id, ++ expert_id=expert_id, ++ ) ++ loaded = True ++ ++ if not loaded: ++ # If local_name weight is not sharded ++ if local_name in all_slices: ++ loaded_weight = all_slices[local_name] ++ weight_loader = getattr( ++ param, "weight_loader", default_weight_loader ++ ) ++ weight_loader(param, loaded_weight) ++ else: ++ raise KeyError( ++ f"Cannot find weight {local_name} in the loaded slices." 
++ ) ++ ++ def load_weights_from_path(self, path: str): ++ # Customized weights loading from a given path of huggingface model ++ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ ++ load_weights_with_hf_path_fast( ++ model=self, ++ weight_path=path, ++ load_weights_with_worker_fn=self._load_weights_with_worker, ++ stacked_params_mapping=self.stacked_params_mapping, ++ expert_params_mapping=self.expert_params_mapping, ++ ) ++ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -diff -ruN ../../../my-sglang/python/sglang/srt/models/qwen3.py ./srt/models/qwen3.py ---- ../../../my-sglang/python/sglang/srt/models/qwen3.py 2025-10-11 15:03:31.207989560 +0800 -+++ ./srt/models/qwen3.py 2025-10-13 13:20:52.076417833 +0800 -@@ -418,97 +418,6 @@ +diff -ruN ./srt/models/qwen3.py ../../../my-sglang/python/sglang/srt/models/qwen3.py +--- ./srt/models/qwen3.py 2025-10-13 13:20:52.076417833 +0800 ++++ ../../../my-sglang/python/sglang/srt/models/qwen3.py 2025-10-11 15:03:31.207989560 +0800 +@@ -418,6 +418,97 @@ def end_layer(self): return self.model.end_layer -- @property -- def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: -- return [ -- # (param_name, shard_name, shard_id) -- ("qkv_proj", "q_proj", "q"), -- ("qkv_proj", "k_proj", "k"), -- ("qkv_proj", "v_proj", "v"), -- ("gate_up_proj", "gate_proj", 0), -- ("gate_up_proj", "up_proj", 1), -- ] -- -- def _load_weights_with_worker( -- self, -- params: Dict[str, torch.nn.Parameter], -- local_names: List[str], -- filenames: List[str], -- weight_path: str, -- ): -- import os -- -- from safetensors import safe_open -- -- from sglang.srt.model_loader.weight_utils import default_weight_loader -- -- all_slices = {} -- for filename in filenames: -- safetensor_file = os.path.join(weight_path, filename) -- with safe_open(safetensor_file, framework="pt", device="cpu") as f: -- for name in f.keys(): -- # all_slices[name] = f.get_slice(name) -- all_slices[name] = f.get_tensor(name) -- -- for local_name in local_names: -- # Skip loading extra bias for GPTQ models. -- if local_name.endswith(".bias") and local_name not in params: -- continue -- # Handle special cases -- if "rotary_emb.inv_freq" in local_name or "projector" in local_name: -- continue -- if ( -- "rotary_emb.cos_cached" in local_name -- or "rotary_emb.sin_cached" in local_name -- ): -- # Models trained using ColossalAI may include these tensors in -- # the checkpoint. Skip them. 
-- continue -- if local_name.startswith("model.vision_tower") and local_name not in params: -- continue -- -- param = params[local_name] -- # Handle weight tying -- if self.config.tie_word_embeddings and "lm_head.weight" in local_name: -- if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: -- local_name = "model.embed_tokens.weight" -- -- loaded = False -- for param_name, shard_name, shard_id in self.stacked_params_mapping: -- if param_name not in local_name: -- continue -- # If local_name weight is sharded into multiple keys -- weight_loader = param.weight_loader -- slice_name = local_name.replace(param_name, shard_name) -- loaded_weight = all_slices[slice_name] -- weight_loader(param, loaded_weight, shard_id) -- loaded = True -- -- if not loaded: -- # If local_name weight is not sharded -- if local_name in all_slices: -- loaded_weight = all_slices[local_name] -- weight_loader = getattr( -- param, "weight_loader", default_weight_loader -- ) -- weight_loader(param, loaded_weight) -- else: -- raise KeyError( -- f"Cannot find weight {local_name} in the loaded slices." -- ) -- -- def load_weights_from_path(self, path: str): -- # Customized weights loading from a given path of huggingface model -- from sglang.srt.models.utils.load import load_weights_with_hf_path_fast -- -- load_weights_with_hf_path_fast( -- model=self, -- weight_path=path, -- load_weights_with_worker_fn=self._load_weights_with_worker, -- stacked_params_mapping=self.stacked_params_mapping, -- tie_word_embeddings=self.config.tie_word_embeddings, -- ) -- ++ @property ++ def stacked_params_mapping(self) -> List[Tuple[str, str, str]]: ++ return [ ++ # (param_name, shard_name, shard_id) ++ ("qkv_proj", "q_proj", "q"), ++ ("qkv_proj", "k_proj", "k"), ++ ("qkv_proj", "v_proj", "v"), ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ ++ def _load_weights_with_worker( ++ self, ++ params: Dict[str, torch.nn.Parameter], ++ local_names: List[str], ++ filenames: List[str], ++ weight_path: str, ++ ): ++ import os ++ ++ from safetensors import safe_open ++ ++ from sglang.srt.model_loader.weight_utils import default_weight_loader ++ ++ all_slices = {} ++ for filename in filenames: ++ safetensor_file = os.path.join(weight_path, filename) ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for name in f.keys(): ++ # all_slices[name] = f.get_slice(name) ++ all_slices[name] = f.get_tensor(name) ++ ++ for local_name in local_names: ++ # Skip loading extra bias for GPTQ models. ++ if local_name.endswith(".bias") and local_name not in params: ++ continue ++ # Handle special cases ++ if "rotary_emb.inv_freq" in local_name or "projector" in local_name: ++ continue ++ if ( ++ "rotary_emb.cos_cached" in local_name ++ or "rotary_emb.sin_cached" in local_name ++ ): ++ # Models trained using ColossalAI may include these tensors in ++ # the checkpoint. Skip them. 
++ continue ++ if local_name.startswith("model.vision_tower") and local_name not in params: ++ continue ++ ++ param = params[local_name] ++ # Handle weight tying ++ if self.config.tie_word_embeddings and "lm_head.weight" in local_name: ++ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: ++ local_name = "model.embed_tokens.weight" ++ ++ loaded = False ++ for param_name, shard_name, shard_id in self.stacked_params_mapping: ++ if param_name not in local_name: ++ continue ++ # If local_name weight is sharded into multiple keys ++ weight_loader = param.weight_loader ++ slice_name = local_name.replace(param_name, shard_name) ++ loaded_weight = all_slices[slice_name] ++ weight_loader(param, loaded_weight, shard_id) ++ loaded = True ++ ++ if not loaded: ++ # If local_name weight is not sharded ++ if local_name in all_slices: ++ loaded_weight = all_slices[local_name] ++ weight_loader = getattr( ++ param, "weight_loader", default_weight_loader ++ ) ++ weight_loader(param, loaded_weight) ++ else: ++ raise KeyError( ++ f"Cannot find weight {local_name} in the loaded slices." ++ ) ++ ++ def load_weights_from_path(self, path: str): ++ # Customized weights loading from a given path of huggingface model ++ from sglang.srt.models.utils.load import load_weights_with_hf_path_fast ++ ++ load_weights_with_hf_path_fast( ++ model=self, ++ weight_path=path, ++ load_weights_with_worker_fn=self._load_weights_with_worker, ++ stacked_params_mapping=self.stacked_params_mapping, ++ tie_word_embeddings=self.config.tie_word_embeddings, ++ ) ++ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) -@@ -559,11 +468,11 @@ +@@ -468,11 +559,11 @@ for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue -- _name = name.replace(weight_name, param_name) -+ name = name.replace(weight_name, param_name) +- name = name.replace(weight_name, param_name) ++ _name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
-- if _name.endswith(".bias") and _name not in params_dict: -+ if name.endswith(".bias") and name not in params_dict: +- if name.endswith(".bias") and name not in params_dict: ++ if _name.endswith(".bias") and _name not in params_dict: continue -- param = params_dict[_name] -+ param = params_dict[name] +- param = params_dict[name] ++ param = params_dict[_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break -diff -ruN ../../../my-sglang/python/sglang/srt/models/utils/load.py ./srt/models/utils/load.py ---- ../../../my-sglang/python/sglang/srt/models/utils/load.py 2025-10-11 15:01:36.754004483 +0800 -+++ ./srt/models/utils/load.py 1970-01-01 08:00:00.000000000 +0800 -@@ -1,170 +0,0 @@ --import json --import os --from collections import defaultdict --from concurrent.futures import ThreadPoolExecutor --from glob import glob --from typing import Callable, Dict, List, Tuple -- --import torch --from safetensors import safe_open --from transformers.utils.hub import cached_file -- --from sglang.srt.model_loader.weight_utils import default_weight_loader -- -- --def get_actual_hf_path(weight_path: str): -- return os.path.dirname(cached_file(weight_path, "config.json")) -- -- --def make_filename_bins( -- local_to_file_map: Dict[str, List[str]], --) -> Tuple[List[List[str]], List[List[str]]]: -- # Allocate local weight name into bins, where each bin access independent files -- # Then we can use multiple threads to concurrently load each bin's parameters. -- # This function has a complexity of O(F + L²) -- # where F = total number of files, L = number of local names -- if not local_to_file_map: -- return [], [] -- -- local_names = list(local_to_file_map.keys()) -- n = len(local_names) -- -- # Convert file lists to sets for O(1) lookups and create file-to-locals mapping -- local_to_files = {name: set(local_to_file_map[name]) for name in local_names} -- file_to_locals = defaultdict(set) -- for local_name, files in local_to_files.items(): -- for file in files: -- file_to_locals[file].add(local_name) -- -- # Union-Find with path compression and union by rank -- parent = list(range(n)) -- rank = [0] * n -- -- def find(x): -- if parent[x] != x: -- parent[x] = find(parent[x]) # Path compression -- return parent[x] -- -- def union(x, y): -- root_x, root_y = find(x), find(y) -- if root_x == root_y: -- return -- -- # Union by rank -- if rank[root_x] < rank[root_y]: -- root_x, root_y = root_y, root_x -- parent[root_y] = root_x -- if rank[root_x] == rank[root_y]: -- rank[root_x] += 1 -- -- # Create name-to-index mapping for O(1) lookups -- name_to_idx = {name: i for i, name in enumerate(local_names)} -- -- # Union locals that share files - O(F) where F is total number of files -- for locals_sharing_file in file_to_locals.values(): -- if len(locals_sharing_file) > 1: -- locals_list = list(locals_sharing_file) -- first_idx = name_to_idx[locals_list[0]] -- for local_name in locals_list[1:]: -- union(first_idx, name_to_idx[local_name]) -- -- # Group by root - O(L) -- root_to_group = defaultdict(list) -- for i, name in enumerate(local_names): -- root_to_group[find(i)].append(name) -- -- # Build result groups - O(L + F) -- grouped_local_names = [] -- grouped_filenames = [] -- -- for group in root_to_group.values(): -- grouped_local_names.append(group) -- # Use set union to merge files from all locals in group -- all_files = set() -- for local_name in group: -- all_files.update(local_to_files[local_name]) -- grouped_filenames.append(list(all_files)) -- -- return 
grouped_local_names, grouped_filenames -- -- --def load_weights_with_hf_path_fast( -- model: torch.nn.Module, -- weight_path: str, -- load_weights_with_worker_fn: Callable, -- stacked_params_mapping: List[Tuple[str, str, str]] | None = None, -- expert_params_mapping: List[Tuple[str, str, str]] | None = None, -- tie_word_embeddings: bool = False, -- max_workers: int = None, --): -- if not os.path.exists(weight_path): -- weight_path = get_actual_hf_path(weight_path) -- index_file = os.path.join(weight_path, "model.safetensors.index.json") -- index = {} -- if os.path.exists(index_file): -- with open(index_file, "r") as f: -- index = json.load(f)["weight_map"] -- else: -- # Search all safetensors files -- safetensor_files = glob(os.path.join(weight_path, "*.safetensors")) -- # If there are safetensors files -- if safetensor_files: -- # Iterate through each safetensors file -- for safetensor_file in safetensor_files: -- with safe_open(safetensor_file, framework="pt", device="cpu") as f: -- for k in f.keys(): -- index[k] = safetensor_file -- else: -- raise FileNotFoundError("No safetensors found in the model path to load.") -- -- params = dict(model.named_parameters()) -- local_names = list(params.keys()) -- -- worker_args = [] -- -- # local name -> set of filenames that contains the weight -- local_to_file_map = defaultdict(set) -- # model.layers.31.mlp.experts -- for local_name in local_names: -- hf_names = [] -- if "mlp.experts" not in local_name and stacked_params_mapping is not None: -- for param_name, shard_name, _ in stacked_params_mapping: -- if param_name in local_name: -- hf_names.append(local_name.replace(param_name, shard_name)) -- if expert_params_mapping is not None: -- for param_name, shard_name, _, _ in expert_params_mapping: -- if param_name in local_name: -- hf_names.append(local_name.replace(param_name, shard_name)) -- if tie_word_embeddings and "lm_head.weight" in local_name: -- hf_names.append("model.embed_tokens.weight") -- if len(hf_names) == 0: -- hf_names.append(local_name) -- for name in hf_names: -- filename = index[name] -- if filename not in local_to_file_map[local_name]: -- local_to_file_map[local_name].add(filename) -- -- grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) -- -- if max_workers is None: -- # assume all GPUs are used by SGLang servers -- max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count())) -- -- for local_names, filenames in zip(grouped_local_names, grouped_filenames): -- worker_args.append( -- dict( -- params=params, -- local_names=local_names, -- filenames=filenames, -- weight_path=weight_path, -- ) -- ) -- -- max_workers = min(max_workers, len(worker_args)) -- with ThreadPoolExecutor(max_workers=max_workers) as executor: -- results = executor.map( -- lambda kwargs: load_weights_with_worker_fn(**kwargs), worker_args -- ) -- # Consume all results to make result all tasks complete -- for _ in results: -- pass -diff -ruN ../../../my-sglang/python/sglang/srt/utils.py ./srt/utils.py ---- ../../../my-sglang/python/sglang/srt/utils.py 2025-10-11 15:03:52.182903128 +0800 -+++ ./srt/utils.py 2025-10-13 13:20:52.081418051 +0800 -@@ -21,7 +21,6 @@ +diff -ruN ./srt/models/utils/load.py ../../../my-sglang/python/sglang/srt/models/utils/load.py +--- ./srt/models/utils/load.py 1970-01-01 08:00:00.000000000 +0800 ++++ ../../../my-sglang/python/sglang/srt/models/utils/load.py 2025-10-11 15:01:36.754004483 +0800 +@@ -0,0 +1,170 @@ ++import json ++import os ++from collections import defaultdict ++from concurrent.futures 
import ThreadPoolExecutor ++from glob import glob ++from typing import Callable, Dict, List, Tuple ++ ++import torch ++from safetensors import safe_open ++from transformers.utils.hub import cached_file ++ ++from sglang.srt.model_loader.weight_utils import default_weight_loader ++ ++ ++def get_actual_hf_path(weight_path: str): ++ return os.path.dirname(cached_file(weight_path, "config.json")) ++ ++ ++def make_filename_bins( ++ local_to_file_map: Dict[str, List[str]], ++) -> Tuple[List[List[str]], List[List[str]]]: ++ # Allocate local weight name into bins, where each bin access independent files ++ # Then we can use multiple threads to concurrently load each bin's parameters. ++ # This function has a complexity of O(F + L²) ++ # where F = total number of files, L = number of local names ++ if not local_to_file_map: ++ return [], [] ++ ++ local_names = list(local_to_file_map.keys()) ++ n = len(local_names) ++ ++ # Convert file lists to sets for O(1) lookups and create file-to-locals mapping ++ local_to_files = {name: set(local_to_file_map[name]) for name in local_names} ++ file_to_locals = defaultdict(set) ++ for local_name, files in local_to_files.items(): ++ for file in files: ++ file_to_locals[file].add(local_name) ++ ++ # Union-Find with path compression and union by rank ++ parent = list(range(n)) ++ rank = [0] * n ++ ++ def find(x): ++ if parent[x] != x: ++ parent[x] = find(parent[x]) # Path compression ++ return parent[x] ++ ++ def union(x, y): ++ root_x, root_y = find(x), find(y) ++ if root_x == root_y: ++ return ++ ++ # Union by rank ++ if rank[root_x] < rank[root_y]: ++ root_x, root_y = root_y, root_x ++ parent[root_y] = root_x ++ if rank[root_x] == rank[root_y]: ++ rank[root_x] += 1 ++ ++ # Create name-to-index mapping for O(1) lookups ++ name_to_idx = {name: i for i, name in enumerate(local_names)} ++ ++ # Union locals that share files - O(F) where F is total number of files ++ for locals_sharing_file in file_to_locals.values(): ++ if len(locals_sharing_file) > 1: ++ locals_list = list(locals_sharing_file) ++ first_idx = name_to_idx[locals_list[0]] ++ for local_name in locals_list[1:]: ++ union(first_idx, name_to_idx[local_name]) ++ ++ # Group by root - O(L) ++ root_to_group = defaultdict(list) ++ for i, name in enumerate(local_names): ++ root_to_group[find(i)].append(name) ++ ++ # Build result groups - O(L + F) ++ grouped_local_names = [] ++ grouped_filenames = [] ++ ++ for group in root_to_group.values(): ++ grouped_local_names.append(group) ++ # Use set union to merge files from all locals in group ++ all_files = set() ++ for local_name in group: ++ all_files.update(local_to_files[local_name]) ++ grouped_filenames.append(list(all_files)) ++ ++ return grouped_local_names, grouped_filenames ++ ++ ++def load_weights_with_hf_path_fast( ++ model: torch.nn.Module, ++ weight_path: str, ++ load_weights_with_worker_fn: Callable, ++ stacked_params_mapping: List[Tuple[str, str, str]] | None = None, ++ expert_params_mapping: List[Tuple[str, str, str]] | None = None, ++ tie_word_embeddings: bool = False, ++ max_workers: int = None, ++): ++ if not os.path.exists(weight_path): ++ weight_path = get_actual_hf_path(weight_path) ++ index_file = os.path.join(weight_path, "model.safetensors.index.json") ++ index = {} ++ if os.path.exists(index_file): ++ with open(index_file, "r") as f: ++ index = json.load(f)["weight_map"] ++ else: ++ # Search all safetensors files ++ safetensor_files = glob(os.path.join(weight_path, "*.safetensors")) ++ # If there are safetensors files ++ if safetensor_files: ++ 
# Iterate through each safetensors file ++ for safetensor_file in safetensor_files: ++ with safe_open(safetensor_file, framework="pt", device="cpu") as f: ++ for k in f.keys(): ++ index[k] = safetensor_file ++ else: ++ raise FileNotFoundError("No safetensors found in the model path to load.") ++ ++ params = dict(model.named_parameters()) ++ local_names = list(params.keys()) ++ ++ worker_args = [] ++ ++ # local name -> set of filenames that contains the weight ++ local_to_file_map = defaultdict(set) ++ # model.layers.31.mlp.experts ++ for local_name in local_names: ++ hf_names = [] ++ if "mlp.experts" not in local_name and stacked_params_mapping is not None: ++ for param_name, shard_name, _ in stacked_params_mapping: ++ if param_name in local_name: ++ hf_names.append(local_name.replace(param_name, shard_name)) ++ if expert_params_mapping is not None: ++ for param_name, shard_name, _, _ in expert_params_mapping: ++ if param_name in local_name: ++ hf_names.append(local_name.replace(param_name, shard_name)) ++ if tie_word_embeddings and "lm_head.weight" in local_name: ++ hf_names.append("model.embed_tokens.weight") ++ if len(hf_names) == 0: ++ hf_names.append(local_name) ++ for name in hf_names: ++ filename = index[name] ++ if filename not in local_to_file_map[local_name]: ++ local_to_file_map[local_name].add(filename) ++ ++ grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) ++ ++ if max_workers is None: ++ # assume all GPUs are used by SGLang servers ++ max_workers = min(8, max(1, os.cpu_count() // torch.cuda.device_count())) ++ ++ for local_names, filenames in zip(grouped_local_names, grouped_filenames): ++ worker_args.append( ++ dict( ++ params=params, ++ local_names=local_names, ++ filenames=filenames, ++ weight_path=weight_path, ++ ) ++ ) ++ ++ max_workers = min(max_workers, len(worker_args)) ++ with ThreadPoolExecutor(max_workers=max_workers) as executor: ++ results = executor.map( ++ lambda kwargs: load_weights_with_worker_fn(**kwargs), worker_args ++ ) ++ # Consume all results to make result all tasks complete ++ for _ in results: ++ pass +diff -ruN ./srt/utils.py ../../../my-sglang/python/sglang/srt/utils.py +--- ./srt/utils.py 2025-10-13 13:20:52.081418051 +0800 ++++ ../../../my-sglang/python/sglang/srt/utils.py 2025-10-11 15:03:52.182903128 +0800 +@@ -21,6 +21,7 @@ import dataclasses import functools import importlib --import inspect ++import inspect import io import ipaddress import itertools
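
Note: the fast-load path carried by this series works in two steps.
load_weights_with_hf_path_fast() first maps each local parameter name to the
safetensors files holding its shards (via model.safetensors.index.json), then
make_filename_bins() unions parameters that touch the same file, so every
resulting bin can be loaded on its own ThreadPoolExecutor thread without two
threads reading the same file. A minimal sketch of the binning behavior, using
hypothetical parameter and file names (the import path only exists after the
patch is applied):

    from sglang.srt.models.utils.load import make_filename_bins

    # qkv_proj and gate_up_proj both read shard file 1, so union-find forces
    # them into one bin; down_proj only reads shard file 2 and gets its own
    # bin, i.e. its own loader thread.
    local_to_file_map = {
        "model.layers.0.self_attn.qkv_proj.weight": ["model-00001-of-00002.safetensors"],
        "model.layers.0.mlp.gate_up_proj.weight": ["model-00001-of-00002.safetensors"],
        "model.layers.1.mlp.down_proj.weight": ["model-00002-of-00002.safetensors"],
    }

    grouped_names, grouped_files = make_filename_bins(local_to_file_map)
    assert grouped_names == [
        ["model.layers.0.self_attn.qkv_proj.weight",
         "model.layers.0.mlp.gate_up_proj.weight"],
        ["model.layers.1.mlp.down_proj.weight"],
    ]
    assert grouped_files == [
        ["model-00001-of-00002.safetensors"],
        ["model-00002-of-00002.safetensors"],
    ]

Each bin is then dispatched through load_weights_with_worker_fn, which is the
path that the enable_fast_load extra-config key toggles in the patched loader.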