@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113
113
isSemiFilterJoin (joinType_)) {
114
114
joinTracker_ = JoinTracker (outputBatchSize_, pool ());
115
115
}
116
- } else if (joinNode_->isAntiJoin ()) {
116
+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
117
117
// Anti join needs to track the left side rows that have no match on the
118
- // right.
118
+ // right. Full outer join needs to track the right side rows that have no
119
+ // match on the left.
119
120
joinTracker_ = JoinTracker (outputBatchSize_, pool ());
120
121
}
121
122
@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392
393
const RowVectorPtr& leftBatch,
393
394
vector_size_t leftRow,
394
395
const RowVectorPtr& rightBatch,
395
- vector_size_t rightRow) {
396
+ vector_size_t rightRow,
397
+ bool isRightJoinForFullOuter) {
396
398
if (outputSize_ == outputBatchSize_) {
397
399
return false ;
398
400
}
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426
428
filterRightInputProjections_);
427
429
428
430
if (joinTracker_) {
429
- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
431
+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
432
+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
430
433
// Record right-side row with a match on the left-side.
431
- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
434
+ joinTracker_->addMatch (
435
+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432
436
} else {
433
437
// Record left-side row with a match on the right-side.
434
- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
438
+ joinTracker_->addMatch (
439
+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435
440
}
436
441
}
437
442
}
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441
446
if (isAntiJoin (joinType_)) {
442
447
VELOX_CHECK (joinTracker_.has_value ());
443
448
// Record left-side row with a match on the right-side.
444
- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
449
+ joinTracker_->addMatch (
450
+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445
451
}
446
452
447
453
++outputSize_;
@@ -460,14 +466,15 @@ bool MergeJoin::prepareOutput(
460
466
return true ;
461
467
}
462
468
463
- if (isRightJoin (joinType_) && right != currentRight_) {
464
- return true ;
465
- }
466
-
467
469
// If there is a new right, we need to flatten the dictionary.
468
470
if (!isRightFlattened_ && right && currentRight_ != right) {
469
471
flattenRightProjections ();
470
472
}
473
+
474
+ if (right != currentRight_) {
475
+ return true ;
476
+ }
477
+
471
478
return false ;
472
479
}
473
480
@@ -490,11 +497,10 @@ bool MergeJoin::prepareOutput(
490
497
}
491
498
} else {
492
499
for (const auto & projection : leftProjections_) {
500
+ auto column = left->childAt (projection.inputChannel );
501
+ column->clearContainingLazyAndWrapped ();
493
502
localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
494
- {},
495
- leftOutputIndices_,
496
- outputBatchSize_,
497
- left->childAt (projection.inputChannel ));
503
+ {}, leftOutputIndices_, outputBatchSize_, column);
498
504
}
499
505
}
500
506
currentLeft_ = left;
@@ -510,11 +516,10 @@ bool MergeJoin::prepareOutput(
510
516
isRightFlattened_ = true ;
511
517
} else {
512
518
for (const auto & projection : rightProjections_) {
519
+ auto column = right->childAt (projection.inputChannel );
520
+ column->clearContainingLazyAndWrapped ();
513
521
localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
514
- {},
515
- rightOutputIndices_,
516
- outputBatchSize_,
517
- right->childAt (projection.inputChannel ));
522
+ {}, rightOutputIndices_, outputBatchSize_, column);
518
523
}
519
524
isRightFlattened_ = false ;
520
525
}
@@ -579,6 +584,39 @@ bool MergeJoin::prepareOutput(
579
584
bool MergeJoin::addToOutput () {
580
585
if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
581
586
return addToOutputForRightJoin ();
587
+ } else if (isFullJoin (joinType_) && filter_) {
588
+ if (!leftForRightJoinMatch_) {
589
+ leftForRightJoinMatch_ = leftMatch_;
590
+ rightForRightJoinMatch_ = rightMatch_;
591
+ }
592
+
593
+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
594
+ auto left = addToOutputForLeftJoin ();
595
+ if (!leftMatch_) {
596
+ leftJoinForFullFinished_ = true ;
597
+ }
598
+ if (left) {
599
+ if (!leftMatch_) {
600
+ leftMatch_ = leftForRightJoinMatch_;
601
+ rightMatch_ = rightForRightJoinMatch_;
602
+ }
603
+
604
+ return true ;
605
+ }
606
+ }
607
+
608
+ if (!leftMatch_ && !rightJoinForFullFinished_) {
609
+ leftMatch_ = leftForRightJoinMatch_;
610
+ rightMatch_ = rightForRightJoinMatch_;
611
+ rightJoinForFullFinished_ = true ;
612
+ }
613
+
614
+ auto right = addToOutputForRightJoin ();
615
+
616
+ leftForRightJoinMatch_ = leftMatch_;
617
+ rightForRightJoinMatch_ = rightMatch_;
618
+
619
+ return right;
582
620
} else {
583
621
return addToOutputForLeftJoin ();
584
622
}
@@ -727,7 +765,9 @@ bool MergeJoin::addToOutputForRightJoin() {
727
765
}
728
766
729
767
for (auto j = leftStartRow; j < leftEndRow; ++j) {
730
- if (!tryAddOutputRow (leftBatch, j, rightBatch, i)) {
768
+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
769
+ if (!tryAddOutputRow (
770
+ leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
731
771
// If we run out of space in the current output_, we will need to
732
772
// produce a buffer and continue processing left later. In this
733
773
// case, we cannot leave left as a lazy vector, since we cannot have
@@ -1141,7 +1181,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
1141
1181
isFullJoin (joinType_)) {
1142
1182
// If output_ is currently wrapping a different buffer, return it
1143
1183
// first.
1144
- if (prepareOutput (input_, nullptr )) {
1184
+ if (prepareOutput (input_, rightInput_ )) {
1145
1185
output_->resize (outputSize_);
1146
1186
return std::move (output_);
1147
1187
}
@@ -1166,7 +1206,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
1166
1206
if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
1167
1207
// If output_ is currently wrapping a different buffer, return it
1168
1208
// first.
1169
- if (prepareOutput (nullptr , rightInput_)) {
1209
+ if (prepareOutput (input_ , rightInput_)) {
1170
1210
output_->resize (outputSize_);
1171
1211
return std::move (output_);
1172
1212
}
@@ -1218,6 +1258,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
1218
1258
endRightRow < rightInput_->size (),
1219
1259
std::nullopt };
1220
1260
1261
+ leftJoinForFullFinished_ = false ;
1262
+ rightJoinForFullFinished_ = false ;
1221
1263
if (!leftMatch_->complete || !rightMatch_->complete ) {
1222
1264
if (!leftMatch_->complete ) {
1223
1265
// Need to continue looking for the end of match.
@@ -1262,8 +1304,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
1262
1304
RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
1263
1305
const auto numRows = output->size ();
1264
1306
1265
- RowVectorPtr fullOuterOutput = nullptr ;
1266
-
1267
1307
BufferPtr indices = allocateIndices (numRows, pool ());
1268
1308
auto * rawIndices = indices->asMutable <vector_size_t >();
1269
1309
vector_size_t numPassed = 0 ;
@@ -1280,84 +1320,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1280
1320
1281
1321
// If all matches for a given left-side row fail the filter, add a row to
1282
1322
// the output with nulls for the right-side columns.
1283
- const auto onMiss = [&](auto row) {
1323
+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
1284
1324
if (isSemiFilterJoin (joinType_)) {
1285
1325
return ;
1286
1326
}
1287
1327
rawIndices[numPassed++] = row;
1288
1328
1289
- if (isFullJoin (joinType_)) {
1290
- // For filtered rows, it is necessary to insert additional data
1291
- // to ensure the result set is complete. Specifically, we
1292
- // need to generate two records: one record containing the
1293
- // columns from the left table along with nulls for the
1294
- // right table, and another record containing the columns
1295
- // from the right table along with nulls for the left table.
1296
- // For instance, the current output is filtered based on the condition
1297
- // t > 1.
1298
-
1299
- // 1, 1
1300
- // 2, 2
1301
- // 3, 3
1302
-
1303
- // In this scenario, we need to additionally insert a record 1, 1.
1304
- // Subsequently, we will set the values of the columns on the left to
1305
- // null and the values of the columns on the right to null as well. By
1306
- // doing so, we will obtain the final result set.
1307
-
1308
- // 1, null
1309
- // null, 1
1310
- // 2, 2
1311
- // 3, 3
1312
- fullOuterOutput = BaseVector::create<RowVector>(
1313
- output->type (), output->size () + 1 , pool ());
1314
-
1315
- for (auto i = 0 ; i < row + 1 ; ++i) {
1316
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1317
- fullOuterOutput->childAt (j)->copy (
1318
- output->childAt (j).get (), i, i, 1 );
1329
+ if (!isRightJoin (joinType_)) {
1330
+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1331
+ for (auto & projection : leftProjections_) {
1332
+ auto target = output->childAt (projection.outputChannel );
1333
+ target->setNull (row, true );
1319
1334
}
1320
- }
1321
-
1322
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1323
- fullOuterOutput->childAt (j)->copy (
1324
- output->childAt (j).get (), row + 1 , row, 1 );
1325
- }
1326
-
1327
- for (auto i = row + 1 ; i < output->size (); ++i) {
1328
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1329
- fullOuterOutput->childAt (j)->copy (
1330
- output->childAt (j).get (), i + 1 , i, 1 );
1335
+ } else {
1336
+ for (auto & projection : rightProjections_) {
1337
+ auto target = output->childAt (projection.outputChannel );
1338
+ target->setNull (row, true );
1331
1339
}
1332
1340
}
1333
-
1334
- for (auto & projection : leftProjections_) {
1335
- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1336
- target->setNull (row, true );
1337
- }
1338
-
1339
- for (auto & projection : rightProjections_) {
1340
- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1341
- target->setNull (row + 1 , true );
1342
- }
1343
- } else if (!isRightJoin (joinType_)) {
1344
- for (auto & projection : rightProjections_) {
1345
- auto & target = output->childAt (projection.outputChannel );
1346
- target->setNull (row, true );
1347
- }
1348
1341
} else {
1349
1342
for (auto & projection : leftProjections_) {
1350
- auto & target = output->childAt (projection.outputChannel );
1343
+ auto target = output->childAt (projection.outputChannel );
1351
1344
target->setNull (row, true );
1352
1345
}
1353
1346
}
1354
1347
};
1355
1348
1356
1349
auto onMatch = [&](auto row, bool firstMatch) {
1357
- const bool isNonSemiAntiJoin =
1358
- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1350
+ const bool isFullLeftJoin =
1351
+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1352
+
1353
+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1354
+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
1359
1355
1360
- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1356
+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1357
+ isNonSemiAntiFullJoin || isFullLeftJoin) {
1361
1358
rawIndices[numPassed++] = row;
1362
1359
}
1363
1360
};
@@ -1418,17 +1415,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1418
1415
1419
1416
if (numPassed == numRows) {
1420
1417
// All rows passed.
1421
- if (fullOuterOutput) {
1422
- return fullOuterOutput;
1423
- }
1424
1418
return output;
1425
1419
}
1426
1420
1427
1421
// Some, but not all rows passed.
1428
- if (fullOuterOutput) {
1429
- return wrap (numPassed, indices, fullOuterOutput);
1430
- }
1431
-
1432
1422
return wrap (numPassed, indices, output);
1433
1423
}
1434
1424
0 commit comments