Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
78fb614
add patch to optimize sglang loading
nuzant Sep 11, 2025
f75fdcd
merge
nuzant Oct 10, 2025
e05bf74
.
nuzant Oct 10, 2025
d408ba8
.
nuzant Oct 10, 2025
5f8e990
.
nuzant Oct 10, 2025
8102814
.
nuzant Oct 10, 2025
5664c12
Merge remote-tracking branch 'origin/mzy/optimize-sglang-load' into m…
nuzant Oct 10, 2025
635fad3
.
nuzant Oct 10, 2025
f62ecc3
Merge remote-tracking branch 'origin/mzy/optimize-sglang-load' into m…
nuzant Oct 10, 2025
677bd21
filename bins
nuzant Oct 11, 2025
478af67
Merge remote-tracking branch 'origin/mzy/optimize-sglang-load' into m…
nuzant Oct 11, 2025
69043aa
.
nuzant Oct 11, 2025
b097673
.
nuzant Oct 11, 2025
74d2580
remove debug info
nuzant Oct 11, 2025
9bab814
patch to 0.5.2
nuzant Oct 11, 2025
8e47d93
Merge remote-tracking branch 'origin/mzy/optimize-sglang-load' into m…
nuzant Oct 11, 2025
0c2f8b3
Merge remote-tracking branch 'origin/main' into mzy/optimize-sglang-load
nuzant Oct 11, 2025
368c761
fix gemini comments
nuzant Oct 11, 2025
a0e6c10
Merge branch 'mzy/optimize-sglang-load' into mzy/antcode/optimize-sgl…
nuzant Oct 11, 2025
6db171b
Merge remote-tracking branch 'origin/main' into mzy/optimize-sglang-load
nuzant Oct 13, 2025
26c871f
use patch instead of git apply
nuzant Oct 13, 2025
13d01dd
Merge branch 'mzy/optimize-sglang-load' into mzy/antcode/optimize-sgl…
nuzant Oct 13, 2025
ff21d52
.
nuzant Oct 13, 2025
028d1f3
.
nuzant Oct 13, 2025
1dd84bb
Merge remote-tracking branch 'origin/main' into mzy/antcode/optimize-…
nuzant Oct 13, 2025
c38a426
add debug info
nuzant Oct 13, 2025
50c328a
.
nuzant Oct 13, 2025
83e548f
remove debug info
nuzant Oct 13, 2025
4148945
.
nuzant Oct 13, 2025
4349c1c
fix update weight from disk
nuzant Oct 13, 2025
683f95f
format
nuzant Oct 13, 2025
1848ea0
.
nuzant Oct 13, 2025
e078360
use diff instead of git diff
nuzant Oct 13, 2025
20ca314
fix patch
nuzant Oct 13, 2025
c2f13b4
.
nuzant Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -599,6 +600,11 @@ class SGLangConfig:
# The interval (in decoding iterations) to log throughput
# and update prometheus metrics
decode_log_interval: int = 1
# Extra loader arguments
# NOTE: These arguments will be parsed into a dict json-string
# and passed as `model_loader_extra_config` to SGLang.
enable_multithread_load: bool = False
enable_fast_load: bool = False

# Use staticmethod to make OmegaConf happy.
@staticmethod
Expand Down Expand Up @@ -649,6 +655,19 @@ def build_args(
):
# Map "all-linear" to "all"
args: Dict = conf_as_dict(sglang_config)
if sglang_config.enable_multithread_load or sglang_config.enable_fast_load:
assert pkg_version.is_version_equal(
"sglang", "0.5.2"
), f"Customized model loading requires exact SGLang version 0.5.2"
model_loader_extra_config = dict(
enable_multithread_load=sglang_config.enable_multithread_load,
enable_fast_load=sglang_config.enable_fast_load,
)
args.pop("enable_multithread_load", None)
args.pop("enable_fast_load", None)
args["model_loader_extra_config"] = json.dumps(
model_loader_extra_config, separators=(",", ":")
)
# Map "all-linear" to "all"
if "lora_target_modules" in args and args["lora_target_modules"]:
args["lora_target_modules"] = [
Expand Down
9 changes: 6 additions & 3 deletions areal/experimental/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,6 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
# dist.barrier() are called when _save_model_to_hf finished

if dist.get_rank() == 0:
fut.result()

update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
Expand All @@ -623,6 +621,8 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
)

fut.result()

dist.barrier(device_ids=[self.device.index])
current_platform.synchronize()

