From 0e4aefb3c1849d120ac9aca27f45f721ee7e6b1a Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Tue, 22 Jul 2025 07:46:17 +0530 Subject: [PATCH 01/11] feat(airflow): apache airflow providers openlineage for datahub --- metadata-ingestion-modules/airflow-plugin/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index 8b09b9ecd94cb..218768c79c55a 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -44,7 +44,7 @@ def get_long_description(): # We remain restrictive on the versions allowed here to prevent # us from being broken by backwards-incompatible changes in the # underlying package. - "openlineage-airflow>=1.2.0,<=1.30.1", + "apache-airflow-providers-openlineage>=1.1.0,<2.5.0", }, } From f27cbfcde14d01eb5da1a04438401f72ad5a424c Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Sun, 3 Aug 2025 00:31:34 +0530 Subject: [PATCH 02/11] feature(airflow openlineage plugin): check and change dependencies for imports in the listener --- .../datahub_listener.py | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index dfb92573dbf46..d45bbbe82daed 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -11,9 +11,32 @@ from airflow.models import Variable from airflow.models.operator import Operator from airflow.models.serialized_dag import SerializedDagModel -from openlineage.airflow.listener import TaskHolder -from openlineage.airflow.utils import redact_with_exclusions -from openlineage.client.serde import Serde +# TODO: to change to Airflow plugin +# from openlineage.airflow.listener import TaskHolder +# Ref: https://github.com/apache/airflow/blob/main/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +from airflow.providers.openlineage.plugins.listener import get_openlineage_listener +# TODO: to change to Airflow plugin +# from openlineage.airflow.utils import redact_with_exclusions +from airflow.providers.openlineage.utils.utils import ( + AIRFLOW_V_3_0_PLUS, + get_airflow_dag_run_facet, + get_airflow_debug_facet, + get_airflow_job_facet, + get_airflow_mapped_task_facet, + get_airflow_run_facet, + get_job_name, + get_task_parent_run_facet, + get_user_provided_run_facets, + is_operator_disabled, + is_selective_lineage_enabled, + print_warning, +) +from airflow.providers.openlineage import conf +# from airflow.providers.openlineage.extractors import ExtractorManager, OperatorLineage +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState +# TODO: to change to Airflow plugin +# from openlineage.client.serde import Serde +from airflow.providers.openlineage.client.serde import Serde import datahub.emitter.mce_builder as builder from datahub.api.entities.datajob import DataJob @@ -87,7 +110,7 @@ def hookimpl(f: _F) -> _F: # type: ignore[misc] KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener" - +# verified - airflow openlineage dependencies def get_airflow_plugin_listener() -> 
Optional["DataHubListener"]: # Using globals instead of functools.lru_cache to make testing easier. global _airflow_listener_initialized @@ -121,13 +144,13 @@ def get_airflow_plugin_listener() -> Optional["DataHubListener"]: if plugin_config.disable_openlineage_plugin: # Deactivate the OpenLineagePlugin listener to avoid conflicts/errors. - from openlineage.airflow.plugin import OpenLineagePlugin + from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin - OpenLineagePlugin.listeners = [] + OpenLineageProviderPlugin.listeners = [] return _airflow_listener - +# verified - airflow openlineage dependencies def run_in_thread(f: _F) -> _F: # This is also responsible for catching exceptions and logging them. @@ -147,7 +170,7 @@ def wrapper(*args, **kwargs): if _RUN_IN_THREAD_TIMEOUT > 0: # If _RUN_IN_THREAD_TIMEOUT is 0, we just kick off the thread and move on. # Because it's a daemon thread, it'll be automatically killed when the main - # thread exists. + # thread exits. start_time = time.time() thread.join(timeout=_RUN_IN_THREAD_TIMEOUT) @@ -167,7 +190,7 @@ def wrapper(*args, **kwargs): return cast(_F, wrapper) - +# verified - airflow openlineage dependencies def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": # Render templates in a copy of the task instance. # This is necessary to get the correct operator args in the extractors. @@ -185,6 +208,7 @@ def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": class DataHubListener: __name__ = "DataHubListener" + # verified - airflow openlineage dependencies def __init__(self, config: DatahubLineageConfig): self.config = config self._set_log_level() @@ -195,7 +219,7 @@ def __init__(self, config: DatahubLineageConfig): # See discussion here https://github.com/OpenLineage/OpenLineage/pull/508 for # why we need to keep track of tasks ourselves. - self._task_holder = TaskHolder() + self._open_lineage_listener = get_openlineage_listener() # In our case, we also want to cache the initial datajob object # so that we can add to it when the task completes. @@ -208,10 +232,12 @@ def __init__(self, config: DatahubLineageConfig): # https://github.com/apache/airflow/blob/e99a518970b2d349a75b1647f6b738c8510fa40e/airflow/listeners/listener.py#L56 # self.__class__ = types.ModuleType + # verified - airflow openlineage dependencies @property def emitter(self): return self._emitter + # verified - airflow openlineage dependencies @property def graph(self) -> Optional[DataHubGraph]: if self._graph: @@ -226,6 +252,7 @@ def graph(self) -> Optional[DataHubGraph]: return self._graph + # verified - airflow openlineage dependencies def _set_log_level(self) -> None: """Set the log level for the plugin and its dependencies. 
@@ -240,6 +267,7 @@ def _set_log_level(self) -> None: if self.config.debug_emitter: logging.getLogger("datahub.emitter").setLevel(logging.DEBUG) + # verified - airflow openlineage dependencies def _make_emit_callback(self) -> Callable[[Optional[Exception], str], None]: def emit_callback(err: Optional[Exception], msg: str) -> None: if err: From ebf2ebac3091b47ded602a65f640d13eafed39b8 Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Sun, 10 Aug 2025 18:53:28 +0530 Subject: [PATCH 03/11] feat(on_task_instance_running): WIP with new plugin version --- .../datahub_listener.py | 1007 ++++------------- 1 file changed, 232 insertions(+), 775 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index d45bbbe82daed..960ebb65dbc84 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -1,824 +1,281 @@ -import asyncio +# Copyright 2018-2025 contributors to the OpenLineage project +# SPDX-License-Identifier: Apache-2.0 + import copy -import functools import logging -import os import threading -import time -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast - -import airflow -from airflow.models import Variable -from airflow.models.operator import Operator -from airflow.models.serialized_dag import SerializedDagModel -# TODO: to change to Airflow plugin -# from openlineage.airflow.listener import TaskHolder -# Ref: https://github.com/apache/airflow/blob/main/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py -from airflow.providers.openlineage.plugins.listener import get_openlineage_listener -# TODO: to change to Airflow plugin -# from openlineage.airflow.utils import redact_with_exclusions -from airflow.providers.openlineage.utils.utils import ( - AIRFLOW_V_3_0_PLUS, - get_airflow_dag_run_facet, - get_airflow_debug_facet, - get_airflow_job_facet, - get_airflow_mapped_task_facet, +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import TYPE_CHECKING, Callable, Optional + +from openlineage.airflow.adapter import OpenLineageAdapter +from openlineage.airflow.extractors import ExtractorManager +from openlineage.airflow.utils import ( + DagUtils, get_airflow_run_facet, + get_custom_facets, + get_dagrun_start_end, get_job_name, - get_task_parent_run_facet, - get_user_provided_run_facets, - is_operator_disabled, - is_selective_lineage_enabled, - print_warning, -) -from airflow.providers.openlineage import conf -# from airflow.providers.openlineage.extractors import ExtractorManager, OperatorLineage -from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState -# TODO: to change to Airflow plugin -# from openlineage.client.serde import Serde -from airflow.providers.openlineage.client.serde import Serde - -import datahub.emitter.mce_builder as builder -from datahub.api.entities.datajob import DataJob -from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult -from datahub.emitter.mce_builder import ( - make_data_platform_urn, - make_dataplatform_instance_urn, -) -from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.emitter.rest_emitter import DatahubRestEmitter -from datahub.ingestion.graph.client import DataHubGraph -from 
datahub.metadata.schema_classes import ( - BrowsePathEntryClass, - BrowsePathsV2Class, - DataFlowKeyClass, - DataJobKeyClass, - DataPlatformInstanceClass, - FineGrainedLineageClass, - FineGrainedLineageDownstreamTypeClass, - FineGrainedLineageUpstreamTypeClass, - OperationClass, - OperationTypeClass, - StatusClass, -) -from datahub.sql_parsing.sqlglot_lineage import SqlParsingResult -from datahub.telemetry import telemetry -from datahub_airflow_plugin._airflow_shims import ( - HAS_AIRFLOW_DAG_LISTENER_API, - HAS_AIRFLOW_DATASET_LISTENER_API, - get_task_inlets, - get_task_outlets, -) -from datahub_airflow_plugin._config import DatahubLineageConfig, get_lineage_config -from datahub_airflow_plugin._datahub_ol_adapter import translate_ol_to_datahub_urn -from datahub_airflow_plugin._extractors import SQL_PARSING_RESULT_KEY, ExtractorManager -from datahub_airflow_plugin._version import __package_name__, __version__ -from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator -from datahub_airflow_plugin.entities import ( - _Entity, - entities_to_datajob_urn_list, - entities_to_dataset_urn_list, + get_task_location, + getboolean, + is_airflow_version_enough, ) -_F = TypeVar("_F", bound=Callable[..., None]) +from airflow.listeners import hookimpl +from airflow.utils import timezone + if TYPE_CHECKING: - from airflow.datasets import Dataset - from airflow.models import DAG, DagRun, TaskInstance from sqlalchemy.orm import Session - # To placate mypy on Airflow versions that don't have the listener API, - # we define a dummy hookimpl that's an identity function. + from airflow.models import BaseOperator, DagRun, TaskInstance - def hookimpl(f: _F) -> _F: # type: ignore[misc] - return f -else: - from airflow.listeners import hookimpl +class TaskHolder: + """Class that stores run data - run_id and task in-memory. This is needed because Airflow + does not always pass all runtime info to on_task_instance_success and + on_task_instance_failed that is needed to emit events. This is not a big problem since + we're only running on worker - in separate process that is always spawned (or forked) on + execution, just like old PHP runtime model. + """ -logger = logging.getLogger(__name__) + def __init__(self): + self.run_data = {} -_airflow_listener_initialized = False -_airflow_listener: Optional["DataHubListener"] = None -_RUN_IN_THREAD = os.getenv("DATAHUB_AIRFLOW_PLUGIN_RUN_IN_THREAD", "true").lower() in ( - "true", - "1", -) -_RUN_IN_THREAD_TIMEOUT = float( - os.getenv("DATAHUB_AIRFLOW_PLUGIN_RUN_IN_THREAD_TIMEOUT", 10) -) -_DATAHUB_CLEANUP_DAG = "Datahub_Cleanup" + def set_task(self, task_instance: "TaskInstance"): + self.run_data[self._pk(task_instance)] = task_instance.task -KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener" + def get_task(self, task_instance: "TaskInstance") -> Optional["BaseOperator"]: + return self.run_data.get(self._pk(task_instance)) -# verified - airflow openlineage dependencies -def get_airflow_plugin_listener() -> Optional["DataHubListener"]: - # Using globals instead of functools.lru_cache to make testing easier. 
- global _airflow_listener_initialized - global _airflow_listener + @staticmethod + def _pk(ti: "TaskInstance"): + return ti.dag_id + ti.task_id + ti.run_id - if not _airflow_listener_initialized: - _airflow_listener_initialized = True - plugin_config = get_lineage_config() +log = logging.getLogger(__name__) +# TODO: move task instance runs to executor +executor: Optional[Executor] = None - if plugin_config.enabled: - _airflow_listener = DataHubListener(config=plugin_config) - logger.info( - f"DataHub plugin v2 (package: {__package_name__} and version: {__version__}) listener initialized with config: {plugin_config}" - ) - telemetry.telemetry_instance.ping( - "airflow-plugin-init", - { - "airflow-version": airflow.__version__, - "datahub-airflow-plugin": "v2", - "datahub-airflow-plugin-dag-events": HAS_AIRFLOW_DAG_LISTENER_API, - "datahub-airflow-plugin-dataset-events": HAS_AIRFLOW_DATASET_LISTENER_API, - "capture_executions": plugin_config.capture_executions, - "capture_tags": plugin_config.capture_tags_info, - "capture_ownership": plugin_config.capture_ownership_info, - "enable_extractors": plugin_config.enable_extractors, - "render_templates": plugin_config.render_templates, - "disable_openlineage_plugin": plugin_config.disable_openlineage_plugin, - }, - ) - if plugin_config.disable_openlineage_plugin: - # Deactivate the OpenLineagePlugin listener to avoid conflicts/errors. - from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin +def execute_in_thread(target: Callable, kwargs=None): + if kwargs is None: + kwargs = {} + thread = threading.Thread(target=target, kwargs=kwargs, daemon=True) + thread.start() - OpenLineageProviderPlugin.listeners = [] + # Join, but ignore checking if thread stopped. If it did, then we shouldn't do anything. + # This basically gives this thread 5 seconds to complete work, then it can be killed, + # as daemon=True. We don't want to deadlock Airflow if our code hangs. - return _airflow_listener + # This will hang if this timeouts, and extractor is running non-daemon thread inside, + # since it will never be cleaned up. Ex. SnowflakeOperator + thread.join(timeout=10) -# verified - airflow openlineage dependencies -def run_in_thread(f: _F) -> _F: - # This is also responsible for catching exceptions and logging them. - @functools.wraps(f) - def wrapper(*args, **kwargs): - try: - if _RUN_IN_THREAD: - # A poor-man's timeout mechanism. - # This ensures that we don't hang the task if the extractors - # are slow or the DataHub API is slow to respond. - - thread = threading.Thread( - target=f, args=args, kwargs=kwargs, daemon=True - ) - thread.start() - - if _RUN_IN_THREAD_TIMEOUT > 0: - # If _RUN_IN_THREAD_TIMEOUT is 0, we just kick off the thread and move on. - # Because it's a daemon thread, it'll be automatically killed when the main - # thread exits. - - start_time = time.time() - thread.join(timeout=_RUN_IN_THREAD_TIMEOUT) - if thread.is_alive(): - logger.warning( - f"Thread for {f.__name__} is still running after {_RUN_IN_THREAD_TIMEOUT} seconds. " - "Continuing without waiting for it to finish." - ) - else: - logger.debug( - f"Thread for {f.__name__} finished after {time.time() - start_time} seconds" - ) - else: - f(*args, **kwargs) - except Exception as e: - logger.warning(e, exc_info=True) - - return cast(_F, wrapper) - -# verified - airflow openlineage dependencies -def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": - # Render templates in a copy of the task instance. 
- # This is necessary to get the correct operator args in the extractors. - try: - task_instance_copy = copy.deepcopy(task_instance) - task_instance_copy.render_templates() - return task_instance_copy - except Exception as e: - logger.info( - f"Error rendering templates in DataHub listener. Jinja-templated variables will not be extracted correctly: {e}. Template rendering improves SQL parsing accuracy. If this causes issues, you can disable it by setting `render_templates` to `false` in the DataHub plugin configuration." - ) - return task_instance - - -class DataHubListener: - __name__ = "DataHubListener" - - # verified - airflow openlineage dependencies - def __init__(self, config: DatahubLineageConfig): - self.config = config - self._set_log_level() - - self._emitter = config.make_emitter_hook().make_emitter() - self._graph: Optional[DataHubGraph] = None - logger.info(f"DataHub plugin v2 using {repr(self._emitter)}") - - # See discussion here https://github.com/OpenLineage/OpenLineage/pull/508 for - # why we need to keep track of tasks ourselves. - self._open_lineage_listener = get_openlineage_listener() - - # In our case, we also want to cache the initial datajob object - # so that we can add to it when the task completes. - self._datajob_holder: Dict[str, DataJob] = {} - - self.extractor_manager = ExtractorManager() - - # This "inherits" from types.ModuleType to avoid issues with Airflow's listener plugin loader. - # It previously (v2.4.x and likely other versions too) would throw errors if it was not a module. - # https://github.com/apache/airflow/blob/e99a518970b2d349a75b1647f6b738c8510fa40e/airflow/listeners/listener.py#L56 - # self.__class__ = types.ModuleType - - # verified - airflow openlineage dependencies - @property - def emitter(self): - return self._emitter - - # verified - airflow openlineage dependencies - @property - def graph(self) -> Optional[DataHubGraph]: - if self._graph: - return self._graph - - if isinstance(self._emitter, DatahubRestEmitter) and not isinstance( - self._emitter, DataHubGraph - ): - # This is lazy initialized to avoid throwing errors on plugin load. - self._graph = self._emitter.to_graph() - self._emitter = self._graph - - return self._graph - - # verified - airflow openlineage dependencies - def _set_log_level(self) -> None: - """Set the log level for the plugin and its dependencies. - - This may need to be called multiple times, since Airflow sometimes - messes with the logging configuration after the plugin is loaded. - In particular, the loggers may get changed when the worker starts - executing a task. - """ - - if self.config.log_level: - logging.getLogger(__name__.split(".")[0]).setLevel(self.config.log_level) - if self.config.debug_emitter: - logging.getLogger("datahub.emitter").setLevel(logging.DEBUG) - - # verified - airflow openlineage dependencies - def _make_emit_callback(self) -> Callable[[Optional[Exception], str], None]: - def emit_callback(err: Optional[Exception], msg: str) -> None: - if err: - logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err) - - return emit_callback - - def _extract_lineage( - self, - datajob: DataJob, - dagrun: "DagRun", - task: "Operator", - task_instance: "TaskInstance", - complete: bool = False, - ) -> None: - """ - Combine lineage (including column lineage) from task inlets/outlets and - extractor-generated task_metadata and write it to the datajob. This - routine is also responsible for converting the lineage to DataHub URNs. 
- """ - - if not self.config.enable_datajob_lineage: - return - - input_urns: List[str] = [] - output_urns: List[str] = [] - fine_grained_lineages: List[FineGrainedLineageClass] = [] - - task_metadata = None - if self.config.enable_extractors: - task_metadata = self.extractor_manager.extract_metadata( - dagrun, - task, - complete=complete, - task_instance=task_instance, - task_uuid=str(datajob.urn), - graph=self.graph, - ) - logger.debug(f"Got task metadata: {task_metadata}") +task_holder = TaskHolder() +extractor_manager = ExtractorManager() +adapter = OpenLineageAdapter() - # Translate task_metadata.inputs/outputs to DataHub URNs. - input_urns.extend( - translate_ol_to_datahub_urn(dataset) for dataset in task_metadata.inputs - ) - output_urns.extend( - translate_ol_to_datahub_urn(dataset) - for dataset in task_metadata.outputs - ) - # Add DataHub-native SQL parser results. - sql_parsing_result: Optional[SqlParsingResult] = None - if task_metadata: - sql_parsing_result = task_metadata.run_facets.pop( - SQL_PARSING_RESULT_KEY, None - ) - if sql_parsing_result: - if error := sql_parsing_result.debug_info.error: - logger.info(f"SQL parsing error: {error}", exc_info=error) - datajob.properties["datahub_sql_parser_error"] = ( - f"{type(error).__name__}: {error}" - ) - if not sql_parsing_result.debug_info.table_error: - input_urns.extend(sql_parsing_result.in_tables) - output_urns.extend(sql_parsing_result.out_tables) - - if sql_parsing_result.column_lineage: - fine_grained_lineages.extend( - FineGrainedLineageClass( - upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, - downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, - upstreams=[ - builder.make_schema_field_urn( - upstream.table, upstream.column - ) - for upstream in column_lineage.upstreams - ], - downstreams=[ - builder.make_schema_field_urn( - downstream.table, downstream.column - ) - for downstream in [column_lineage.downstream] - if downstream.table - ], - ) - for column_lineage in sql_parsing_result.column_lineage - ) - - # Add DataHub-native inlets/outlets. - # These are filtered out by the extractor, so we need to add them manually. - input_urns.extend( - iolet.urn for iolet in get_task_inlets(task) if isinstance(iolet, _Entity) - ) - output_urns.extend( - iolet.urn for iolet in get_task_outlets(task) if isinstance(iolet, _Entity) - ) +def direct_execution(): + return is_airflow_version_enough("2.6.0") or getboolean( + "OPENLINEAGE_AIRFLOW_ENABLE_DIRECT_EXECUTION", False + ) - # Write the lineage to the datajob object. - datajob.inlets.extend(entities_to_dataset_urn_list(input_urns)) - datajob.outlets.extend(entities_to_dataset_urn_list(output_urns)) - datajob.upstream_urns.extend(entities_to_datajob_urn_list(input_urns)) - datajob.fine_grained_lineages.extend(fine_grained_lineages) - # Merge in extra stuff that was present in the DataJob we constructed - # at the start of the task. 
- if complete: - original_datajob = self._datajob_holder.get(str(datajob.urn), None) +def execute(_callable): + try: + if direct_execution(): + _callable() else: - self._datajob_holder[str(datajob.urn)] = datajob - original_datajob = None - - if original_datajob: - logger.debug("Merging start datajob into finish datajob") - datajob.inlets.extend(original_datajob.inlets) - datajob.outlets.extend(original_datajob.outlets) - datajob.upstream_urns.extend(original_datajob.upstream_urns) - datajob.fine_grained_lineages.extend(original_datajob.fine_grained_lineages) - - for k, v in original_datajob.properties.items(): - datajob.properties.setdefault(k, v) - - # Deduplicate inlets/outlets. - datajob.inlets = list(sorted(set(datajob.inlets), key=lambda x: str(x))) - datajob.outlets = list(sorted(set(datajob.outlets), key=lambda x: str(x))) - datajob.upstream_urns = list( - sorted(set(datajob.upstream_urns), key=lambda x: str(x)) - ) - - # Write all other OL facets as DataHub properties. - if task_metadata: - for k, v in task_metadata.job_facets.items(): - datajob.properties[f"openlineage_job_facet_{k}"] = Serde.to_json( - redact_with_exclusions(v) - ) - - for k, v in task_metadata.run_facets.items(): - datajob.properties[f"openlineage_run_facet_{k}"] = Serde.to_json( - redact_with_exclusions(v) - ) - - def check_kill_switch(self): - if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true": - logger.debug("DataHub listener disabled by kill switch") - return True - return False - - @hookimpl - @run_in_thread - def on_task_instance_running( - self, - previous_state: None, - task_instance: "TaskInstance", - session: "Session", # This will always be QUEUED - ) -> None: - if self.check_kill_switch(): - return - self._set_log_level() - - # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. - if not hasattr(task_instance, "task"): - # The type ignore is to placate mypy on Airflow 2.1.x. - logger.warning( - f"No task set for task_id: {task_instance.task_id} - " # type: ignore[attr-defined] - f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" # type: ignore[attr-defined] - ) - return - - logger.debug( - f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" + execute_in_thread(_callable) + except Exception: + # Make sure we're not failing task, even for things we think can't happen + log.exception("Failed to emit OpenLineage event due to exception") + + +@hookimpl +def on_task_instance_running(previous_state, task_instance: "TaskInstance", session: "Session"): + if not hasattr(task_instance, "task"): + log.warning( + "No task set for TI object task_id: %s - dag_id: %s - run_id %s", + task_instance.task_id, + task_instance.dag_id, + task_instance.run_id, ) + return - if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): - logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") - return - - if self.config.render_templates: - task_instance = _render_templates(task_instance) - - # The type ignore is to placate mypy on Airflow 2.1.x. - dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] - task = task_instance.task - assert task is not None - dag: "DAG" = task.dag # type: ignore[assignment] + log.debug("OpenLineage listener got notification about task instance start") + dagrun = task_instance.dag_run - self._task_holder.set_task(task_instance) - - # Handle async operators in Airflow 2.3 by skipping deferred state. 
- # Inspired by https://github.com/OpenLineage/OpenLineage/pull/1601 - if task_instance.next_method is not None: # type: ignore[attr-defined] - return - - datajob = AirflowGenerator.generate_datajob( - cluster=self.config.cluster, - task=task, - dag=dag, - capture_tags=self.config.capture_tags_info, - capture_owner=self.config.capture_ownership_info, - config=self.config, - ) - - # TODO: Make use of get_task_location to extract github urls. - - # Add lineage info. - self._extract_lineage(datajob, dagrun, task, task_instance) - - # TODO: Add handling for Airflow mapped tasks using task_instance.map_index - - for mcp in datajob.generate_mcp( - generate_lineage=self.config.enable_datajob_lineage, - materialize_iolets=self.config.materialize_iolets, - ): - self.emitter.emit(mcp, self._make_emit_callback()) - logger.debug(f"Emitted DataHub Datajob start: {datajob}") - - if self.config.capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=self.emitter, - config=self.config, - ti=task_instance, - dag=dag, - dag_run=dagrun, - datajob=datajob, - emit_templates=False, + def on_running(): + nonlocal task_instance + try: + ti = copy.deepcopy(task_instance) + except Exception as err: + log.debug( + f"Creating a task instance copy failed; proceeding without rendering templates. Error: {err}" ) - logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}") - - self.emitter.flush() - - logger.debug( - f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}" - ) - - self.materialize_iolets(datajob) - - def materialize_iolets(self, datajob: DataJob) -> None: - if self.config.materialize_iolets: - for outlet in datajob.outlets: - reported_time: int = int(time.time() * 1000) - operation = OperationClass( - timestampMillis=reported_time, - operationType=OperationTypeClass.CREATE, - lastUpdatedTimestamp=reported_time, - actor=builder.make_user_urn("airflow"), - ) - - operation_mcp = MetadataChangeProposalWrapper( - entityUrn=str(outlet), aspect=operation - ) - - self.emitter.emit(operation_mcp) - logger.debug(f"Emitted Dataset Operation: {outlet}") - else: - if self.graph: - for outlet in datajob.outlets: - if not self.graph.exists(str(outlet)): - logger.warning(f"Dataset {str(outlet)} not materialized") - for inlet in datajob.inlets: - if not self.graph.exists(str(inlet)): - logger.warning(f"Dataset {str(inlet)} not materialized") - - def on_task_instance_finish( - self, task_instance: "TaskInstance", status: InstanceRunResult - ) -> None: - dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] - - if self.config.render_templates: - task_instance = _render_templates(task_instance) - - # We must prefer the task attribute, in case modifications to the task's inlets/outlets - # were made by the execute() method. 
- if getattr(task_instance, "task", None): - task = task_instance.task + ti = task_instance else: - task = self._task_holder.get_task(task_instance) - assert task is not None - - dag: "DAG" = task.dag # type: ignore[assignment] - - if not self.config.dag_filter_pattern.allowed(dag.dag_id): - logger.debug(f"DAG {dag.dag_id} is not allowed by the pattern") + ti.render_templates() + + task = ti.task + dag = task.dag + task_holder.set_task(ti) + # that's a workaround to detect task running from deferred state + # we return here because Airflow 2.3 needs task from deferred state + if ti.next_method is not None: return - datajob = AirflowGenerator.generate_datajob( - cluster=self.config.cluster, - task=task, - dag=dag, - capture_tags=self.config.capture_tags_info, - capture_owner=self.config.capture_ownership_info, - config=self.config, + parent_run_id = OpenLineageAdapter.build_dag_run_id( + dag_id=dag.dag_id, + execution_date=dagrun.execution_date, ) - - # Add lineage info. - self._extract_lineage(datajob, dagrun, task, task_instance, complete=True) - - for mcp in datajob.generate_mcp( - generate_lineage=self.config.enable_datajob_lineage, - materialize_iolets=self.config.materialize_iolets, - ): - self.emitter.emit(mcp, self._make_emit_callback()) - logger.debug(f"Emitted DataHub Datajob finish w/ status {status}: {datajob}") - - if self.config.capture_executions: - dpi = AirflowGenerator.complete_datajob( - emitter=self.emitter, - cluster=self.config.cluster, - ti=task_instance, - dag=dag, - dag_run=dagrun, - datajob=datajob, - result=status, - config=self.config, - ) - logger.debug( - f"Emitted DataHub DataProcess Instance with status {status}: {dpi}" - ) - - self.emitter.flush() - - @hookimpl - @run_in_thread - def on_task_instance_success( - self, previous_state: None, task_instance: "TaskInstance", session: "Session" - ) -> None: - if self.check_kill_switch(): - return - - self._set_log_level() - - logger.debug( - f"DataHub listener got notification about task instance success for {task_instance.task_id}" - ) - self.on_task_instance_finish(task_instance, status=InstanceRunResult.SUCCESS) - logger.debug( - f"DataHub listener finished processing task instance success for {task_instance.task_id}" + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + dag_id=dag.dag_id, + task_id=task.task_id, + try_number=ti._try_number, + execution_date=ti.execution_date, ) - @hookimpl - @run_in_thread - def on_task_instance_failed( - self, previous_state: None, task_instance: "TaskInstance", session: "Session" - ) -> None: - if self.check_kill_switch(): - return - - self._set_log_level() - - logger.debug( - f"DataHub listener got notification about task instance failure for {task_instance.task_id}" + task_metadata = extractor_manager.extract_metadata(dagrun, task, task_uuid=task_uuid) + + ti_start_time = ti.start_date if ti.start_date else timezone.utcnow() + start, end = get_dagrun_start_end(dagrun=dagrun, dag=dag) + + adapter.start_task( + run_id=task_uuid, + job_name=get_job_name(task), + job_description=dag.description, + event_time=DagUtils.get_start_time(ti_start_time), + parent_job_name=dag.dag_id, + parent_run_id=parent_run_id, + code_location=get_task_location(task), + nominal_start_time=DagUtils.get_start_time(start), + nominal_end_time=DagUtils.to_iso_8601(end), + owners=dag.owner.split(", "), + task=task_metadata, + run_facets={ + **get_custom_facets(dagrun, task, dagrun.external_trigger, ti), + **get_airflow_run_facet(dagrun, dag, ti, task, task_uuid), + }, ) - # TODO: Handle 
UP_FOR_RETRY state. - self.on_task_instance_finish(task_instance, status=InstanceRunResult.FAILURE) - logger.debug( - f"DataHub listener finished processing task instance failure for {task_instance.task_id}" + execute(on_running) + + +@hookimpl +def on_task_instance_success(previous_state, task_instance: "TaskInstance", session): + log.debug("OpenLineage listener got notification about task instance success") + task = task_holder.get_task(task_instance) or task_instance.task + dag = task.dag + dagrun = task_instance.dag_run + + parent_run_id = OpenLineageAdapter.build_dag_run_id( + dag_id=dag.dag_id, + execution_date=dagrun.execution_date, + ) + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + dag_id=dag.dag_id, + task_id=task.task_id, + try_number=task_instance._try_number, + execution_date=task_instance.execution_date, + ) + + def on_success(): + task_metadata = extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance ) - - def on_dag_start(self, dag_run: "DagRun") -> None: - dag = dag_run.dag - if not dag: - logger.warning( - f"DataHub listener could not find DAG for {dag_run.dag_id} - {dag_run.run_id}. Dag won't be captured" - ) - return - - dataflow = AirflowGenerator.generate_dataflow( - config=self.config, - dag=dag, + adapter.complete_task( + run_id=task_uuid, + job_name=get_job_name(task), + parent_job_name=dag.dag_id, + parent_run_id=parent_run_id, + end_time=DagUtils.to_iso_8601(task_instance.end_date), + task=task_metadata, ) - dataflow.emit(self.emitter, callback=self._make_emit_callback()) - logger.debug(f"Emitted DataHub DataFlow: {dataflow}") - event: MetadataChangeProposalWrapper = MetadataChangeProposalWrapper( - entityUrn=str(dataflow.urn), aspect=StatusClass(removed=False) + execute(on_success) + + +@hookimpl +def on_task_instance_failed(previous_state, task_instance: "TaskInstance", session): + log.debug("OpenLineage listener got notification about task instance failure") + task = task_holder.get_task(task_instance) or task_instance.task + dag = task.dag + dagrun = task_instance.dag_run + + parent_run_id = OpenLineageAdapter.build_dag_run_id( + dag_id=dag.dag_id, + execution_date=dagrun.execution_date, + ) + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + dag_id=dag.dag_id, + task_id=task.task_id, + try_number=task_instance._try_number, + execution_date=task_instance.execution_date, + ) + + def on_failure(): + task_metadata = extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance ) - self.emitter.emit(event) - for task in dag.tasks: - task_urn = builder.make_data_job_urn_with_flow( - str(dataflow.urn), task.task_id - ) - event = MetadataChangeProposalWrapper( - entityUrn=task_urn, aspect=StatusClass(removed=False) - ) - self.emitter.emit(event) + end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() - if self.config.platform_instance: - instance = make_dataplatform_instance_urn( - platform="airflow", - instance=self.config.platform_instance, - ) - event = MetadataChangeProposalWrapper( - entityUrn=str(dataflow.urn), - aspect=DataPlatformInstanceClass( - platform=make_data_platform_urn("airflow"), - instance=instance, - ), - ) - self.emitter.emit(event) - - # emit tags - for tag in dataflow.tags: - tag_urn = builder.make_tag_urn(tag) - - event = MetadataChangeProposalWrapper( - entityUrn=tag_urn, aspect=StatusClass(removed=False) - ) - self.emitter.emit(event) - - browsePaths: List[BrowsePathEntryClass] = [] - if 
self.config.platform_instance: - urn = make_dataplatform_instance_urn( - "airflow", self.config.platform_instance - ) - browsePaths.append(BrowsePathEntryClass(self.config.platform_instance, urn)) - browsePaths.append(BrowsePathEntryClass(str(dag.dag_id))) - browse_path_v2_event: MetadataChangeProposalWrapper = ( - MetadataChangeProposalWrapper( - entityUrn=str(dataflow.urn), - aspect=BrowsePathsV2Class( - path=browsePaths, - ), - ) + adapter.fail_task( + run_id=task_uuid, + job_name=get_job_name(task), + parent_job_name=dag.dag_id, + parent_run_id=parent_run_id, + end_time=DagUtils.to_iso_8601(end_date), + task=task_metadata, ) - self.emitter.emit(browse_path_v2_event) - - if dag.dag_id == _DATAHUB_CLEANUP_DAG: - assert self.graph - - logger.debug("Initiating the cleanup of obsolete data from datahub") - - # get all ingested dataflow and datajob - ingested_dataflow_urns = list( - self.graph.get_urns_by_filter( - platform="airflow", - entity_types=["dataFlow"], - platform_instance=self.config.platform_instance, - ) - ) - ingested_datajob_urns = list( - self.graph.get_urns_by_filter( - platform="airflow", - entity_types=["dataJob"], - platform_instance=self.config.platform_instance, - ) - ) - - # filter the ingested dataflow and datajob based on the cluster - filtered_ingested_dataflow_urns: List = [] - filtered_ingested_datajob_urns: List = [] - - for ingested_dataflow_urn in ingested_dataflow_urns: - data_flow_aspect = self.graph.get_aspect( - entity_urn=ingested_dataflow_urn, aspect_type=DataFlowKeyClass - ) - if ( - data_flow_aspect is not None - and data_flow_aspect.flowId != _DATAHUB_CLEANUP_DAG - and data_flow_aspect is not None - and data_flow_aspect.cluster == self.config.cluster - ): - filtered_ingested_dataflow_urns.append(ingested_dataflow_urn) - - for ingested_datajob_urn in ingested_datajob_urns: - data_job_aspect = self.graph.get_aspect( - entity_urn=ingested_datajob_urn, aspect_type=DataJobKeyClass - ) - if ( - data_job_aspect is not None - and data_job_aspect.flow in filtered_ingested_dataflow_urns - ): - filtered_ingested_datajob_urns.append(ingested_datajob_urn) - - # get all airflow dags - all_airflow_dags = SerializedDagModel.read_all_dags().values() - - airflow_flow_urns: List = [] - airflow_job_urns: List = [] - - for dag in all_airflow_dags: - flow_urn = builder.make_data_flow_urn( - orchestrator="airflow", - flow_id=dag.dag_id, - cluster=self.config.cluster, - platform_instance=self.config.platform_instance, - ) - airflow_flow_urns.append(flow_urn) - - for task in dag.tasks: - airflow_job_urns.append( - builder.make_data_job_urn_with_flow(str(flow_urn), task.task_id) - ) - - obsolete_pipelines = set(filtered_ingested_dataflow_urns) - set( - airflow_flow_urns - ) - obsolete_tasks = set(filtered_ingested_datajob_urns) - set(airflow_job_urns) - - obsolete_urns = obsolete_pipelines.union(obsolete_tasks) - - asyncio.run(self._soft_delete_obsolete_urns(obsolete_urns=obsolete_urns)) - - logger.debug(f"total pipelines removed = {len(obsolete_pipelines)}") - logger.debug(f"total tasks removed = {len(obsolete_tasks)}") - - @hookimpl - @run_in_thread - def on_dag_run_running(self, dag_run: "DagRun", msg: str) -> None: - if self.check_kill_switch(): - return - - self._set_log_level() - - logger.debug( - f"DataHub listener got notification about dag run start for {dag_run.dag_id}" - ) - - assert dag_run.dag_id - if not self.config.dag_filter_pattern.allowed(dag_run.dag_id): - logger.debug(f"DAG {dag_run.dag_id} is not allowed by the pattern") - return - - 
self.on_dag_start(dag_run) - self.emitter.flush() - - # TODO: Add hooks for on_dag_run_success, on_dag_run_failed -> call AirflowGenerator.complete_dataflow - - if HAS_AIRFLOW_DATASET_LISTENER_API: - - @hookimpl - @run_in_thread - def on_dataset_created(self, dataset: "Dataset") -> None: - self._set_log_level() - - logger.debug( - f"DataHub listener got notification about dataset create for {dataset}" - ) - - @hookimpl - @run_in_thread - def on_dataset_changed(self, dataset: "Dataset") -> None: - self._set_log_level() - - logger.debug( - f"DataHub listener got notification about dataset change for {dataset}" - ) - - async def _soft_delete_obsolete_urns(self, obsolete_urns): - delete_tasks = [self._delete_obsolete_data(urn) for urn in obsolete_urns] - await asyncio.gather(*delete_tasks) - - async def _delete_obsolete_data(self, obsolete_urn): - assert self.graph - if self.graph.exists(str(obsolete_urn)): - self.graph.soft_delete_entity(str(obsolete_urn)) + execute(on_failure) + + +@hookimpl +def on_starting(component): + global executor + executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + + +@hookimpl +def before_stopping(component): + if executor: + # stom accepting new events + executor.shutdown(wait=False) + # block until all pending events are processed + adapter.close() + + +@hookimpl +def on_dag_run_running(dag_run: "DagRun", msg: str): + if not executor: + log.error("Executor have not started before `on_dag_run_running`") + return + start, end = get_dagrun_start_end(dag_run, dag_run.dag) + executor.submit( + adapter.dag_started, + dag_run=dag_run, + msg=msg, + nominal_start_time=DagUtils.get_start_time(start), + nominal_end_time=DagUtils.to_iso_8601(end), + ) + + +@hookimpl +def on_dag_run_success(dag_run: "DagRun", msg: str): + if not executor: + log.error("Executor have not started before `on_dag_run_success`") + return + executor.submit(adapter.dag_success, dag_run=dag_run, msg=msg) + + +@hookimpl +def on_dag_run_failed(dag_run: "DagRun", msg: str): + if not executor: + log.error("Executor have not started before `on_dag_run_failed`") + return + executor.submit(adapter.dag_failed, dag_run=dag_run, msg=msg) From 8f90c235c1f5f3f4b4420830736102382d353b9c Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Sun, 10 Aug 2025 20:12:31 +0530 Subject: [PATCH 04/11] feat(on_task_instance_running): WIP with new plugin version --- .../datahub_listener.py | 1159 +++++++++++++---- 1 file changed, 926 insertions(+), 233 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index 960ebb65dbc84..7c801a5bb9147 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -1,281 +1,974 @@ -# Copyright 2018-2025 contributors to the OpenLineage project -# SPDX-License-Identifier: Apache-2.0 - +import asyncio import copy +import functools import logging +import os +from datetime import datetime import threading -from concurrent.futures import Executor, ThreadPoolExecutor -from typing import TYPE_CHECKING, Callable, Optional - -from openlineage.airflow.adapter import OpenLineageAdapter -from openlineage.airflow.extractors import ExtractorManager -from openlineage.airflow.utils import ( - DagUtils, +import time +from typing import TYPE_CHECKING, 
Callable, Dict, List, Optional, TypeVar, cast
+
+import airflow
+from airflow.models import Variable
+from airflow.models.operator import Operator
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.timeout import timeout
+from airflow.utils import timezone
+from airflow.settings import configure_orm
+from airflow.stats import Stats
+# TODO: to change to Airflow plugin
+# from openlineage.airflow.listener import TaskHolder
+# Ref: https://github.com/apache/airflow/blob/main/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+# TODO: to change to Airflow plugin
+# from openlineage.airflow.utils import redact_with_exclusions
+from airflow.providers.openlineage.utils.utils import (
+    AIRFLOW_V_3_0_PLUS,
+    get_airflow_dag_run_facet,
+    get_airflow_debug_facet,
+    get_airflow_job_facet,
+    get_airflow_mapped_task_facet,
     get_airflow_run_facet,
-    get_custom_facets,
-    get_dagrun_start_end,
     get_job_name,
-    get_task_location,
-    getboolean,
-    is_airflow_version_enough,
+    get_task_parent_run_facet,
+    get_task_documentation,
+    get_user_provided_run_facets,
+    is_operator_disabled,
+    is_selective_lineage_enabled,
+    print_warning,
+)
+from airflow.providers.openlineage import conf
+from airflow.providers.openlineage.extractors.manager import ExtractorManager
+from airflow.providers.openlineage.extractors.base import OperatorLineage
+from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
+# NOTE: the provider package does not bundle the OpenLineage client, so Serde
+# is still imported from the openlineage-python client library.
+from openlineage.client.serde import Serde
+
+import datahub.emitter.mce_builder as builder
+from datahub.api.entities.datajob import DataJob
+from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
+from datahub.emitter.mce_builder import (
+    make_data_platform_urn,
+    make_dataplatform_instance_urn,
+)
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.rest_emitter import DatahubRestEmitter
+from datahub.ingestion.graph.client import DataHubGraph
+from datahub.metadata.schema_classes import (
+    BrowsePathEntryClass,
+    BrowsePathsV2Class,
+    DataFlowKeyClass,
+    DataJobKeyClass,
+    DataPlatformInstanceClass,
+    FineGrainedLineageClass,
+    FineGrainedLineageDownstreamTypeClass,
+    FineGrainedLineageUpstreamTypeClass,
+    OperationClass,
+    OperationTypeClass,
+    StatusClass,
+)
+from datahub.sql_parsing.sqlglot_lineage import SqlParsingResult
+from datahub.telemetry import telemetry
+from datahub_airflow_plugin._airflow_shims import (
+    HAS_AIRFLOW_DAG_LISTENER_API,
+    HAS_AIRFLOW_DATASET_LISTENER_API,
+    get_task_inlets,
+    get_task_outlets,
+)
+from datahub_airflow_plugin._config import DatahubLineageConfig, get_lineage_config
+from datahub_airflow_plugin._datahub_ol_adapter import translate_ol_to_datahub_urn
+from datahub_airflow_plugin._extractors import SQL_PARSING_RESULT_KEY, ExtractorManager
+from datahub_airflow_plugin._version import __package_name__, __version__
+from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
+from datahub_airflow_plugin.entities import (
+    _Entity,
+    entities_to_datajob_urn_list,
+    entities_to_dataset_urn_list,
 )

+_F = TypeVar("_F", bound=Callable[..., None])

 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
+    from 
airflow.datasets import Dataset + from airflow.models import DAG, DagRun, TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + from airflow.settings import Session - from airflow.models import BaseOperator, DagRun, TaskInstance + # To placate mypy on Airflow versions that don't have the listener API, + # we define a dummy hookimpl that's an identity function. + def hookimpl(f: _F) -> _F: # type: ignore[misc] + return f -class TaskHolder: - """Class that stores run data - run_id and task in-memory. This is needed because Airflow - does not always pass all runtime info to on_task_instance_success and - on_task_instance_failed that is needed to emit events. This is not a big problem since - we're only running on worker - in separate process that is always spawned (or forked) on - execution, just like old PHP runtime model. - """ +else: + from airflow.listeners import hookimpl - def __init__(self): - self.run_data = {} +logger = logging.getLogger(__name__) - def set_task(self, task_instance: "TaskInstance"): - self.run_data[self._pk(task_instance)] = task_instance.task +_airflow_listener_initialized = False +_airflow_listener: Optional["DataHubListener"] = None +_RUN_IN_THREAD = os.getenv("DATAHUB_AIRFLOW_PLUGIN_RUN_IN_THREAD", "true").lower() in ( + "true", + "1", +) +_RUN_IN_THREAD_TIMEOUT = float( + os.getenv("DATAHUB_AIRFLOW_PLUGIN_RUN_IN_THREAD_TIMEOUT", 10) +) +_DATAHUB_CLEANUP_DAG = "Datahub_Cleanup" - def get_task(self, task_instance: "TaskInstance") -> Optional["BaseOperator"]: - return self.run_data.get(self._pk(task_instance)) +KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener" - @staticmethod - def _pk(ti: "TaskInstance"): - return ti.dag_id + ti.task_id + ti.run_id +# verified - airflow openlineage dependencies +def get_airflow_plugin_listener() -> Optional["DataHubListener"]: + # Using globals instead of functools.lru_cache to make testing easier. + global _airflow_listener_initialized + global _airflow_listener + if not _airflow_listener_initialized: + _airflow_listener_initialized = True -log = logging.getLogger(__name__) -# TODO: move task instance runs to executor -executor: Optional[Executor] = None + plugin_config = get_lineage_config() + if plugin_config.enabled: + _airflow_listener = DataHubListener(config=plugin_config) + logger.info( + f"DataHub plugin v2 (package: {__package_name__} and version: {__version__}) listener initialized with config: {plugin_config}" + ) + telemetry.telemetry_instance.ping( + "airflow-plugin-init", + { + "airflow-version": airflow.__version__, + "datahub-airflow-plugin": "v2", + "datahub-airflow-plugin-dag-events": HAS_AIRFLOW_DAG_LISTENER_API, + "datahub-airflow-plugin-dataset-events": HAS_AIRFLOW_DATASET_LISTENER_API, + "capture_executions": plugin_config.capture_executions, + "capture_tags": plugin_config.capture_tags_info, + "capture_ownership": plugin_config.capture_ownership_info, + "enable_extractors": plugin_config.enable_extractors, + "render_templates": plugin_config.render_templates, + "disable_openlineage_plugin": plugin_config.disable_openlineage_plugin, + }, + ) + + if plugin_config.disable_openlineage_plugin: + # Deactivate the OpenLineagePlugin listener to avoid conflicts/errors. 
+ from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin -def execute_in_thread(target: Callable, kwargs=None): - if kwargs is None: - kwargs = {} - thread = threading.Thread(target=target, kwargs=kwargs, daemon=True) - thread.start() + OpenLineageProviderPlugin.listeners = [] - # Join, but ignore checking if thread stopped. If it did, then we shouldn't do anything. - # This basically gives this thread 5 seconds to complete work, then it can be killed, - # as daemon=True. We don't want to deadlock Airflow if our code hangs. + return _airflow_listener - # This will hang if this timeouts, and extractor is running non-daemon thread inside, - # since it will never be cleaned up. Ex. SnowflakeOperator - thread.join(timeout=10) +# verified - airflow openlineage dependencies +def run_in_thread(f: _F) -> _F: + # This is also responsible for catching exceptions and logging them. + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + if _RUN_IN_THREAD: + # A poor-man's timeout mechanism. + # This ensures that we don't hang the task if the extractors + # are slow or the DataHub API is slow to respond. + + thread = threading.Thread( + target=f, args=args, kwargs=kwargs, daemon=True + ) + thread.start() + + if _RUN_IN_THREAD_TIMEOUT > 0: + # If _RUN_IN_THREAD_TIMEOUT is 0, we just kick off the thread and move on. + # Because it's a daemon thread, it'll be automatically killed when the main + # thread exits. + + start_time = time.time() + thread.join(timeout=_RUN_IN_THREAD_TIMEOUT) + if thread.is_alive(): + logger.warning( + f"Thread for {f.__name__} is still running after {_RUN_IN_THREAD_TIMEOUT} seconds. " + "Continuing without waiting for it to finish." + ) + else: + logger.debug( + f"Thread for {f.__name__} finished after {time.time() - start_time} seconds" + ) + else: + f(*args, **kwargs) + except Exception as e: + logger.warning(e, exc_info=True) + + return cast(_F, wrapper) + +# verified - airflow openlineage dependencies +def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": + # Render templates in a copy of the task instance. + # This is necessary to get the correct operator args in the extractors. + try: + task_instance_copy = copy.deepcopy(task_instance) + task_instance_copy.render_templates() + return task_instance_copy + except Exception as e: + logger.info( + f"Error rendering templates in DataHub listener. Jinja-templated variables will not be extracted correctly: {e}. Template rendering improves SQL parsing accuracy. If this causes issues, you can disable it by setting `render_templates` to `false` in the DataHub plugin configuration." + ) + return task_instance + + +class DataHubListener: + __name__ = "DataHubListener" + + # verified - airflow openlineage dependencies + def __init__(self, config: DatahubLineageConfig): + self.config = config + self._set_log_level() + + self._emitter = config.make_emitter_hook().make_emitter() + self._graph: Optional[DataHubGraph] = None + logger.info(f"DataHub plugin v2 using {repr(self._emitter)}") + + # See discussion here https://github.com/OpenLineage/OpenLineage/pull/508 for + # why we need to keep track of tasks ourselves. + self._open_lineage_listener = get_openlineage_listener() + + # In our case, we also want to cache the initial datajob object + # so that we can add to it when the task completes. 
+ self._datajob_holder: Dict[str, DataJob] = {} + + self.extractor_manager = ExtractorManager() + + # This "inherits" from types.ModuleType to avoid issues with Airflow's listener plugin loader. + # It previously (v2.4.x and likely other versions too) would throw errors if it was not a module. + # https://github.com/apache/airflow/blob/e99a518970b2d349a75b1647f6b738c8510fa40e/airflow/listeners/listener.py#L56 + # self.__class__ = types.ModuleType + + # verified - airflow openlineage dependencies + @property + def emitter(self): + return self._emitter + + # verified - airflow openlineage dependencies + @property + def graph(self) -> Optional[DataHubGraph]: + if self._graph: + return self._graph + + if isinstance(self._emitter, DatahubRestEmitter) and not isinstance( + self._emitter, DataHubGraph + ): + # This is lazy initialized to avoid throwing errors on plugin load. + self._graph = self._emitter.to_graph() + self._emitter = self._graph + + return self._graph + + # verified - airflow openlineage dependencies + def _set_log_level(self) -> None: + """Set the log level for the plugin and its dependencies. + + This may need to be called multiple times, since Airflow sometimes + messes with the logging configuration after the plugin is loaded. + In particular, the loggers may get changed when the worker starts + executing a task. + """ + + if self.config.log_level: + logging.getLogger(__name__.split(".")[0]).setLevel(self.config.log_level) + if self.config.debug_emitter: + logging.getLogger("datahub.emitter").setLevel(logging.DEBUG) + + # verified - airflow openlineage dependencies + def _make_emit_callback(self) -> Callable[[Optional[Exception], str], None]: + def emit_callback(err: Optional[Exception], msg: str) -> None: + if err: + logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err) + + return emit_callback + + # verified - airflow openlineage dependencies + def _extract_lineage( + self, + datajob: DataJob, + dagrun: "DagRun", + task: "Operator", + task_instance: "TaskInstance", + complete: bool = False, + ) -> None: + """ + Combine lineage (including column lineage) from task inlets/outlets and + extractor-generated task_metadata and write it to the datajob. This + routine is also responsible for converting the lineage to DataHub URNs. + """ + + if not self.config.enable_datajob_lineage: + return -task_holder = TaskHolder() -extractor_manager = ExtractorManager() -adapter = OpenLineageAdapter() + input_urns: List[str] = [] + output_urns: List[str] = [] + fine_grained_lineages: List[FineGrainedLineageClass] = [] + + task_metadata = None + if self.config.enable_extractors: + task_metadata = self.extractor_manager.extract_metadata( + dagrun, + task, + complete=complete, + task_instance=task_instance, + task_uuid=str(datajob.urn), + graph=self.graph, + ) + logger.debug(f"Got task metadata: {task_metadata}") + # Translate task_metadata.inputs/outputs to DataHub URNs. + input_urns.extend( + translate_ol_to_datahub_urn(dataset) for dataset in task_metadata.inputs + ) + output_urns.extend( + translate_ol_to_datahub_urn(dataset) + for dataset in task_metadata.outputs + ) -def direct_execution(): - return is_airflow_version_enough("2.6.0") or getboolean( - "OPENLINEAGE_AIRFLOW_ENABLE_DIRECT_EXECUTION", False - ) + # Add DataHub-native SQL parser results. 
+ sql_parsing_result: Optional[SqlParsingResult] = None + if task_metadata: + sql_parsing_result = task_metadata.run_facets.pop( + SQL_PARSING_RESULT_KEY, None + ) + if sql_parsing_result: + if error := sql_parsing_result.debug_info.error: + logger.info(f"SQL parsing error: {error}", exc_info=error) + datajob.properties["datahub_sql_parser_error"] = ( + f"{type(error).__name__}: {error}" + ) + if not sql_parsing_result.debug_info.table_error: + input_urns.extend(sql_parsing_result.in_tables) + output_urns.extend(sql_parsing_result.out_tables) + + if sql_parsing_result.column_lineage: + fine_grained_lineages.extend( + FineGrainedLineageClass( + upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, + downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, + upstreams=[ + builder.make_schema_field_urn( + upstream.table, upstream.column + ) + for upstream in column_lineage.upstreams + ], + downstreams=[ + builder.make_schema_field_urn( + downstream.table, downstream.column + ) + for downstream in [column_lineage.downstream] + if downstream.table + ], + ) + for column_lineage in sql_parsing_result.column_lineage + ) + + # Add DataHub-native inlets/outlets. + # These are filtered out by the extractor, so we need to add them manually. + input_urns.extend( + iolet.urn for iolet in get_task_inlets(task) if isinstance(iolet, _Entity) + ) + output_urns.extend( + iolet.urn for iolet in get_task_outlets(task) if isinstance(iolet, _Entity) + ) + # Write the lineage to the datajob object. + datajob.inlets.extend(entities_to_dataset_urn_list(input_urns)) + datajob.outlets.extend(entities_to_dataset_urn_list(output_urns)) + datajob.upstream_urns.extend(entities_to_datajob_urn_list(input_urns)) + datajob.fine_grained_lineages.extend(fine_grained_lineages) -def execute(_callable): - try: - if direct_execution(): - _callable() + # Merge in extra stuff that was present in the DataJob we constructed + # at the start of the task. + if complete: + original_datajob = self._datajob_holder.get(str(datajob.urn), None) else: - execute_in_thread(_callable) - except Exception: - # Make sure we're not failing task, even for things we think can't happen - log.exception("Failed to emit OpenLineage event due to exception") - - -@hookimpl -def on_task_instance_running(previous_state, task_instance: "TaskInstance", session: "Session"): - if not hasattr(task_instance, "task"): - log.warning( - "No task set for TI object task_id: %s - dag_id: %s - run_id %s", - task_instance.task_id, - task_instance.dag_id, - task_instance.run_id, + self._datajob_holder[str(datajob.urn)] = datajob + original_datajob = None + + if original_datajob: + logger.debug("Merging start datajob into finish datajob") + datajob.inlets.extend(original_datajob.inlets) + datajob.outlets.extend(original_datajob.outlets) + datajob.upstream_urns.extend(original_datajob.upstream_urns) + datajob.fine_grained_lineages.extend(original_datajob.fine_grained_lineages) + + for k, v in original_datajob.properties.items(): + datajob.properties.setdefault(k, v) + + # Deduplicate inlets/outlets. + datajob.inlets = list(sorted(set(datajob.inlets), key=lambda x: str(x))) + datajob.outlets = list(sorted(set(datajob.outlets), key=lambda x: str(x))) + datajob.upstream_urns = list( + sorted(set(datajob.upstream_urns), key=lambda x: str(x)) ) - return - log.debug("OpenLineage listener got notification about task instance start") - dagrun = task_instance.dag_run + # Write all other OL facets as DataHub properties. 
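The column-lineage block above turns each parsed downstream column into a fine-grained lineage entry whose upstreams and downstream are schema-field URNs. A condensed sketch of that construction, with a hand-written column mapping standing in for the SQL parser output (table and column names are made up):

import datahub.emitter.mce_builder as builder
from datahub.metadata.schema_classes import (
    FineGrainedLineageClass,
    FineGrainedLineageDownstreamTypeClass,
    FineGrainedLineageUpstreamTypeClass,
)

# Hypothetical parser output: downstream column -> list of (upstream table urn, column).
upstream_table = builder.make_dataset_urn("snowflake", "analytics.public.orders")
downstream_table = builder.make_dataset_urn("snowflake", "analytics.public.order_summary")
column_map = {"total_amount": [(upstream_table, "amount")]}

fine_grained = [
    FineGrainedLineageClass(
        upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET,
        downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD,
        upstreams=[builder.make_schema_field_urn(t, c) for t, c in ups],
        downstreams=[builder.make_schema_field_urn(downstream_table, down_col)],
    )
    for down_col, ups in column_map.items()
]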
+ if task_metadata: + for k, v in task_metadata.job_facets.items(): + datajob.properties[f"openlineage_job_facet_{k}"] = Serde.to_json( + redact_with_exclusions(v) + ) + + for k, v in task_metadata.run_facets.items(): + datajob.properties[f"openlineage_run_facet_{k}"] = Serde.to_json( + redact_with_exclusions(v) + ) + + # verified - airflow openlineage dependencies + def check_kill_switch(self): + if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true": + logger.debug("DataHub listener disabled by kill switch") + return True + return False + + if AIRFLOW_V_3_0_PLUS: + + @hookimpl + @run_in_thread + def on_task_instance_running(self, previous_state: TaskInstanceState, task_instance: "TaskInstance", session: "Session") -> None: + self.log.debug("DataHub listener got notification about task instance start") + context = task_instance.get_template_context() + task = context["task"] + + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + start_date = task_instance.start_date + self._on_task_instance_running(task_instance, dag, dagrun, task, start_date) + else: + + @hookimpl + @run_in_thread + def on_task_instance_running(self, previous_state: TaskInstance, task_instance: "TaskInstance", session: "Session") -> None: + from airflow.providers.openlineage.utils.utils import is_ti_rescheduled_already + + if not getattr(task_instance, "task", None) is not None: + logger.warning( + "No task set for TI object task_id: %s - dag_id: %s - run_id %s", + task_instance.task_id, + task_instance.dag_id, + task_instance.run_id, + ) + return + + logger.debug("OpenLineage listener got notification about task instance start") + task = task_instance.task + if TYPE_CHECKING: + assert task + start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() + + if is_ti_rescheduled_already(task_instance): + self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") + return + self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task, start_date) + + def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task, start_date: datetime): + if is_operator_disabled(task): + self.log.debug("Skipping OpenLineage event emission for operator `%s` " + "due to its presence in [openlineage] disabled_for_operators.", + task.task_type, + ) + return - def on_running(): - nonlocal task_instance - try: - ti = copy.deepcopy(task_instance) - except Exception as err: - log.debug( - f"Creating a task instance copy failed; proceeding without rendering templates. 
Error: {err}" + if not is_selective_lineage_enabled(task): + self.log.debug( + "Skipping OpenLineage event emission for task `%s` " + "due to lack of explicit lineage enablement for task or DAG while " + "[openlineage] selective_enable is on.", + task_instance.task_id, + ) + return + + # Needs to be calculated outside of inner method so that it gets cached for usage in fork processes + debug_facet = get_airflow_debug_facet() + + @print_warning(self.log) + def on_running(): + context = task_instance.get_template_context() + if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0: + self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") + return + + date = dagrun.logical_date + if AIRFLOW_V_3_0_PLUS and date is None: + date = dagrun.run_after + + clear_number = 0 + if hasattr(dagrun, "clear_number"): + clear_number = dagrun.clear_number + + parent_run_id = self.adapter.build_dag_run_id( + dag_id=task_instance.dag_id, + logical_date=date, + clear_number=clear_number, + ) + + task_uuid = self.adapter.build_task_instance_run_id( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + try_number=task_instance.try_number, + logical_date=date, + map_index=task_instance.map_index, + ) + event_type = RunState.RUNNING.value.lower() + operator_name = task.task_type.lower() + + data_interval_start = dagrun.data_interval_start + if isinstance(data_interval_start, datetime): + data_interval_start = data_interval_start.isoformat() + data_interval_end = dagrun.data_interval_end + if isinstance(data_interval_end, datetime): + data_interval_end = data_interval_end.isoformat() + + doc, doc_type = get_task_documentation(task) + if not doc: + doc, doc_type = get_dag_documentation(dag) + + with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): + task_metadata = self.extractor_manager.extract_metadata( + dagrun=dagrun, task=task, task_instance_state=TaskInstanceState.RUNNING + ) + + redacted_event = self.adapter.start_task( + run_id=task_uuid, + job_name=get_job_name(task_instance), + job_description=doc, + job_description_type=doc_type, + event_time=start_date.isoformat(), + nominal_start_time=data_interval_start, + nominal_end_time=data_interval_end, + # If task owner is default ("airflow"), use DAG owner instead that may have more details + owners=[x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")], + tags=dag.tags, + task=task_metadata, + run_facets={ + **get_task_parent_run_facet(parent_run_id=parent_run_id, parent_job_name=dag.dag_id), + **get_user_provided_run_facets(task_instance, TaskInstanceState.RUNNING), + **get_airflow_mapped_task_facet(task_instance), + **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid), + **debug_facet, + }, + ) + Stats.gauge( + f"ol.event.size.{event_type}.{operator_name}", + len(Serde.to_json(redacted_event).encode("utf-8")), + ) + + self._execute(on_running, "on_running", use_fork=True) + + + @hookimpl + @run_in_thread + def on_task_instance_running( + self, + previous_state: None, + task_instance: "TaskInstance", + session: "Session", # This will always be QUEUED + ) -> None: + if self.check_kill_switch(): + return + self._set_log_level() + + # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. + if not hasattr(task_instance, "task"): + # The type ignore is to placate mypy on Airflow 2.1.x. 
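The provider code above wraps extraction in Airflow's statsd-style `Stats` facade and records the serialized event size as a gauge. A minimal sketch of that metrics pattern, using `json` in place of the OpenLineage Serde for brevity; the metric names and payload are illustrative only.

import json

from airflow.stats import Stats

def emit_with_metrics(event: dict, event_type: str, operator_name: str) -> None:
    # Time the (stand-in) extraction/serialization step under a dotted metric name.
    with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
        payload = json.dumps(event)

    # Record the payload size so oversized events are easy to spot.
    Stats.gauge(
        f"ol.event.size.{event_type}.{operator_name}",
        len(payload.encode("utf-8")),
    )

emit_with_metrics({"job": "example_dag.example_task"}, "running", "pythonoperator")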
+ logger.warning( + f"No task set for task_id: {task_instance.task_id} - " # type: ignore[attr-defined] + f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" # type: ignore[attr-defined] + ) + return + + logger.debug( + f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" + ) + + if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): + logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") + return + + if self.config.render_templates: + task_instance = _render_templates(task_instance) + + # The type ignore is to placate mypy on Airflow 2.1.x. + dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] + task = task_instance.task + assert task is not None + dag: "DAG" = task.dag # type: ignore[assignment] + + # TODO: Replace with airflow openlineage plugin equivalent + # TODO: Implement each step for Airflow < 3.0 and Airflow >= 3.0 + self._task_holder.set_task(task_instance) + + # Handle async operators in Airflow 2.3 by skipping deferred state. + # Inspired by https://github.com/OpenLineage/OpenLineage/pull/1601 + if task_instance.next_method is not None: # type: ignore[attr-defined] + return + + datajob = AirflowGenerator.generate_datajob( + cluster=self.config.cluster, + task=task, + dag=dag, + capture_tags=self.config.capture_tags_info, + capture_owner=self.config.capture_ownership_info, + config=self.config, + ) + + # TODO: Make use of get_task_location to extract github urls. + + # Add lineage info. + self._extract_lineage(datajob, dagrun, task, task_instance) + + # TODO: Add handling for Airflow mapped tasks using task_instance.map_index + + for mcp in datajob.generate_mcp( + generate_lineage=self.config.enable_datajob_lineage, + materialize_iolets=self.config.materialize_iolets, + ): + self.emitter.emit(mcp, self._make_emit_callback()) + logger.debug(f"Emitted DataHub Datajob start: {datajob}") + + if self.config.capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=self.emitter, + config=self.config, + ti=task_instance, + dag=dag, + dag_run=dagrun, + datajob=datajob, + emit_templates=False, ) - ti = task_instance + logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}") + + self.emitter.flush() + + logger.debug( + f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}" + ) + + self.materialize_iolets(datajob) + + def materialize_iolets(self, datajob: DataJob) -> None: + if self.config.materialize_iolets: + for outlet in datajob.outlets: + reported_time: int = int(time.time() * 1000) + operation = OperationClass( + timestampMillis=reported_time, + operationType=OperationTypeClass.CREATE, + lastUpdatedTimestamp=reported_time, + actor=builder.make_user_urn("airflow"), + ) + + operation_mcp = MetadataChangeProposalWrapper( + entityUrn=str(outlet), aspect=operation + ) + + self.emitter.emit(operation_mcp) + logger.debug(f"Emitted Dataset Operation: {outlet}") else: - ti.render_templates() - - task = ti.task - dag = task.dag - task_holder.set_task(ti) - # that's a workaround to detect task running from deferred state - # we return here because Airflow 2.3 needs task from deferred state - if ti.next_method is not None: + if self.graph: + for outlet in datajob.outlets: + if not self.graph.exists(str(outlet)): + logger.warning(f"Dataset {str(outlet)} not materialized") + for inlet in datajob.inlets: + if not self.graph.exists(str(inlet)): + logger.warning(f"Dataset {str(inlet)} not 
materialized") + + def on_task_instance_finish( + self, task_instance: "TaskInstance", status: InstanceRunResult + ) -> None: + dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] + + if self.config.render_templates: + task_instance = _render_templates(task_instance) + + # We must prefer the task attribute, in case modifications to the task's inlets/outlets + # were made by the execute() method. + if getattr(task_instance, "task", None): + task = task_instance.task + else: + task = self._task_holder.get_task(task_instance) + assert task is not None + + dag: "DAG" = task.dag # type: ignore[assignment] + + if not self.config.dag_filter_pattern.allowed(dag.dag_id): + logger.debug(f"DAG {dag.dag_id} is not allowed by the pattern") return - parent_run_id = OpenLineageAdapter.build_dag_run_id( - dag_id=dag.dag_id, - execution_date=dagrun.execution_date, + datajob = AirflowGenerator.generate_datajob( + cluster=self.config.cluster, + task=task, + dag=dag, + capture_tags=self.config.capture_tags_info, + capture_owner=self.config.capture_ownership_info, + config=self.config, ) - task_uuid = OpenLineageAdapter.build_task_instance_run_id( - dag_id=dag.dag_id, - task_id=task.task_id, - try_number=ti._try_number, - execution_date=ti.execution_date, + + # Add lineage info. + self._extract_lineage(datajob, dagrun, task, task_instance, complete=True) + + for mcp in datajob.generate_mcp( + generate_lineage=self.config.enable_datajob_lineage, + materialize_iolets=self.config.materialize_iolets, + ): + self.emitter.emit(mcp, self._make_emit_callback()) + logger.debug(f"Emitted DataHub Datajob finish w/ status {status}: {datajob}") + + if self.config.capture_executions: + dpi = AirflowGenerator.complete_datajob( + emitter=self.emitter, + cluster=self.config.cluster, + ti=task_instance, + dag=dag, + dag_run=dagrun, + datajob=datajob, + result=status, + config=self.config, + ) + logger.debug( + f"Emitted DataHub DataProcess Instance with status {status}: {dpi}" + ) + + self.emitter.flush() + + @hookimpl + @run_in_thread + def on_task_instance_success( + self, previous_state: None, task_instance: "TaskInstance", session: "Session" + ) -> None: + if self.check_kill_switch(): + return + + self._set_log_level() + + logger.debug( + f"DataHub listener got notification about task instance success for {task_instance.task_id}" + ) + self.on_task_instance_finish(task_instance, status=InstanceRunResult.SUCCESS) + logger.debug( + f"DataHub listener finished processing task instance success for {task_instance.task_id}" ) - task_metadata = extractor_manager.extract_metadata(dagrun, task, task_uuid=task_uuid) - - ti_start_time = ti.start_date if ti.start_date else timezone.utcnow() - start, end = get_dagrun_start_end(dagrun=dagrun, dag=dag) - - adapter.start_task( - run_id=task_uuid, - job_name=get_job_name(task), - job_description=dag.description, - event_time=DagUtils.get_start_time(ti_start_time), - parent_job_name=dag.dag_id, - parent_run_id=parent_run_id, - code_location=get_task_location(task), - nominal_start_time=DagUtils.get_start_time(start), - nominal_end_time=DagUtils.to_iso_8601(end), - owners=dag.owner.split(", "), - task=task_metadata, - run_facets={ - **get_custom_facets(dagrun, task, dagrun.external_trigger, ti), - **get_airflow_run_facet(dagrun, dag, ti, task, task_uuid), - }, + @hookimpl + @run_in_thread + def on_task_instance_failed( + self, previous_state: None, task_instance: "TaskInstance", session: "Session" + ) -> None: + if self.check_kill_switch(): + return + + self._set_log_level() 
+ + logger.debug( + f"DataHub listener got notification about task instance failure for {task_instance.task_id}" ) - execute(on_running) - - -@hookimpl -def on_task_instance_success(previous_state, task_instance: "TaskInstance", session): - log.debug("OpenLineage listener got notification about task instance success") - task = task_holder.get_task(task_instance) or task_instance.task - dag = task.dag - dagrun = task_instance.dag_run - - parent_run_id = OpenLineageAdapter.build_dag_run_id( - dag_id=dag.dag_id, - execution_date=dagrun.execution_date, - ) - task_uuid = OpenLineageAdapter.build_task_instance_run_id( - dag_id=dag.dag_id, - task_id=task.task_id, - try_number=task_instance._try_number, - execution_date=task_instance.execution_date, - ) - - def on_success(): - task_metadata = extractor_manager.extract_metadata( - dagrun, task, complete=True, task_instance=task_instance + # TODO: Handle UP_FOR_RETRY state. + self.on_task_instance_finish(task_instance, status=InstanceRunResult.FAILURE) + logger.debug( + f"DataHub listener finished processing task instance failure for {task_instance.task_id}" ) - adapter.complete_task( - run_id=task_uuid, - job_name=get_job_name(task), - parent_job_name=dag.dag_id, - parent_run_id=parent_run_id, - end_time=DagUtils.to_iso_8601(task_instance.end_date), - task=task_metadata, + + def on_dag_start(self, dag_run: "DagRun") -> None: + dag = dag_run.dag + if not dag: + logger.warning( + f"DataHub listener could not find DAG for {dag_run.dag_id} - {dag_run.run_id}. Dag won't be captured" + ) + return + + dataflow = AirflowGenerator.generate_dataflow( + config=self.config, + dag=dag, ) + dataflow.emit(self.emitter, callback=self._make_emit_callback()) + logger.debug(f"Emitted DataHub DataFlow: {dataflow}") - execute(on_success) - - -@hookimpl -def on_task_instance_failed(previous_state, task_instance: "TaskInstance", session): - log.debug("OpenLineage listener got notification about task instance failure") - task = task_holder.get_task(task_instance) or task_instance.task - dag = task.dag - dagrun = task_instance.dag_run - - parent_run_id = OpenLineageAdapter.build_dag_run_id( - dag_id=dag.dag_id, - execution_date=dagrun.execution_date, - ) - task_uuid = OpenLineageAdapter.build_task_instance_run_id( - dag_id=dag.dag_id, - task_id=task.task_id, - try_number=task_instance._try_number, - execution_date=task_instance.execution_date, - ) - - def on_failure(): - task_metadata = extractor_manager.extract_metadata( - dagrun, task, complete=True, task_instance=task_instance + event: MetadataChangeProposalWrapper = MetadataChangeProposalWrapper( + entityUrn=str(dataflow.urn), aspect=StatusClass(removed=False) ) + self.emitter.emit(event) - end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() + for task in dag.tasks: + task_urn = builder.make_data_job_urn_with_flow( + str(dataflow.urn), task.task_id + ) + event = MetadataChangeProposalWrapper( + entityUrn=task_urn, aspect=StatusClass(removed=False) + ) + self.emitter.emit(event) - adapter.fail_task( - run_id=task_uuid, - job_name=get_job_name(task), - parent_job_name=dag.dag_id, - parent_run_id=parent_run_id, - end_time=DagUtils.to_iso_8601(end_date), - task=task_metadata, + if self.config.platform_instance: + instance = make_dataplatform_instance_urn( + platform="airflow", + instance=self.config.platform_instance, + ) + event = MetadataChangeProposalWrapper( + entityUrn=str(dataflow.urn), + aspect=DataPlatformInstanceClass( + platform=make_data_platform_urn("airflow"), + 
instance=instance, + ), + ) + self.emitter.emit(event) + + # emit tags + for tag in dataflow.tags: + tag_urn = builder.make_tag_urn(tag) + + event = MetadataChangeProposalWrapper( + entityUrn=tag_urn, aspect=StatusClass(removed=False) + ) + self.emitter.emit(event) + + browsePaths: List[BrowsePathEntryClass] = [] + if self.config.platform_instance: + urn = make_dataplatform_instance_urn( + "airflow", self.config.platform_instance + ) + browsePaths.append(BrowsePathEntryClass(self.config.platform_instance, urn)) + browsePaths.append(BrowsePathEntryClass(str(dag.dag_id))) + browse_path_v2_event: MetadataChangeProposalWrapper = ( + MetadataChangeProposalWrapper( + entityUrn=str(dataflow.urn), + aspect=BrowsePathsV2Class( + path=browsePaths, + ), + ) ) + self.emitter.emit(browse_path_v2_event) + + if dag.dag_id == _DATAHUB_CLEANUP_DAG: + assert self.graph + + logger.debug("Initiating the cleanup of obsolete data from datahub") + + # get all ingested dataflow and datajob + ingested_dataflow_urns = list( + self.graph.get_urns_by_filter( + platform="airflow", + entity_types=["dataFlow"], + platform_instance=self.config.platform_instance, + ) + ) + ingested_datajob_urns = list( + self.graph.get_urns_by_filter( + platform="airflow", + entity_types=["dataJob"], + platform_instance=self.config.platform_instance, + ) + ) + + # filter the ingested dataflow and datajob based on the cluster + filtered_ingested_dataflow_urns: List = [] + filtered_ingested_datajob_urns: List = [] + + for ingested_dataflow_urn in ingested_dataflow_urns: + data_flow_aspect = self.graph.get_aspect( + entity_urn=ingested_dataflow_urn, aspect_type=DataFlowKeyClass + ) + if ( + data_flow_aspect is not None + and data_flow_aspect.flowId != _DATAHUB_CLEANUP_DAG + and data_flow_aspect is not None + and data_flow_aspect.cluster == self.config.cluster + ): + filtered_ingested_dataflow_urns.append(ingested_dataflow_urn) + + for ingested_datajob_urn in ingested_datajob_urns: + data_job_aspect = self.graph.get_aspect( + entity_urn=ingested_datajob_urn, aspect_type=DataJobKeyClass + ) + if ( + data_job_aspect is not None + and data_job_aspect.flow in filtered_ingested_dataflow_urns + ): + filtered_ingested_datajob_urns.append(ingested_datajob_urn) + + # get all airflow dags + all_airflow_dags = SerializedDagModel.read_all_dags().values() + + airflow_flow_urns: List = [] + airflow_job_urns: List = [] + + for dag in all_airflow_dags: + flow_urn = builder.make_data_flow_urn( + orchestrator="airflow", + flow_id=dag.dag_id, + cluster=self.config.cluster, + platform_instance=self.config.platform_instance, + ) + airflow_flow_urns.append(flow_urn) + + for task in dag.tasks: + airflow_job_urns.append( + builder.make_data_job_urn_with_flow(str(flow_urn), task.task_id) + ) + + obsolete_pipelines = set(filtered_ingested_dataflow_urns) - set( + airflow_flow_urns + ) + obsolete_tasks = set(filtered_ingested_datajob_urns) - set(airflow_job_urns) + + obsolete_urns = obsolete_pipelines.union(obsolete_tasks) + + asyncio.run(self._soft_delete_obsolete_urns(obsolete_urns=obsolete_urns)) + + logger.debug(f"total pipelines removed = {len(obsolete_pipelines)}") + logger.debug(f"total tasks removed = {len(obsolete_tasks)}") + + @hookimpl + @run_in_thread + def on_dag_run_running(self, dag_run: "DagRun", msg: str) -> None: + if self.check_kill_switch(): + return + + self._set_log_level() + + logger.debug( + f"DataHub listener got notification about dag run start for {dag_run.dag_id}" + ) + + assert dag_run.dag_id + if not 
self.config.dag_filter_pattern.allowed(dag_run.dag_id): + logger.debug(f"DAG {dag_run.dag_id} is not allowed by the pattern") + return + + self.on_dag_start(dag_run) + self.emitter.flush() + + # TODO: Add hooks for on_dag_run_success, on_dag_run_failed -> call AirflowGenerator.complete_dataflow + + if HAS_AIRFLOW_DATASET_LISTENER_API: + + @hookimpl + @run_in_thread + def on_dataset_created(self, dataset: "Dataset") -> None: + self._set_log_level() + + logger.debug( + f"DataHub listener got notification about dataset create for {dataset}" + ) + + @hookimpl + @run_in_thread + def on_dataset_changed(self, dataset: "Dataset") -> None: + self._set_log_level() + + logger.debug( + f"DataHub listener got notification about dataset change for {dataset}" + ) + + async def _soft_delete_obsolete_urns(self, obsolete_urns): + delete_tasks = [self._delete_obsolete_data(urn) for urn in obsolete_urns] + await asyncio.gather(*delete_tasks) + + async def _delete_obsolete_data(self, obsolete_urn): + assert self.graph - execute(on_failure) - - -@hookimpl -def on_starting(component): - global executor - executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") - - -@hookimpl -def before_stopping(component): - if executor: - # stom accepting new events - executor.shutdown(wait=False) - # block until all pending events are processed - adapter.close() - - -@hookimpl -def on_dag_run_running(dag_run: "DagRun", msg: str): - if not executor: - log.error("Executor have not started before `on_dag_run_running`") - return - start, end = get_dagrun_start_end(dag_run, dag_run.dag) - executor.submit( - adapter.dag_started, - dag_run=dag_run, - msg=msg, - nominal_start_time=DagUtils.get_start_time(start), - nominal_end_time=DagUtils.to_iso_8601(end), - ) - - -@hookimpl -def on_dag_run_success(dag_run: "DagRun", msg: str): - if not executor: - log.error("Executor have not started before `on_dag_run_success`") - return - executor.submit(adapter.dag_success, dag_run=dag_run, msg=msg) - - -@hookimpl -def on_dag_run_failed(dag_run: "DagRun", msg: str): - if not executor: - log.error("Executor have not started before `on_dag_run_failed`") - return - executor.submit(adapter.dag_failed, dag_run=dag_run, msg=msg) + if self.graph.exists(str(obsolete_urn)): + self.graph.soft_delete_entity(str(obsolete_urn)) From 63f1282e088ff8af4512ddd00f191091e3976c80 Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Thu, 14 Aug 2025 00:32:06 +0530 Subject: [PATCH 05/11] feat(on_task_instance_running): from airflow plugin --- .../datahub_listener.py | 344 +++++++----------- 1 file changed, 140 insertions(+), 204 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index 7c801a5bb9147..503a2f66d6df3 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -42,9 +42,8 @@ from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState -# TODO: to change to Airflow plugin -# from openlineage.client.serde import Serde -from airflow.providers.openlineage.client.serde import Serde +from openlineage.client.serde import Serde +# from 
airflow.providers.openlineage.client.serde import Serde import datahub.emitter.mce_builder as builder from datahub.api.entities.datajob import DataJob @@ -119,7 +118,6 @@ def hookimpl(f: _F) -> _F: # type: ignore[misc] KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener" -# verified - airflow openlineage dependencies def get_airflow_plugin_listener() -> Optional["DataHubListener"]: # Using globals instead of functools.lru_cache to make testing easier. global _airflow_listener_initialized @@ -159,7 +157,6 @@ def get_airflow_plugin_listener() -> Optional["DataHubListener"]: return _airflow_listener -# verified - airflow openlineage dependencies def run_in_thread(f: _F) -> _F: # This is also responsible for catching exceptions and logging them. @@ -199,8 +196,7 @@ def wrapper(*args, **kwargs): return cast(_F, wrapper) -# verified - airflow openlineage dependencies -def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": +def _render_templates(task_instance: "TaskInstance" ) -> "TaskInstance": # Render templates in a copy of the task instance. # This is necessary to get the correct operator args in the extractors. try: @@ -217,7 +213,6 @@ def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": class DataHubListener: __name__ = "DataHubListener" - # verified - airflow openlineage dependencies def __init__(self, config: DatahubLineageConfig): self.config = config self._set_log_level() @@ -241,12 +236,10 @@ def __init__(self, config: DatahubLineageConfig): # https://github.com/apache/airflow/blob/e99a518970b2d349a75b1647f6b738c8510fa40e/airflow/listeners/listener.py#L56 # self.__class__ = types.ModuleType - # verified - airflow openlineage dependencies @property def emitter(self): return self._emitter - # verified - airflow openlineage dependencies @property def graph(self) -> Optional[DataHubGraph]: if self._graph: @@ -261,7 +254,6 @@ def graph(self) -> Optional[DataHubGraph]: return self._graph - # verified - airflow openlineage dependencies def _set_log_level(self) -> None: """Set the log level for the plugin and its dependencies. 
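This hunk moves `Serde` back to the openlineage client package, since the provider does not re-export it. A common way to keep such imports resilient during a migration like this is a guarded import; this is a generic pattern offered as a sketch, not something the patch itself introduces.

# Guarded import: fall back gracefully if the OpenLineage provider is absent.
try:
    from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
    HAS_OPENLINEAGE_PROVIDER = True
except ImportError:
    get_openlineage_listener = None  # type: ignore[assignment]
    HAS_OPENLINEAGE_PROVIDER = False

from openlineage.client.serde import Serde  # client-level serde is unchanged

if HAS_OPENLINEAGE_PROVIDER:
    listener = get_openlineage_listener()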
@@ -276,7 +268,6 @@ def _set_log_level(self) -> None: if self.config.debug_emitter: logging.getLogger("datahub.emitter").setLevel(logging.DEBUG) - # verified - airflow openlineage dependencies def _make_emit_callback(self) -> Callable[[Optional[Exception], str], None]: def emit_callback(err: Optional[Exception], msg: str) -> None: if err: @@ -284,7 +275,6 @@ def emit_callback(err: Optional[Exception], msg: str) -> None: return emit_callback - # verified - airflow openlineage dependencies def _extract_lineage( self, datajob: DataJob, @@ -417,239 +407,185 @@ def _extract_lineage( redact_with_exclusions(v) ) - # verified - airflow openlineage dependencies def check_kill_switch(self): if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true": logger.debug("DataHub listener disabled by kill switch") return True return False - if AIRFLOW_V_3_0_PLUS: @hookimpl @run_in_thread - def on_task_instance_running(self, previous_state: TaskInstanceState, task_instance: "TaskInstance", session: "Session") -> None: - self.log.debug("DataHub listener got notification about task instance start") - context = task_instance.get_template_context() - task = context["task"] - - if TYPE_CHECKING: - assert task - dagrun = context["dag_run"] - dag = context["dag"] - start_date = task_instance.start_date - self._on_task_instance_running(task_instance, dag, dagrun, task, start_date) - else: - - @hookimpl - @run_in_thread - def on_task_instance_running(self, previous_state: TaskInstance, task_instance: "TaskInstance", session: "Session") -> None: - from airflow.providers.openlineage.utils.utils import is_ti_rescheduled_already - - if not getattr(task_instance, "task", None) is not None: - logger.warning( - "No task set for TI object task_id: %s - dag_id: %s - run_id %s", - task_instance.task_id, - task_instance.dag_id, - task_instance.run_id, - ) + def on_task_instance_running( + self, + previous_state: TaskInstanceState, + task_instance: RuntimeTaskInstance # This will always be QUEUED + ) -> None: + if self.check_kill_switch(): return + self._set_log_level() - logger.debug("OpenLineage listener got notification about task instance start") - task = task_instance.task - if TYPE_CHECKING: - assert task - start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() - - if is_ti_rescheduled_already(task_instance): - self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") - return - self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task, start_date) + # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. 
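For orientation, the hook methods being reshaped in this patch are discovered through Airflow's pluggy-based listener API and registered via a plugin class. A bare-bones registration sketch using the Airflow 2 hook signature; the module, class, and plugin names are made up.

from airflow.listeners import hookimpl
from airflow.plugins_manager import AirflowPlugin

class _ExampleListener:
    @hookimpl
    def on_task_instance_running(self, previous_state, task_instance, session):
        # Called by the worker when a task instance transitions to RUNNING.
        print(f"task started: {task_instance.task_id}")

class ExampleListenerPlugin(AirflowPlugin):
    name = "example_listener_plugin"
    listeners = [_ExampleListener()]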
- def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task, start_date: datetime): - if is_operator_disabled(task): - self.log.debug("Skipping OpenLineage event emission for operator `%s` " - "due to its presence in [openlineage] disabled_for_operators.", - task.task_type, - ) - return - if not is_selective_lineage_enabled(task): - self.log.debug( - "Skipping OpenLineage event emission for task `%s` " - "due to lack of explicit lineage enablement for task or DAG while " - "[openlineage] selective_enable is on.", - task_instance.task_id, + logger.debug( + f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" ) - return - # Needs to be calculated outside of inner method so that it gets cached for usage in fork processes - debug_facet = get_airflow_debug_facet() - - @print_warning(self.log) - def on_running(): - context = task_instance.get_template_context() - if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0: - self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") + if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): + logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") return - date = dagrun.logical_date - if AIRFLOW_V_3_0_PLUS and date is None: - date = dagrun.run_after + if self.config.render_templates: + task_instance = _render_templates(task_instance) - clear_number = 0 - if hasattr(dagrun, "clear_number"): - clear_number = dagrun.clear_number + # The type ignore is to placate mypy on Airflow 2.1.x. + dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] + task = task_instance.task + assert task is not None + dag: "DAG" = task.dag # type: ignore[assignment] + start_date = task_instance.start_date + # self._on_task_instance_running(task_instance, dag, dagrun, task, start_date) - parent_run_id = self.adapter.build_dag_run_id( - dag_id=task_instance.dag_id, - logical_date=date, - clear_number=clear_number, - ) - task_uuid = self.adapter.build_task_instance_run_id( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - try_number=task_instance.try_number, - logical_date=date, - map_index=task_instance.map_index, - ) - event_type = RunState.RUNNING.value.lower() - operator_name = task.task_type.lower() - - data_interval_start = dagrun.data_interval_start - if isinstance(data_interval_start, datetime): - data_interval_start = data_interval_start.isoformat() - data_interval_end = dagrun.data_interval_end - if isinstance(data_interval_end, datetime): - data_interval_end = data_interval_end.isoformat() - - doc, doc_type = get_task_documentation(task) - if not doc: - doc, doc_type = get_dag_documentation(dag) - - with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): - task_metadata = self.extractor_manager.extract_metadata( - dagrun=dagrun, task=task, task_instance_state=TaskInstanceState.RUNNING - ) + #TODO: Datahub specific methods below - redacted_event = self.adapter.start_task( - run_id=task_uuid, - job_name=get_job_name(task_instance), - job_description=doc, - job_description_type=doc_type, - event_time=start_date.isoformat(), - nominal_start_time=data_interval_start, - nominal_end_time=data_interval_end, - # If task owner is default ("airflow"), use DAG owner instead that may have more details - owners=[x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")], - tags=dag.tags, - task=task_metadata, - 
run_facets={ - **get_task_parent_run_facet(parent_run_id=parent_run_id, parent_job_name=dag.dag_id), - **get_user_provided_run_facets(task_instance, TaskInstanceState.RUNNING), - **get_airflow_mapped_task_facet(task_instance), - **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid), - **debug_facet, - }, - ) - Stats.gauge( - f"ol.event.size.{event_type}.{operator_name}", - len(Serde.to_json(redacted_event).encode("utf-8")), + datajob = AirflowGenerator.generate_datajob( + cluster=self.config.cluster, + task=task, + dag=dag, + capture_tags=self.config.capture_tags_info, + capture_owner=self.config.capture_ownership_info, + config=self.config, ) - self._execute(on_running, "on_running", use_fork=True) + # TODO: Make use of get_task_location to extract github urls. + + # Add lineage info. + self._extract_lineage(datajob, dagrun, task, task_instance) + + # TODO: Add handling for Airflow mapped tasks using task_instance.map_index + + for mcp in datajob.generate_mcp( + generate_lineage=self.config.enable_datajob_lineage, + materialize_iolets=self.config.materialize_iolets, + ): + self.emitter.emit(mcp, self._make_emit_callback()) + logger.debug(f"Emitted DataHub Datajob start: {datajob}") + + if self.config.capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=self.emitter, + config=self.config, + ti=task_instance, + dag=dag, + dag_run=dagrun, + datajob=datajob, + emit_templates=False, + ) + logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}") + self.emitter.flush() - @hookimpl - @run_in_thread - def on_task_instance_running( - self, - previous_state: None, - task_instance: "TaskInstance", - session: "Session", # This will always be QUEUED - ) -> None: - if self.check_kill_switch(): - return - self._set_log_level() - - # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. - if not hasattr(task_instance, "task"): - # The type ignore is to placate mypy on Airflow 2.1.x. - logger.warning( - f"No task set for task_id: {task_instance.task_id} - " # type: ignore[attr-defined] - f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" # type: ignore[attr-defined] + logger.debug( + f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}" ) - return - - logger.debug( - f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" - ) - if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): - logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") - return + self.materialize_iolets(datajob) + else: - if self.config.render_templates: - task_instance = _render_templates(task_instance) + @hookimpl + @run_in_thread + def on_task_instance_running( + self, + previous_state: None, + task_instance: "TaskInstance", + session: "Session", # This will always be QUEUED + ) -> None: + from airflow.providers.openlineage.utils.utils import is_ti_scheduled_already + + if self.check_kill_switch(): + return + self._set_log_level() - # The type ignore is to placate mypy on Airflow 2.1.x. - dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] - task = task_instance.task - assert task is not None - dag: "DAG" = task.dag # type: ignore[assignment] + # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. + if not hasattr(task_instance, "task"): + # The type ignore is to placate mypy on Airflow 2.1.x. 
+ logger.warning( + f"No task set for task_id: {task_instance.task_id} - " # type: ignore[attr-defined] + f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" # type: ignore[attr-defined] + ) + return - # TODO: Replace with airflow openlineage plugin equivalent - # TODO: Implement each step for Airflow < 3.0 and Airflow >= 3.0 - self._task_holder.set_task(task_instance) + logger.debug( + f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" + ) - # Handle async operators in Airflow 2.3 by skipping deferred state. - # Inspired by https://github.com/OpenLineage/OpenLineage/pull/1601 - if task_instance.next_method is not None: # type: ignore[attr-defined] - return + if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): + logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") + return - datajob = AirflowGenerator.generate_datajob( - cluster=self.config.cluster, - task=task, - dag=dag, - capture_tags=self.config.capture_tags_info, - capture_owner=self.config.capture_ownership_info, - config=self.config, - ) + if self.config.render_templates: + task_instance = _render_templates(task_instance) - # TODO: Make use of get_task_location to extract github urls. + # The type ignore is to placate mypy on Airflow 2.1.x. + dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] + task = task_instance.task + if TYPE_CHECKING: + assert task + start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() + dag: "DAG" = task.dag # type: ignore[assignment] - # Add lineage info. - self._extract_lineage(datajob, dagrun, task, task_instance) + if is_ti_scheduled_already(task_instance): + self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") + return - # TODO: Add handling for Airflow mapped tasks using task_instance.map_index + # self._on_task_instance_running(task_instance, dag, dag_run, task, start_date) - for mcp in datajob.generate_mcp( - generate_lineage=self.config.enable_datajob_lineage, - materialize_iolets=self.config.materialize_iolets, - ): - self.emitter.emit(mcp, self._make_emit_callback()) - logger.debug(f"Emitted DataHub Datajob start: {datajob}") + # TODO: Datahub specific methods below - if self.config.capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=self.emitter, - config=self.config, - ti=task_instance, + datajob = AirflowGenerator.generate_datajob( + cluster=self.config.cluster, + task=task, dag=dag, - dag_run=dagrun, - datajob=datajob, - emit_templates=False, + capture_tags=self.config.capture_tags_info, + capture_owner=self.config.capture_ownership_info, + config=self.config, ) - logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}") - self.emitter.flush() + # TODO: Make use of get_task_location to extract github urls. + + # Add lineage info. 
+ self._extract_lineage(datajob, dagrun, task, task_instance) + + # TODO: Add handling for Airflow mapped tasks using task_instance.map_index + + for mcp in datajob.generate_mcp( + generate_lineage=self.config.enable_datajob_lineage, + materialize_iolets=self.config.materialize_iolets, + ): + self.emitter.emit(mcp, self._make_emit_callback()) + logger.debug(f"Emitted DataHub Datajob start: {datajob}") + + if self.config.capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=self.emitter, + config=self.config, + ti=task_instance, + dag=dag, + dag_run=dagrun, + datajob=datajob, + emit_templates=False, + ) + logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}") - logger.debug( - f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}" - ) + self.emitter.flush() + + logger.debug( + f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}" + ) - self.materialize_iolets(datajob) + self.materialize_iolets(datajob) def materialize_iolets(self, datajob: DataJob) -> None: if self.config.materialize_iolets: From 93a5e224e62846f09badfbd8d5b522d40f214acc Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:45:50 +0530 Subject: [PATCH 06/11] Update datahub_listener.py --- .../src/datahub_airflow_plugin/datahub_listener.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index 503a2f66d6df3..bc3dccfda58a7 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -39,11 +39,8 @@ print_warning, ) from airflow.providers.openlineage import conf -from airflow.providers.openlineage.extractors.manager import ExtractorManager -from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState from openlineage.client.serde import Serde -# from airflow.providers.openlineage.client.serde import Serde import datahub.emitter.mce_builder as builder from datahub.api.entities.datajob import DataJob @@ -445,10 +442,6 @@ def on_task_instance_running( assert task is not None dag: "DAG" = task.dag # type: ignore[assignment] start_date = task_instance.start_date - # self._on_task_instance_running(task_instance, dag, dagrun, task, start_date) - - - #TODO: Datahub specific methods below datajob = AirflowGenerator.generate_datajob( cluster=self.config.cluster, @@ -462,7 +455,7 @@ def on_task_instance_running( # TODO: Make use of get_task_location to extract github urls. # Add lineage info. 
- self._extract_lineage(datajob, dagrun, task, task_instance) + self._extract_lineage(datajob, dagrun, task, task_instance, start_date) # TODO: Add handling for Airflow mapped tasks using task_instance.map_index @@ -542,8 +535,6 @@ def on_task_instance_running( # self._on_task_instance_running(task_instance, dag, dag_run, task, start_date) - # TODO: Datahub specific methods below - datajob = AirflowGenerator.generate_datajob( cluster=self.config.cluster, task=task, From ac1e2d678beac5742a56a34a165924bf3429cac9 Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:46:14 +0530 Subject: [PATCH 07/11] Update _extractors.py --- .../src/datahub_airflow_plugin/_extractors.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py index 2750773cfd732..2437acf2b6d21 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py @@ -4,13 +4,19 @@ from typing import TYPE_CHECKING, Optional from airflow.models.operator import Operator +from airflow.providers.openlineage.extractors.manager import ( + ExtractorManager as ExtractorManager, + BaseExtractor as BaseExtractor, + TaskMetadata as TaskMetadata +) +from airflow.providers.openlineage.extractors.base import OperatorLineage from openlineage.airflow.extractors import ( - BaseExtractor, + BaseExtractor as OLBaseExtractor, ExtractorManager as OLExtractorManager, - TaskMetadata, + TaskMetadata as OLTaskMetadata, ) -from openlineage.airflow.extractors.snowflake_extractor import SnowflakeExtractor -from openlineage.airflow.extractors.sql_extractor import SqlExtractor +from openlineage.airflow.extractors.snowflake_extractor import OLSnowflakeExtractor +from openlineage.airflow.extractors.sql_extractor import OLSqlExtractor from openlineage.airflow.utils import get_operator_class, try_import_from_string from openlineage.client.facet import ( ExtractionError, @@ -38,7 +44,7 @@ SQL_PARSING_RESULT_KEY = "datahub_sql" -class ExtractorManager(OLExtractorManager): +class ExtractorManager(ExtractorManager): # TODO: On Airflow 2.7, the OLExtractorManager is part of the built-in Airflow API. # When available, we should use that instead. The same goe for most of the OL # extractors. @@ -75,7 +81,7 @@ def _patch_extractors(self): # Patch the SqlExtractor.extract() method. stack.enter_context( unittest.mock.patch.object( - SqlExtractor, + OLSqlExtractor, "extract", _sql_extractor_extract, ) @@ -84,7 +90,7 @@ def _patch_extractors(self): # Patch the SnowflakeExtractor.default_schema property. stack.enter_context( unittest.mock.patch.object( - SnowflakeExtractor, + OLSnowflakeExtractor, "default_schema", property(_snowflake_default_schema), ) @@ -112,7 +118,7 @@ def extract_metadata( dagrun, task, complete, task_instance, task_uuid ) - def _get_extractor(self, task: "Operator") -> Optional[BaseExtractor]: + def _get_extractor(self, task: "Operator") -> Optional[OLBaseExtractor]: # By adding this, we can use the generic extractor as a fallback for # any operator that inherits from SQLExecuteQueryOperator. 
clazz = get_operator_class(task) @@ -130,7 +136,7 @@ def _get_extractor(self, task: "Operator") -> Optional[BaseExtractor]: return extractor -class GenericSqlExtractor(SqlExtractor): +class GenericSqlExtractor(OLSqlExtractor): # Note that the extract() method is patched elsewhere. @property @@ -158,7 +164,7 @@ def _get_database(self) -> Optional[str]: return None -def _sql_extractor_extract(self: "SqlExtractor") -> TaskMetadata: +def _sql_extractor_extract(self: "OLSqlExtractor") -> OLTaskMetadata: # Why not override the OL sql_parse method directly, instead of overriding # extract()? A few reasons: # @@ -198,7 +204,7 @@ def _sql_extractor_extract(self: "SqlExtractor") -> TaskMetadata: def _parse_sql_into_task_metadata( - self: "BaseExtractor", + self: "OLBaseExtractor", sql: str, platform: str, default_database: Optional[str], @@ -207,7 +213,7 @@ def _parse_sql_into_task_metadata( task_name = f"{self.operator.dag_id}.{self.operator.task_id}" run_facets = {} - job_facets = {"sql": SqlJobFacet(query=SqlExtractor._normalize_sql(sql))} + job_facets = {"sql": SqlJobFacet(query=OLSqlExtractor._normalize_sql(sql))} # Prepare to run the SQL parser. graph = self.context.get(_DATAHUB_GRAPH_CONTEXT_KEY, None) @@ -250,7 +256,7 @@ def _parse_sql_into_task_metadata( # facet dict in the extractor's processing logic. run_facets[SQL_PARSING_RESULT_KEY] = sql_parsing_result # type: ignore - return TaskMetadata( + return OLTaskMetadata( name=task_name, inputs=[], outputs=[], @@ -259,8 +265,8 @@ def _parse_sql_into_task_metadata( ) -class BigQueryInsertJobOperatorExtractor(BaseExtractor): - def extract(self) -> Optional[TaskMetadata]: +class BigQueryInsertJobOperatorExtractor(OLBaseExtractor): + def extract(self) -> Optional[OLTaskMetadata]: from airflow.providers.google.cloud.operators.bigquery import ( BigQueryInsertJobOperator, # type: ignore ) @@ -303,8 +309,8 @@ def extract(self) -> Optional[TaskMetadata]: return task_metadata -class AthenaOperatorExtractor(BaseExtractor): - def extract(self) -> Optional[TaskMetadata]: +class AthenaOperatorExtractor(OLBaseExtractor): + def extract(self) -> Optional[OLTaskMetadata]: from airflow.providers.amazon.aws.operators.athena import ( AthenaOperator, # type: ignore ) @@ -324,7 +330,7 @@ def extract(self) -> Optional[TaskMetadata]: ) -def _snowflake_default_schema(self: "SnowflakeExtractor") -> Optional[str]: +def _snowflake_default_schema(self: "OLSnowflakeExtractor") -> Optional[str]: if hasattr(self.operator, "schema") and self.operator.schema is not None: return self.operator.schema return ( From 8b2b6c99097aaeee19e665382369b2c716983e5b Mon Sep 17 00:00:00 2001 From: Harish Kesava Rao Date: Sun, 17 Aug 2025 19:23:41 +0530 Subject: [PATCH 08/11] feat(on_task_instance_running): changes to airflow v3.0 --- .../src/datahub_airflow_plugin/_extractors.py | 7 +- .../datahub_listener.py | 96 +++++++++++-------- 2 files changed, 57 insertions(+), 46 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py index 2437acf2b6d21..d5c1cd11a2f50 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py @@ -5,14 +5,13 @@ from airflow.models.operator import Operator from airflow.providers.openlineage.extractors.manager import ( - ExtractorManager as ExtractorManager, BaseExtractor as BaseExtractor, - TaskMetadata 
as TaskMetadata + ExtractorManager as ExtractorManager, + TaskMetadata as TaskMetadata, ) -from airflow.providers.openlineage.extractors.base import OperatorLineage from openlineage.airflow.extractors import ( BaseExtractor as OLBaseExtractor, - ExtractorManager as OLExtractorManager, + # ExtractorManager as OLExtractorManager, TaskMetadata as OLTaskMetadata, ) from openlineage.airflow.extractors.snowflake_extractor import OLSnowflakeExtractor diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index bc3dccfda58a7..e1b228390c777 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -3,7 +3,6 @@ import functools import logging import os -from datetime import datetime import threading import time from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast @@ -12,34 +11,17 @@ from airflow.models import Variable from airflow.models.operator import Operator from airflow.models.serialized_dag import SerializedDagModel -from airflow.utils.state import TaskInstanceState -from airflow.utils.timeout import timeout -from airflow.utils import timezone -from airflow.settings improt configure_orm -from airflow.stats import Stats -# TODO: to change to Airflow plugin -# from openlineage.airflow.listener import TaskHolder -# Ref: https://github.com/apache/airflow/blob/main/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py from airflow.providers.openlineage.plugins.listener import get_openlineage_listener + # TODO: to change to Airflow plugin # from openlineage.airflow.utils import redact_with_exclusions from airflow.providers.openlineage.utils.utils import ( AIRFLOW_V_3_0_PLUS, - get_airflow_dag_run_facet, - get_airflow_debug_facet, - get_airflow_job_facet, - get_airflow_mapped_task_facet, - get_airflow_run_facet, - get_job_name, - get_task_parent_run_facet, - get_task_documentation, - get_user_provided_run_facets, + OpenLineageRedactor, is_operator_disabled, is_selective_lineage_enabled, - print_warning, ) -from airflow.providers.openlineage import conf -from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState +from airflow.utils.state import TaskInstanceState from openlineage.client.serde import Serde import datahub.emitter.mce_builder as builder @@ -115,6 +97,7 @@ def hookimpl(f: _F) -> _F: # type: ignore[misc] KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener" + def get_airflow_plugin_listener() -> Optional["DataHubListener"]: # Using globals instead of functools.lru_cache to make testing easier. global _airflow_listener_initialized @@ -148,12 +131,15 @@ def get_airflow_plugin_listener() -> Optional["DataHubListener"]: if plugin_config.disable_openlineage_plugin: # Deactivate the OpenLineagePlugin listener to avoid conflicts/errors. - from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin + from airflow.providers.openlineage.plugins.openlineage import ( + OpenLineageProviderPlugin, + ) OpenLineageProviderPlugin.listeners = [] return _airflow_listener + def run_in_thread(f: _F) -> _F: # This is also responsible for catching exceptions and logging them. 
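Patch 08 keeps the structure of defining a different hook signature per Airflow major version at class-body time, gated on the provider's `AIRFLOW_V_3_0_PLUS` flag (Airflow 3 drops the `session` argument from this hook, as the diff shows). A trimmed sketch of that gating, with the bodies and decorators omitted for brevity:

from airflow.providers.openlineage.utils.utils import AIRFLOW_V_3_0_PLUS

class VersionAwareListener:
    if AIRFLOW_V_3_0_PLUS:
        # Airflow 3 style: no `session` parameter on this hook.
        def on_task_instance_running(self, previous_state, task_instance):
            print(f"3.x start: {task_instance.task_id}")
    else:
        # Airflow 2 style: `session` is still part of the hook signature.
        def on_task_instance_running(self, previous_state, task_instance, session):
            print(f"2.x start: {task_instance.task_id}")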
@@ -193,7 +179,8 @@ def wrapper(*args, **kwargs): return cast(_F, wrapper) -def _render_templates(task_instance: "TaskInstance" ) -> "TaskInstance": + +def _render_templates(task_instance: "TaskInstance") -> "TaskInstance": # Render templates in a copy of the task instance. # This is necessary to get the correct operator args in the extractors. try: @@ -221,6 +208,7 @@ def __init__(self, config: DatahubLineageConfig): # See discussion here https://github.com/OpenLineage/OpenLineage/pull/508 for # why we need to keep track of tasks ourselves. self._open_lineage_listener = get_openlineage_listener() + self.redact_with_exclusions = OpenLineageRedactor() # In our case, we also want to cache the initial datajob object # so that we can add to it when the task completes. @@ -396,12 +384,12 @@ def _extract_lineage( if task_metadata: for k, v in task_metadata.job_facets.items(): datajob.properties[f"openlineage_job_facet_{k}"] = Serde.to_json( - redact_with_exclusions(v) + self.redact_with_exclusions._redact(v) ) for k, v in task_metadata.run_facets.items(): datajob.properties[f"openlineage_run_facet_{k}"] = Serde.to_json( - redact_with_exclusions(v) + self.redact_with_exclusions._redact(v) ) def check_kill_switch(self): @@ -409,6 +397,21 @@ def check_kill_switch(self): logger.debug("DataHub listener disabled by kill switch") return True return False + + def _should_skip_task(self, task, task_instance): + # Mimic OpenLineageListener's operator and selective lineage checks + if is_operator_disabled(task): + logger.debug( + f"Skipping DataHub event emission for operator `{task.task_type}` due to its presence in disabled_for_operators." + ) + return True + if not is_selective_lineage_enabled(task): + logger.debug( + f"Skipping DataHub event emission for task `{task_instance.task_id}` due to lack of explicit lineage enablement for task or DAG while selective_enable is on." + ) + return True + return False + if AIRFLOW_V_3_0_PLUS: @hookimpl @@ -416,7 +419,7 @@ def check_kill_switch(self): def on_task_instance_running( self, previous_state: TaskInstanceState, - task_instance: RuntimeTaskInstance # This will always be QUEUED + task_instance: RuntimeTaskInstance, # This will always be QUEUED ) -> None: if self.check_kill_switch(): return @@ -424,13 +427,17 @@ def on_task_instance_running( # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. - logger.debug( f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_id}" ) if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): - logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") + logger.debug( + f"DAG {task_instance.dag_id} is not allowed by the pattern" + ) + return + + if self._should_skip_task(task_instance.task, task_instance): return if self.config.render_templates: @@ -441,7 +448,6 @@ def on_task_instance_running( task = task_instance.task assert task is not None dag: "DAG" = task.dag # type: ignore[assignment] - start_date = task_instance.start_date datajob = AirflowGenerator.generate_datajob( cluster=self.config.cluster, @@ -455,7 +461,7 @@ def on_task_instance_running( # TODO: Make use of get_task_location to extract github urls. # Add lineage info. 
- self._extract_lineage(datajob, dagrun, task, task_instance, start_date) + self._extract_lineage(datajob, dagrun, task, task_instance) # TODO: Add handling for Airflow mapped tasks using task_instance.map_index @@ -490,12 +496,14 @@ def on_task_instance_running( @hookimpl @run_in_thread def on_task_instance_running( - self, - previous_state: None, - task_instance: "TaskInstance", - session: "Session", # This will always be QUEUED + self, + previous_state: None, + task_instance: "TaskInstance", + session: "Session", # This will always be QUEUED ) -> None: - from airflow.providers.openlineage.utils.utils import is_ti_scheduled_already + from airflow.providers.openlineage.utils.utils import ( + is_ti_scheduled_already, + ) if self.check_kill_switch(): return @@ -515,7 +523,12 @@ def on_task_instance_running( ) if not self.config.dag_filter_pattern.allowed(task_instance.dag_id): - logger.debug(f"DAG {task_instance.dag_id} is not allowed by the pattern") + logger.debug( + f"DAG {task_instance.dag_id} is not allowed by the pattern" + ) + return + + if self._should_skip_task(task_instance.task, task_instance): return if self.config.render_templates: @@ -526,15 +539,14 @@ def on_task_instance_running( task = task_instance.task if TYPE_CHECKING: assert task - start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() dag: "DAG" = task.dag # type: ignore[assignment] if is_ti_scheduled_already(task_instance): - self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") + self.log.debug( + "Skipping this instance of rescheduled task - START event was emitted already" + ) return - # self._on_task_instance_running(task_instance, dag, dag_run, task, start_date) - datajob = AirflowGenerator.generate_datajob( cluster=self.config.cluster, task=task, @@ -552,8 +564,8 @@ def on_task_instance_running( # TODO: Add handling for Airflow mapped tasks using task_instance.map_index for mcp in datajob.generate_mcp( - generate_lineage=self.config.enable_datajob_lineage, - materialize_iolets=self.config.materialize_iolets, + generate_lineage=self.config.enable_datajob_lineage, + materialize_iolets=self.config.materialize_iolets, ): self.emitter.emit(mcp, self._make_emit_callback()) logger.debug(f"Emitted DataHub Datajob start: {datajob}") From eca44718ef63d4d9e2fdbf11dfa19efa43e4820b Mon Sep 17 00:00:00 2001 From: Harish Kesava Rao Date: Sun, 17 Aug 2025 20:43:53 +0530 Subject: [PATCH 09/11] fix(taskmetadata): airflow plugin extract_metadata method returns operatorlineage objects, linted --- .../src/datahub_airflow_plugin/_extractors.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py index d5c1cd11a2f50..13fd11bb048ca 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py @@ -5,9 +5,8 @@ from airflow.models.operator import Operator from airflow.providers.openlineage.extractors.manager import ( - BaseExtractor as BaseExtractor, - ExtractorManager as ExtractorManager, - TaskMetadata as TaskMetadata, + ExtractorManager, + OperatorLineage, ) from openlineage.airflow.extractors import ( BaseExtractor as OLBaseExtractor, @@ -110,7 +109,7 @@ def extract_metadata( task_instance: Optional["TaskInstance"] = None, task_uuid: Optional[str] = 
None, graph: Optional["DataHubGraph"] = None, - ) -> TaskMetadata: + ) -> OperatorLineage: self._graph = graph with self._patch_extractors(): return super().extract_metadata( @@ -208,7 +207,7 @@ def _parse_sql_into_task_metadata( platform: str, default_database: Optional[str], default_schema: Optional[str], -) -> TaskMetadata: +) -> OLTaskMetadata: task_name = f"{self.operator.dag_id}.{self.operator.task_id}" run_facets = {} From d4ce493f844e2486ae8a9700fbd94a5a5eed0fdd Mon Sep 17 00:00:00 2001 From: hkr <104939283+harishkesavarao@users.noreply.github.com> Date: Tue, 26 Aug 2025 00:13:46 +0530 Subject: [PATCH 10/11] feat(extractor): changed imports, WIP sqlparser for databases --- .../src/datahub_airflow_plugin/_extractors.py | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py index 13fd11bb048ca..f9cb8d8285a7d 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py @@ -3,25 +3,28 @@ import unittest.mock from typing import TYPE_CHECKING, Optional -from airflow.models.operator import Operator -from airflow.providers.openlineage.extractors.manager import ( - ExtractorManager, - OperatorLineage, +from airflow.providers.openlineage.extractors import ( + BaseExtractor as BaseExtractor, + OperatorLineage as OperatorLineage, ) -from openlineage.airflow.extractors import ( - BaseExtractor as OLBaseExtractor, - # ExtractorManager as OLExtractorManager, - TaskMetadata as OLTaskMetadata, -) -from openlineage.airflow.extractors.snowflake_extractor import OLSnowflakeExtractor -from openlineage.airflow.extractors.sql_extractor import OLSqlExtractor -from openlineage.airflow.utils import get_operator_class, try_import_from_string -from openlineage.client.facet import ( - ExtractionError, - ExtractionErrorRunFacet, - SqlJobFacet, -) - +from airflow.providers.openlineage.extractors.bash import BaseExtractor +from airflow.providers.openlineage.extractors.manager import ExtractorManager +from airflow.providers.openlineage.extractors.python import PythonExtractor +from airflow.providers.openlineage.utils.utils import try_import_from_string + +# from openlineage.airflow.extractors import ( +# BaseExtractor as OLBaseExtractor, +# # ExtractorManager as OLExtractorManager, +# TaskMetadata as OLTaskMetadata, +# ) +# from openlineage.airflow.extractors.snowflake_extractor import OLSnowflakeExtractor +# from openlineage.airflow.extractors.sql_extractor import OLSqlExtractor +# from openlineage.airflow.utils import get_operator_class, try_import_from_string +# from openlineage.client.facet import ( +# ExtractionError, +# ExtractionErrorRunFacet, +# SqlJobFacet, +# ) import datahub.emitter.mce_builder as builder from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import ( get_platform_from_sqlalchemy_uri, @@ -34,6 +37,7 @@ if TYPE_CHECKING: from airflow.models import DagRun, TaskInstance + from airflow.models.operator import Operator from datahub.ingestion.graph.client import DataHubGraph @@ -41,12 +45,13 @@ _DATAHUB_GRAPH_CONTEXT_KEY = "datahub_graph" SQL_PARSING_RESULT_KEY = "datahub_sql" +def _iter_extractor_types() -> Iterator[type[BaseExtractor]]: + if PythonExtractor is not None: + yield PythonExtractor + if BaseExtractor is not None: + yield BaseExtractor class 
ExtractorManager(ExtractorManager):
-    # TODO: On Airflow 2.7, the OLExtractorManager is part of the built-in Airflow API.
-    # When available, we should use that instead. The same goe for most of the OL
-    # extractors.
-
     def __init__(self):
         super().__init__()
 
@@ -63,11 +68,11 @@ def __init__(self):
             "SqliteOperator",
         ]
         for operator in _sql_operator_overrides:
-            self.task_to_extractor.extractors[operator] = GenericSqlExtractor
+            self.extractors[operator] = GenericSqlExtractor
 
-        self.task_to_extractor.extractors["AthenaOperator"] = AthenaOperatorExtractor
+        self.extractors["AthenaOperator"] = AthenaOperatorExtractor
 
-        self.task_to_extractor.extractors["BigQueryInsertJobOperator"] = (
+        self.extractors["BigQueryInsertJobOperator"] = (
             BigQueryInsertJobOperatorExtractor
         )
 

From 98219b9cf6cb414dfa48b18b0369ca257c3db9cd Mon Sep 17 00:00:00 2001
From: hkr <104939283+harishkesavarao@users.noreply.github.com>
Date: Wed, 27 Aug 2025 01:00:17 +0530
Subject: [PATCH 11/11] Changes to use SQLParser instead of SQLExtractor

---
 .../src/datahub_airflow_plugin/_extractors.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py
index f9cb8d8285a7d..71d256f1cbc10 100644
--- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py
+++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_extractors.py
@@ -11,6 +11,8 @@
 from airflow.providers.openlineage.extractors.manager import ExtractorManager
 from airflow.providers.openlineage.extractors.python import PythonExtractor
 from airflow.providers.openlineage.utils.utils import try_import_from_string
+from airflow.utils.state import TaskInstanceState
+from airflow.providers.openlineage.sqlparser import SQLParser
 
 # from openlineage.airflow.extractors import (
 #     BaseExtractor as OLBaseExtractor,
 #     # ExtractorManager as OLExtractorManager,
 #     TaskMetadata as OLTaskMetadata,
 # )
@@ -110,15 +112,14 @@ def extract_metadata(
         self,
         dagrun: "DagRun",
         task: "Operator",
-        complete: bool = False,
-        task_instance: Optional["TaskInstance"] = None,
+        task_instance_state: Optional["TaskInstanceState"] = None,
         task_uuid: Optional[str] = None,
         graph: Optional["DataHubGraph"] = None,
     ) -> OperatorLineage:
         self._graph = graph
         with self._patch_extractors():
             return super().extract_metadata(
-                dagrun, task, complete, task_instance, task_uuid
+                dagrun, task, task_instance_state, task_uuid
             )
 
     def _get_extractor(self, task: "Operator") -> Optional[OLBaseExtractor]:
@@ -139,7 +140,7 @@ def _get_extractor(self, task: "Operator") -> Optional[OLBaseExtractor]:
         return extractor
 
 
-class GenericSqlExtractor(OLSqlExtractor):
+class GenericSqlExtractor(SQLParser):
     # Note that the extract() method is patched elsewhere.
 
     @property
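# NOTE: illustrative sketch, not part of the patches above. Once _extractors.py
# is fully on apache-airflow-providers-openlineage, a custom extractor returns
# OperatorLineage instead of openlineage-airflow's TaskMetadata. The
# MyCustomOperator name and the "sql" facet key are hypothetical; overriding
# _execute_extraction() and registering by assigning into
# ExtractorManager.extractors mirror the approach taken in the diffs above and
# are assumed to match recent provider releases rather than a stable public API.
from typing import List

from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage
from airflow.providers.openlineage.extractors.manager import ExtractorManager
from openlineage.client.facet import SqlJobFacet


class MyCustomOperatorExtractor(BaseExtractor):
    @classmethod
    def get_operator_classnames(cls) -> List[str]:
        # The manager matches extractors to tasks by operator class name.
        return ["MyCustomOperator"]

    def _execute_extraction(self) -> OperatorLineage:
        # self.operator is the task's operator; surface whatever is useful as
        # job facets (inputs/outputs omitted for brevity).
        sql = getattr(self.operator, "sql", "") or ""
        return OperatorLineage(job_facets={"sql": SqlJobFacet(query=sql)})


manager = ExtractorManager()
manager.extractors["MyCustomOperator"] = MyCustomOperatorExtractor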
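# NOTE: illustrative sketch, not part of the patches above. The last patch
# rebases GenericSqlExtractor on the provider's SQLParser; this is roughly what
# that parser offers for table-level lineage. The constructor arguments and the
# shape of the parse() result (in_tables/out_tables from openlineage-sql's
# SqlMeta, with qualified_name on each table) are assumptions based on recent
# provider and openlineage-sql releases and may differ between versions.
from airflow.providers.openlineage.sqlparser import SQLParser

parser = SQLParser(dialect="snowflake", default_schema="PUBLIC")
sql_meta = parser.parse("INSERT INTO analytics.orders SELECT * FROM raw.orders")

if sql_meta is not None:
    upstream_tables = [t.qualified_name for t in sql_meta.in_tables]
    downstream_tables = [t.qualified_name for t in sql_meta.out_tables]
    # These names would then be mapped to DataHub dataset URNs, e.g. via
    # builder.make_dataset_urn(platform="snowflake", name=..., env=...).
    print(upstream_tables, downstream_tables)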