introduce device_context to simplify code. #3875
ykaitao wants to merge 1 commit into unslothai:main
Conversation
Gemini Code Assist summary of changes: This pull request refactors device-specific operations by introducing a DeviceContext class.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c8a4746caa
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you:
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
Code Review
This pull request introduces a DeviceContext class to encapsulate and simplify device-specific operations, which is a great refactoring for code clarity and maintainability. The changes in unsloth/models/llama.py correctly adopt this new abstraction.
My review includes a critical fix for a typo that would cause a NameError and a suggestion to improve exception handling for better robustness. Overall, this is a positive change that improves the codebase structure.
Force-pushed from c8a4746 to fce9504.
/gemini review
Code Review
This pull request introduces a DeviceContext class to encapsulate device-specific logic, which significantly simplifies the code in unsloth/models/llama.py. The refactoring is well-done and improves maintainability by centralizing device-related information and operations. I have a couple of suggestions to further improve the code's clarity and robustness.
Force-pushed from fce9504 to eb3d918.
Hi @ykaitao, there is similar statistics logic in vision.py. Would you be able to add your simplifications there as well? Thank you!
Hi @mmathew23, yes I can do it later today.
Force-pushed from eb3d918 to 4a2c058.
/gemini review |
Hi @mmathew23, I have handled the similar statistics logic in vision.py. Please review again.
Code Review
This pull request introduces a DeviceContext class to encapsulate device-specific logic, which is an excellent refactoring. It significantly simplifies the code in unsloth/models/llama.py and unsloth/models/vision.py by removing duplicated device-checking logic and centralizing it. This greatly improves code clarity and maintainability. My review includes a few suggestions to further enhance the new DeviceContext class for better code quality and to improve exception handling in the model files.
Force-pushed from 8baf20c to 156d32c.
/gemini review |
Code Review
This pull request introduces a DeviceContext class to encapsulate device-specific logic, which is a great improvement. This refactoring simplifies the code in unsloth/models/llama.py and unsloth/models/vision.py by removing duplicated if/elif/else blocks for handling different device types (CUDA, HIP, XPU), enhancing code readability and maintainability. I've found one critical issue regarding module loading that could break the library for users without an XPU-enabled PyTorch build, and I've provided a suggestion to fix it.
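To make the duplication the reviews describe concrete, here is an illustrative sketch. The device-type names (CUDA, HIP, XPU) come from the thread, but the code itself is a hypothetical stand-in (no PyTorch required) rather than the PR's literal diff: a repeated if/elif/else at every call site, versus a single dispatch table resolved once.

```python
# Illustrative sketch only: not the PR's actual code. A fake backend
# object stands in for torch.cuda / torch.xpu so the example runs
# without PyTorch.
class _Backend:
    def __init__(self, label):
        self.label = label

    def empty_cache(self):
        return f"cleared {self.label}"

CUDA = _Backend("cuda")
XPU = _Backend("xpu")

# Before: every call site branches on the device type.
def clear_cache_branching(device_type):
    if device_type == "xpu":
        return XPU.empty_cache()
    elif device_type in ("cuda", "hip"):  # HIP reuses the CUDA module
        return CUDA.empty_cache()
    raise ValueError(f"Unsupported device type: {device_type}")

# After: resolve the backend once in a context object, then call through it.
_BACKENDS = {"xpu": XPU, "cuda": CUDA, "hip": CUDA}

def clear_cache_context(device_type):
    return _BACKENDS[device_type].empty_cache()
```

Centralizing the mapping means a new device type is added in one place instead of in every branching call site.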
Force-pushed from f586eb4 to 0e62bc7.
Hi Team, I have addressed all the comments. Can this PR be merged? Or do you have additional comments? @mmathew23 @Datta0 @danielhanchen |
danielhanchen left a comment
Tested on 1x NVIDIA B200 (compute capability 10.0, CUDA 12.8, torch 2.9.1, transformers 4.57.6, vLLM 0.15.1) with 10 notebooks:
| Notebook | Result |
|---|---|
| Llama3.1_8B_Alpaca | PASS |
| Qwen3_14B_Reasoning_Conversational | PASS |
| Phi_4_Conversational | PASS |
| Mistral_v0.3_7B_Conversational | PASS |
| Qwen3_VL_8B_Vision | PASS |
| Gemma3_4B_Vision | PASS |
| bert_classification | PASS |
| gpt_oss_20B_GRPO | FAIL (pre-existing: GptOssExperts.down_projs -- model architecture issue) |
| Qwen3_4B_GRPO | FAIL (pre-existing: device_synchronize NameError in compiled cache) |
| Whisper | FAIL (pre-existing: torchcodec/ffmpeg env) |
No regressions on NVIDIA. All 7 testable notebooks pass. The 3 failures are pre-existing issues unrelated to this PR.
Code review note for HIP/AMD: The refactoring drops resolve_hip_gpu_stats_name() which was used for AMD GPU name resolution. This is a regression for ROCm users -- the old logic in llama.py had a 30-line function handling HIP-specific GPU name fallbacks. If ROCm support is intended, this needs to be preserved in device_type.py.
The core refactoring (moving clean_gpu_cache/get_current_device to device_type.py) is clean and sound. Import paths are backwards-compatible via re-export in llama.py.
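To make the HIP note above concrete, here is a hypothetical sketch of where a HIP-specific name fallback could hook into the new abstraction. This is not the original 30-line resolve_hip_gpu_stats_name(); the `props` object mimics what `torch.cuda.get_device_properties(0)` returns, and `gcnArchName` is the architecture string ROCm builds expose on device properties.

```python
# Hypothetical sketch, not the original resolve_hip_gpu_stats_name():
# illustrates where a HIP name fallback could live in device_type.py.
_DEFAULT_NAMES = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}

def resolve_gpu_name(props, device_type):
    name = getattr(props, "name", "") or ""
    if not name and device_type == "hip":
        # ROCm sometimes reports an empty name; fall back to the
        # GCN architecture string (e.g. "gfx90a") when present.
        name = getattr(props, "gcnArchName", "") or ""
    return name or _DEFAULT_NAMES[device_type]
```

The point is only that the fallback chain for HIP needs more steps than the generic `props.name` check the refactored code keeps.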
Force-pushed from 0e62bc7 to bf557e6.
Hi Team, I have resolved the conflicts. @mmathew23 @Datta0 @danielhanchen |
/gemini review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bf557e64b9
```python
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
```
Avoid eager torch.xpu lookup in device module map
Constructing DEVICE_MODULE_MAP with "xpu": torch.xpu eagerly dereferences torch.xpu even when running on CUDA/HIP, so CUDA-only PyTorch builds that do not expose an xpu attribute will fail during import with AttributeError before device_type is checked. This is a startup regression (the file already treats xpu as optional via hasattr(torch, "xpu") in get_device_type()), and it can block all model initialization on those environments.
Code Review
This pull request introduces a DeviceContext class to encapsulate and simplify device-specific operations, which is a great improvement for code clarity and maintainability. The changes in llama.py and vision.py effectively leverage this new class, removing duplicated code. My review includes a few suggestions to further improve the DeviceContext class by moving method-level dictionaries to class-level constants for better organization, and to make exception handling more specific.
```python
class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
        if device_type not in DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        names = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}
        return names[self.device_type] + " Device. "
```
For better code organization and to avoid re-creating dictionaries on each method call, it's a good practice to define DEVICE_MODULE_MAP and names as class-level constants.
```python
class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
    _DEFAULT_NAMES = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        if device_type not in self.DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = self.DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()
        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)
        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        return self._DEFAULT_NAMES[self.device_type] + " Device. "
```

```python
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
try:
    vllm_version = f" vLLM: {importlib_version('vllm')}."
except:
```
```python
gpu_stats_name, gpu_stats_snippet, max_memory = device_context.get_stats()
try:
    vllm_version = f" vLLM: {importlib_version('vllm')}."
except:
```
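The bare `except:` in the hunks above also swallows `KeyboardInterrupt` and `SystemExit`. A sketch of the narrower handling the reviews suggest, wrapped in a hypothetical helper (the function name is an assumption, not from the PR):

```python
from importlib.metadata import version as importlib_version, PackageNotFoundError

# Hypothetical helper illustrating the reviews' suggestion: catch only
# PackageNotFoundError instead of a bare `except:`, so unrelated errors
# (and interrupts) still propagate.
def optional_version_snippet(package):
    try:
        return f" {package}: {importlib_version(package)}."
    except PackageNotFoundError:
        return ""
```

At the call site this keeps the existing behavior (empty string when vLLM is absent) while no longer masking genuine failures.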