[rollout] fix: guard sglang profiling when self.tokenizer_manager is … #6217

LeiDing191 wants to merge 1 commit into verl-project:main from
Conversation
Code Review
This pull request adds safety checks for the `tokenizer_manager` attribute in the `start_profile` and `stop_profile` methods to prevent errors when it is uninitialized. Feedback suggests replacing the `getattr` and `None` checks with a `node_rank == 0` guard to maintain consistency with the class's existing patterns and reduce log noise on non-primary nodes.
```python
tokenizer_manager = getattr(self, "tokenizer_manager", None)
if tokenizer_manager is None:
    logger.warning(
        "Skip SGLang start_profile because tokenizer_manager is None, "
        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
    )
    return
await tokenizer_manager.start_profile(**profile_args)
```
The use of getattr and checking for None with a warning is inconsistent with the established pattern in this class (e.g., in wake_up, sleep, and clear_kv_cache), which uses self.node_rank == 0 to guard operations that only apply to the primary node.
Furthermore, logging a warning on every non-zero rank node during profiling is noisy in multi-node setups where this is the expected state. It is better to use the node_rank == 0 check and direct attribute access, as tokenizer_manager is guaranteed to be initialized on the primary node after launch_server completes.
Suggested change:

```diff
-tokenizer_manager = getattr(self, "tokenizer_manager", None)
-if tokenizer_manager is None:
-    logger.warning(
-        "Skip SGLang start_profile because tokenizer_manager is None, "
-        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
-    )
-    return
-await tokenizer_manager.start_profile(**profile_args)
+if self.node_rank == 0:
+    await self.tokenizer_manager.start_profile(**profile_args)
```
```python
tokenizer_manager = getattr(self, "tokenizer_manager", None)
if tokenizer_manager is None:
    logger.warning(
        "Skip SGLang stop_profile because tokenizer_manager is None, "
        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
    )
    return
await tokenizer_manager.stop_profile()
```
Similar to start_profile, using self.node_rank == 0 is more consistent with the rest of the file and avoids unnecessary log noise on non-zero rank nodes.
Suggested change:

```diff
-tokenizer_manager = getattr(self, "tokenizer_manager", None)
-if tokenizer_manager is None:
-    logger.warning(
-        "Skip SGLang stop_profile because tokenizer_manager is None, "
-        f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
-    )
-    return
-await tokenizer_manager.stop_profile()
+if self.node_rank == 0:
+    await self.tokenizer_manager.stop_profile()
```
What does this PR do?

Fixes a multi-node SGLang rollout profiling crash that occurs when a non-zero-node SGLang server actor does not have a tokenizer manager. `start_profile()`/`stop_profile()` now skip tokenizer-manager profiling on server actors whose `tokenizer_manager` is unavailable, while preserving the existing behavior on the node that owns the tokenizer manager.

Why?

`RolloutReplica.start_profile()` broadcasts the profiling request to all rollout server actors. For each multi-node SGLang rollout replica, `SGLangReplica.launch_servers()` creates one `SGLangHttpServer` actor for each `node_rank` in `range(self.nnodes)` and appends all of them to `self.servers`. However, in SGLang's server launch flow, non-zero `node_rank` processes do not initialize a tokenizer manager.

The issue was reproduced with SGLang `v0.5.6`. In SGLang `v0.5.6`, `_launch_subprocesses(...)` returns early for non-zero node ranks after the scheduler subprocesses are ready. So, on non-zero node ranks, SGLang intentionally skips tokenizer/detokenizer initialization and can return `tokenizer_manager=None`.

I also checked later SGLang versions after `v0.5.6`: `v0.5.7`, `v0.5.8`, `v0.5.9`, and upstream `main` on 2026-04-30. The same launch design is still present. In `v0.5.9`, which SGLang docs list as the last release branch at the time of checking, non-zero node ranks still return before detokenizer/tokenizer-manager initialization. The current SGLang `main` branch has changed the return tuple shape again, but the relevant behavior is still the same: non-zero node ranks return `tokenizer_manager=None` before detokenizer/tokenizer-manager initialization.

Therefore, when verl broadcasts `start_profile()` to every `SGLangHttpServer`, the non-zero-node actor can hit this call path with an uninitialized `tokenizer_manager`, which raises an `AttributeError`.

This PR adds a guard so that only server actors with an initialized tokenizer manager call SGLang's tokenizer-manager profiling API.
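The failure mode can be reproduced with a minimal stand-in. The class name mirrors the PR text, but the body is a simplified sketch: the real code path goes through Ray actors and SGLang's launch flow, which are omitted here.

```python
# Minimal stand-in for the pre-fix failure mode: on a non-zero node rank
# the tokenizer manager was never initialized, so the unguarded call
# raises AttributeError.
import asyncio

class SGLangHttpServer:
    def __init__(self, replica_rank, node_rank, tokenizer_manager=None):
        self.replica_rank = replica_rank
        self.node_rank = node_rank
        self.tokenizer_manager = tokenizer_manager  # None on non-zero ranks

    async def start_profile(self, **profile_args):
        # Unguarded call, as before this PR.
        await self.tokenizer_manager.start_profile(**profile_args)

server = SGLangHttpServer(replica_rank=0, node_rank=1)
try:
    asyncio.run(server.start_profile())
except AttributeError as exc:
    print(f"AttributeError: {exc}")
# -> AttributeError: 'NoneType' object has no attribute 'start_profile'
```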
Test

`python -m py_compile verl/workers/rollout/sglang_rollout/async_sglang_server.py`

Reproduced before the fix with a 2-node SGLang rollout profiling run: profiling failed when `start_profile()` was dispatched to the non-zero-node SGLang server actor. Sanitized traceback from the failed run:

After this fix, the actor without a tokenizer manager skips tokenizer-manager profiling instead of failing the training step.
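The fixed behavior can be sketched as follows. This is a simplified stand-in assuming the `getattr`-based guard from this PR: the logger name is hypothetical, the actor machinery is omitted, and the `"skipped"` return value is added only so the result can be observed (the real method simply returns).

```python
# Sketch of the guarded start_profile on a stand-in class: an actor
# without a tokenizer manager logs a warning and returns instead of
# raising.
import asyncio
import logging

logger = logging.getLogger("sglang_profile_guard")  # hypothetical name

class SGLangHttpServer:
    def __init__(self, replica_rank, node_rank):
        self.replica_rank = replica_rank
        self.node_rank = node_rank
        # Non-zero node ranks never set self.tokenizer_manager at all.

    async def start_profile(self, **profile_args):
        tokenizer_manager = getattr(self, "tokenizer_manager", None)
        if tokenizer_manager is None:
            logger.warning(
                "Skip SGLang start_profile because tokenizer_manager is None, "
                f"replica_rank={self.replica_rank}, node_rank={self.node_rank}"
            )
            return "skipped"
        return await tokenizer_manager.start_profile(**profile_args)

# The non-zero-node actor now skips cleanly instead of crashing the step.
print(asyncio.run(SGLangHttpServer(replica_rank=0, node_rank=1).start_profile()))
# -> skipped
```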
API and Usage Example
No API change.
Design & Code Changes
- `tokenizer_manager is None` guard in `SGLangHttpServer.start_profile()`.
- The same guard in `SGLangHttpServer.stop_profile()`.
- `replica_rank` and `node_rank` are logged when profiling is skipped, for easier debugging.