Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 153 additions & 86 deletions bigframes/core/expression_factoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
Hashable,
Iterable,
Iterator,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

from bigframes.core import (
Expand All @@ -38,12 +40,158 @@
graphs,
identifiers,
nodes,
subquery_expression,
window_spec,
)
import bigframes.core.ordered_sets as sets

# NOTE(review): not referenced in the visible part of this file; presumably an
# upper bound on how complex an expression may be and still be inlined - confirm.
_MAX_INLINE_COMPLEXITY = 10

# Generic type variable for helpers in this module.
T = TypeVar("T")
# Key under which ExpressionGraph groups column definitions (see
# ExpressionGraph._expr_domain): the concrete window spec for window
# expressions, "Scalar" for scalar expressions, and "Other" for subquery
# expressions.
ExprDomain = Union[window_spec.WindowSpec, Literal["Scalar", "Other"]]


class ExpressionGraph(graphs.DiGraph[nodes.ColumnDef]):
    """Dependency graph over column definitions.

    An edge ``(a, b)`` means the expression defining column ``a`` references
    column ``b``; ``b`` is a *child* of ``a``. A definition with no remaining
    children is "free": everything it depends on has already been extracted
    (or lives outside this graph). Free ids are bucketed by ``ExprDomain`` so
    that callers can pull out all free scalar expressions, or all free window
    expressions sharing one window spec, in a single step.

    Structure is stored in an id-level ``graphs.DiGraph`` (``self._graph``);
    this class translates between ids and ``nodes.ColumnDef`` objects and
    keeps the free-id buckets in sync with removals.
    """

    def __init__(self, column_defs: Sequence[nodes.ColumnDef]):
        """Build the graph from ``column_defs``.

        Args:
            column_defs: Definitions to schedule. Assumed to have unique ids.

        Raises:
            ValueError: If an initially free definition holds an expression
                kind that ``_expr_domain`` does not recognize.
        """
        # Assumption: All column defs have unique ids
        self._id_to_cdef = {cdef.id: cdef for cdef in column_defs}
        self._graph = graphs.DiGraph(
            (cdef.id for cdef in column_defs),
            (
                (cdef.id, ref_id)
                for cdef in column_defs
                for ref_id in cdef.expression.column_references
                # References to columns not defined here are external inputs,
                # not edges of this graph.
                if ref_id in self._id_to_cdef
            ),
        )

        # TODO: Also prevent inlining expensive or non-deterministic
        # We avoid inlining multi-parent ids, as they would be inlined multiple
        # places, potentially increasing work and/or compiled text size.
        # NOTE(review): the threshold is "> 2" parents, i.e. three or more;
        # confirm this is intended rather than "more than one".
        self._multi_parent_ids = set(
            col_id
            for col_id in self._graph.graph_nodes
            if len(list(self._graph.parents(col_id))) > 2
        )
        self._free_ids_by_domain: dict[
            ExprDomain, sets.InsertionOrderedSet[identifiers.ColumnId]
        ] = collections.defaultdict(sets.InsertionOrderedSet)

        for col_id in self._graph.graph_nodes:
            if len(list(self._graph.children(col_id))) == 0:
                self._mark_free(col_id)

    @property
    def graph_nodes(self) -> Iterable[nodes.ColumnDef]:
        """Column definitions still present in the graph."""
        return map(self._id_to_cdef.__getitem__, self._graph.graph_nodes)

    @property
    def empty(self):
        """Whether every definition has been removed."""
        return self._graph.empty

    def __len__(self):
        """Number of definitions still present in the graph."""
        return len(self._graph)

    def parents(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]:
        """Yield the definitions whose expressions reference ``node``."""
        yield from map(self._id_to_cdef.__getitem__, self._graph.parents(node.id))

    def children(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]:
        """Yield the definitions that ``node``'s expression references."""
        yield from map(self._id_to_cdef.__getitem__, self._graph.children(node.id))

    def _expr_domain(self, expr: expression.Expression) -> ExprDomain:
        """Return the bucket key under which ``expr`` is grouped.

        Raises:
            ValueError: If ``expr`` is not a scalar, window or subquery
                expression.
        """
        if expr.is_scalar_expr:
            return "Scalar"
        elif isinstance(expr, agg_expressions.WindowExpression):
            return expr.window
        elif isinstance(expr, subquery_expression.SubqueryExpression):
            return "Other"
        else:
            raise ValueError(f"unrecognized expression {expr}")

    def _mark_free(self, id: identifiers.ColumnId):
        """Record ``id`` as having no remaining children."""
        # If this expands further, probably generalize a compatibility key
        domain = self._expr_domain(self._id_to_cdef[id].expression)
        self._free_ids_by_domain[domain].add(id)

    def _remove_free_mark(self, id: identifiers.ColumnId):
        """Forget ``id`` as free; a no-op if it was not marked."""
        domain = self._expr_domain(self._id_to_cdef[id].expression)
        # Use .get() so that probing a domain does not create an empty bucket
        # in the defaultdict.
        free_ids = self._free_ids_by_domain.get(domain)
        if free_ids is not None and id in free_ids:
            free_ids.remove(id)

    def remove_node(self, node: nodes.ColumnDef) -> None:
        """Remove ``node`` and mark any parent that loses its last child as free."""
        node_id = node.id
        # Snapshot the parents first: the iterator may be lazy, and the
        # underlying entry for node_id is deleted by the removal below.
        parent_ids = list(self._graph.parents(node_id))
        self._graph.remove_node(node_id)
        self._remove_free_mark(node_id)
        for parent_id in parent_ids:
            if parent_id == node_id:
                # Self-reference; the node is already gone.
                continue
            if len(list(self._graph.children(parent_id))) == 0:
                self._mark_free(parent_id)

    def extract_scalar_exprs(self) -> Sequence[nodes.ColumnDef]:
        """Remove and return every scalar definition that can be computed now.

        Repeatedly drains the free "Scalar" bucket. Each extracted expression
        has references to scalars extracted earlier in the same call bound
        (inlined) into it, so the returned definitions can be evaluated
        together.

        Returns:
            The extracted definitions, in extraction order; empty if no scalar
            definition is currently free.
        """
        results: dict[identifiers.ColumnId, expression.Expression] = dict()
        # Converges: each pass either shrinks the graph or finds no candidate
        # and breaks.
        while True:
            # NOTE(review): free ids have no remaining children, so this
            # filter currently never excludes anything; kept as written
            # pending confirmation of the intended inlining rule.
            candidate_ids = [
                col_id
                for col_id in self._free_ids_by_domain["Scalar"]
                if not any(
                    (
                        child in self._multi_parent_ids
                        and col_id in results.keys()
                        and not is_simple(results[col_id])
                    )
                    for child in self._graph.children(col_id)
                )
            ]
            if len(candidate_ids) == 0:
                break
            for col_id in candidate_ids:
                cdef = self._id_to_cdef[col_id]
                # Go through self.remove_node (not the raw id graph) so the
                # free buckets stay in sync and parents can become free.
                self.remove_node(cdef)
                results[col_id] = cdef.expression.bind_refs(
                    results, allow_partial_bindings=True
                )
        # TODO: We can prune expressions that won't be reused here,
        return tuple(nodes.ColumnDef(expr, col_id) for col_id, expr in results.items())

    def extract_window_expr(
        self,
    ) -> Optional[Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]]:
        """Remove and return all free window definitions sharing one window.

        Returns:
            ``(defs, window)`` where ``defs`` hold the analytic expression of
            each extracted window expression (non-empty), or ``None`` if no
            window expression is currently free.
        """
        # Iterate over a snapshot: removals below may add buckets.
        for domain, free_ids in list(self._free_ids_by_domain.items()):
            # String domains ("Scalar", "Other") are not windows, and buckets
            # persist after being drained, so skip empty ones.
            if isinstance(domain, str) or len(free_ids) == 0:
                continue
            # Materialize before removing: the bucket is mutated by
            # remove_node, and the definitions are needed twice.
            window_cdefs = tuple(self._id_to_cdef[col_id] for col_id in free_ids)
            agg_exprs = tuple(
                nodes.ColumnDef(
                    cast(
                        agg_expressions.WindowExpression, cdef.expression
                    ).analytic_expr,
                    cdef.id,
                )
                for cdef in window_cdefs
            )
            for cdef in window_cdefs:
                self.remove_node(cdef)
            return (agg_exprs, domain)

        return None


