diff --git a/helion/_testing.py b/helion/_testing.py
index 8f81b8068..931e8a5b6 100644
--- a/helion/_testing.py
+++ b/helion/_testing.py
@@ -60,6 +60,16 @@ def is_cpu() -> bool:
         or _get_triton_backend() == "cpu"
     )
+
+
+def is_mtia() -> bool:
+    """Return True if running on MTIA."""
+    return _get_triton_backend() == "mtia"
+
+
+def skipIfMTIA(reason: str) -> Callable[[Callable], Callable]:
+    """Skip the decorated test when running on MTIA."""
+    return unittest.skipIf(is_mtia(), reason)
 
 
 class _LogCapture(logging.Handler):
     """Simple logging handler to capture log records."""
@@ -98,14 +108,16 @@ def is_cuda() -> bool:
     """Return True if running on CUDA (NVIDIA GPU)."""
     return _get_triton_backend() == "cuda" and torch.cuda.is_available()
 
 
 PROJECT_ROOT: Path = Path(__file__).parent.parent
 EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
 
 if is_cpu():
     DEVICE = torch.device("cpu")
 elif torch.xpu.is_available():
     DEVICE = torch.device("xpu")
+elif is_mtia():
+    DEVICE = torch.device("mtia")
 else:
     DEVICE = torch.device("cuda")
 
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 875074d9b..56b39db72 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -13,13 +13,17 @@
 from .triton_helpers import triton_send_signal as triton_send_signal
 from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
 
 if TYPE_CHECKING:
     import triton
 
 
 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
-    return torch.empty(size, device="cuda", dtype=torch.int8)
+    # Allocate on the device targeted by the active Triton backend instead of assuming CUDA.
+    import triton  # runtime import; the module-level import above is TYPE_CHECKING-only
+
+    backend = triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    return torch.empty(size, device=backend, dtype=torch.int8)
 
 
 def set_triton_allocator() -> None:
@@ -66,14 +70,24 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
     # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
     elif device.type == "xpu":
         available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    elif device.type == "mtia":
+        try:
+            from triton_mtia.backend.compiler import get_num_sm_for_arch
+        except ImportError as e:
+            raise RuntimeError("MTIA device requested, but triton_mtia is not available.") from e
+        import triton  # runtime import; the module-level import is TYPE_CHECKING-only
+
+        # torch.device carries no arch info; use the arch reported by the active Triton target.
+        arch = triton.runtime.driver.active.get_current_target().arch  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+        available_sms = get_num_sm_for_arch(arch)
     else:
-        raise AssertionError("TODO: implement for other devices")
+        raise NotImplementedError(f"get_num_sm not implemented for device type: {device.type}")
 
     if reserved_sms <= 0:
         return available_sms
     return max(available_sms - reserved_sms, 1)
 
 
 def default_launcher(
     triton_kernel: triton.JITFunction,
     grid: tuple[int, ...],
@@ -83,6 +97,25 @@ def default_launcher(
     **kwargs: dict,
 ) -> object:
     """Default launcher function that executes the kernel immediately."""
+    # Ask Triton which backend is active instead of assuming CUDA.
+    import triton
+    backend = triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    if backend == "mtia":
+        # MTIA-specific initialization.
+        import os
+
+        try:
+            from mtia.re.re_unittest_lib import init_mtia_device
+            from triton_mtia.python.mtia.eager import mtia_triton_launcher
+
+            init_mtia_device()
+            # Ignore the disk cache; kernels still keep an in-memory cache.
+            os.environ.setdefault("TRITON_ALWAYS_COMPILE", "1")
+            mtia_triton_launcher.init()
+        except ImportError as e:
+            raise RuntimeError(f"MTIA backend selected but required modules not available: {e}") from e
+
+    # CUDA and MTIA share the same kernel launch path.
     return triton_kernel.run(
         *args,
         grid=grid,
diff --git a/test/test_constexpr.py b/test/test_constexpr.py
index ee4d8f619..b3e5792f5 100644
--- a/test/test_constexpr.py
+++ b/test/test_constexpr.py
@@ -7,14 +7,29 @@
 
 import helion
 from helion._testing import DEVICE
-from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import is_mtia
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRefEager
 import helion.language as hl
 
+if is_mtia():
+    from mtia.re.re_unittest_lib import MTIAUnittest
+    import mtia.host_runtime.torch_mtia.dynamic_library  # noqa: F401
+
 
-class TestConstExpr(RefEagerTestBase, TestCase):
+class TestConstExpr(TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        # Run the standard TestCase setup first.
+        super().setUpClass()
+        if is_mtia():
+            # MTIAUnittest.setUpClass performs the MTIA-specific test setup.
+            MTIAUnittest.setUpClass.__func__(cls)
+            # Initialize the MTIA device.
+            torch.mtia.init()
+
     def test_constexpr_float(self):
         @helion.kernel()
         def fn(x: torch.Tensor, v: hl.constexpr) -> torch.Tensor:
@@ -95,6 +110,7 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
         self.assertExpectedJournal(code)
 
     @skipIfRefEager("Triton codegen does not work in ref eager mode")
+    @skipIfMTIA("Not supported on MTIA. Error: 'Expected IntList but got GenericList'")
     def test_block_size_constexpr_assignment_in_host_code(self) -> None:
         @helion.kernel(
             config=helion.Config(