Skip to content

Commit ff8b213

Browse files
authored
Auto detect low vram (#956)
1 parent 3ccf510 commit ff8b213

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ jobs:
131131
if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi
132132
if [[ "${{ matrix.expecttest-accept }}" == "true" ]]; then export EXPECTTEST_ACCEPT=1; fi
133133
if [[ "${{ matrix.ref-eager }}" == "true" ]]; then export HELION_INTERPRET=1; fi
134-
if [[ "${{ matrix.alias }}" == *"a10g"* ]]; then export HELION_DEV_LOW_VRAM=1; fi
135134
# -rf: print failed tests
136135
# --timeout: max allowed time for each test
137136
pytest -rf --timeout=60

benchmarks/run.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,7 @@ class RunResult:
183183
"tritonbench.operators.jagged_mean.operator",
184184
"examples.jagged_mean",
185185
"jagged_mean_tritonbench",
186-
{"B": 32, "M": 8, "seqlen": 64}
187-
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
188-
else {},
186+
{},
189187
),
190188
"fp8_gemm": (
191189
"tritonbench.operators.fp8_gemm.fp8_gemm",
@@ -208,9 +206,7 @@ class RunResult:
208206
"tritonbench.operators.cross_entropy.operator",
209207
"examples.cross_entropy",
210208
"cross_entropy",
211-
{"B": 4, "T": 512, "v_range": "10,15"}
212-
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
213-
else {},
209+
{},
214210
),
215211
"fp8_attention": (
216212
"tritonbench.operators.fp8_attention.operator",

helion/_testing.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,19 @@ def skipIfNotCUDA() -> Callable[[Callable], Callable]:
7676
def skipIfLowVRAM(
7777
reason: str = "Test requires high VRAM",
7878
) -> Callable[[Callable], Callable]:
79-
"""Skip test if HELION_DEV_LOW_VRAM=1 is set"""
80-
return unittest.skipIf(os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1", reason)
79+
"""Skip test on systems with low GPU VRAM."""
80+
81+
threshold_bytes = int(30.0 * (1024**3))
82+
total_memory: int | None = None
83+
try:
84+
if torch.cuda.is_available():
85+
props = torch.cuda.get_device_properties(torch.cuda.current_device())
86+
total_memory = int(getattr(props, "total_memory", 0))
87+
except Exception:
88+
total_memory = None
89+
90+
low_vram = total_memory is not None and total_memory < threshold_bytes
91+
return unittest.skipIf(low_vram, reason)
8192

8293

8394
def skipIfPy314(reason: str) -> Callable[[Callable], Callable]:

0 commit comments

Comments
 (0)