def unique_nodes(
Expand Down Expand Up @@ -324,106 +472,25 @@ def push_into_tree(
target_ids: Sequence[identifiers.ColumnId],
) -> nodes.BigFrameNode:
curr_root = root
by_id = {expr.id: expr for expr in exprs}
# id -> id
graph = graphs.DiGraph(
(expr.id for expr in exprs),
(
(expr.id, child_id)
for expr in exprs
for child_id in expr.expression.column_references
if child_id in by_id.keys()
),
)
# TODO: Also prevent inlining expensive or non-deterministic
# We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
multi_parent_ids = set(id for id in graph.nodes if len(list(graph.parents(id))) > 2)
scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr)

analytic_defs = filter(
lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs
)
analytic_by_window = grouped(
map(
lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x),
analytic_defs,
)
)

def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]:
results: dict[identifiers.ColumnId, expression.Expression] = dict()
while (
True
): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks
candidate_ids = list(
id
for id in graph.sinks
if (id in scalar_ids)
and not any(
(
child in multi_parent_ids
and id in results.keys()
and not is_simple(results[id])
)
for child in graph.children(id)
)
)
if len(candidate_ids) == 0:
break
for id in candidate_ids:
graph.remove_node(id)
new_exprs = {
id: by_id[id].expression.bind_refs(
results, allow_partial_bindings=True
)
}
results.update(new_exprs)
# TODO: We can prune expressions that won't be reused here,
return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items())

