Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,32 @@ def get_install_features(lib_name: str = None):
possible_features = ["cu12", "amd"]

if not lib_name:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"},
"B": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)

lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
has_nvidia = which("nvidia-smi") is not None
has_rocm = which("rocm-smi") is not None
has_amd = which("amd-smi") is not None
has_amd_gpu = has_rocm or has_amd

if has_nvidia and not has_amd_gpu:
lib_name = "cu12"
print("Auto-detected NVIDIA GPU. Using CUDA 12.x backend.")
elif has_amd_gpu and not has_nvidia:
lib_name = "amd"
print("Auto-detected AMD GPU. Using AMD backend.")
else:
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"},
"B": {"pretty": "AMD", "internal": "amd"},
}
print(
"WARNING: Auto-detection failed. "
"Please ensure you have either an NVIDIA GPU (with nvidia-smi) "
"or an AMD GPU (with rocm-smi or amd-smi) installed."
)
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")

# Write to start options
start_options["gpu_lib"] = lib_name
Expand Down