Skip to content

Commit a7b8c7f

Browse files
Make weight_rule independent of airflow-core priority_strategy (#62210)
* Make weight_rule independent of airflow-core priority_strategy * better typing for protocol --------- Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
1 parent f5d7a3a commit a7b8c7f

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

task-sdk/.pre-commit-config.yaml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ repos:
3333
exclude: |
3434
(?x)
3535
# TODO: These files need to be refactored to remove core coupling
36+
^src/airflow/sdk/bases/operator\.py$|
3637
^src/airflow/sdk/definitions/decorators/__init__\.pyi$|
3738
^src/airflow/sdk/definitions/decorators/setup_teardown\.py$|
3839
^src/airflow/sdk/definitions/asset/__init__\.py$|
@@ -41,13 +42,15 @@ repos:
4142
^src/airflow/sdk/definitions/mappedoperator\.py$|
4243
^src/airflow/sdk/definitions/deadline\.py$|
4344
^src/airflow/sdk/definitions/dag\.py$|
44-
^src/airflow/sdk/execution_time/execute_workload\.py$|
4545
^src/airflow/sdk/definitions/_internal/types\.py$|
46-
^src/airflow/sdk/serde/serializers/kubernetes\.py$|
47-
^src/airflow/sdk/execution_time/task_runner\.py$|
48-
^src/airflow/sdk/execution_time/supervisor\.py$|
46+
^src/airflow/sdk/execution_time/execute_workload\.py$|
4947
^src/airflow/sdk/execution_time/secrets_masker\.py$|
50-
^src/airflow/sdk/bases/operator\.py$
48+
^src/airflow/sdk/execution_time/supervisor\.py$|
49+
^src/airflow/sdk/execution_time/task_runner\.py$|
50+
^src/airflow/sdk/io/path.py$|
51+
^src/airflow/sdk/log.py$|
52+
^src/airflow/sdk/serde/serializers/kubernetes\.py$|
53+
^src/airflow/sdk/types.py$
5154
- id: check-init-decorator-arguments
5255
name: Sync model __init__ and decorator arguments
5356
language: python

task-sdk/src/airflow/sdk/bases/operator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def db_safe_priority(priority_weight: int) -> int:
9696
from airflow.sdk.definitions.operator_resources import Resources
9797
from airflow.sdk.definitions.taskgroup import TaskGroup
9898
from airflow.sdk.definitions.xcom_arg import XComArg
99-
from airflow.task.priority_strategy import PriorityWeightStrategy
99+
from airflow.sdk.types import WeightRuleParam
100100
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
101101

102102
TaskPreExecuteHook = Callable[[Context], None]
@@ -298,7 +298,7 @@ def partial(
298298
retry_delay: timedelta | float = ...,
299299
retry_exponential_backoff: float = ...,
300300
priority_weight: int = ...,
301-
weight_rule: str | PriorityWeightStrategy = ...,
301+
weight_rule: WeightRuleParam = ...,
302302
sla: timedelta | None = ...,
303303
map_index_template: str | None = ...,
304304
max_active_tis_per_dag: int | None = ...,
@@ -868,7 +868,7 @@ def say_hello_world(**context):
868868
params: ParamsDict | dict = field(default_factory=ParamsDict)
869869
default_args: dict | None = None
870870
priority_weight: int = DEFAULT_PRIORITY_WEIGHT
871-
weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
871+
weight_rule: WeightRuleParam = field(default=DEFAULT_WEIGHT_RULE)
872872
queue: str = DEFAULT_QUEUE
873873
pool: str = DEFAULT_POOL_NAME
874874
pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1024,7 +1024,7 @@ def __init__(
10241024
params: collections.abc.MutableMapping[str, Any] | None = None,
10251025
default_args: dict | None = None,
10261026
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
1027-
weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
1027+
weight_rule: WeightRuleParam = DEFAULT_WEIGHT_RULE,
10281028
queue: str = DEFAULT_QUEUE,
10291029
pool: str | None = None,
10301030
pool_slots: int = DEFAULT_POOL_SLOTS,

task-sdk/src/airflow/sdk/definitions/mappedoperator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
)
6767
from airflow.sdk.definitions.operator_resources import Resources
6868
from airflow.sdk.definitions.param import ParamsDict
69-
from airflow.task.priority_strategy import PriorityWeightStrategy
69+
from airflow.sdk.types import WeightRuleParam
7070
from airflow.triggers.base import StartTriggerArgs
7171

7272
ValidationSource = Literal["expand"] | Literal["partial"]
@@ -555,11 +555,11 @@ def priority_weight(self, value: int) -> None:
555555
self.partial_kwargs["priority_weight"] = value
556556

557557
@property
558-
def weight_rule(self) -> PriorityWeightStrategy:
558+
def weight_rule(self) -> WeightRuleParam:
559559
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
560560

561561
@weight_rule.setter
562-
def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
562+
def weight_rule(self, value: WeightRuleParam) -> None:
563563
self.partial_kwargs["weight_rule"] = value
564564

565565
@property

task-sdk/src/airflow/sdk/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections.abc import Iterable
2222
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias
2323

24+
from airflow.sdk.api.datamodels._generated import WeightRule
2425
from airflow.sdk.bases.xcom import BaseXCom
2526
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
2627

@@ -29,6 +30,7 @@
2930

3031
from pydantic import AwareDatetime, JsonValue
3132

33+
from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance
3234
from airflow.sdk._shared.logging.types import Logger as Logger
3335
from airflow.sdk.api.datamodels._generated import PreviousTIResponse, TaskInstanceState
3436
from airflow.sdk.bases.operator import BaseOperator
@@ -39,6 +41,21 @@
3941
Operator: TypeAlias = BaseOperator | MappedOperator
4042

4143

44+
class WeightRuleProtocol(Protocol):
45+
"""
46+
Protocol for custom weight strategy instances.
47+
48+
Matches objects that implement get_weight(ti).
49+
"""
50+
51+
def get_weight(self, ti: SchedulerTaskInstance) -> int:
52+
"""Return the priority weight for the task instance."""
53+
...
54+
55+
56+
WeightRuleParam: TypeAlias = str | WeightRule | WeightRuleProtocol
57+
58+
4259
class TaskInstanceKey(NamedTuple):
4360
"""Key used to identify task instance."""
4461

0 commit comments

Comments
 (0)