def graph_extract_window_expr() -> Optional[
Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]
]:
for id in graph.sinks:
next_def = by_id[id]
if isinstance(next_def.expression, agg_expressions.WindowExpression):
window = next_def.expression.window
window_exprs = [
cdef
for cdef in analytic_by_window[window]
if cdef.id in graph.sinks
]
agg_exprs = tuple(
nodes.ColumnDef(
cast(
agg_expressions.WindowExpression, cdef.expression
).analytic_expr,
cdef.id,
)
for cdef in window_exprs
)
for cdef in window_exprs:
graph.remove_node(cdef.id)
return (agg_exprs, window)

return None
graph = ExpressionGraph(exprs)

while not graph.empty:
pre_size = len(graph.nodes)
scalar_exprs = graph_extract_scalar_exprs()
pre_size = len(graph)
scalar_exprs = graph.extract_scalar_exprs()
if scalar_exprs:
curr_root = nodes.ProjectionNode(
curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
)
while result := graph_extract_window_expr():
while result := graph.extract_window_expr():
defs, window = result
assert len(defs) > 0
curr_root = nodes.WindowOpNode(
curr_root,
tuple(defs),
window,
)
if len(graph.nodes) >= pre_size:
if len(graph) >= pre_size:
raise ValueError("graph didn't shrink")
# TODO: Try to get the ordering right earlier, so can avoid this extra node.
post_ids = (*root.ids, *target_ids)
Expand Down
25 changes: 8 additions & 17 deletions bigframes/core/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import collections.abc
from typing import Dict, Generic, Hashable, Iterable, Iterator, Tuple, TypeVar

import bigframes.core.ordered_sets as sets
Expand All @@ -28,32 +29,26 @@ def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
self._children: Dict[T, sets.InsertionOrderedSet[T]] = collections.defaultdict(
sets.InsertionOrderedSet
)
self._sinks: sets.InsertionOrderedSet[T] = sets.InsertionOrderedSet()
for node in nodes:
self._children[node]
self._parents[node]
self._sinks.add(node)
for src, dst in edges:
assert src in self.nodes
assert dst in self.nodes
assert src in self.graph_nodes
assert dst in self.graph_nodes
self._children[src].add(dst)
self._parents[dst].add(src)
# sinks have no children
if src in self._sinks:
self._sinks.remove(src)

def __len__(self):
return len(self._children.keys())

@property
def nodes(self):
def graph_nodes(self) -> Iterable[T]:
# should be the same set of ids as self._parents
return self._children.keys()

@property
def sinks(self) -> Iterable[T]:
return self._sinks

@property
def empty(self):
return len(self.nodes) == 0
return len(self) == 0

def parents(self, node: T) -> Iterator[T]:
assert node in self._parents
Expand All @@ -68,9 +63,5 @@ def remove_node(self, node: T) -> None:
self._parents[child].remove(node)
for parent in self._parents[node]:
self._children[parent].remove(node)
if len(self._children[parent]) == 0:
self._sinks.add(parent)
del self._children[node]
del self._parents[node]
if node in self._sinks:
self._sinks.remove(node)
Loading
Loading