diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py index b58330f5a4..a3467c5b82 100644 --- a/bigframes/core/expression_factoring.py +++ b/bigframes/core/expression_factoring.py @@ -25,11 +25,13 @@ Hashable, Iterable, Iterator, + Literal, Mapping, Optional, Sequence, Tuple, TypeVar, + Union, ) from bigframes.core import ( @@ -38,12 +40,158 @@ graphs, identifiers, nodes, + subquery_expression, window_spec, ) +import bigframes.core.ordered_sets as sets _MAX_INLINE_COMPLEXITY = 10 T = TypeVar("T") +ExprDomain = Union[window_spec.WindowSpec, Literal["Scalar", "Other"]] + + +class ExpressionGraph(graphs.DiGraph[nodes.ColumnDef]): + def __init__(self, column_defs: Sequence[nodes.ColumnDef]): + # Assumption: All column defs have unique ids + expr_ids = set(cdef.id for cdef in column_defs) + self._graph = graphs.DiGraph( + (expr.id for expr in column_defs), + ( + (expr.id, child_id) + for expr in column_defs + for child_id in expr.expression.column_references + if child_id in expr_ids + ), + ) + self._id_to_cdef = {cdef.id: cdef for cdef in column_defs} + + # TODO: Also prevent inlining expensive or non-deterministic + # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size + self._multi_parent_ids = set( + id + for id in self._graph.graph_nodes + if len(list(self._graph.parents(id))) > 2 + ) + self._free_ids_by_domain: dict[ + ExprDomain, sets.InsertionOrderedSet[identifiers.ColumnId] + ] = collections.defaultdict(sets.InsertionOrderedSet) + + for id in self._graph.graph_nodes: + if len(list(self._graph.children(id))) == 0: + self._mark_free(id) + + @property + def graph_nodes(self) -> Iterable[nodes.ColumnDef]: + # should be the same set of ids as self._parents + return map(self._id_to_cdef.__getitem__, self._graph.graph_nodes) + + @property + def empty(self): + return self._graph.empty + + def __len__(self): + return len(self._graph) + + def parents(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]: + yield from map(self._id_to_cdef.__getitem__, self._graph.parents(node.id)) + + def children(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]: + yield from map(self._id_to_cdef.__getitem__, self._graph.children(node.id)) + + def _expr_domain(self, expr: expression.Expression) -> ExprDomain: + if expr.is_scalar_expr: + return "Scalar" + elif isinstance(expr, agg_expressions.WindowExpression): + return expr.window + elif isinstance(expr, subquery_expression.SubqueryExpression): + return "Other" + else: + raise ValueError(f"unrecognized expression {expr}") + + def _mark_free(self, id: identifiers.ColumnId): + cdef = self._id_to_cdef[id] + expr = cdef.expression + # If this expands further, probably generalize a compatibility key + self._free_ids_by_domain[self._expr_domain(expr)].add(id) + + def _remove_free_mark(self, id: identifiers.ColumnId): + cdef = self._id_to_cdef[id] + expr = cdef.expression + # If this expands further, probably generalize a compatibility key + if id in self._free_ids_by_domain[self._expr_domain(expr)]: + self._free_ids_by_domain[self._expr_domain(expr)].remove(id) + + def remove_node(self, node: nodes.ColumnDef) -> None: + for child in self._children[node]: + self._parents[child].remove(node) + for parent in self._parents[node]: + self._children[parent].remove(node) + if len(self._children[parent]) == 0: + self._mark_free(parent.id) + del self._children[node] + del self._parents[node] + self._remove_free_mark(node.id) + + def extract_scalar_exprs(self) -> Sequence[nodes.ColumnDef]: + results: dict[identifiers.ColumnId, expression.Expression] = dict() + while ( + True + ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks + candidate_ids = list( + id + for id in self._free_ids_by_domain["Scalar"] + if not any( + ( + child in self._multi_parent_ids + and id in results.keys() + and not is_simple(results[id]) + ) + for child in self._graph.children(id) + ) + ) + if len(candidate_ids) == 0: + break + for id in candidate_ids: + self._graph.remove_node(id) + new_exprs = { + id: self._id_to_cdef[id].expression.bind_refs( + results, allow_partial_bindings=True + ) + } + results.update(new_exprs) + # TODO: We can prune expressions that won't be reused here, + return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items()) + + def extract_window_expr( + self, + ) -> Optional[Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]]: + window = next( + ( + domain + for domain in self._free_ids_by_domain + if domain not in ["Scalar", "Other"] + ), + None, + ) + assert not isinstance(window, str) + if window: + window_expr_ids = self._free_ids_by_domain[window] + window_exprs = (self._id_to_cdef[id] for id in window_expr_ids) + agg_exprs = tuple( + nodes.ColumnDef( + cast( + agg_expressions.WindowExpression, cdef.expression + ).analytic_expr, + cdef.id, + ) + for cdef in window_exprs + ) + for cdef in window_exprs: + self.remove_node(cdef) + return (agg_exprs, window) + + return None def unique_nodes( @@ -324,98 +472,17 @@ def push_into_tree( target_ids: Sequence[identifiers.ColumnId], ) -> nodes.BigFrameNode: curr_root = root - by_id = {expr.id: expr for expr in exprs} # id -> id - graph = graphs.DiGraph( - (expr.id for expr in exprs), - ( - (expr.id, child_id) - for expr in exprs - for child_id in expr.expression.column_references - if child_id in by_id.keys() - ), - ) - # TODO: Also prevent inlining expensive or non-deterministic - # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size - multi_parent_ids = set(id for id in graph.nodes if len(list(graph.parents(id))) > 2) - scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr) - - analytic_defs = filter( - lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs - ) - analytic_by_window = grouped( - map( - lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x), - analytic_defs, - ) - ) - - def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]: - results: dict[identifiers.ColumnId, expression.Expression] = dict() - while ( - True - ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks - candidate_ids = list( - id - for id in graph.sinks - if (id in scalar_ids) - and not any( - ( - child in multi_parent_ids - and id in results.keys() - and not is_simple(results[id]) - ) - for child in graph.children(id) - ) - ) - if len(candidate_ids) == 0: - break - for id in candidate_ids: - graph.remove_node(id) - new_exprs = { - id: by_id[id].expression.bind_refs( - results, allow_partial_bindings=True - ) - } - results.update(new_exprs) - # TODO: We can prune expressions that won't be reused here, - return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items()) - - def graph_extract_window_expr() -> Optional[ - Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec] - ]: - for id in graph.sinks: - next_def = by_id[id] - if isinstance(next_def.expression, agg_expressions.WindowExpression): - window = next_def.expression.window - window_exprs = [ - cdef - for cdef in analytic_by_window[window] - if cdef.id in graph.sinks - ] - agg_exprs = tuple( - nodes.ColumnDef( - cast( - agg_expressions.WindowExpression, cdef.expression - ).analytic_expr, - cdef.id, - ) - for cdef in window_exprs - ) - for cdef in window_exprs: - graph.remove_node(cdef.id) - return (agg_exprs, window) - - return None + graph = ExpressionGraph(exprs) while not graph.empty: - pre_size = len(graph.nodes) - scalar_exprs = graph_extract_scalar_exprs() + pre_size = len(graph) + scalar_exprs = graph.extract_scalar_exprs() if scalar_exprs: curr_root = nodes.ProjectionNode( curr_root, tuple((x.expression, x.id) for x in scalar_exprs) ) - while result := graph_extract_window_expr(): + while result := graph.extract_window_expr(): defs, window = result assert len(defs) > 0 curr_root = nodes.WindowOpNode( @@ -423,7 +490,7 @@ def graph_extract_window_expr() -> Optional[ tuple(defs), window, ) - if len(graph.nodes) >= pre_size: + if len(graph) >= pre_size: raise ValueError("graph didn't shrink") # TODO: Try to get the ordering right earlier, so can avoid this extra node. post_ids = (*root.ids, *target_ids) diff --git a/bigframes/core/graphs.py b/bigframes/core/graphs.py index b7ce80e3cf..a05a744d76 100644 --- a/bigframes/core/graphs.py +++ b/bigframes/core/graphs.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import collections.abc from typing import Dict, Generic, Hashable, Iterable, Iterator, Tuple, TypeVar import bigframes.core.ordered_sets as sets @@ -28,32 +29,26 @@ def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]): self._children: Dict[T, sets.InsertionOrderedSet[T]] = collections.defaultdict( sets.InsertionOrderedSet ) - self._sinks: sets.InsertionOrderedSet[T] = sets.InsertionOrderedSet() for node in nodes: self._children[node] self._parents[node] - self._sinks.add(node) for src, dst in edges: - assert src in self.nodes - assert dst in self.nodes + assert src in self.graph_nodes + assert dst in self.graph_nodes self._children[src].add(dst) self._parents[dst].add(src) - # sinks have no children - if src in self._sinks: - self._sinks.remove(src) + + def __len__(self): + return len(self._children.keys()) @property - def nodes(self): + def graph_nodes(self) -> Iterable[T]: # should be the same set of ids as self._parents return self._children.keys() - @property - def sinks(self) -> Iterable[T]: - return self._sinks - @property def empty(self): - return len(self.nodes) == 0 + return len(self) == 0 def parents(self, node: T) -> Iterator[T]: assert node in self._parents @@ -68,9 +63,5 @@ def remove_node(self, node: T) -> None: self._parents[child].remove(node) for parent in self._parents[node]: self._children[parent].remove(node) - if len(self._children[parent]) == 0: - self._sinks.add(parent) del self._children[node] del self._parents[node] - if node in self._sinks: - self._sinks.remove(node) diff --git a/bigframes/core/subquery_expression.py b/bigframes/core/subquery_expression.py new file mode 100644 index 0000000000..facb4f98b5 --- /dev/null +++ b/bigframes/core/subquery_expression.py @@ -0,0 +1,101 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +import dataclasses +import functools +import itertools +import typing +from typing import Callable, Mapping, Tuple + +from bigframes import dtypes +from bigframes.core import bigframe_node, expression +import bigframes.core.identifiers as ids + + +@dataclasses.dataclass(frozen=True) +class SubqueryExpression(expression.Expression): + """Represents windowing or aggregation over a column.""" + + subquery: bigframe_node.BigFrameNode + + @property + def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: + return tuple( + itertools.chain.from_iterable( + map(lambda x: x.column_references, self.inputs) + ) + ) + + @functools.cached_property + def is_resolved(self) -> bool: + return False + + @functools.cached_property + def output_type(self) -> dtypes.ExpressionType: + raise ValueError("Subquery has no output type.") + + @property + @abc.abstractmethod + def inputs( + self, + ) -> typing.Tuple[expression.Expression, ...]: + ... + + @property + def children(self) -> Tuple[expression.Expression, ...]: + return self.inputs + + @property + def free_variables(self) -> typing.Tuple[str, ...]: + return tuple( + itertools.chain.from_iterable(map(lambda x: x.free_variables, self.inputs)) + ) + + @property + def is_const(self) -> bool: + return all(child.is_const for child in self.inputs) + + @functools.cached_property + def is_scalar_expr(self) -> bool: + return False + + @abc.abstractmethod + def replace_args(self, *arg) -> SubqueryExpression: + ... + + def transform_children( + self, t: Callable[[expression.Expression], expression.Expression] + ) -> SubqueryExpression: + return self.replace_args(*(t(arg) for arg in self.inputs)) + + def bind_variables( + self, + bindings: Mapping[str, expression.Expression], + allow_partial_bindings: bool = False, + ) -> SubqueryExpression: + return self.transform_children( + lambda x: x.bind_variables(bindings, allow_partial_bindings) + ) + + def bind_refs( + self, + bindings: Mapping[ids.ColumnId, expression.Expression], + allow_partial_bindings: bool = False, + ) -> SubqueryExpression: + return self.transform_children( + lambda x: x.bind_refs(bindings, allow_partial_bindings) + )