[rollout] fix: guard sglang profiling when self.tokenizer_manager is …#6217

Open
LeiDing191 wants to merge 1 commit into verl-project:main from LeiDing191:fix-sglang-profile-tokenizer_manager-none
Conversation

@LeiDing191

What does this PR do?

Fixes a crash in multi-node SGLang rollout profiling that occurs when a non-zero-node SGLang server actor has no initialized tokenizer manager.

start_profile() / stop_profile() now skip tokenizer-manager profiling on server actors whose tokenizer_manager is unavailable, while preserving the existing behavior on the node that owns the tokenizer manager.

Why?

RolloutReplica.start_profile() broadcasts the profiling request to all rollout server actors:

await asyncio.gather(*[server.start_profile.remote(**kwargs) for server in self.servers])
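
This fan-out is why a single bad actor fails the whole profiling call: asyncio.gather propagates the first exception it sees. A minimal, self-contained sketch of the failure mode (plain coroutines stand in for the Ray actor calls; the names here are illustrative, not the real verl classes):

```python
import asyncio

# Stand-in for SGLangHttpServer.start_profile.remote(); the real call goes
# through Ray, but the exception propagation through gather() is the same.
async def start_profile_on(node_rank: int, has_tokenizer_manager: bool) -> str:
    if not has_tokenizer_manager:
        # Mirrors the crash on non-zero node ranks before the fix.
        raise AttributeError("'NoneType' object has no attribute 'start_profile'")
    return f"profiling started on node_rank={node_rank}"

async def broadcast() -> list[str]:
    # gather() re-raises the first exception, so one failing actor
    # fails the entire start_profile() broadcast.
    return await asyncio.gather(
        start_profile_on(0, True),   # node 0: has a tokenizer manager
        start_profile_on(1, False),  # node 1: tokenizer_manager is None
    )

if __name__ == "__main__":
    try:
        asyncio.run(broadcast())
    except AttributeError as e:
        print(f"broadcast failed: {e}")
```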

For each multi-node SGLang rollout replica, SGLangReplica.launch_servers() creates one SGLangHttpServer actor for each node_rank in range(self.nnodes), and appends all of them to self.servers:

# create server actor in each node with node affinity and cuda visible devices
for node_rank in range(self.nnodes):
    ...
    server = self.server_class.options(
        scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
            node_id=node_id,
            soft=False,
        ),
        ...
    ).remote(
        ...
        node_rank=node_rank,
        nnodes=self.nnodes,
        ...
    )
    self.servers.append(server)

However, in SGLang's server launch flow, non-zero node_rank processes do not initialize a tokenizer manager.

The issue was reproduced with SGLang v0.5.6. In SGLang v0.5.6, _launch_subprocesses(...) returns early for non-zero node ranks after the scheduler subprocesses are ready:

if server_args.node_rank >= 1:
    # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
    # so they can just wait here.
    ...
    if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
        # When using `Engine` as a Python API, we don't want to block here.
        return None, None, None, port_args
    ...
    return None, None, None, port_args

# tokenizer manager is initialized only after the non-zero-node early return.
if server_args.tokenizer_worker_num == 1:
    tokenizer_manager, template_manager = _init_tokenizer_manager(
        server_args, port_args
    )
else:
    tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
    template_manager = None

So, on non-zero node ranks, SGLang intentionally skips tokenizer/detokenizer initialization and can return tokenizer_manager=None.
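
The net effect on the verl side can be sketched as follows (a simplified, hypothetical stand-in for how the server actor ends up holding tokenizer_manager=None; the class and function names are illustrative, not the real verl/SGLang code):

```python
# Simplified model of SGLang's launch flow: non-zero node ranks return
# early with no tokenizer manager, and the server actor stores that None.

def launch_subprocesses(node_rank: int):
    if node_rank >= 1:
        # Non-zero ranks skip tokenizer/detokenizer initialization entirely.
        return None, None
    return "TokenizerManager()", "TemplateManager()"

class FakeServer:
    """Illustrative stand-in for an SGLangHttpServer actor."""

    def __init__(self, node_rank: int):
        self.node_rank = node_rank
        self.tokenizer_manager, _ = launch_subprocesses(node_rank)

primary = FakeServer(node_rank=0)
secondary = FakeServer(node_rank=1)
print(primary.tokenizer_manager)    # a tokenizer manager object
print(secondary.tokenizer_manager)  # None -> .start_profile() would raise
```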

I also checked SGLang versions after v0.5.6: v0.5.7, v0.5.8, v0.5.9, and upstream main as of 2026-04-30. The same launch design is still present. In v0.5.9, the latest release branch listed in the SGLang docs at the time of checking, non-zero node ranks still return before detokenizer/tokenizer-manager initialization:

if server_args.node_rank >= 1:
    # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
    # so they can just wait here.
    ...
    if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
        # When using `Engine` as a Python API, we don't want to block here.
        return None, None, scheduler_infos, port_args
    ...
    return None, None, scheduler_infos, port_args

# tokenizer manager is initialized only after the non-zero-node early return.
if server_args.tokenizer_worker_num == 1:
    tokenizer_manager, template_manager = init_tokenizer_manager_func(
        server_args, port_args
    )
else:
    tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
    template_manager = None

The current SGLang main branch has changed the return tuple shape again, but the relevant behavior is still the same: non-zero node ranks return tokenizer_manager=None before detokenizer/tokenizer-manager initialization:

if server_args.node_rank >= 1:
    # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
    # so they can just wait here.
    ...
    if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
        # When using `Engine` as a Python API, we don't want to block here.
        return (
            None,
            None,
            port_args,
            scheduler_init_result,
            None,
        )
    ...
    return (
        None,
        None,
        port_args,
        scheduler_init_result,
        None,
    )

# tokenizer manager is initialized only after the non-zero-node early return.
if server_args.tokenizer_worker_num == 1:
    tokenizer_manager, template_manager = init_tokenizer_manager_func(
        server_args, port_args
    )
else:
    tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
    template_manager = None

Therefore, when verl broadcasts start_profile() to every SGLangHttpServer, the non-zero-node actor can hit:

await self.tokenizer_manager.start_profile(...)

with:

self.tokenizer_manager is None

which raises:

AttributeError: 'NoneType' object has no attribute 'start_profile'

This PR adds a guard so that only server actors with an initialized tokenizer manager call SGLang's tokenizer-manager profiling API.
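
The guard can be sketched as below (a simplified, self-contained model; SGLangHttpServerSketch is an illustrative stand-in for verl's SGLangHttpServer, not the real class):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

class SGLangHttpServerSketch:
    """Illustrative stand-in for verl's SGLangHttpServer actor."""

    def __init__(self, replica_rank: int, node_rank: int, tokenizer_manager=None):
        self.replica_rank = replica_rank
        self.node_rank = node_rank
        self.tokenizer_manager = tokenizer_manager

    async def start_profile(self, **profile_args):
        # Guard: non-zero node ranks have no tokenizer manager, so skip
        # profiling instead of raising AttributeError on None.
        tokenizer_manager = getattr(self, "tokenizer_manager", None)
        if tokenizer_manager is None:
            logger.warning(
                "Skip SGLang start_profile because tokenizer_manager is None, "
                f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
            )
            return
        await tokenizer_manager.start_profile(**profile_args)

# Before the fix this call raised AttributeError; with the guard it is a no-op.
asyncio.run(SGLangHttpServerSketch(replica_rank=0, node_rank=1).start_profile())
```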

Test

  • python -m py_compile verl/workers/rollout/sglang_rollout/async_sglang_server.py

Reproduced before the fix with a 2-node SGLang rollout profiling run:

actor_rollout_ref.rollout.name=sglang
actor_rollout_ref.rollout.tensor_model_parallel_size=16
actor_rollout_ref.rollout.expert_parallel_size=1
actor_rollout_ref.rollout.profiler.enable=True
trainer.n_gpus_per_node=8
trainer.nnodes=2
global_profiler.tool=torch
global_profiler.steps=[1]

Before this fix, profiling failed when start_profile() was dispatched to the non-zero-node SGLang server actor.

Sanitized traceback from the failed run:

ray.exceptions.RayTaskError(AttributeError): ray::TaskRunner.run()
  File ".../verl/trainer/ppo/ray_trainer.py", line 1483, in fit
    self.async_rollout_manager.start_profile()
  File ".../verl/experimental/agent_loop/agent_loop.py", line 1030, in start_profile
    self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
  File ".../verl/experimental/agent_loop/agent_loop.py", line 1038, in run_all
    await asyncio.gather(*tasks)
  File ".../verl/workers/rollout/replica.py", line 240, in start_profile
    await asyncio.gather(*[server.start_profile.remote() for server in self.servers])

ray.exceptions.RayTaskError(AttributeError): ray::SGLangHttpServer.start_profile()
  File ".../verl/workers/rollout/sglang_rollout/async_sglang_server.py", line 499, in start_profile
    await self.tokenizer_manager.start_profile(
AttributeError: 'NoneType' object has no attribute 'start_profile'

After this fix, the actor without a tokenizer manager skips tokenizer-manager profiling instead of failing the training step.

API and Usage Example

No API change.

Design & Code Changes

  • Add a tokenizer_manager is None guard in SGLangHttpServer.start_profile().
  • Add the same guard in SGLangHttpServer.stop_profile().
  • Log replica_rank and node_rank when profiling is skipped for easier debugging.

@LeiDing191 LeiDing191 requested a review from chenhaiq as a code owner April 30, 2026 03:10

@gemini-code-assist (Bot) left a comment:


Code Review

This pull request adds safety checks for the tokenizer_manager attribute in the start_profile and stop_profile methods to prevent errors when it is uninitialized. Feedback suggests replacing the getattr and None checks with a node_rank == 0 guard to maintain consistency with the class's existing patterns and reduce log noise on non-primary nodes.

Comment on lines +571 to +578
tokenizer_manager = getattr(self, "tokenizer_manager", None)
if tokenizer_manager is None:
    logger.warning(
        "Skip SGLang start_profile because tokenizer_manager is None, "
        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
    )
    return
await tokenizer_manager.start_profile(**profile_args)

Severity: high

The use of getattr and checking for None with a warning is inconsistent with the established pattern in this class (e.g., in wake_up, sleep, and clear_kv_cache), which uses self.node_rank == 0 to guard operations that only apply to the primary node.

Furthermore, logging a warning on every non-zero rank node during profiling is noisy in multi-node setups where this is the expected state. It is better to use the node_rank == 0 check and direct attribute access, as tokenizer_manager is guaranteed to be initialized on the primary node after launch_server completes.

Suggested change
-tokenizer_manager = getattr(self, "tokenizer_manager", None)
-if tokenizer_manager is None:
-    logger.warning(
-        "Skip SGLang start_profile because tokenizer_manager is None, "
-        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
-    )
-    return
-await tokenizer_manager.start_profile(**profile_args)
+if self.node_rank == 0:
+    await self.tokenizer_manager.start_profile(**profile_args)

Comment on lines +586 to +593
tokenizer_manager = getattr(self, "tokenizer_manager", None)
if tokenizer_manager is None:
    logger.warning(
        "Skip SGLang stop_profile because tokenizer_manager is None, "
        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
    )
    return
await tokenizer_manager.stop_profile()

Severity: high

Similar to start_profile, using self.node_rank == 0 is more consistent with the rest of the file and avoids unnecessary log noise on non-zero rank nodes.

Suggested change
-tokenizer_manager = getattr(self, "tokenizer_manager", None)
-if tokenizer_manager is None:
-    logger.warning(
-        "Skip SGLang stop_profile because tokenizer_manager is None, "
-        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
-    )
-    return
-await tokenizer_manager.stop_profile()
+if self.node_rank == 0:
+    await self.tokenizer_manager.stop_profile()
