Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metadata-ingestion-modules/airflow-plugin/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_long_description():
# We remain restrictive on the versions allowed here to prevent
# us from being broken by backwards-incompatible changes in the
# underlying package.
"openlineage-airflow>=1.2.0,<=1.30.1",
"apache-airflow-providers-openlineage>=1.1.0,<2.5.0",
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from typing import TYPE_CHECKING, Optional

from airflow.models.operator import Operator
from airflow.providers.openlineage.extractors.manager import (
ExtractorManager,
OperatorLineage,
)
from openlineage.airflow.extractors import (
BaseExtractor,
ExtractorManager as OLExtractorManager,
TaskMetadata,
BaseExtractor as OLBaseExtractor,
# ExtractorManager as OLExtractorManager,
TaskMetadata as OLTaskMetadata,
)
from openlineage.airflow.extractors.snowflake_extractor import SnowflakeExtractor
from openlineage.airflow.extractors.sql_extractor import SqlExtractor
from openlineage.airflow.extractors.snowflake_extractor import OLSnowflakeExtractor
from openlineage.airflow.extractors.sql_extractor import OLSqlExtractor
from openlineage.airflow.utils import get_operator_class, try_import_from_string
from openlineage.client.facet import (
ExtractionError,
Expand Down Expand Up @@ -38,7 +42,7 @@
SQL_PARSING_RESULT_KEY = "datahub_sql"


class ExtractorManager(OLExtractorManager):
class ExtractorManager(ExtractorManager):
# TODO: On Airflow 2.7, the OLExtractorManager is part of the built-in Airflow API.
# When available, we should use that instead. The same goes for most of the OL
# extractors.
Expand Down Expand Up @@ -75,7 +79,7 @@ def _patch_extractors(self):
# Patch the SqlExtractor.extract() method.
stack.enter_context(
unittest.mock.patch.object(
SqlExtractor,
OLSqlExtractor,
"extract",
_sql_extractor_extract,
)
Expand All @@ -84,7 +88,7 @@ def _patch_extractors(self):
# Patch the SnowflakeExtractor.default_schema property.
stack.enter_context(
unittest.mock.patch.object(
SnowflakeExtractor,
OLSnowflakeExtractor,
"default_schema",
property(_snowflake_default_schema),
)
Expand All @@ -105,14 +109,14 @@ def extract_metadata(
task_instance: Optional["TaskInstance"] = None,
task_uuid: Optional[str] = None,
graph: Optional["DataHubGraph"] = None,
) -> TaskMetadata:
) -> OperatorLineage:
self._graph = graph
with self._patch_extractors():
return super().extract_metadata(
dagrun, task, complete, task_instance, task_uuid
)

def _get_extractor(self, task: "Operator") -> Optional[BaseExtractor]:
def _get_extractor(self, task: "Operator") -> Optional[OLBaseExtractor]:
# By adding this, we can use the generic extractor as a fallback for
# any operator that inherits from SQLExecuteQueryOperator.
clazz = get_operator_class(task)
Expand All @@ -130,7 +134,7 @@ def _get_extractor(self, task: "Operator") -> Optional[BaseExtractor]:
return extractor


class GenericSqlExtractor(SqlExtractor):
class GenericSqlExtractor(OLSqlExtractor):
# Note that the extract() method is patched elsewhere.

@property
Expand Down Expand Up @@ -158,7 +162,7 @@ def _get_database(self) -> Optional[str]:
return None


def _sql_extractor_extract(self: "SqlExtractor") -> TaskMetadata:
def _sql_extractor_extract(self: "OLSqlExtractor") -> OLTaskMetadata:
# Why not override the OL sql_parse method directly, instead of overriding
# extract()? A few reasons:
#
Expand Down Expand Up @@ -198,16 +202,16 @@ def _sql_extractor_extract(self: "SqlExtractor") -> TaskMetadata:


def _parse_sql_into_task_metadata(
self: "BaseExtractor",
self: "OLBaseExtractor",
sql: str,
platform: str,
default_database: Optional[str],
default_schema: Optional[str],
) -> TaskMetadata:
) -> OLTaskMetadata:
task_name = f"{self.operator.dag_id}.{self.operator.task_id}"

run_facets = {}
job_facets = {"sql": SqlJobFacet(query=SqlExtractor._normalize_sql(sql))}
job_facets = {"sql": SqlJobFacet(query=OLSqlExtractor._normalize_sql(sql))}

# Prepare to run the SQL parser.
graph = self.context.get(_DATAHUB_GRAPH_CONTEXT_KEY, None)
Expand Down Expand Up @@ -250,7 +254,7 @@ def _parse_sql_into_task_metadata(
# facet dict in the extractor's processing logic.
run_facets[SQL_PARSING_RESULT_KEY] = sql_parsing_result # type: ignore

return TaskMetadata(
return OLTaskMetadata(
name=task_name,
inputs=[],
outputs=[],
Expand All @@ -259,8 +263,8 @@ def _parse_sql_into_task_metadata(
)


class BigQueryInsertJobOperatorExtractor(BaseExtractor):
def extract(self) -> Optional[TaskMetadata]:
class BigQueryInsertJobOperatorExtractor(OLBaseExtractor):
def extract(self) -> Optional[OLTaskMetadata]:
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryInsertJobOperator, # type: ignore
)
Expand Down Expand Up @@ -303,8 +307,8 @@ def extract(self) -> Optional[TaskMetadata]:
return task_metadata


class AthenaOperatorExtractor(BaseExtractor):
def extract(self) -> Optional[TaskMetadata]:
class AthenaOperatorExtractor(OLBaseExtractor):
def extract(self) -> Optional[OLTaskMetadata]:
from airflow.providers.amazon.aws.operators.athena import (
AthenaOperator, # type: ignore
)
Expand All @@ -324,7 +328,7 @@ def extract(self) -> Optional[TaskMetadata]:
)


def _snowflake_default_schema(self: "SnowflakeExtractor") -> Optional[str]:
def _snowflake_default_schema(self: "OLSnowflakeExtractor") -> Optional[str]:
if hasattr(self.operator, "schema") and self.operator.schema is not None:
return self.operator.schema
return (
Expand Down
Loading
Loading