
Conversation

chaojun-zhang
Contributor

@chaojun-zhang chaojun-zhang commented Aug 7, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

When using Intel GPUs with Ray, the device_control_env for Ray is ONEAPI_DEVICE_SELECTOR, which requires identifiers prefixed with level_zero:. However, vLLM currently constructs device IDs without the level_zero prefix, and changing this would require extensive modifications across multiple areas, including numerous unit tests (UT).
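To illustrate the mismatch (a hypothetical helper for illustration only, not part of this PR), mapping vLLM's bare device indices to the selector format would look like this:

```python
# Hypothetical helper: prefix bare device indices with the level_zero backend,
# producing the identifier format that ONEAPI_DEVICE_SELECTOR expects.
def to_oneapi_selector(device_ids):
    return "level_zero:" + ",".join(str(i) for i in device_ids)

assert to_oneapi_selector([0, 1]) == "level_zero:0,1"
```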

Alternatively, there is a more convenient device_control_env called ZE_AFFINITY_MASK. To use it, we need to set it manually when launching workers via ray.remote and remove the existing ONEAPI_DEVICE_SELECTOR from the environment variables; otherwise, different tensor-parallel (TP) workers could land on the same device and cause a device OOM.
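A minimal sketch of this approach (illustrative only, not the PR's actual diff; assumes a 2-way TP setup):

```python
# Pin each Ray actor to one XPU via ZE_AFFINITY_MASK, and drop the selector
# variable so it cannot override the mask and land two workers on one device.
import os
import ray

@ray.remote
class Worker:
    def __init__(self):
        # Assumed cleanup step: remove the conflicting selector if present.
        os.environ.pop("ONEAPI_DEVICE_SELECTOR", None)

    def visible_devices(self):
        return os.environ.get("ZE_AFFINITY_MASK")

ray.init()
workers = [
    Worker.options(runtime_env={"env_vars": {"ZE_AFFINITY_MASK": str(rank)}}).remote()
    for rank in range(2)  # one distinct device index per TP worker
]
print(ray.get([w.visible_devices.remote() for w in workers]))  # e.g. ['0', '1']
```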

Test Plan

VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py

Test Result

Without this PR:
(EngineCore_0 pid=2085439) (RayWorkerWrapper pid=2071473) INFO 08-18 23:44:38 [xpu_worker.py:139] After memory profiling run, peak memory usage is 5582.46 MB, torch mem is 134.00 MB, non-torch mem is 5064.37 MB, free gpu memory is 60337.63 MB.
(MyLLM pid=2071464) (EngineCore_0 pid=2085439) INFO 08-18 23:44:39 [kv_cache_utils.py:829] GPU KV cache size: 1,173,696 tokens
(MyLLM pid=2071464) (EngineCore_0 pid=2085439) INFO 08-18 23:44:39 [kv_cache_utils.py:833] Maximum concurrency for 2,048 tokens per request: 573.09x
(MyLLM pid=2071464) (EngineCore_0 pid=2085439) INFO 08-18 23:44:39 [kv_cache_utils.py:829] GPU KV cache size: 1,185,344 tokens
(MyLLM pid=2071464) (EngineCore_0 pid=2085439) INFO 08-18 23:44:39 [kv_cache_utils.py:833] Maximum concurrency for 2,048 tokens per request: 578.78x
(MyLLM pid=2071464) (EngineCore_0 pid=2085439) INFO 08-18 23:44:39 [core.py:199] init engine (profile, create kv cache, warmup model) took 1.84 seconds
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) Process EngineCore_0:
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) Traceback (most recent call last):
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) self.run()
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) self._target(*self._args, **self._kwargs)
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 687, in run_engine_core
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) raise e
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 674, in run_engine_core
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) engine_core = EngineCoreProc(*args, **kwargs)
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 475, in __init__
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) super().__init__(vllm_config, executor_class, log_stats,
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 86, in __init__
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) self._initialize_kv_caches(vllm_config)
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 173, in _initialize_kv_caches
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) kv_cache_configs = [
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/engine/core.py", line 174, in <listcomp>
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/core/kv_cache_utils.py", line 1075, in get_kv_cache_config
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) File "/home/chaojun/vllm/vllm/v1/core/kv_cache_utils.py", line 662, in check_enough_kv_cache_memory
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) raise ValueError("No available memory for the cache blocks. "
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) ValueError: No available memory for the cache blocks. Try increasing gpu_memory_utilization when initializing the engine.
(MyLLM pid=2071463) (EngineCore_0 pid=2085443) (RayWorkerWrapper pid=2071468) INFO 08-18 23:44:22 [worker_base.py:591] Injected <class 'rlhf_utils.ColocateWorkerExtension'>

With this PR:
The OOM above no longer occurs.

(Optional) Documentation Update

@mergify mergify bot added the documentation Improvements or additions to documentation label Aug 7, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Intel XPUs to the RLHF co-location example. The changes involve generalizing device-specific environment variables and logic in both the example scripts and core vLLM components like the Ray executor and worker base.

My main feedback concerns a potential bug in vllm/executor/ray_distributed_executor.py where ray_remote_kwargs is overwritten, which could break other functionality, such as nsight profiling, when using XPUs. I've provided a suggestion to fix this.

Overall, the changes are a good step towards broader hardware support in vLLM.
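A hedged sketch of the kind of fix suggested above (names assumed from the comment, not copied from the actual diff): merge the XPU environment variables into any existing runtime_env instead of replacing ray_remote_kwargs wholesale.

```python
# Hypothetical: merge rather than overwrite, so unrelated settings that also
# live in ray_remote_kwargs (e.g. nsight profiling) survive the XPU changes.
def add_xpu_affinity(ray_remote_kwargs: dict, device_ids: list[int]) -> dict:
    runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
    env_vars = runtime_env.setdefault("env_vars", {})
    env_vars["ZE_AFFINITY_MASK"] = ",".join(str(i) for i in device_ids)
    return ray_remote_kwargs
```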


github-actions bot commented Aug 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which exercises a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@chaojun-zhang chaojun-zhang marked this pull request as ready for review August 7, 2025 02:16
@chaojun-zhang chaojun-zhang changed the title [XPU] support rlhf colocate on XPU [XPU] support rlhf colocate with ray on XPU Aug 7, 2025
@chaojun-zhang chaojun-zhang changed the title [XPU] support rlhf colocate with ray on XPU [XPU] support ray distribute executor on XPU Aug 7, 2025
@chaojun-zhang chaojun-zhang marked this pull request as draft August 7, 2025 08:13
@chaojun-zhang chaojun-zhang marked this pull request as ready for review August 7, 2025 12:55
@chaojun-zhang chaojun-zhang changed the title [XPU] support ray distribute executor on XPU [XPU] Fix OOM when manually specifying ZE_AFFINITY_MASK with Ray distributed executor on XPU Aug 19, 2025
@chaojun-zhang chaojun-zhang force-pushed the upstream/rlhf_colate branch 2 times, most recently from f945416 to 27a6c4a Compare August 26, 2025 01:15
@chaojun-zhang
Copy link
Contributor Author

@simon-mo, can you help review this PR?
