Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ runtime
:nosignatures:
kernel
set_default_settings
Config
Settings
```
Expand Down
23 changes: 2 additions & 21 deletions docs/api/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Settings can be configured via:

1. **Environment variables**
2. **Keyword arguments to `@helion.kernel`**
3. **Global defaults via `helion.set_default_settings()`**

If both are provided, decorator arguments take precedence.

## Configuration Examples

Expand Down Expand Up @@ -62,24 +63,6 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
return result
```

### Global Configuration

```python
import logging
import helion

# Set global defaults
with helion.set_default_settings(
ignore_warnings=[helion.exc.TensorOperationInWrapper],
autotune_log_level=logging.WARNING
):
# All kernels in this block use these settings
@helion.kernel
def kernel1(x): ...

@helion.kernel
def kernel2(x): ...
```

## Settings Reference

Expand Down Expand Up @@ -231,9 +214,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
## Functions

```{eval-rst}
.. autofunction:: set_default_settings

.. automethod:: Settings.default
```

## Environment Variable Reference
Expand Down
2 changes: 0 additions & 2 deletions helion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .runtime import kernel as jit # alias
from .runtime.settings import RefMode
from .runtime.settings import Settings
from .runtime.settings import set_default_settings

__all__ = [
"Config",
Expand All @@ -28,7 +27,6 @@
"language",
"next_power_of_2",
"runtime",
"set_default_settings",
]

_logging.init_logs()
4 changes: 2 additions & 2 deletions helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
Args:
fn: The function to be compiled as a Helion kernel.
configs: A list of configurations to use for the kernel.
settings: The settings to be used by the Kernel. If None, default settings are used.
settings: The settings to be used by the Kernel. If None, a new `Settings()` instance is created.
key: Optional callable that returns an extra hashable component for specialization.
"""
super().__init__()
Expand All @@ -88,7 +88,7 @@ def __init__(
self.name: str = fn.__name__
self.fn: types.FunctionType = fn
self.signature: inspect.Signature = inspect.signature(fn)
self.settings: Settings = settings or Settings.default()
self.settings: Settings = settings or Settings()
self._key_fn: Callable[..., Hashable] | None = key
self.configs: list[Config] = [
Config(**c) if isinstance(c, dict) else c # pyright: ignore[reportArgumentType]
Expand Down
53 changes: 0 additions & 53 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import logging
import os
import threading
import time
from typing import TYPE_CHECKING
from typing import Literal
Expand All @@ -20,47 +19,15 @@
from .ref_mode import RefMode

if TYPE_CHECKING:
from contextlib import AbstractContextManager

from ..autotuner.base_search import BaseAutotuner
from .kernel import BoundKernel

class _TLS(Protocol):
default_settings: Settings | None

class AutotunerFunction(Protocol):
def __call__(
self, bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner: ...


_tls: _TLS = cast("_TLS", threading.local())


def set_default_settings(settings: Settings) -> AbstractContextManager[None, None]:
"""
Set the default settings for the current thread and return a context manager
that restores the previous settings upon exit.
Args:
settings: The Settings object to set as the default.
Returns:
AbstractContextManager[None, None]: A context manager that restores the previous settings upon exit.
"""
prior = getattr(_tls, "default_settings", None)
_tls.default_settings = settings

class _RestoreContext:
def __enter__(self) -> None:
pass

def __exit__(self, *args: object) -> None:
_tls.default_settings = prior

return _RestoreContext()


def default_autotuner_fn(
bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
Expand Down Expand Up @@ -254,15 +221,8 @@ class Settings(_Settings):
def __init__(self, **settings: object) -> None:
"""
Initialize the Settings object with the provided dictionary of settings.
If no settings are provided, the default settings are used (see `set_default_settings`).
Args:
settings: Keyword arguments representing various settings.
"""

if defaults := getattr(_tls, "default_settings", None):
settings = {**defaults.to_dict(), **settings}

super().__init__(**settings) # pyright: ignore[reportArgumentType]

self._check_ref_eager_mode_before_print_output_code()
Expand Down Expand Up @@ -323,16 +283,3 @@ def _check_ref_eager_mode_before_print_output_code(self) -> None:
"""
if self.ref_mode == RefMode.EAGER and self.print_output_code:
raise exc.RefEagerModeCodePrintError

@staticmethod
def default() -> Settings:
"""
Get the default Settings object. If no default settings are set, create a new one.
Returns:
Settings: The default Settings object.
"""
result = getattr(_tls, "default_settings", None)
if result is None:
_tls.default_settings = result = Settings()
return result
Loading