Skip to content

Commit 9c6ce5f

Browse files
zhouyuan authored and Lakehouse Engine Bot committed
[11771] Fix smj result mismatch issue
Signed-off-by: Yuan <[email protected]>
Alchemy-item: [[11771] Fix smj result mismatch issue](#27 (comment)) commit 1/1 - 8c77615
1 parent 503f2ed commit 9c6ce5f

File tree

3 files changed: +146 additions, -97 deletions (lines changed)

velox/exec/MergeJoin.cpp

Lines changed: 82 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113
isSemiFilterJoin(joinType_)) {
114114
joinTracker_ = JoinTracker(outputBatchSize_, pool());
115115
}
116-
} else if (joinNode_->isAntiJoin()) {
116+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
117117
// Anti join needs to track the left side rows that have no match on the
118-
// right.
118+
// right. Full outer join needs to track the right side rows that have no
119+
// match on the left.
119120
joinTracker_ = JoinTracker(outputBatchSize_, pool());
120121
}
121122

@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392393
const RowVectorPtr& leftBatch,
393394
vector_size_t leftRow,
394395
const RowVectorPtr& rightBatch,
395-
vector_size_t rightRow) {
396+
vector_size_t rightRow,
397+
bool isRightJoinForFullOuter) {
396398
if (outputSize_ == outputBatchSize_) {
397399
return false;
398400
}
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426428
filterRightInputProjections_);
427429

428430
if (joinTracker_) {
429-
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
431+
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
432+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
430433
// Record right-side row with a match on the left-side.
431-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
434+
joinTracker_->addMatch(
435+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432436
} else {
433437
// Record left-side row with a match on the right-side.
434-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
438+
joinTracker_->addMatch(
439+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435440
}
436441
}
437442
}
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441446
if (isAntiJoin(joinType_)) {
442447
VELOX_CHECK(joinTracker_.has_value());
443448
// Record left-side row with a match on the right-side.
444-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
449+
joinTracker_->addMatch(
450+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445451
}
446452

447453
++outputSize_;
@@ -460,14 +466,15 @@ bool MergeJoin::prepareOutput(
460466
return true;
461467
}
462468

463-
if (isRightJoin(joinType_) && right != currentRight_) {
464-
return true;
465-
}
466-
467469
// If there is a new right, we need to flatten the dictionary.
468470
if (!isRightFlattened_ && right && currentRight_ != right) {
469471
flattenRightProjections();
470472
}
473+
474+
if (right != currentRight_) {
475+
return true;
476+
}
477+
471478
return false;
472479
}
473480

@@ -490,11 +497,10 @@ bool MergeJoin::prepareOutput(
490497
}
491498
} else {
492499
for (const auto& projection : leftProjections_) {
500+
auto column = left->childAt(projection.inputChannel);
501+
column->clearContainingLazyAndWrapped();
493502
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
494-
{},
495-
leftOutputIndices_,
496-
outputBatchSize_,
497-
left->childAt(projection.inputChannel));
503+
{}, leftOutputIndices_, outputBatchSize_, column);
498504
}
499505
}
500506
currentLeft_ = left;
@@ -510,11 +516,10 @@ bool MergeJoin::prepareOutput(
510516
isRightFlattened_ = true;
511517
} else {
512518
for (const auto& projection : rightProjections_) {
519+
auto column = right->childAt(projection.inputChannel);
520+
column->clearContainingLazyAndWrapped();
513521
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
514-
{},
515-
rightOutputIndices_,
516-
outputBatchSize_,
517-
right->childAt(projection.inputChannel));
522+
{}, rightOutputIndices_, outputBatchSize_, column);
518523
}
519524
isRightFlattened_ = false;
520525
}
@@ -579,6 +584,39 @@ bool MergeJoin::prepareOutput(
579584
bool MergeJoin::addToOutput() {
580585
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
581586
return addToOutputForRightJoin();
587+
} else if (isFullJoin(joinType_) && filter_) {
588+
if (!leftForRightJoinMatch_) {
589+
leftForRightJoinMatch_ = leftMatch_;
590+
rightForRightJoinMatch_ = rightMatch_;
591+
}
592+
593+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
594+
auto left = addToOutputForLeftJoin();
595+
if (!leftMatch_) {
596+
leftJoinForFullFinished_ = true;
597+
}
598+
if (left) {
599+
if (!leftMatch_) {
600+
leftMatch_ = leftForRightJoinMatch_;
601+
rightMatch_ = rightForRightJoinMatch_;
602+
}
603+
604+
return true;
605+
}
606+
}
607+
608+
if (!leftMatch_ && !rightJoinForFullFinished_) {
609+
leftMatch_ = leftForRightJoinMatch_;
610+
rightMatch_ = rightForRightJoinMatch_;
611+
rightJoinForFullFinished_ = true;
612+
}
613+
614+
auto right = addToOutputForRightJoin();
615+
616+
leftForRightJoinMatch_ = leftMatch_;
617+
rightForRightJoinMatch_ = rightMatch_;
618+
619+
return right;
582620
} else {
583621
return addToOutputForLeftJoin();
584622
}
@@ -727,7 +765,9 @@ bool MergeJoin::addToOutputForRightJoin() {
727765
}
728766

