diff --git a/.gitignore b/.gitignore index bd40e59..65f7054 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,5 @@ dev/cleanup.py .python-version .databricks-login.json + +.test_*.ipynb diff --git a/src/databricks/labs/blueprint/_logging_context.py b/src/databricks/labs/blueprint/_logging_context.py new file mode 100644 index 0000000..535f5ef --- /dev/null +++ b/src/databricks/labs/blueprint/_logging_context.py @@ -0,0 +1,179 @@ +"""internal plumbing for passing logging context (dict) to logger instances""" + +import dataclasses +import inspect +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from contextvars import ContextVar +from functools import partial, wraps +from types import MappingProxyType +from typing import TYPE_CHECKING, Annotated, Any, TypeVar, get_origin + +if TYPE_CHECKING: + T = TypeVar("T") + # SkipLogging[list[str]] will be treated by type checkers as list[str], because that's what Annotated is at runtime + # if this workaround is not in place, caller will complain about having wrong typing + # https://github.com/python/typing/discussions/1229 + SkipLogging = Annotated[T, ...] 
+else: + + @dataclasses.dataclass(slots=True) + class SkipLogging: + """`@logging_context_params` will ignore parameters annotated with this class.""" + + def __class_getitem__(cls, item: Any) -> Any: + return Annotated[item, SkipLogging()] + + +_CTX: ContextVar = ContextVar("ctx", default={}) + + +def _params_str(params: dict[str, Any]): + return ", ".join(f"{k}={v!r}" for k, v in params.items()) + + +def _get_skip_logging_param_names(sig: inspect.Signature): + """Generates list of parameter names having SkipLogging annotation""" + for name, param in sig.parameters.items(): + ann = param.annotation + + # only consider annotation + if not ann or get_origin(ann) is not Annotated: + continue + + # there can be many annotations for each param + for meta in ann.__metadata__: + # type checker thinks SkipLogging is a generic, despite it being Annotated + if meta and isinstance(meta, SkipLogging): # type: ignore + yield name + + +def _skip_dict_key(params: dict, keys_to_skip: set): + return {k: v for k, v in params.items() if k not in keys_to_skip} + + +def current_context(): + """Returns dictionary of current context set via `with logging_context(...)` context manager or `@logging_context_params` decorator + + Example: + current_context() + >>> {'foo': 'bar', 'a': 2} + + """ + return _CTX.get() + + +def current_context_repr(): + """Returns repr like "key1=val1, key2=val2" string representation of current_context(), or "" in case context is empty""" + return _params_str(current_context()) + + +@contextmanager +def logging_context(**kwds): + """Context manager adding keywords to current logging context. Thread and async safe. 
+ + Example: + with logging_context(foo="bar", a=2): + logger.info("hello") + >>> 2025-06-06 07:15:09,329 - __main__ - INFO - hello (foo='bar', a=2) + """ + # Get the current context and update it with new keywords + current_ctx = _CTX.get() + new_ctx = {**current_ctx, **kwds} + token = _CTX.set(MappingProxyType(new_ctx)) + try: + yield _CTX.get() + except Exception as e: + # python 3.11+: https://docs.python.org/3.11/tutorial/errors.html#enriching-exceptions-with-notes + # https://peps.python.org/pep-0678/ + if hasattr(e, "add_note"): + # __notes__ list[str] is only defined if notes were added, otherwise is not there + # we only want to add note if there was no note before, otherwise chaining context causes problems + if not getattr(e, "__notes__", None): + e.add_note(f"Context: {_params_str(current_context())}") + + raise + finally: + _CTX.reset(token) + + +def logging_context_params(func=None, **extra_context): + """Decorator that automatically adds all the function parameters to current logging context. + + Any passed keyword arguments will be added to the context. Function parameters take precedence over the extra keywords in case the names would overlap. + + Parameters annotated with `SkipLogging` will be ignored from being added to the context. 
+ + Example: + + @logging_context_params(foo="bar") + def do_math(a: int, b: SkipLogging[int]): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + return r + + >>> 2025-06-06 07:15:09,329 - __main__ - INFO - result of 2**8 is 256 (foo='bar', a=2) + + Note: + - `a` parameter will be logged, type annotation is optional + - `b` parameter won't be logged because it is annotated with `SkipLogging` + - `foo` parameter will be logged because it is passed as extra context to the decorator + + """ + + if func is None: + return partial(logging_context_params, **extra_context) + + # will use function's signature to bind positional params to name of the param + sig = inspect.signature(func) + skip_params = set(_get_skip_logging_param_names(sig)) + + @wraps(func) + def wrapper(*args, **kwds): + # only bind if there are positional args + # extra context has lower priority than any of the args + # skip_params is used to filter out parameters that are annotated with SkipLogging + + if args: + bound = sig.bind(*args, **kwds) + ctx_data = {**extra_context, **_skip_dict_key(bound.arguments, skip_params)} + else: + ctx_data = {**extra_context, **_skip_dict_key(kwds, skip_params)} + + with logging_context(**ctx_data): + return func(*args, **kwds) + + return wrapper + + +class LoggingContextInjectingFilter(logging.Filter): + """Adds current_context() to the log record.""" + + def filter(self, record): + # https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information + # https://docs.python.org/3/howto/logging-cookbook.html#use-of-contextvars + ctx = current_context() + record.context = f"{_params_str(ctx)}" if ctx else "" + record.context_msg = f" ({record.context})" if record.context else "" + return True + + +class LoggingThreadPoolExecutor(ThreadPoolExecutor): + """ThreadPoolExecutor drop-in replacement that will apply current logging context to all newly started threads.""" + + def __init__(self, max_workers=None, 
thread_name_prefix="", initializer=None, initargs=()): + self.__current_context = current_context() + self.__wrapped_initializer = initializer + + super().__init__( + max_workers=max_workers, + thread_name_prefix=thread_name_prefix, + initializer=self._logging_context_init, + initargs=initargs, + ) + + def _logging_context_init(self, *args): + _CTX.set(self.__current_context) + if self.__wrapped_initializer: + self.__wrapped_initializer(*args) diff --git a/src/databricks/labs/blueprint/logger.py b/src/databricks/labs/blueprint/logger.py index b686dfe..02063e8 100644 --- a/src/databricks/labs/blueprint/logger.py +++ b/src/databricks/labs/blueprint/logger.py @@ -4,6 +4,23 @@ import sys from typing import TextIO +from ._logging_context import ( + LoggingContextInjectingFilter, + SkipLogging, + current_context, + logging_context, + logging_context_params, +) + +__all__ = [ + "NiceFormatter", + "install_logger", + "current_context", + "SkipLogging", + "logging_context_params", + "logging_context", +] + class NiceFormatter(logging.Formatter): """A nice formatter for logging. It uses colors and bold text if the console supports it.""" @@ -36,7 +53,7 @@ def __init__(self, *, probe_tty: bool = False, stream: TextIO = sys.stdout) -> N stream: the output stream to which the formatter will write, used to check if it is a console. probe_tty: If true, the formatter will enable color support if the output stream appears to be a console. """ - super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s", datefmt="%H:%M:%S") + super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s%(context_msg)s", datefmt="%H:%M:%S") # Used to colorize the level names. 
self._levels = { logging.DEBUG: self._bold(f"{self.CYAN} DEBUG"), @@ -88,7 +105,12 @@ def format(self, record: logging.LogRecord) -> str: color_marker = self._msg_colors[record.levelno] thread_name = f"[{record.threadName}]" if record.threadName != "MainThread" else "" - return f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}" + + # safe check, just in case injection filter is removed + context_repr = record.context if hasattr(record, "context") else "" + context_msg = f" {self.GRAY}({context_repr}){self.RESET}" if context_repr else "" + + return f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}{context_msg}" def install_logger( @@ -102,6 +124,7 @@ def install_logger( - All existing handlers will be removed. - A new handler will be installed with our custom formatter. It will be configured to emit logs at the given level (default: DEBUG) or higher, to the specified stream (default: sys.stderr). + - A new (injection) filter for adding logger_context will be added, that will add `context` with current context, to all logger messages. Args: level: The logging level to set for the console handler. 
@@ -115,6 +138,8 @@ def install_logger( root.removeHandler(handler) console_handler = logging.StreamHandler(stream) console_handler.setFormatter(NiceFormatter(stream=stream)) + console_handler.addFilter(LoggingContextInjectingFilter()) console_handler.setLevel(level) + root.addHandler(console_handler) return console_handler diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index 8be6359..157cd84 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -8,9 +8,10 @@ import re import threading from collections.abc import Callable, Collection, Sequence -from concurrent.futures import ThreadPoolExecutor from typing import Generic, TypeVar +from ._logging_context import LoggingThreadPoolExecutor + MIN_THREADS = 8 Result = TypeVar("Result") @@ -61,12 +62,12 @@ def gather( return cls(name, tasks, num_threads=num_threads)._run() @classmethod - def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Collection[Result]: + def strict(cls, name: str, tasks: Sequence[Task[Result]], num_threads: int | None = None) -> Collection[Result]: """Run tasks in parallel and raise ManyError if any task fails""" # this dunder variable is hiding this method from tracebacks, making it cleaner # for the user to see the actual error without too much noise. 
__tracebackhide__ = True # pylint: disable=unused-variable - collected, errs = cls.gather(name, tasks) + collected, errs = cls.gather(name, tasks, num_threads) if errs: if len(errs) == 1: raise errs[0] @@ -114,7 +115,7 @@ def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int): def _execute(self): """Run tasks in parallel and return futures""" thread_name_prefix = re.sub(r"\W+", "_", self._name) - with ThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool: + with LoggingThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool: futures = [] for task in self._tasks: if task is None: diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index ecfcf84..928220e 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -11,7 +11,12 @@ import pytest -from databricks.labs.blueprint.logger import NiceFormatter, install_logger +from databricks.labs.blueprint._logging_context import LoggingContextInjectingFilter +from databricks.labs.blueprint.logger import ( + NiceFormatter, + install_logger, + logging_context, +) class LogCaptureHandler(logging.Handler): @@ -32,6 +37,7 @@ def emit(self, record: logging.LogRecord) -> None: def record_capturing(cls, logger: logging.Logger) -> Generator[LogCaptureHandler, None, None]: """Temporarily capture all log records, in addition to existing handling.""" handler = LogCaptureHandler() + handler.addFilter(LoggingContextInjectingFilter()) logger.addHandler(handler) try: yield handler @@ -49,7 +55,9 @@ class LoggingSystemFixture: def __init__(self) -> None: self.output_buffer = io.StringIO() self.root = logging.RootLogger(logging.WARNING) - self.root.addHandler(logging.StreamHandler(self.output_buffer)) + handler = logging.StreamHandler(self.output_buffer) + handler.addFilter(LoggingContextInjectingFilter()) + self.root.addHandler(handler) self.manager = logging.Manager(self.root) def getLogger(self, name: str) -> logging.Logger: @@ -84,6 +92,8 @@ def 
test_install_logger(logging_system) -> None: # Verify that the root logger was configured as expected. assert root.level == logging.FATAL # remains unchanged assert root.handlers == [handler] + assert len(handler.filters) == 1 + assert isinstance(handler.filters[0], LoggingContextInjectingFilter) assert handler.level == logging.INFO assert isinstance(handler.formatter, NiceFormatter) @@ -98,7 +108,8 @@ def test_installed_logger_logging(logging_system) -> None: logger = logging_system.getLogger(__file__) logger.debug("This is a debug message") logger.info("This is an info message") - logger.warning("This is a warning message") + with logging_context(foo="bar-warning"): + logger.warning("This is a warning message") logger.error("This is an error message", exc_info=KeyError(123)) logger.critical("This is a critical message") @@ -107,6 +118,7 @@ def test_installed_logger_logging(logging_system) -> None: assert "This is a debug message" in output assert "This is an info message" in output assert "This is a warning message" in output + assert "(foo='bar-warning')" in output assert "This is an error message\nKeyError: 123" in output assert "This is a critical message" in output @@ -348,3 +360,22 @@ def test_formatter_format_exception(use_colors: bool) -> None: " raise RuntimeError(exception_message)", ] assert exception == "RuntimeError: Test exception." 
+ + +@pytest.mark.parametrize("use_colors", (True, False), ids=("with_colors", "without_colors")) +def test_formatter_with_logging_context(use_colors: bool) -> None: + """Ensure the formatter correctly formats message when logging_context is used""" + formatter = NiceFormatter() + formatter.colors = use_colors + + with logging_context(foo="bar", baz="zak"): + record = create_record(logging.DEBUG, " This is a test message for logging context") + assert hasattr(record, "context") + assert record.context == "foo='bar', baz='zak'" + formatted = formatter.format(record) + assert record.context in formatted, "record context not in formatted" + stripped = _strip_sgr_sequences(formatted) if use_colors else formatted + assert record.context in stripped, "record context not in stripped" + + # H:M:S LEVEL [logger_name] message (logging_context) + assert stripped.endswith(" This is a test message for logging context (foo='bar', baz='zak')") diff --git a/tests/unit/test_logger_context.py b/tests/unit/test_logger_context.py new file mode 100644 index 0000000..a3c628a --- /dev/null +++ b/tests/unit/test_logger_context.py @@ -0,0 +1,203 @@ +import logging +import sys + +from databricks.labs.blueprint._logging_context import LoggingThreadPoolExecutor +from databricks.labs.blueprint.logger import ( + SkipLogging, + current_context, + logging_context, + logging_context_params, +) + +# only python 3.11+ supports notes in exceptions, hence assert only on these versions... 
+SUPPORTS_NOTES = (sys.version_info[0], sys.version_info[1]) >= (3, 11) + + +def test_nested_logger_context(): + logger = logging.getLogger(__name__) + + ctx0 = current_context() + logger.info("before entering context") + + with logging_context(user="Alice", action="read") as ctx1: + logger.info("inside of first context") + assert ctx1 == {"user": "Alice", "action": "read"} + + with logging_context(action="write") as ctx2: + logger.info("inner context") + assert ctx2 == {"user": "Alice", "action": "write"} + + logger.info("still inside first") + assert current_context() == ctx1 + + logger.info("after exiting context") + assert current_context() == ctx0 + + with logging_context(user="Bob", action="write") as ctx2: + logger.info("inside of second context") + assert ctx2 == {"user": "Bob", "action": "write"} + + +def test_exception_with_notest_flat(): + logger = logging.getLogger(__name__) + + try: + with logging_context( + user="Alice", + action="read", + ): + logger.info("inside of first context") + with logging_context(top_secret="47"): + 1 / 0 + except Exception as e: + logger.exception(f"Exception! {e}") + if SUPPORTS_NOTES: + assert e.__notes__ == ["Context: user='Alice', action='read', top_secret='47'"] + assert str(e) == "division by zero" + + +def test_exception_with_notest_nested(): + logger = logging.getLogger(__name__) + logger.info("before entering context") + + try: + with logging_context(file="foo.txt"): + with logging_context( + user="Alice", + action="read", + ): + logger.info("inside of first context") + 1 / 0 + except Exception as e: + logger.exception(f"Exception! {e}") + logger.error(f"Error! 
{e}") + if SUPPORTS_NOTES: + assert e.__notes__ == ["Context: file='foo.txt', user='Alice', action='read'"] + assert str(e) == "division by zero" + + +def test_logging_function_params_empty_deco_call(): + logger = logging.getLogger(__name__) + + @logging_context_params() + def do_math_verbose_test(a: int, b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"a": a, "b": b} + return r + + assert current_context() == {} + do_math_verbose_test(2, b=8) + assert current_context() == {} + do_math_verbose_test(2, b=7) + assert current_context() == {} + + +def test_logging_function_params_no_call(): + logger = logging.getLogger(__name__) + + @logging_context_params + def do_math_verbose_test(a: int, b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"a": a, "b": b} + return r + + assert current_context() == {} + do_math_verbose_test(2, b=8) + assert current_context() == {} + do_math_verbose_test(2, b=7) + assert current_context() == {} + + +def test_logging_function_skip_loggingl(): + logger = logging.getLogger(__name__) + + @logging_context_params + def do_math_verbose_test(a: SkipLogging[int], b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"b": b} + return r + + assert current_context() == {} + do_math_verbose_test(2, b=8) + assert current_context() == {} + do_math_verbose_test(2, b=7) + assert current_context() == {} + + +def test_logging_function_params_shadow_deco_call(): + logger = logging.getLogger(__name__) + + @logging_context_params(a="bar") + def do_math_verbose_test(a: int, b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"a": a, "b": b} + return r + + assert current_context() == {} + do_math_verbose_test(2, b=8) + assert current_context() == {} + do_math_verbose_test(2, b=7) + assert current_context() == {} + + +def test_logging_function_params_non_shadow_deco_call(): + logger = 
logging.getLogger(__name__) + + @logging_context_params(foo="bar") + def do_math_verbose_test(a: int, b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"foo": "bar", "a": a, "b": b} + return r + + assert current_context() == {} + do_math_verbose_test(2, b=8) + assert current_context() == {} + do_math_verbose_test(2, b=7) + assert current_context() == {} + + +def test_logging_function_params_multiple_contexts(): + logger = logging.getLogger(__name__) + + @logging_context_params(foo="bar") + def do_math_verbose_test(a: int, b): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"foo": "bar", "a": a, "b": b, "x": "6"} + return r + + with logging_context(x="6"): + do_math_verbose_test(2, b=8) + + +def test_logging_thread_pool(): + logger = logging.getLogger(__name__) + + @logging_context_params(foo="bar") + def do_math_verbose(a, b: int): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"foo": "bar", "a": a, "b": b, "user": "Alice"} + return r + + def do_math_verbose_without_context(a, b: int): + r = pow(a, b) + logger.info(f"result of {a}**{b} is {r}") + assert current_context() == {"foo": "bar" if a == 2 else "zar", "a": a, "b": b, "user": "Alice"} + return r + + with logging_context(user="Alice"): + futures = [] + with LoggingThreadPoolExecutor(max_workers=1) as executor: + futures.append(executor.submit(do_math_verbose, 2, 2)) + futures.append(executor.submit(do_math_verbose, 2, 6)) + futures.append(executor.submit(logging_context_params(foo="zar")(do_math_verbose_without_context), 3, 8)) + futures.append(executor.submit(logging_context_params(foo="zar")(do_math_verbose_without_context), 3, 12)) + + for f in futures: + f.result() diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index b1cf174..223bd59 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -1,12 +1,17 @@ import logging import os 
+import sys from functools import partial from unittest.mock import MagicMock, patch from databricks.sdk.core import DatabricksError +from databricks.labs.blueprint.logger import logging_context_params from databricks.labs.blueprint.parallel import Threads +# only python 3.11+ supports notes in exceptions, hence test only on these versions... +SUPPORTS_NOTES = (sys.version_info[0], sys.version_info[1]) >= (3, 11) + def _predictable_messages(caplog): res = [] @@ -152,6 +157,47 @@ def fails_on_odd(n=1, dummy=None): "testing(n=1) task failed: failed", ] == _predictable_messages(caplog) + # no context, no notes + for e in errors: + assert getattr(e, "__notes__", None) is None + + +def test_odd_partial_failed_with_context(caplog): + caplog.set_level(logging.INFO) + + # the decorator pushes context information into the notes of raised exceptions + @logging_context_params + def fails_on_odd(n=1, dummy=None): + if isinstance(n, str): + raise RuntimeError("strings are not supported!") + + if n % 2: + msg = "failed" + raise DatabricksError(msg) + + tasks = [ + partial(fails_on_odd, n=1), + partial(fails_on_odd, 1, dummy="6"), + partial(fails_on_odd), + partial(fails_on_odd, n="aaa"), + ] + + results, errors = Threads.gather("testing", tasks) + + assert [] == results + assert 4 == len(errors) + assert [ + "All 'testing' tasks failed!!!", + "testing task failed: failed", + "testing(1, dummy='6') task failed: failed", + "testing(n='aaa') task failed: strings are not supported!", + "testing(n=1) task failed: failed", + ] == _predictable_messages(caplog) + + if SUPPORTS_NOTES: + for e in errors: + assert e.__notes__ is not None + def test_cpu_count() -> None: """Verify a CPU count is available."""