|
32 | 32 | import io.trino.plugin.jdbc.JdbcColumnHandle; |
33 | 33 | import io.trino.plugin.jdbc.JdbcExpression; |
34 | 34 | import io.trino.plugin.jdbc.JdbcJoinCondition; |
| 35 | +import io.trino.plugin.jdbc.JdbcSortItem; |
35 | 36 | import io.trino.plugin.jdbc.JdbcTableHandle; |
36 | 37 | import io.trino.plugin.jdbc.JdbcTypeHandle; |
37 | 38 | import io.trino.plugin.jdbc.LongReadFunction; |
|
101 | 102 | import java.util.Spliterator; |
102 | 103 | import java.util.Spliterators; |
103 | 104 | import java.util.function.BiFunction; |
| 105 | +import java.util.stream.Stream; |
104 | 106 | import java.util.stream.StreamSupport; |
105 | 107 |
|
106 | 108 | import static com.google.common.base.Preconditions.checkArgument; |
| 109 | +import static com.google.common.base.Preconditions.checkState; |
107 | 110 | import static com.google.common.base.Strings.emptyToNull; |
108 | 111 | import static com.google.common.base.Verify.verify; |
109 | 112 | import static io.airlift.slice.Slices.utf8Slice; |
|
164 | 167 | import static java.lang.String.join; |
165 | 168 | import static java.util.Locale.ENGLISH; |
166 | 169 | import static java.util.concurrent.TimeUnit.DAYS; |
| 170 | +import static java.util.stream.Collectors.joining; |
167 | 171 |
|
168 | 172 | public class OracleClient |
169 | 173 | extends BaseJdbcClient |
@@ -222,6 +226,17 @@ public class OracleClient |
222 | 226 | .put(TIMESTAMP_TZ_MILLIS, WriteMapping.longMapping("timestamp(3) with time zone", oracleTimestampWithTimeZoneWriteFunction())) |
223 | 227 | .buildOrThrow(); |
224 | 228 |
|
| 229 | + private static final Set<String> ORACLE_COLLATABLE_TYPES = ImmutableSet.<String>builder() |
| 230 | + .add("char") |
| 231 | + .add("nchar") |
| 232 | + .add("varchar") |
| 233 | + .add("varchar2") |
| 234 | + .add("nvarchar") |
| 235 | + .add("nvarchar2") |
| 236 | + .add("clob") |
| 237 | + .add("nclob") |
| 238 | + .build(); |
| 239 | + |
225 | 240 | private final boolean synonymsEnabled; |
226 | 241 | private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter; |
227 | 242 | private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter; |
@@ -629,6 +644,72 @@ public boolean isLimitGuaranteed(ConnectorSession session) |
629 | 644 | return true; |
630 | 645 | } |
631 | 646 |
|
| 647 | + @Override |
| 648 | + public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder) |
| 649 | + { |
| 650 | + for (JdbcSortItem sortItem : sortOrder) { |
| 651 | + if (isCollatableInTrino(sortItem.column()) && !isCollatableInOracle(sortItem.column())) { |
| 652 | + // If it is a Trino text type but its Oracle type is NOT collatable, |
| 653 | + // we cannot guarantee correct sorting and must prevent pushdown. |
| 654 | + return false; |
| 655 | + } |
| 656 | + // Non-textual types (numbers, dates, etc.) are safe for TopN pushdown |
| 657 | + // as their sorting is same in Trino and Oracle |
| 658 | + } |
| 659 | + return true; |
| 660 | + } |
| 661 | + |
| 662 | + @Override |
| 663 | + protected Optional<TopNFunction> topNFunction() |
| 664 | + { |
| 665 | + return Optional.of((query, sortItems, limit) -> { |
| 666 | + String orderBy = sortItems.stream() |
| 667 | + .flatMap(sortItem -> { |
| 668 | + String collation = ""; |
| 669 | + if (isCollatableInTrino(sortItem.column())) { |
| 670 | + checkState( |
| 671 | + isCollatableInOracle(sortItem.column()), |
| 672 | + "Column '%s' is collatable in Trino but not in Oracle. Check database configuration.", |
| 673 | + sortItem.column().getColumnName() |
| 674 | + ); |
| 675 | + // Trino encodes all text as UTF-8, and sorts text as Unicode codepoints. |
| 676 | + // In Oracle BINARY collation provides unsigned byte-by-byte comparison. |
| 677 | + // If the Oracle DB uses UTF8 or AL32UTF8 charset, then BINARY collation will match Trino's sorting. |
| 678 | + // If the Oracle DB uses a non UTF8 charset, like WE8ISO8859P1 or WE8MSWIN1252, |
| 679 | + // then Oracle will not sort the same way as in Trino. |
| 680 | + collation = "COLLATE BINARY"; |
| 681 | + } |
| 682 | + String ordering = sortItem.sortOrder().isAscending() ? "ASC" : "DESC"; |
| 683 | + String nullsHandling = switch (sortItem.sortOrder()) { |
| 684 | + // In Oracle both ASC and DESC imply NULLS LAST, but we'll be explicit |
| 685 | + case ASC_NULLS_FIRST, DESC_NULLS_FIRST -> "NULLS FIRST"; |
| 686 | + case ASC_NULLS_LAST, DESC_NULLS_LAST -> "NULLS LAST"; |
| 687 | + }; |
| 688 | + return Stream.of(format("%s %s %s %s", quoted(sortItem.column().getColumnName()), collation, ordering, nullsHandling)); |
| 689 | + }) |
| 690 | + .collect(joining(", ")); |
| 691 | + return format("%s ORDER BY %s FETCH FIRST %s ROWS ONLY", query, orderBy, limit); |
| 692 | + }); |
| 693 | + } |
| 694 | + |
| 695 | + private boolean isCollatableInTrino(JdbcColumnHandle column) |
| 696 | + { |
| 697 | + return column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType; |
| 698 | + } |
| 699 | + |
| 700 | + private boolean isCollatableInOracle(JdbcColumnHandle column) |
| 701 | + { |
| 702 | + String jdbcTypeName = column.getJdbcTypeHandle().jdbcTypeName() |
| 703 | + .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + column.getJdbcTypeHandle())); |
| 704 | + return ORACLE_COLLATABLE_TYPES.contains(jdbcTypeName.toLowerCase(ENGLISH)); |
| 705 | + } |
| 706 | + |
| 707 | + @Override |
| 708 | + public boolean isTopNGuaranteed(ConnectorSession session) |
| 709 | + { |
| 710 | + return true; |
| 711 | + } |
| 712 | + |
632 | 713 | @Override |
633 | 714 | protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) |
634 | 715 | { |
|
0 commit comments