11 changes: 9 additions & 2 deletions helion/_testing.py
@@ -60,6 +60,12 @@ def is_cpu() -> bool:
        or _get_triton_backend() == "cpu"
    )

def is_mtia() -> bool:
    """Return True if running on MTIA."""
    return _get_triton_backend() == "mtia"

def skipIfMTIA(reason: str) -> Callable[[Callable], Callable]:
    return unittest.skipIf(is_mtia(), reason)

class _LogCapture(logging.Handler):
"""Simple logging handler to capture log records."""
@@ -98,14 +104,16 @@ def is_cuda() -> bool:
"""Return True if running on CUDA (NVIDIA GPU)."""
return _get_triton_backend() == "cuda" and torch.cuda.is_available()


PROJECT_ROOT: Path = Path(__file__).parent.parent
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
DEVICE = None

if is_cpu():
    DEVICE = torch.device("cpu")
elif torch.xpu.is_available():
    DEVICE = torch.device("xpu")
elif is_mtia():
    DEVICE = torch.device("mtia")
else:
    DEVICE = torch.device("cuda")

@@ -212,7 +220,6 @@ def skipIfPyTorchBaseVerLessThan(min_version: str) -> Callable[[Callable], Calla
f"PyTorch version {min_version} or higher required",
)


Contributor: I'd expect this to cause some lints, though for some reason CI isn't running. Can you do `./lint.sh install && ./lint.sh`?
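The lint concern presumably covers blank-line spacing around the new top-level helpers and the combined `from helion._testing import ...` line added in test_constexpr.py. A minimal sketch of the formatter-clean layout, assuming PEP 8 style rules (two blank lines between top-level definitions, one import per line, matching the surrounding files); the helper bodies are copied from this diff and `_get_triton_backend` is assumed to already exist in helion/_testing.py:

```python
# Hypothetical lint-clean layout for the new helpers inside helion/_testing.py.
# Assumes the project's linter enforces two blank lines between top-level
# definitions; _get_triton_backend is defined earlier in this module.
import unittest
from collections.abc import Callable


def is_mtia() -> bool:
    """Return True if running on MTIA."""
    return _get_triton_backend() == "mtia"


def skipIfMTIA(reason: str) -> Callable[[Callable], Callable]:
    return unittest.skipIf(is_mtia(), reason)
```

The combined import in test_constexpr.py would likewise be split into one import per line to match the existing imports in that file.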

@contextlib.contextmanager
def track_run_ref_calls() -> Generator[list[int], None, None]:
"""Context manager that tracks BoundKernel.run_ref calls.
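To illustrate how the helpers added in this file fit together, here is a hedged usage sketch, not taken from the PR: `DEVICE` now resolves to the MTIA device when the Triton backend reports `mtia`, and `skipIfMTIA` gates tests that are not yet supported there.

```python
# Hypothetical test using the new helpers; the assertions are illustrative only.
import unittest

import torch

from helion._testing import DEVICE
from helion._testing import TestCase
from helion._testing import is_mtia
from helion._testing import skipIfMTIA


class ExampleDeviceTest(TestCase):
    def test_device_selection(self) -> None:
        # DEVICE is cpu, xpu, mtia, or cuda depending on the active backend.
        x = torch.zeros(16, device=DEVICE)
        self.assertEqual(x.device.type, DEVICE.type)

    @skipIfMTIA("illustrative: feature not yet supported on MTIA")
    def test_non_mtia_only(self) -> None:
        self.assertFalse(is_mtia())


if __name__ == "__main__":
    unittest.main()
```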
31 changes: 28 additions & 3 deletions helion/runtime/__init__.py
@@ -13,13 +13,16 @@
from .triton_helpers import triton_send_signal as triton_send_signal
from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
from .triton_helpers import triton_wait_signal as triton_wait_signal
import os

if TYPE_CHECKING:
    import triton


def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
    return torch.empty(size, device="cuda", dtype=torch.int8)
    # Dynamically get device from Triton backend
    backend = triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
    return torch.empty(size, device=backend, dtype=torch.int8)
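For context on the `_alloc_fn` change: the `(size, alignment, stream)` signature matches Triton's allocator hook, and `set_triton_allocator()` below presumably registers it. A minimal sketch, assuming `triton.set_allocator` is the registration entry point (the registration call itself is not shown in this diff):

```python
# Sketch of registering a backend-aware scratch allocator with Triton. Assumes
# triton.set_allocator(...) accepts a callable with the (size, alignment,
# stream) signature; the actual wiring lives in set_triton_allocator().
import torch
import triton


def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
    # Allocate scratch on the device the active Triton backend targets
    # ("cuda", "xpu", "mtia", ...) instead of hard-coding "cuda".
    backend = triton.runtime.driver.active.get_current_target().backend
    return torch.empty(size, device=backend, dtype=torch.int8)


triton.set_allocator(_alloc_fn)  # register once, before launching kernels that need scratch
```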


def set_triton_allocator() -> None:
@@ -66,14 +69,19 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
    # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
    elif device.type == "xpu":
        available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
    elif device.type == "mtia":
        try:
            from triton_mtia.backend.compiler import get_num_sm_for_arch
            return get_num_sm_for_arch(device.backend.arch)
        except ImportError:
            raise RuntimeError("MTIA backend selected, but not available.")
    else:
        raise AssertionError("TODO: implement for other devices")
        raise NotImplementedError(f"get_num_sm not implemented for device type: {device.type}")

    if reserved_sms <= 0:
        return available_sms
    return max(available_sms - reserved_sms, 1)


def default_launcher(
    triton_kernel: triton.JITFunction,
    grid: tuple[int, ...],
@@ -83,6 +91,23 @@
    **kwargs: dict,
) -> object:
    """Default launcher function that executes the kernel immediately."""
    # Get current backend from Triton
    import triton
    backend = triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
    if backend == "mtia":
        # MTIA-specific initialization
        try:
            from mtia.re.re_unittest_lib import init_mtia_device
            from triton_mtia.python.mtia.eager import mtia_triton_launcher

            init_mtia_device()
            # Ignore disk cache. Kernels will still keep an in-memory cache.
            os.environ.setdefault("TRITON_ALWAYS_COMPILE", "1")
Contributor: Is this required? Seems a bit odd to be running this code on every kernel launch and not cleaning it up.
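One hedged way to address this concern, not part of the PR: hoist the MTIA setup into a memoized helper so it runs at most once per process, keeping the per-launch path cheap. The module and function names below mirror the ones in this diff; the caching wrapper itself is hypothetical.

```python
# Sketch: once-per-process MTIA initialization instead of per-launch setup.
import functools
import os


@functools.lru_cache(maxsize=1)
def _ensure_mtia_initialized() -> None:
    from mtia.re.re_unittest_lib import init_mtia_device
    from triton_mtia.python.mtia.eager import mtia_triton_launcher

    init_mtia_device()
    # Ignore disk cache. Kernels will still keep an in-memory cache.
    os.environ.setdefault("TRITON_ALWAYS_COMPILE", "1")
    mtia_triton_launcher.init()


# default_launcher would then reduce to:
#     if backend == "mtia":
#         _ensure_mtia_initialized()
```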

            mtia_triton_launcher.init()
        except ImportError as e:
            raise RuntimeError(f"MTIA backend selected but required modules not available: {e}")

    # For both CUDA and MTIA, use the same kernel execution
    return triton_kernel.run(
        *args,
        grid=grid,
20 changes: 17 additions & 3 deletions test/test_constexpr.py
@@ -7,14 +7,27 @@

import helion
from helion._testing import DEVICE
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfRefEager
from helion._testing import skipIfRefEager, skipIfMTIA, is_mtia
import helion.language as hl

if is_mtia():
    from mtia.re.re_unittest_lib import MTIAUnittest
    import mtia.host_runtime.torch_mtia.dynamic_library  # noqa: F401


class TestConstExpr(TestCase):
    @classmethod
    def setUpClass(cls):
        # Explicitly call setUpClass for TestCase.
        super().setUpClass()
        if is_mtia():
            # Call MTIAUnittest.setUpClass for MTIA initialization
            MTIAUnittest.setUpClass.__func__(cls)
            # Initialize MTIA properly
            torch.mtia.init()

class TestConstExpr(RefEagerTestBase, TestCase):
    def test_constexpr_float(self):
        @helion.kernel()
        def fn(x: torch.Tensor, v: hl.constexpr) -> torch.Tensor:
@@ -95,6 +108,7 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
        self.assertExpectedJournal(code)

    @skipIfRefEager("Triton codegen does not work in ref eager mode")
    @skipIfMTIA("Not supported on MTIA. Error: \"Expected IntList but got GenericList\"")
    def test_block_size_constexpr_assignment_in_host_code(self) -> None:
        @helion.kernel(
            config=helion.Config(