729767
for (auto j = leftStartRow; j < leftEndRow; ++j) {
730-
if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) {
768+
const auto isRightJoinForFullOuter = isFullJoin(joinType_);
769+
if (!tryAddOutputRow(
770+
leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
731771
// If we run out of space in the current output_, we will need to
732772
// produce a buffer and continue processing left later. In this
733773
// case, we cannot leave left as a lazy vector, since we cannot have
@@ -1141,7 +1181,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11411181
isFullJoin(joinType_)) {
11421182
// If output_ is currently wrapping a different buffer, return it
11431183
// first.
1144-
if (prepareOutput(input_, nullptr)) {
1184+
if (prepareOutput(input_, rightInput_)) {
11451185
output_->resize(outputSize_);
11461186
return std::move(output_);
11471187
}
@@ -1166,7 +1206,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11661206
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
11671207
// If output_ is currently wrapping a different buffer, return it
11681208
// first.
1169-
if (prepareOutput(nullptr, rightInput_)) {
1209+
if (prepareOutput(input_, rightInput_)) {
11701210
output_->resize(outputSize_);
11711211
return std::move(output_);
11721212
}
@@ -1218,6 +1258,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12181258
endRightRow < rightInput_->size(),
12191259
std::nullopt};
12201260

1261+
leftJoinForFullFinished_ = false;
1262+
rightJoinForFullFinished_ = false;
12211263
if (!leftMatch_->complete || !rightMatch_->complete) {
12221264
if (!leftMatch_->complete) {
12231265
// Need to continue looking for the end of match.
@@ -1262,8 +1304,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12621304
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12631305
const auto numRows = output->size();
12641306

1265-
RowVectorPtr fullOuterOutput = nullptr;
1266-
12671307
BufferPtr indices = allocateIndices(numRows, pool());
12681308
auto* rawIndices = indices->asMutable<vector_size_t>();
12691309
vector_size_t numPassed = 0;
@@ -1280,84 +1320,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12801320

12811321
// If all matches for a given left-side row fail the filter, add a row to
12821322
// the output with nulls for the right-side columns.
1283-
const auto onMiss = [&](auto row) {
1323+
const auto onMiss = [&](auto row, bool isRightJoinForFullOuter) {
12841324
if (isSemiFilterJoin(joinType_)) {
12851325
return;
12861326
}
12871327
rawIndices[numPassed++] = row;
12881328

1289-
if (isFullJoin(joinType_)) {
1290-
// For filtered rows, it is necessary to insert additional data
1291-
// to ensure the result set is complete. Specifically, we
1292-
// need to generate two records: one record containing the
1293-
// columns from the left table along with nulls for the
1294-
// right table, and another record containing the columns
1295-
// from the right table along with nulls for the left table.
1296-
// For instance, the current output is filtered based on the condition
1297-
// t > 1.
1298-
1299-
// 1, 1
1300-
// 2, 2
1301-
// 3, 3
1302-
1303-
// In this scenario, we need to additionally insert a record 1, 1.
1304-
// Subsequently, we will set the values of the columns on the left to
1305-
// null and the values of the columns on the right to null as well. By
1306-
// doing so, we will obtain the final result set.
1307-
1308-
// 1, null
1309-
// null, 1
1310-
// 2, 2
1311-
// 3, 3
1312-
fullOuterOutput = BaseVector::create<RowVector>(
1313-
output->type(), output->size() + 1, pool());
1314-
1315-
for (auto i = 0; i < row + 1; ++i) {
1316-
for (auto j = 0; j < output->type()->size(); ++j) {
1317-
fullOuterOutput->childAt(j)->copy(
1318-
output->childAt(j).get(), i, i, 1);
1329+
if (!isRightJoin(joinType_)) {
1330+
if (isFullJoin(joinType_) && isRightJoinForFullOuter) {
1331+
for (auto& projection : leftProjections_) {
1332+
auto target = output->childAt(projection.outputChannel);
1333+
target->setNull(row, true);
13191334
}
1320-
}
1321-
1322-
for (auto j = 0; j < output->type()->size(); ++j) {
1323-
fullOuterOutput->childAt(j)->copy(
1324-
output->childAt(j).get(), row + 1, row, 1);
1325-
}
1326-
1327-
for (auto i = row + 1; i < output->size(); ++i) {
1328-
for (auto j = 0; j < output->type()->size(); ++j) {
1329-
fullOuterOutput->childAt(j)->copy(
1330-
output->childAt(j).get(), i + 1, i, 1);
1335+
} else {
1336+
for (auto& projection : rightProjections_) {
1337+
auto target = output->childAt(projection.outputChannel);
1338+
target->setNull(row, true);
13311339
}
13321340
}
1333-
1334-
for (auto& projection : leftProjections_) {
1335-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1336-
target->setNull(row, true);
1337-
}
1338-
1339-
for (auto& projection : rightProjections_) {
1340-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1341-
target->setNull(row + 1, true);
1342-
}
1343-
} else if (!isRightJoin(joinType_)) {
1344-
for (auto& projection : rightProjections_) {
1345-
auto& target = output->childAt(projection.outputChannel);
1346-
target->setNull(row, true);
1347-
}
13481341
} else {
13491342
for (auto& projection : leftProjections_) {
1350-
auto& target = output->childAt(projection.outputChannel);
1343+
auto target = output->childAt(projection.outputChannel);
13511344
target->setNull(row, true);
13521345
}
13531346
}
13541347
};
13551348

13561349
auto onMatch = [&](auto row, bool firstMatch) {
1357-
const bool isNonSemiAntiJoin =
1358-
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);
1350+
const bool isFullLeftJoin =
1351+
isFullJoin(joinType_) && !joinTracker_->isRightJoinForFullOuter(row);
1352+
1353+
const bool isNonSemiAntiFullJoin = !isSemiFilterJoin(joinType_) &&
1354+
!isAntiJoin(joinType_) && !isFullJoin(joinType_);
13591355

1360-
if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
1356+
if ((isSemiFilterJoin(joinType_) && firstMatch) ||
1357+
isNonSemiAntiFullJoin || isFullLeftJoin) {
13611358
rawIndices[numPassed++] = row;
13621359
}
13631360
};
@@ -1418,17 +1415,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14181415

14191416
if (numPassed == numRows) {
14201417
// All rows passed.
1421-
if (fullOuterOutput) {
1422-
return fullOuterOutput;
1423-
}
14241418
return output;
14251419
}
14261420

14271421
// Some, but not all rows passed.
1428-
if (fullOuterOutput) {
1429-
return wrap(numPassed, indices, fullOuterOutput);
1430-
}
1431-
14321422
return wrap(numPassed, indices, output);
14331423
}
14341424

0 commit comments

Comments
 (0)