Expand All @@ -642,7 +642,10 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta):
)
self.rollout_engine = engine

if not self.weight_update_group_initialized:
if (
meta.type == current_platform.communication_backend
and not self.weight_update_group_initialized
):
self._init_weight_update_from_distributed(meta)
self.weight_update_group_initialized = True

Expand Down
5 changes: 4 additions & 1 deletion areal/launcher/sglang_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from areal.platforms import current_platform
from areal.utils import logging, name_resolve, names
from areal.utils.launcher import TRITON_CACHE_PATH
from areal.utils.launcher import TRITON_CACHE_PATH, apply_sglang_patch
from areal.utils.network import find_free_ports, gethostip

logger = logging.getLogger("SGLangServer Wrapper")
Expand Down Expand Up @@ -130,6 +130,9 @@ def __init__(
self.server_process = None
self.n_gpus_per_node = n_gpus_per_node

if self.config.enable_fast_load or self.config.enable_multithread_load:
apply_sglang_patch()
Comment on lines +133 to +134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd better apply the patch in launchers instead of in sglang_server to avoid data race issues.

Copy link
Collaborator Author

@nuzant nuzant Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I am wrong, isn't one node expected to run only one areal.launcher.sglang_server instance now? So there will not be any data race issue.


def run(self):
gpus_per_server = self.allocation_mode.gen_instance_size
cross_nodes = False
Expand Down
69 changes: 68 additions & 1 deletion areal/utils/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import getpass
import os
import pathlib
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, Optional

from areal.api.alloc_mode import AllocationMode, AllocationType
from areal.utils import logging, name_resolve, names
from areal.utils import logging, name_resolve, names, pkg_version

logger = logging.getLogger("Launcher Utils")

Expand Down Expand Up @@ -154,3 +158,66 @@ def validate_config_for_distributed_launcher(config):
assert (
allocation_mode.gen.tp_size <= config.cluster.n_gpus_per_node
), "Currently only support vLLM TP size less <= #GPUs per node."


def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = None
sglang_meta = subprocess.check_output(
[sys.executable, "-m", "pip", "show", "sglang"]
).decode("utf-8")
# Prioritize editable install location, since pip show lists both locations
# if installed in editable mode.
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]) / "sglang")
break
else:
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Location: "):
target_path = str(Path(line.split(": ")[1]) / "sglang")
break

if not target_path or not os.path.exists(target_path):
raise RuntimeError("Could not determine the installation path of SGLang.")

patch_binary = shutil.which("patch")
if not patch_binary:
raise RuntimeError(
"Could not locate the `patch` command; SGLang patch application failed."
)
result = subprocess.run(
[patch_binary, "-p1", "-N", "-i", patch_path],
cwd=target_path,
capture_output=True,
text=True,
)

output = (result.stdout or "") + (result.stderr or "")
if result.returncode == 0:
logger.info(f"Applied SGLang patch {patch_path} to {target_path}")
elif (
"Reversed (or previously applied) patch detected" in output
or "Skipping patch." in output
):
logger.warning(
f"SGLang patch {patch_path} appears to be already applied for {target_path}."
)
else:
logger.error(
"Failed to apply SGLang patch %s to %s. Output:\n%s",
patch_path,
target_path,
output.strip(),
)
raise RuntimeError(
f"SGLang patch {patch_path} failed with exit code {result.returncode}."
)
12 changes: 12 additions & 0 deletions areal/utils/pkg_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,15 @@ def is_version_less(package_name: str, target_version: str) -> bool:
"""
installed_version = get_version(package_name)
return compare_versions(installed_version, target_version) < 0


def is_version_equal(package_name: str, target_version: str) -> bool:
"""
Check if the installed version of a package is equal to the target version.

:param package_name: Name of the package.
:param target_version: Target version to compare against.
:return: True if the installed version is equal to the target version, False otherwise.
"""
installed_version = get_version(package_name)
return compare_versions(installed_version, target_version) == 0
2 changes: 2 additions & 0 deletions docs/cli_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ https://github.com/sgl-project/sglang for detailed documentation.
| `show_time_cost` | boolean | `False` | - |
| `enable_metrics` | boolean | `True` | - |
| `decode_log_interval` | integer | `1` | - |
| `enable_multithread_load` | boolean | `False` | - |
| `enable_fast_load` | boolean | `False` | - |

(section-v-llm)=

Expand Down
Loading
Loading