From d4c8e7b217a3a78df2e4f84608f092e75678fd5f Mon Sep 17 00:00:00 2001 From: Jan Szczepaniec Date: Tue, 25 Nov 2025 08:31:14 -0800 Subject: [PATCH] Add setup for Helion to compile on MTIA with basic test Summary: Enable Helion with MTIA for the first time. This is very basic setup to just make showcase it working. Following diffs will try to enable more modes (e.g. eager) and more tests. I used default mtia machine arch setup for now + set triton_mtia backend + compile it through JIT flow. Highlights/doubts: - Remove RefEager mode from the test, seems to not be used? - For some reason i have to call mtia_triton_launcher.init() twice (?), otherwise it result with inaccuracy.. i did not debug it fully - There is a check `if TYPE_CHECKING:` in __init__.py, not sure how does this work for functions, can i import triton? will CI be able to check this for me? - I disabled one test when run on mtia. Test complains about different type interfered from Triton-MTIA compiler (error: IntList but got GenericList) - Hardcoded grid size for mtia architectures, I will have to cover this part separately, sadly this is not fully covered in torch by Triton-MTIA, when doing Triton-MTIA update i created hardcoded stub, since it is not being used. v2 We cannot expose arch names to opensource, hide them then in triton_mtia. CPU added support for their backend, use their _get_triton_backend instead of calling triton.runtime.driver.active.get_current_target directly v3 Aligned further to changes made when diff was not landed and i removed custom implementation for handling torch device. Removed calling is_available from torch/mtia/__init__.py Differential Revision: D76375133 --- helion/_testing.py | 11 +++++++++-- helion/runtime/__init__.py | 31 ++++++++++++++++++++++++++++--- test/test_constexpr.py | 20 +++++++++++++++++--- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/helion/_testing.py b/helion/_testing.py index 8f81b8068..931e8a5b6 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -60,6 +60,12 @@ def is_cpu() -> bool: or _get_triton_backend() == "cpu" ) +def is_mtia() -> bool: + """Return True if running on MTIA.""" + return _get_triton_backend() == "mtia" + +def skipIfMTIA(reason: str) -> Callable[[Callable], Callable]: + return unittest.skipIf(is_mtia(), reason) class _LogCapture(logging.Handler): """Simple logging handler to capture log records.""" @@ -98,14 +104,16 @@ def is_cuda() -> bool: """Return True if running on CUDA (NVIDIA GPU).""" return _get_triton_backend() == "cuda" and torch.cuda.is_available() - PROJECT_ROOT: Path = Path(__file__).parent.parent EXAMPLES_DIR: Path = PROJECT_ROOT / "examples" +DEVICE = None if is_cpu(): DEVICE = torch.device("cpu") elif torch.xpu.is_available(): DEVICE = torch.device("xpu") +elif is_mtia(): + DEVICE = torch.device("mtia") else: DEVICE = torch.device("cuda") @@ -212,7 +220,6 @@ def skipIfPyTorchBaseVerLessThan(min_version: str) -> Callable[[Callable], Calla f"PyTorch version {min_version} or higher required", ) - @contextlib.contextmanager def track_run_ref_calls() -> Generator[list[int], None, None]: """Context manager that tracks BoundKernel.run_ref calls. diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 875074d9b..56b39db72 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -13,13 +13,16 @@ from .triton_helpers import triton_send_signal as triton_send_signal from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal from .triton_helpers import triton_wait_signal as triton_wait_signal +import os if TYPE_CHECKING: import triton def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor: - return torch.empty(size, device="cuda", dtype=torch.int8) + # Dynamically get device from Triton backend + backend = triton.runtime.driver.active.get_current_target().backend # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + return torch.empty(size, device=backend, dtype=torch.int8) def set_triton_allocator() -> None: @@ -66,14 +69,19 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int: # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number. elif device.type == "xpu": available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count + elif device.type == "mtia": + try: + from triton_mtia.backend.compiler import get_num_sm_for_arch + return get_num_sm_for_arch(device.backend.arch) + except ImportError: + raise RuntimeError("MTIA backend selected, but not available.") else: - raise AssertionError("TODO: implement for other devices") + raise NotImplementedError(f"get_num_sm not implemented for device type: {device.type}") if reserved_sms <= 0: return available_sms return max(available_sms - reserved_sms, 1) - def default_launcher( triton_kernel: triton.JITFunction, grid: tuple[int, ...], @@ -83,6 +91,23 @@ def default_launcher( **kwargs: dict, ) -> object: """Default launcher function that executes the kernel immediately.""" + # Get current backend from Triton + import triton + backend = triton.runtime.driver.active.get_current_target().backend # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + if backend == "mtia": + # MTIA-specific initialization + try: + from mtia.re.re_unittest_lib import init_mtia_device + from triton_mtia.python.mtia.eager import mtia_triton_launcher + + init_mtia_device() + # Ignore disk cache. Kernels will still keep an in-memory cache. + os.environ.setdefault("TRITON_ALWAYS_COMPILE", "1") + mtia_triton_launcher.init() + except ImportError as e: + raise RuntimeError(f"MTIA backend selected but required modules not available: {e}") + + # For both CUDA and MTIA, use the same kernel execution return triton_kernel.run( *args, grid=grid, diff --git a/test/test_constexpr.py b/test/test_constexpr.py index ee4d8f619..b3e5792f5 100644 --- a/test/test_constexpr.py +++ b/test/test_constexpr.py @@ -7,14 +7,27 @@ import helion from helion._testing import DEVICE -from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output -from helion._testing import skipIfRefEager +from helion._testing import skipIfRefEager, skipIfMTIA, is_mtia import helion.language as hl +if is_mtia(): + from mtia.re.re_unittest_lib import MTIAUnittest + import mtia.host_runtime.torch_mtia.dynamic_library # noqa: F401 + + +class TestConstExpr(TestCase): + @classmethod + def setUpClass(cls): + # Explicitly call setUpClass for TestCase. + super().setUpClass() + if is_mtia(): + # Call MTIAUnittest.setUpClass for MTIA initialization + MTIAUnittest.setUpClass.__func__(cls) + # Initialize MTIA properly + torch.mtia.init() -class TestConstExpr(RefEagerTestBase, TestCase): def test_constexpr_float(self): @helion.kernel() def fn(x: torch.Tensor, v: hl.constexpr) -> torch.Tensor: @@ -95,6 +108,7 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor: self.assertExpectedJournal(code) @skipIfRefEager("Triton codegen does not work in ref eager mode") + @skipIfMTIA("Not supported on MTIA. Error: \"Expected IntList but got GenericList\"") def test_block_size_constexpr_assignment_in_host_code(self) -> None: @helion.kernel( config=helion.Config(