
introduce device_context to simplify code. #3875

Open
ykaitao wants to merge 1 commit into unslothai:main from ykaitao:ktyang_device_context

Conversation

@ykaitao
Contributor

@ykaitao ykaitao commented Jan 11, 2026

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @ykaitao, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors device-specific operations by introducing a DeviceContext class. This new abstraction centralizes logic for handling different GPU types (CUDA, HIP, XPU), including retrieving device properties and managing GPU cache. The change significantly simplifies the codebase by eliminating redundant conditional checks and promoting a more modular design for device interaction.

Highlights

  • Centralized Device Management: Introduced a new DeviceContext class in unsloth/device_type.py to encapsulate and centralize device-specific operations for CUDA, HIP, and XPU. This class handles retrieving device properties, toolkit versions, and managing GPU cache.
  • Code Simplification and Readability: Replaced repetitive conditional logic for device statistics gathering and GPU cache management with calls to a singleton instance of the DeviceContext. This significantly reduces boilerplate code in unsloth/models/llama.py and improves overall readability.
  • Improved Maintainability: By consolidating device-related functions into a single class, future additions or modifications to device support (e.g., new GPU types or toolkit versions) will be more straightforward and require changes in fewer places.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review. Performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary. Provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist. Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help. Displays a list of available commands.

Customization

To customize Gemini Code Assist for the GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c8a4746caa

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a DeviceContext class to encapsulate and simplify device-specific operations, which is a great refactoring for code clarity and maintainability. The changes in unsloth/models/llama.py correctly adopt this new abstraction.

My review includes a critical fix for a typo that would cause a NameError and a suggestion to improve exception handling for better robustness. Overall, this is a positive change that improves the codebase structure.

@ykaitao force-pushed the ktyang_device_context branch from c8a4746 to fce9504 on January 11, 2026 at 05:31
@ykaitao
Contributor Author

ykaitao commented Jan 11, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a DeviceContext class to encapsulate device-specific logic, which significantly simplifies the code in unsloth/models/llama.py. The refactoring is well-done and improves maintainability by centralizing device-related information and operations. I have a couple of suggestions to further improve the code's clarity and robustness.

@ykaitao force-pushed the ktyang_device_context branch from fce9504 to eb3d918 on January 11, 2026 at 05:39
@mmathew23
Collaborator

Hi @ykaitao, there is similar statistics logic in vision.py. Would you be able to add your simplifications there as well? Thank you!

@ykaitao
Contributor Author

ykaitao commented Jan 14, 2026

Hi @ykaitao, there is similar statistics logic in vision.py. Would you be able to add your simplifications there as well? Thank you!

Hi @mmathew23 , yes I can do it later today.

@ykaitao force-pushed the ktyang_device_context branch from eb3d918 to 4a2c058 on January 15, 2026 at 23:26
@ykaitao
Contributor Author

ykaitao commented Jan 15, 2026

/gemini review

@ykaitao
Contributor Author

ykaitao commented Jan 15, 2026

Hi @ykaitao, there is similar statistics logic in vision.py. Would you be able to add your simplifications there as well? Thank you!

Hi @mmathew23 , I have handled the similar statistics logic in vision.py. Please review again.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a DeviceContext class to encapsulate device-specific logic, which is an excellent refactoring. It significantly simplifies the code in unsloth/models/llama.py and unsloth/models/vision.py by removing duplicated device-checking logic and centralizing it. This greatly improves code clarity and maintainability. My review includes a few suggestions to further enhance the new DeviceContext class for better code quality and to improve exception handling in the model files.

@ykaitao force-pushed the ktyang_device_context branch from 8baf20c to 156d32c on January 16, 2026 at 00:11
@ykaitao
Contributor Author

ykaitao commented Jan 16, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a DeviceContext class to encapsulate device-specific logic, which is a great improvement. This refactoring simplifies the code in unsloth/models/llama.py and unsloth/models/vision.py by removing duplicated if/elif/else blocks for handling different device types (CUDA, HIP, XPU), enhancing code readability and maintainability. I've found one critical issue regarding module loading that could break the library for users without an XPU-enabled PyTorch build, and I've provided a suggestion to fix it.

@ykaitao force-pushed the ktyang_device_context branch from f586eb4 to 0e62bc7 on January 19, 2026 at 05:39
@ykaitao
Contributor Author

ykaitao commented Jan 19, 2026

Hi Team, I have addressed all the comments. Can this PR be merged? Or do you have additional comments? @mmathew23 @Datta0 @danielhanchen

Contributor

@danielhanchen danielhanchen left a comment


Tested on 1x NVIDIA B200 (compute capability 10.0, CUDA 12.8, torch 2.9.1, transformers 4.57.6, vLLM 0.15.1) with 10 notebooks:

Notebook Result
Llama3.1_8B_Alpaca PASS
Qwen3_14B_Reasoning_Conversational PASS
Phi_4_Conversational PASS
Mistral_v0.3_7B_Conversational PASS
Qwen3_VL_8B_Vision PASS
Gemma3_4B_Vision PASS
bert_classification PASS
gpt_oss_20B_GRPO FAIL (pre-existing: GptOssExperts.down_projs -- model architecture issue)
Qwen3_4B_GRPO FAIL (pre-existing: device_synchronize NameError in compiled cache)
Whisper FAIL (pre-existing: torchcodec/ffmpeg env)

No regressions on NVIDIA. All 7 testable notebooks pass. The 3 failures are pre-existing issues unrelated to this PR.

Code review note for HIP/AMD: The refactoring drops resolve_hip_gpu_stats_name() which was used for AMD GPU name resolution. This is a regression for ROCm users -- the old logic in llama.py had a 30-line function handling HIP-specific GPU name fallbacks. If ROCm support is intended, this needs to be preserved in device_type.py.

The core refactoring (moving clean_gpu_cache/get_current_device to device_type.py) is clean and sound. Import paths are backwards-compatible via re-export in llama.py.

@ykaitao force-pushed the ktyang_device_context branch from 0e62bc7 to bf557e6 on March 6, 2026 at 07:08
@ykaitao requested review from Datta0 and mmathew23 as code owners on March 6, 2026 at 07:08
@ykaitao
Contributor Author

ykaitao commented Mar 6, 2026

Hi Team, I have resolved the conflicts. @mmathew23 @Datta0 @danielhanchen

@ykaitao
Contributor Author

ykaitao commented Mar 6, 2026

/gemini review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bf557e64b9


    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}


P1: Avoid eager torch.xpu lookup in device module map

Constructing DEVICE_MODULE_MAP with "xpu": torch.xpu eagerly dereferences torch.xpu even when running on CUDA/HIP, so CUDA-only PyTorch builds that do not expose an xpu attribute will fail during import with AttributeError before device_type is checked. This is a startup regression (the file already treats xpu as optional via hasattr(torch, "xpu") in get_device_type()), and it can block all model initialization on those environments.

Useful? React with 👍 / 👎.
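One way to address this P1, sketched here with a stand-in object instead of the real torch module: keep attribute *names* in the map and resolve them lazily with getattr, so torch.xpu is dereferenced only when an XPU device is actually requested. The function and variable names below are illustrative, not from the PR:

```python
from types import SimpleNamespace

def get_device_module(torch_like, device_type: str):
    """Map device types to attribute names, resolving lazily so that
    CUDA-only builds lacking an `xpu` attribute never fail at import
    time. `torch_like` stands in for the real torch module."""
    attr_map = {"xpu": "xpu", "cuda": "cuda", "hip": "cuda"}
    if device_type not in attr_map:
        raise ValueError(f"Unsupported device type: {device_type}")
    # getattr runs only for the requested device type, so torch.xpu is
    # untouched unless "xpu" is explicitly asked for.
    return getattr(torch_like, attr_map[device_type])

# A CUDA-only build: no `xpu` attribute at all, yet lookup still works.
cuda_only_torch = SimpleNamespace(cuda="cuda-module")
```

An AttributeError can now only occur when "xpu" is requested on a build without it, rather than on every import.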

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a DeviceContext class to encapsulate and simplify device-specific operations, which is a great improvement for code clarity and maintainability. The changes in llama.py and vision.py effectively leverage this new class, removing duplicated code. My review includes a few suggestions to further improve the DeviceContext class by moving method-level dictionaries to class-level constants for better organization, and to make exception handling more specific.

Comment on lines +134 to +161
class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
        if device_type not in DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        names = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}
        return names[self.device_type] + " Device. "


Severity: medium

For better code organization and to avoid re-creating dictionaries on each method call, it's a good practice to define DEVICE_MODULE_MAP and names as class-level constants.

class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""
    DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
    _DEFAULT_NAMES = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        if device_type not in self.DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = self.DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        return self._DEFAULT_NAMES[self.device_type] + " Device. "

    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    try:
        vllm_version = f" vLLM: {importlib_version('vllm')}."
    except:


Severity: medium

It's generally better to catch specific exceptions rather than using a bare except:. This avoids accidentally catching other unexpected errors like KeyboardInterrupt or SystemExit. Using except Exception: is a safer alternative.

Suggested change
except:
except Exception:
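A small sketch of the suggested pattern (optional_vllm_snippet and the injectable get_version parameter are illustrative, not from the PR): except Exception still swallows a missing vLLM install, but process-control exceptions such as KeyboardInterrupt and SystemExit, which derive from BaseException rather than Exception, now propagate instead of being silently eaten:

```python
from importlib.metadata import version as importlib_version

def optional_vllm_snippet(get_version=lambda: importlib_version("vllm")) -> str:
    """Return a version snippet, or '' if the package lookup fails.
    `except Exception` catches PackageNotFoundError and similar errors,
    but lets BaseException subclasses (KeyboardInterrupt, SystemExit)
    escape, unlike a bare `except:`."""
    try:
        return f" vLLM: {get_version()}."
    except Exception:
        return ""
```

With a bare except:, pressing Ctrl-C during this lookup would be swallowed and the snippet would silently come back empty; with except Exception: the interrupt reaches the caller.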

    gpu_stats_name, gpu_stats_snippet, max_memory = device_context.get_stats()
    try:
        vllm_version = f" vLLM: {importlib_version('vllm')}."
    except:


Severity: medium

It's generally better to catch specific exceptions rather than using a bare except:. This avoids accidentally catching other unexpected errors like KeyboardInterrupt or SystemExit. Using except Exception: is a safer alternative.

Suggested change
except:
except Exception:

