diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 96280bbcf944..d15de07755c6 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1942,7 +1942,12 @@ impl LazyFrame { } #[cfg(feature = "merge_sorted")] - pub fn merge_sorted(self, other: LazyFrame, key: S) -> PolarsResult + pub fn merge_sorted( + self, + other: LazyFrame, + key: S, + maintain_order: bool, + ) -> PolarsResult where S: Into, { @@ -1952,6 +1957,7 @@ impl LazyFrame { input_left: Arc::new(self.logical_plan), input_right: Arc::new(other.logical_plan), key, + maintain_order, }; Ok(LazyFrame::from_logical_plan(lp, self.opt_state)) } diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index b6d4734c80f4..6ff1fe42f9c2 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -804,6 +804,7 @@ fn create_physical_plan_impl( input_left, input_right, key, + maintain_order: _, } => { let (input_left, input_right) = state.with_new_branch(|new_state| { ( diff --git a/crates/polars-plan/dsl-schema-hashes.json b/crates/polars-plan/dsl-schema-hashes.json index 72fc9ff54c2b..2ace571f34a3 100644 --- a/crates/polars-plan/dsl-schema-hashes.json +++ b/crates/polars-plan/dsl-schema-hashes.json @@ -44,7 +44,7 @@ "Dimension": "68880cdb10230df6c8c1632b073c80bd8ceb5c56a368c0cb438431ca9f3d3b31", "DistinctOptionsDSL": "41be5ec69ef9a614f2b36ac5deadfecdea5cca847ae1ada9d4bc626ff52a5b38", "DslFunction": "221f1a46a043c8ed54f57be981bf24509f04f5f91f0f08e0acc180d96f842ebf", - "DslPlan": "14caf5b73e69c4975ff3a57331891521ff5b78c96bbaf8d6cc9be57c82f3ea98", + "DslPlan": "037aeb1be892efd716c6934961e6df74dcd38815064b6d7efa72efe41e6e913d", "Duration": "44999d59023085cbb592ce94b30d34f9b983081fc72bd6435a49bdf0869c0074", "Duration2": "f251cb1bee2955a17c6defe1573bce21ddbe6cdf6eb9324a19cd37932ab29347", "DynListLiteralValue": 
"2266a553cb4a943f7097f24539eaa802453cf8742675996215235bd682dec0e8", diff --git a/crates/polars-plan/src/dsl/plan.rs b/crates/polars-plan/src/dsl/plan.rs index 18f44955ccc6..d9d6f560139e 100644 --- a/crates/polars-plan/src/dsl/plan.rs +++ b/crates/polars-plan/src/dsl/plan.rs @@ -169,6 +169,7 @@ pub enum DslPlan { input_left: Arc, input_right: Arc, key: PlSmallStr, + maintain_order: bool, }, IR { // Keep the original Dsl around as we need that for serialization. @@ -211,7 +212,7 @@ impl Clone for DslPlan { #[cfg(feature = "pivot")] Self::Pivot { input, on, on_columns, index, values, agg, separator, maintain_order, column_naming } => Self::Pivot { input: input.clone(), on: on.clone(), on_columns: on_columns.clone(), index: index.clone(), values: values.clone(), agg: agg.clone(), separator: separator.clone(), maintain_order: *maintain_order, column_naming: *column_naming }, #[cfg(feature = "merge_sorted")] - Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() }, + Self::MergeSorted { input_left, input_right, key, maintain_order } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone(), maintain_order: *maintain_order }, Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version}, } } diff --git a/crates/polars-plan/src/dsl/serializable_plan.rs b/crates/polars-plan/src/dsl/serializable_plan.rs index 21460852f5b3..135ff8f6796a 100644 --- a/crates/polars-plan/src/dsl/serializable_plan.rs +++ b/crates/polars-plan/src/dsl/serializable_plan.rs @@ -146,6 +146,7 @@ pub(crate) enum SerializableDslPlanNode { input_left: DslPlanKey, input_right: DslPlanKey, key: PlSmallStr, + maintain_order: bool, }, IR { dsl: DslPlanKey, @@ -360,10 +361,12 @@ fn convert_dsl_plan_to_serializable_plan( input_left, input_right, key, + maintain_order, } => SP::MergeSorted { input_left: dsl_plan_key(input_left, 
arenas), input_right: dsl_plan_key(input_right, arenas), key: key.clone(), + maintain_order: *maintain_order, }, DP::IR { dsl, @@ -608,10 +611,12 @@ fn try_convert_serializable_plan_to_dsl_plan( input_left, input_right, key, + maintain_order, } => Ok(DP::MergeSorted { input_left: get_dsl_plan(*input_left, ser_dsl_plan, arenas)?, input_right: get_dsl_plan(*input_right, ser_dsl_plan, arenas)?, key: key.clone(), + maintain_order: *maintain_order, }), SP::IR { dsl: dsl_key, diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs index b8091633f64e..b2c0fd5c37fc 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs @@ -1502,6 +1502,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult input_left, input_right, key, + maintain_order, } => { let input_left = to_alp_impl(owned(input_left), ctxt) .map_err(|e| e.context(failed_here!(merge_sorted)))?; @@ -1523,6 +1524,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult input_left, input_right, key, + maintain_order, } }, DslPlan::IR { node, dsl, version } => { diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 603dcfcaf69e..e981a6f8d46f 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -320,6 +320,7 @@ impl<'a> IRDotDisplay<'a> { input_left, input_right, key, + .. } => { recurse!(*input_left); recurse!(*input_right); diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 8c4a927d70d8..7e9e21e18bd7 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -246,6 +246,7 @@ impl<'a> IRDisplay<'a> { input_left, input_right, key: _, + .. 
} => { write_ir_non_recursive(f, ir_node, self.lp.expr_arena, output_schema, indent)?; write!(f, ":")?; @@ -1002,6 +1003,7 @@ pub fn write_ir_non_recursive( input_left: _, input_right: _, key, + .. } => write!(f, "{:indent$}MERGE SORTED ON '{key}'", ""), IR::Invalid => write!(f, "{:indent$}INVALID", ""), } diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index a57ae66e6844..70b4cb8ea754 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -158,6 +158,7 @@ pub enum IR { input_left: Node, input_right: Node, key: PlSmallStr, + maintain_order: bool, }, #[default] Invalid, diff --git a/crates/polars-plan/src/plans/ir/tree_format.rs b/crates/polars-plan/src/plans/ir/tree_format.rs index aaef5e8b36f0..d7ffcf0cb50b 100644 --- a/crates/polars-plan/src/plans/ir/tree_format.rs +++ b/crates/polars-plan/src/plans/ir/tree_format.rs @@ -386,6 +386,7 @@ impl<'a> TreeFmtNode<'a> { input_left, input_right, key, + .. 
} => ND( wh(h, &format!("MERGE SORTED ON '{key}")), [self.lp_node(Some("LEFT PLAN:".to_string()), *input_left)] diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 110af0165831..780df520864a 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -779,6 +779,7 @@ impl ProjectionPushDown { input_left, input_right, key, + maintain_order, } => { if ctx.has_pushed_down() { // make sure that the filter column is projected @@ -792,6 +793,7 @@ impl ProjectionPushDown { input_left, input_right, key, + maintain_order, }) }, Invalid => unreachable!(), diff --git a/crates/polars-plan/src/plans/optimizer/simplify_ordering/mod.rs b/crates/polars-plan/src/plans/optimizer/simplify_ordering/mod.rs index 3a741401b98b..5d66333d4bc3 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_ordering/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_ordering/mod.rs @@ -415,6 +415,7 @@ impl SimplifyIRNodeOrder<'_> { input_left, input_right, key: _, + .. 
} => { let ([in_edge_lhs, in_edge_rhs], [out_edge]) = unpack_edges!(3); diff --git a/crates/polars-plan/src/plans/visitor/hash.rs b/crates/polars-plan/src/plans/visitor/hash.rs index 8362ad96f22b..ba2942edfb0d 100644 --- a/crates/polars-plan/src/plans/visitor/hash.rs +++ b/crates/polars-plan/src/plans/visitor/hash.rs @@ -261,8 +261,10 @@ impl Hash for IRHashWrap<'_> { input_left: _, input_right: _, key, + maintain_order, } => { key.hash(state); + maintain_order.hash(state); }, IR::Invalid => unreachable!(), } diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 3b6fb9bd8337..8224fbd3c2a1 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -1503,12 +1503,12 @@ impl PyLazyFrame { } #[cfg(feature = "merge_sorted")] - fn merge_sorted(&self, other: Self, key: &str) -> PyResult { + fn merge_sorted(&self, other: Self, key: &str, maintain_order: bool) -> PyResult { let out = self .ldf .read() .clone() - .merge_sorted(other.ldf.into_inner(), key) + .merge_sorted(other.ldf.into_inner(), key, maintain_order) .map_err(PyPolarsErr::from)?; Ok(out.into()) } diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index 416c7c81a797..48e092d79cae 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -281,6 +281,8 @@ pub struct MergeSorted { input_right: usize, #[pyo3(get)] key: String, + #[pyo3(get)] + maintain_order: bool, } #[pyclass(frozen)] @@ -744,10 +746,12 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult> { input_left, input_right, key, + maintain_order, } => MergeSorted { input_left: input_left.0, input_right: input_right.0, key: key.to_string(), + maintain_order: *maintain_order, } .into_py_any(py), IR::Invalid => Err(PyNotImplementedError::new_err("Invalid")), diff --git 
a/crates/polars-stream/src/nodes/merge_sorted.rs b/crates/polars-stream/src/nodes/merge_sorted.rs index cb34d9daae92..910f3bb1c454 100644 --- a/crates/polars-stream/src/nodes/merge_sorted.rs +++ b/crates/polars-stream/src/nodes/merge_sorted.rs @@ -15,18 +15,22 @@ pub struct MergeSortedNode { starting_nulls: bool, + maintain_order: bool, + // Not yet merged buffers. left_unmerged: VecDeque, right_unmerged: VecDeque, } impl MergeSortedNode { - pub fn new() -> Self { + pub fn new(maintain_order: bool) -> Self { Self { seq: MorselSeq::default(), starting_nulls: false, + maintain_order, + left_unmerged: VecDeque::new(), right_unmerged: VecDeque::new(), } @@ -42,6 +46,7 @@ fn find_mergeable( is_first: bool, starting_nulls: &mut bool, + maintain_order: bool, ) -> PolarsResult> { fn first_non_empty(vd: &mut VecDeque) -> Option { let mut df = vd.pop_front()?; @@ -133,13 +138,26 @@ fn find_mergeable( } else if left_key_last.lt(&right_key_last)?.all() { // @TODO: This is essentially search sorted, but that does not // support categoricals at moment. - let gt_mask = right_key.gt(&left_key_last)?; - right_cutoff = gt_mask.first_true_idx().unwrap_or(gt_mask.len()); + if maintain_order { + // When maintaining order, hold back right-side rows with keys + // equal to left's max, since more left rows with that key may + // arrive in later morsels. + let gte_mask = right_key.gt_eq(&left_key_last)?; + right_cutoff = gte_mask.first_true_idx().unwrap_or(gte_mask.len()); + } else { + let gt_mask = right_key.gt(&left_key_last)?; + right_cutoff = gt_mask.first_true_idx().unwrap_or(gt_mask.len()); + } } else if left_key_last.gt(&right_key_last)?.all() { // @TODO: This is essentially search sorted, but that does not // support categoricals at moment. let gt_mask = left_key.gt(&right_key_last)?; left_cutoff = gt_mask.first_true_idx().unwrap_or(gt_mask.len()); + } else if maintain_order { + // Keys are equal at both maxima. 
Hold back right-side rows with + // keys equal to the shared maximum to ensure left-biased ordering. + let gte_mask = right_key.gt_eq(&left_key_last)?; + right_cutoff = gte_mask.first_true_idx().unwrap_or(gte_mask.len()); } let left_mergeable: DataFrame; @@ -235,6 +253,7 @@ impl ComputeNode for MergeSortedNode { let seq = &mut self.seq; let starting_nulls = &mut self.starting_nulls; + let maintain_order = self.maintain_order; let left_unmerged = &mut self.left_unmerged; let right_unmerged = &mut self.right_unmerged; @@ -319,6 +338,7 @@ impl ComputeNode for MergeSortedNode { right_unmerged, seq.to_u64() == 0, starting_nulls, + maintain_order, )? { let left_mergeable = Morsel::new(left_mergeable, *seq, source_token.clone()); @@ -379,6 +399,7 @@ impl ComputeNode for MergeSortedNode { right_unmerged, seq.to_u64() == 0, starting_nulls, + maintain_order, )? { let left_mergeable = Morsel::new(left_mergeable, *seq, source_token.clone()); diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 823ff9a70e7e..a4b286f3f14c 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -804,6 +804,7 @@ fn visualize_plan_rec( PhysNodeKind::MergeSorted { input_left, input_right, + .. 
} => ("merge-sorted".to_string(), &[*input_left, *input_right][..]), #[cfg(feature = "ewma")] PhysNodeKind::EwmMean { input, options: _ } => ("ewm-mean".to_string(), &[*input][..]), diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 7470d687daa5..ab851ab58a75 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -331,10 +331,12 @@ pub fn lower_ir( input_left, input_right, key, + maintain_order, } => { let input_left = *input_left; let input_right = *input_right; let key = key.clone(); + let maintain_order = *maintain_order; let mut phys_left = lower_ir!(input_left)?; let mut phys_right = lower_ir!(input_right)?; @@ -379,6 +381,7 @@ pub fn lower_ir( PhysNodeKind::MergeSorted { input_left: phys_left, input_right: phys_right, + maintain_order, } }, diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index f5f28df68cdc..53fac8de3a5c 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -473,6 +473,7 @@ pub enum PhysNodeKind { MergeSorted { input_left: PhysStream, input_right: PhysStream, + maintain_order: bool, }, #[cfg(feature = "ewma")] diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index a9aaeef54ff8..510ca868d206 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -1318,11 +1318,12 @@ fn to_graph_rec<'a>( MergeSorted { input_left, input_right, + maintain_order, } => { let left_input_key = to_graph_rec(input_left.node, ctx)?; let right_input_key = to_graph_rec(input_right.node, ctx)?; ctx.graph.add_node( - nodes::merge_sorted::MergeSortedNode::new(), + nodes::merge_sorted::MergeSortedNode::new(*maintain_order), [ (left_input_key, input_left.port), (right_input_key, 
input_right.port), diff --git a/py-polars/src/polars/_plr.pyi b/py-polars/src/polars/_plr.pyi index dd1a470a9207..07f272abd04a 100644 --- a/py-polars/src/polars/_plr.pyi +++ b/py-polars/src/polars/_plr.pyi @@ -1154,7 +1154,9 @@ class PyLazyFrame: def collect_schema(self) -> dict[str, Any]: ... def unnest(self, columns: PySelector, separator: str | None) -> PyLazyFrame: ... def count(self) -> PyLazyFrame: ... - def merge_sorted(self, other: PyLazyFrame, key: str) -> PyLazyFrame: ... + def merge_sorted( + self, other: PyLazyFrame, key: str, maintain_order: bool + ) -> PyLazyFrame: ... def hint_sorted( self, columns: list[str], descending: list[bool], nulls_last: list[bool] ) -> PyLazyFrame: ... diff --git a/py-polars/src/polars/dataframe/frame.py b/py-polars/src/polars/dataframe/frame.py index 280d07412e47..91473fe460bf 100644 --- a/py-polars/src/polars/dataframe/frame.py +++ b/py-polars/src/polars/dataframe/frame.py @@ -12453,7 +12453,13 @@ def corr(self, *, label: str | None = None, **kwargs: Any) -> DataFrame: df.insert_column(0, cols) return df - def merge_sorted(self, other: DataFrame, key: str) -> DataFrame: + def merge_sorted( + self, + other: DataFrame, + key: str, + *, + maintain_order: bool = False, + ) -> DataFrame: """ Take two sorted DataFrames and merge them by the sorted key. @@ -12470,6 +12476,10 @@ def merge_sorted(self, other: DataFrame, key: str) -> DataFrame: Other DataFrame that must be merged key Key that is sorted. + maintain_order + If ``True``, the output is guaranteed to have left-biased ordering + for equal keys: rows from the left frame appear before rows from + the right frame when their keys are equal. Examples -------- @@ -12520,8 +12530,8 @@ def merge_sorted(self, other: DataFrame, key: str) -> DataFrame: Notes ----- - No guarantee is given over the output row order when the key is equal - between the both dataframes. 
+ Unless ``maintain_order=True``, no guarantee is given over the output + row order when the key is equal between the both dataframes. The key must be sorted in ascending order. """ @@ -12531,7 +12541,7 @@ def merge_sorted(self, other: DataFrame, key: str) -> DataFrame: return ( self.lazy() - .merge_sorted(other.lazy(), key) + .merge_sorted(other.lazy(), key, maintain_order=maintain_order) .collect(optimizations=QueryOptFlags._eager()) ) diff --git a/py-polars/src/polars/functions/eager.py b/py-polars/src/polars/functions/eager.py index fa71c89d9d70..7b7ee1c55118 100644 --- a/py-polars/src/polars/functions/eager.py +++ b/py-polars/src/polars/functions/eager.py @@ -604,7 +604,12 @@ def join_fn(x: pl.LazyFrame, y: pl.LazyFrame) -> pl.LazyFrame: @unstable() -def merge_sorted(items: Iterable[PolarsType], key: str) -> PolarsType: +def merge_sorted( + items: Iterable[PolarsType], + key: str, + *, + maintain_order: bool = False, +) -> PolarsType: """ Merge multiple sorted DataFrames or LazyFrames by the sorted key. @@ -623,6 +628,10 @@ def merge_sorted(items: Iterable[PolarsType], key: str) -> PolarsType: DataFrames or LazyFrames to merge. key Key that is sorted. + maintain_order + If ``True``, the output is guaranteed to have left-biased ordering + for equal keys: rows from the left frame appear before rows from + the right frame when their keys are equal. Examples -------- @@ -654,8 +663,8 @@ def merge_sorted(items: Iterable[PolarsType], key: str) -> PolarsType: Notes ----- - No guarantee is given over the output row order when the key is equal - between dataframes. + Unless ``maintain_order=True``, no guarantee is given over the output + row order when the key is equal between dataframes. The key must be sorted in ascending order. 
""" @@ -674,7 +683,7 @@ def merge_sorted(items: Iterable[PolarsType], key: str) -> PolarsType: frames = [df.lazy() for df in elems] def reduce_fn(x: pl.LazyFrame, y: pl.LazyFrame) -> pl.LazyFrame: - return x.merge_sorted(y, key=key) + return x.merge_sorted(y, key=key, maintain_order=maintain_order) lf = reduce_balanced(reduce_fn, frames) eager = isinstance(elems[0], pl.DataFrame) diff --git a/py-polars/src/polars/lazyframe/frame.py b/py-polars/src/polars/lazyframe/frame.py index a60ae9071c3b..e2d456e54dae 100644 --- a/py-polars/src/polars/lazyframe/frame.py +++ b/py-polars/src/polars/lazyframe/frame.py @@ -8765,7 +8765,13 @@ def unnest( return self._from_pyldf(self._ldf.unnest(subset._pyselector, separator)) - def merge_sorted(self, other: LazyFrame, key: str) -> LazyFrame: + def merge_sorted( + self, + other: LazyFrame, + key: str, + *, + maintain_order: bool = False, + ) -> LazyFrame: """ Take two sorted DataFrames and merge them by the sorted key. @@ -8782,6 +8788,10 @@ def merge_sorted(self, other: LazyFrame, key: str) -> LazyFrame: Other DataFrame that must be merged key Key that is sorted. + maintain_order + If ``True``, the output is guaranteed to have left-biased ordering + for equal keys: rows from the left frame appear before rows from + the right frame when their keys are equal. Examples -------- @@ -8832,13 +8842,13 @@ def merge_sorted(self, other: LazyFrame, key: str) -> LazyFrame: Notes ----- - No guarantee is given over the output row order when the key is equal - between the both dataframes. + Unless ``maintain_order=True``, no guarantee is given over the output + row order when the key is equal between the both dataframes. The key must be sorted in ascending order. 
""" require_same_type(self, other) - return self._from_pyldf(self._ldf.merge_sorted(other._ldf, key)) + return self._from_pyldf(self._ldf.merge_sorted(other._ldf, key, maintain_order)) def set_sorted( self, diff --git a/py-polars/tests/unit/operations/test_merge_sorted.py b/py-polars/tests/unit/operations/test_merge_sorted.py index c2d1d6b7ce01..cf53ddada358 100644 --- a/py-polars/tests/unit/operations/test_merge_sorted.py +++ b/py-polars/tests/unit/operations/test_merge_sorted.py @@ -345,3 +345,62 @@ def test_merge_sorted_multiple_associativity(n_dfs: int, lazy: bool) -> None: df_chained_from_right = df.merge_sorted(df_chained_from_right, key="n") assert_frame_equal(df_chained_from_right, df_full) + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_merge_sorted_maintain_order(streaming: bool) -> None: + """Test that maintain_order=True guarantees left-biased ordering for equal keys.""" + left = pl.DataFrame({"src": ["L1", "L2", "L3"], "key": [1, 2, 3]}) + right = pl.DataFrame({"src": ["R1", "R2", "R3"], "key": [2, 3, 4]}) + + result = ( + left.lazy() + .merge_sorted(right.lazy(), key="key", maintain_order=True) + .collect(engine="streaming" if streaming else "in-memory") + ) + + expected = pl.DataFrame( + { + "src": ["L1", "L2", "R1", "L3", "R2", "R3"], + "key": [1, 2, 2, 3, 3, 4], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_merge_sorted_maintain_order_all_equal(streaming: bool) -> None: + """Test maintain_order when all keys are equal.""" + left = pl.DataFrame({"src": ["L1", "L2"], "key": [1, 1]}) + right = pl.DataFrame({"src": ["R1", "R2"], "key": [1, 1]}) + + result = ( + left.lazy() + .merge_sorted(right.lazy(), key="key", maintain_order=True) + .collect(engine="streaming" if streaming else "in-memory") + ) + + expected = pl.DataFrame( + { + "src": ["L1", "L2", "R1", "R2"], + "key": [1, 1, 1, 1], + } + ) + assert_frame_equal(result, expected) + + 
+def test_merge_sorted_maintain_order_dataframe() -> None:
+    """Test maintain_order via the DataFrame.merge_sorted API."""
+    # DataFrame.merge_sorted takes no engine argument: nothing to parametrize.
+    left = pl.DataFrame({"src": ["L1", "L2"], "key": [1, 2]})
+    right = pl.DataFrame({"src": ["R1", "R2"], "key": [1, 2]})
+
+    result = left.merge_sorted(right, key="key", maintain_order=True)
+
+    expected = pl.DataFrame(
+        {
+            "src": ["L1", "R1", "L2", "R2"],
+            "key": [1, 1, 2, 2],
+        }
+    )
+    assert_frame_equal(result, expected)