2323
2424import argparse
2525import collections
26+ from contextlib import suppress
2627import dataclasses
2728import functools
2829import gc
3233import os
3334from pathlib import Path
3435from pprint import pformat
36+ import shutil
3537import subprocess
3638import sys
3739import tempfile
38- from typing import TYPE_CHECKING
3940from typing import Any
4041from typing import Callable
42+ from typing import cast
4143
4244import torch
4345from torch .utils ._pytree import tree_leaves
4446from torch .utils ._pytree import tree_map
4547
4648from helion ._utils import counters
4749
48- if TYPE_CHECKING :
49- from tritonbench .utils .triton_op import BenchmarkOperator
50- from tritonbench .utils .triton_op import BenchmarkOperatorMetrics
50+ logger : logging .Logger = logging .getLogger (__name__ )
5151
52- try :
53- from tritonbench .utils .env_utils import get_nvidia_gpu_model
54- from tritonbench .utils .env_utils import is_cuda
52+ StrPath = str | os .PathLike [str ]
5553
56- IS_B200 = is_cuda () and get_nvidia_gpu_model () == "NVIDIA B200"
57- except ImportError :
58- print ("Failed B200 detection since tritonbench is not installed (yet)" )
59- IS_B200 = False
54+ if os .getenv ("HELION_BENCHMARK_DISABLE_LOGGING" , "0" ) == "1" :
55+ logging .disable (logging .CRITICAL )
56+
57+
58+ def is_cuda () -> bool :
59+ return torch .version .cuda is not None
60+
61+
62+ def get_nvidia_gpu_model () -> str :
63+ """
64+ Retrieves the model of the NVIDIA GPU being used.
65+ Will return the name of the first GPU listed.
66+ Returns:
67+ str: The model of the NVIDIA GPU or empty str if not found.
68+ """
69+ try :
70+ model = subprocess .check_output (
71+ ["nvidia-smi" , "--query-gpu=name" , "--format=csv,noheader,nounits" ]
72+ )
73+ return model .decode ().strip ().split ("\n " )[0 ]
74+ except OSError :
75+ logger .warning ("nvidia-smi not found. Returning empty str." )
76+ return ""
77+
78+
79+ IS_B200 = is_cuda () and get_nvidia_gpu_model () == "NVIDIA B200"
6080
6181
6282def log_tensor_metadata (args : tuple [object , ...], kwargs : dict [str , object ]) -> None :
@@ -82,11 +102,6 @@ def describe_tensor(obj: object) -> object:
82102 )
83103
84104
85- logger : logging .Logger = logging .getLogger (__name__ )
86-
87- if os .getenv ("HELION_BENCHMARK_DISABLE_LOGGING" , "0" ) == "1" :
88- logging .disable (logging .CRITICAL )
89-
90105# Maximum number of inputs to use
91106MAX_NUM_INPUTS = 20
92107
@@ -600,109 +615,141 @@ class RunResult:
600615}
601616
602617
603- def get_system_memory_gb () -> float :
604- """Get system memory in GB."""
605- try :
606- # Try to read from /proc/meminfo on Linux
607- meminfo_path = Path ("/proc/meminfo" )
608- if meminfo_path .exists ():
609- with open (meminfo_path ) as f :
610- for line in f :
611- if line .startswith ("MemTotal:" ):
612- # Extract memory in kB and convert to GB
613- mem_kb = int (line .split ()[1 ])
614- return mem_kb / (1024 * 1024 )
615-
616- # Fallback: use psutil if available
617- try :
618- import psutil
618+ def check_and_setup_tritonbench () -> None :
619+ """Ensure a usable tritonbench installation is available."""
619620
620- return psutil . virtual_memory (). total / ( 1024 ** 3 )
621- except ImportError :
622- pass
621+ benchmarks_dir = Path ( __file__ ). parent
622+ tritonbench_path = benchmarks_dir / "tritonbench"
623+ installing_marker = ( benchmarks_dir / ".tritonbench_installing" ). resolve ()
623624
624- except Exception :
625- pass
625+ try :
626+ import tritonbench # pyright: ignore[reportMissingImports]
627+
628+ module_file = getattr (tritonbench , "__file__" , None )
629+ tb_repo_path = tritonbench_path .resolve ()
630+
631+ candidate_paths : list [Path ] = []
632+
633+ def add_candidate_path (entry : object ) -> None :
634+ if not isinstance (entry , (str , os .PathLike )):
635+ return
636+ path_entry = cast ("StrPath" , entry )
637+ with suppress (TypeError , OSError , RuntimeError ):
638+ candidate_paths .append (Path (path_entry ))
639+
640+ if module_file is not None :
641+ add_candidate_path (module_file )
642+
643+ module_paths = getattr (tritonbench , "__path__" , None )
644+ if module_paths is not None :
645+ for entry in module_paths :
646+ add_candidate_path (entry )
647+
648+ def is_local (path : Path ) -> bool :
649+ try :
650+ resolved_path = path .resolve ()
651+ except (OSError , RuntimeError ):
652+ return False
653+ return (
654+ resolved_path == tb_repo_path or tb_repo_path in resolved_path .parents
655+ )
626656
627- # Default to assuming high memory if we can't detect
628- return 32.0
657+ has_local_checkout = any (is_local (path ) for path in candidate_paths )
629658
659+ if candidate_paths and not has_local_checkout :
660+ # If tritonbench is not from local checkout, assume it's a proper installation
661+ return
630662
631- def check_and_setup_tritonbench () -> None :
632- """Check if tritonbench is installed and install it from GitHub if not."""
633- # Check if tritonbench is already installed
634- if importlib .util .find_spec ("tritonbench" ) is not None :
635- return # Already installed
663+ if has_local_checkout :
664+ if installing_marker .exists ():
665+ print (
666+ "Detected partially installed tritonbench; reinstalling local checkout." ,
667+ file = sys .stderr ,
668+ )
669+ else :
670+ return
671+ else :
672+ print (
673+ "Unable to determine tritonbench import path; reinstalling local checkout." ,
674+ file = sys .stderr ,
675+ )
636676
637- print ("Tritonbench not found. Installing..." , file = sys .stderr )
677+ except ImportError :
678+ pass
638679
639- # Clone to benchmarks/tritonbench
640- benchmarks_dir = Path (__file__ ).parent
641- tritonbench_path = benchmarks_dir / "tritonbench"
680+ print (
681+ "Installing tritonbench from source..." ,
682+ file = sys .stderr ,
683+ )
642684 print (f"Using tritonbench path: { tritonbench_path } " )
643685
644- try :
645- # Clone the repository if it doesn't exist
646- if not tritonbench_path .exists ():
647- print ("Cloning tritonbench repository..." , file = sys .stderr )
648- subprocess .run (
649- [
650- "git" ,
651- "clone" ,
652- "https://github.com/meta-pytorch/tritonbench.git" ,
653- str (tritonbench_path ),
654- ],
655- check = True ,
656- )
686+ if tritonbench_path .exists ():
687+ print ("Removing existing tritonbench checkout..." , file = sys .stderr )
688+ if tritonbench_path .is_dir ():
689+ shutil .rmtree (tritonbench_path )
690+ else :
691+ tritonbench_path .unlink ()
657692
658- # Initialize submodules
659- print ("Initializing tritonbench's submodules..." , file = sys .stderr )
660- subprocess .run (
661- ["git" , "submodule" , "update" , "--init" , "--recursive" ],
662- cwd = tritonbench_path ,
663- check = True ,
664- )
693+ sys .modules .pop ("tritonbench" , None )
665694
666- # Detect system memory and choose install flags.
667- # Low-memory systems can freeze when building dependencies like flash-attn,
668- # so we only install the Liger library in that case.
669- memory_gb = get_system_memory_gb ()
670- install_flag = "--liger" if memory_gb < 16 else "--all"
695+ installing_marker .touch ()
671696
672- # Install optional dependencies for tritonbench
673- print (
674- f"Running install.py { install_flag } (detected { memory_gb :.1f} GB system RAM)..." ,
675- file = sys .stderr ,
697+ try :
698+ print ("Cloning tritonbench repository..." , file = sys .stderr )
699+ subprocess .run (
700+ [
701+ "git" ,
702+ "clone" ,
703+ "https://github.com/meta-pytorch/tritonbench.git" ,
704+ str (tritonbench_path ),
705+ ],
706+ cwd = benchmarks_dir ,
707+ check = True ,
676708 )
677- env = os .environ .copy ()
678- if install_flag == "--all" :
679- # Set max jobs to 4 to avoid OOM
680- env ["MAX_JOBS" ] = "4"
709+
710+ print ("Initializing tritonbench submodules..." , file = sys .stderr )
681711 subprocess .run (
682- [sys .executable , "install.py" , install_flag ],
712+ ["git" , "submodule" , "update" , "--init" , "--recursive" ],
713+ cwd = tritonbench_path ,
714+ check = True ,
715+ )
716+
717+ print ("Installing tritonbench requirements..." , file = sys .stderr )
718+ subprocess .run (
719+ [
720+ sys .executable ,
721+ "-m" ,
722+ "pip" ,
723+ "install" ,
724+ "-r" ,
725+ "requirements.txt" ,
726+ ],
727+ cwd = tritonbench_path ,
728+ check = True ,
729+ )
730+
731+ print ("Running install.py --liger..." , file = sys .stderr )
732+ subprocess .run (
733+ [sys .executable , "install.py" , "--liger" ],
683734 cwd = tritonbench_path ,
684735 check = True ,
685- env = env ,
686736 )
687737
688- # Install tritonbench package
689738 print ("Installing tritonbench package..." , file = sys .stderr )
690739 subprocess .run (
691- [sys .executable , "-m" , "pip" , "install" , "-e" , str (tritonbench_path )],
740+ [sys .executable , "-m" , "pip" , "install" , "-e" , "." ],
741+ cwd = tritonbench_path ,
692742 check = True ,
693743 )
694744
695- # Invalidate import caches to recognize newly installed package
696745 importlib .invalidate_caches ()
697746
698- # Verify installation worked
699747 try :
700- import tritonbench # noqa: F401 # pyright: ignore[reportMissingImports]
748+ import tritonbench # pyright: ignore[reportMissingImports]
701749
702- print (
703- f"Tritonbench installed successfully with { install_flag } ." ,
704- file = sys .stderr ,
705- )
750+ print ("Tritonbench installed successfully." , file = sys .stderr )
751+ if installing_marker .exists ():
752+ installing_marker .unlink ()
706753 except ImportError :
707754 print (
708755 "Error: Tritonbench package installation failed. The package cannot be imported." ,
@@ -789,6 +836,8 @@ def run_kernel_variants(
789836 from tritonbench .utils .parser import ( # pyright: ignore[reportMissingImports]
790837 get_parser ,
791838 )
839+ from tritonbench .utils .triton_op import BenchmarkOperator
840+ from tritonbench .utils .triton_op import BenchmarkOperatorMetrics
792841
793842 # Get the tritonbench operator name, stripping -bwd suffix for backward operators
794843 operator_name = kernel_name .removesuffix ("-bwd" )
0 commit comments