|
| 1 | +package org.jetbrains.kotlinx.dataframe.io.db |
| 2 | + |
| 3 | +import org.duckdb.DuckDBColumnType |
| 4 | +import org.duckdb.DuckDBColumnType.ARRAY |
| 5 | +import org.duckdb.DuckDBColumnType.BIGINT |
| 6 | +import org.duckdb.DuckDBColumnType.BIT |
| 7 | +import org.duckdb.DuckDBColumnType.BLOB |
| 8 | +import org.duckdb.DuckDBColumnType.BOOLEAN |
| 9 | +import org.duckdb.DuckDBColumnType.DATE |
| 10 | +import org.duckdb.DuckDBColumnType.DECIMAL |
| 11 | +import org.duckdb.DuckDBColumnType.DOUBLE |
| 12 | +import org.duckdb.DuckDBColumnType.ENUM |
| 13 | +import org.duckdb.DuckDBColumnType.FLOAT |
| 14 | +import org.duckdb.DuckDBColumnType.HUGEINT |
| 15 | +import org.duckdb.DuckDBColumnType.INTEGER |
| 16 | +import org.duckdb.DuckDBColumnType.INTERVAL |
| 17 | +import org.duckdb.DuckDBColumnType.JSON |
| 18 | +import org.duckdb.DuckDBColumnType.LIST |
| 19 | +import org.duckdb.DuckDBColumnType.MAP |
| 20 | +import org.duckdb.DuckDBColumnType.SMALLINT |
| 21 | +import org.duckdb.DuckDBColumnType.STRUCT |
| 22 | +import org.duckdb.DuckDBColumnType.TIME |
| 23 | +import org.duckdb.DuckDBColumnType.TIMESTAMP |
| 24 | +import org.duckdb.DuckDBColumnType.TIMESTAMP_MS |
| 25 | +import org.duckdb.DuckDBColumnType.TIMESTAMP_NS |
| 26 | +import org.duckdb.DuckDBColumnType.TIMESTAMP_S |
| 27 | +import org.duckdb.DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE |
| 28 | +import org.duckdb.DuckDBColumnType.TIME_WITH_TIME_ZONE |
| 29 | +import org.duckdb.DuckDBColumnType.TINYINT |
| 30 | +import org.duckdb.DuckDBColumnType.UBIGINT |
| 31 | +import org.duckdb.DuckDBColumnType.UHUGEINT |
| 32 | +import org.duckdb.DuckDBColumnType.UINTEGER |
| 33 | +import org.duckdb.DuckDBColumnType.UNION |
| 34 | +import org.duckdb.DuckDBColumnType.UNKNOWN |
| 35 | +import org.duckdb.DuckDBColumnType.USMALLINT |
| 36 | +import org.duckdb.DuckDBColumnType.UTINYINT |
| 37 | +import org.duckdb.DuckDBColumnType.UUID |
| 38 | +import org.duckdb.DuckDBColumnType.VARCHAR |
| 39 | +import org.duckdb.DuckDBResultSetMetaData |
| 40 | +import org.duckdb.JsonNode |
| 41 | +import org.jetbrains.kotlinx.dataframe.DataFrame |
| 42 | +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata |
| 43 | +import org.jetbrains.kotlinx.dataframe.io.TableMetadata |
| 44 | +import org.jetbrains.kotlinx.dataframe.io.db.DuckDb.convertSqlTypeToKType |
| 45 | +import org.jetbrains.kotlinx.dataframe.io.getSchemaForAllSqlTables |
| 46 | +import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables |
| 47 | +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema |
| 48 | +import java.math.BigDecimal |
| 49 | +import java.math.BigInteger |
| 50 | +import java.sql.Array |
| 51 | +import java.sql.Blob |
| 52 | +import java.sql.DatabaseMetaData |
| 53 | +import java.sql.ResultSet |
| 54 | +import java.sql.Struct |
| 55 | +import java.sql.Timestamp |
| 56 | +import java.time.LocalDate |
| 57 | +import java.time.LocalTime |
| 58 | +import java.time.OffsetDateTime |
| 59 | +import java.time.OffsetTime |
| 60 | +import java.util.UUID |
| 61 | +import kotlin.reflect.KType |
| 62 | +import kotlin.reflect.KTypeProjection |
| 63 | +import kotlin.reflect.full.createType |
| 64 | +import kotlin.reflect.full.withNullability |
| 65 | +import kotlin.reflect.typeOf |
| 66 | + |
| 67 | +/** |
| 68 | + * Represents the [DuckDB](http://duckdb.org/) database type. |
| 69 | + * |
| 70 | + * This class provides methods to convert data from a [ResultSet] to the appropriate type for DuckDB, |
| 71 | + * and to generate the corresponding [column schema][ColumnSchema]. |
| 72 | + */ |
| 73 | +public object DuckDb : DbType("duckdb") { |
| 74 | + |
| 75 | +    /** The fully qualified class name of the DuckDB JDBC driver. */ |
| 76 | + override val driverClassName: String = "org.duckdb.DuckDBDriver" |
| 77 | + |
/**
 * Determines the Kotlin type that values of the JDBC column described by [tableColumnMetadata]
 * are read as in Java/Kotlin.
 *
 * The lookup is delegated to [toKType], which receives the column's SQL type name together with
 * its nullability. The mapping is intended to follow [org.duckdb.DuckDBVector.getObject], i.e. the
 * type returned by [ResultSet.getObject] of the DuckDB JDBC driver.
 *
 * @param tableColumnMetadata the JDBC metadata of the column
 * @return the Kotlin type of the column's values, nullable if the column is nullable
 */
override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType {
    val typeName = tableColumnMetadata.sqlTypeName
    val nullable = tableColumnMetadata.isNullable
    return typeName.toKType(nullable)
}
| 87 | + |
/**
 * Describes how a JDBC column is represented as a DataFrame value column.
 *
 * The value type of the resulting schema is whatever [convertSqlTypeToKType] reports for
 * [tableColumnMetadata].
 *
 * @param tableColumnMetadata the JDBC metadata of the column
 * @return a [ColumnSchema.Value] wrapping the column's Kotlin type
 */
override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema =
    ColumnSchema.Value(convertSqlTypeToKType(tableColumnMetadata))
| 96 | + |
/**
 * Maps a DuckDB SQL type name (the receiver) onto the Kotlin type its values are read as.
 *
 * The receiver is classified with [DuckDBResultSetMetaData.TypeNameToType], after which every
 * [DuckDBColumnType] is mapped explicitly. The mapping is intended to follow
 * [org.duckdb.DuckDBVector.getObject] exactly.
 *
 * Entries marked "// dataframe-jdbc" are the ones that
 * [org.jetbrains.kotlinx.dataframe.io.makeCommonSqlToKTypeMapping] already covers correctly at the
 * moment. They are repeated here so that nested types (such as the key and value of a `MAP`) can be
 * resolved from this one complete type-map.
 *
 * @receiver the SQL type name of a column or of a nested type
 * @param isNullable whether the returned type should be marked nullable
 * @return the matching Kotlin type, with its nullability set to [isNullable]
 */
@Suppress("ktlint:standard:blank-line-between-when-conditions")
internal fun String.toKType(isNullable: Boolean): KType {
    val typeName = this
    val columnType = DuckDBResultSetMetaData.TypeNameToType(typeName)

    // No `else` branch on purpose: the compiler then flags any DuckDBColumnType that is not handled.
    val baseType: KType = when (columnType) {
        // Boolean and signed integers
        BOOLEAN -> typeOf<Boolean>() // dataframe-jdbc
        TINYINT -> typeOf<Byte>()
        SMALLINT -> typeOf<Short>()
        INTEGER -> typeOf<Int>() // dataframe-jdbc
        BIGINT -> typeOf<Long>() // dataframe-jdbc
        HUGEINT -> typeOf<BigInteger>()

        // Unsigned integers map to a wider signed type
        UTINYINT -> typeOf<Short>()
        USMALLINT -> typeOf<Int>()
        UINTEGER -> typeOf<Long>()
        UBIGINT, UHUGEINT -> typeOf<BigInteger>()

        // Floating point and fixed point numbers
        FLOAT -> typeOf<Float>() // dataframe-jdbc
        DOUBLE -> typeOf<Double>() // dataframe-jdbc
        DECIMAL -> typeOf<BigDecimal>() // dataframe-jdbc

        // Dates and times
        DATE -> typeOf<LocalDate>()
        TIME -> typeOf<LocalTime>()
        TIME_WITH_TIME_ZONE -> typeOf<OffsetTime>() // dataframe-jdbc
        TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S -> typeOf<Timestamp>() // dataframe-jdbc
        TIMESTAMP_WITH_TIME_ZONE -> typeOf<OffsetDateTime>() // dataframe-jdbc

        // Other scalar types
        JSON -> typeOf<JsonNode>()
        BLOB -> typeOf<Blob>()
        UUID -> typeOf<UUID>()
        VARCHAR -> typeOf<String>()

        // Nested types
        MAP -> {
            val (keyTypeName, valueTypeName) = parseMapTypes(typeName)
            // Keys are resolved as non-nullable, values as nullable.
            val keyProjection = KTypeProjection.invariant(keyTypeName.toKType(false))
            val valueProjection = KTypeProjection.covariant(valueTypeName.toKType(true))
            Map::class.createType(listOf(keyProjection, valueProjection))
        }

        // TODO requires #1266 and #1273 for specific types: once available, the element type can be
        //  resolved with `parseListType(typeName).toKType(true)` and used as a covariant type argument.
        LIST, ARRAY -> typeOf<Array>()

        STRUCT -> typeOf<Struct>() // TODO requires #1266 for specific types
        UNION -> typeOf<Any>() // Cannot handle this in Kotlin

        // Everything else is typed as String
        UNKNOWN, BIT, INTERVAL, ENUM -> typeOf<String>()
    }

    return baseType.withNullability(isNullable)
}
| 155 | + |
/**
 * Parses a map type name of the form `MAP(X, Y)` into the key type name `X` and the value type
 * name `Y`.
 *
 * The separator is the first comma that is neither nested inside parentheses nor part of a quoted
 * section (single or double quotes). Nested types such as `DECIMAL(18,3)`, `MAP(A, B)` or
 * `STRUCT("a(b" INTEGER)` are therefore kept in one piece.
 *
 * @param typeString the complete type name; it must start with `MAP(` and end with `)`
 * @return the trimmed key and value type names, in that order
 * @throws IllegalStateException if [typeString] is not of the form `MAP(X, Y)`
 */
internal fun parseMapTypes(typeString: String): Pair<String, String> {
    if (!typeString.startsWith("MAP(") || !typeString.endsWith(")")) {
        error("invalid MAP type: $typeString")
    }

    val content = typeString.removeSurrounding("MAP(", ")")

    // Find the comma that separates key and value types
    var parenCount = 0
    // The quote character that opened the section we are currently inside, or null outside of quotes.
    var openQuote: Char? = null
    var commaIndex = -1
    for (i in content.indices) {
        val char = content[i]
        if (openQuote != null) {
            // Parentheses and commas inside a quoted section are plain text and must not be counted.
            // A doubled (escaped) quote closes and immediately reopens the section, which is harmless.
            if (char == openQuote) openQuote = null
            continue
        }
        when (char) {
            '"', '\'' -> openQuote = char

            '(' -> parenCount++

            ')' -> parenCount--

            ',' -> if (parenCount == 0) {
                commaIndex = i
                break
            }
        }
    }

    if (commaIndex == -1) error("invalid MAP type: $typeString")
    val keyType = content.take(commaIndex).trim()
    val valueType = content.substring(commaIndex + 1).trim()
    return Pair(keyType, valueType)
}
| 185 | + |
/**
 * Parses a list/array type name into the type name of its elements by removing the last pair of
 * square brackets: `X[]` and `X[123]` become `X`, and `X[][]` becomes `X[]`.
 *
 * @param typeString the complete type name; it must end with `]` and contain a `[`
 * @return everything in front of the last `[`
 * @throws IllegalStateException if [typeString] does not end with `]` or contains no `[`
 */
internal fun parseListType(typeString: String): String {
    val openingBracketIndex = typeString.lastIndexOf('[')

    // A missing '[' must be rejected here as well; otherwise `take(-1)` would fail with an
    // unrelated IllegalArgumentException.
    if (!typeString.endsWith("]") || openingBracketIndex == -1) {
        error("invalid LIST/ARRAY type: $typeString")
    }

    return typeString.take(openingBracketIndex)
}
| 194 | + |
/**
 * Tells system tables apart from user-created ones when using
 * [DataFrame.readAllSqlTables][DataFrame.Companion.readAllSqlTables] and
 * [DataFrame.getSchemaForAllSqlTables][DataFrame.Companion.getSchemaForAllSqlTables].
 *
 * A table counts as a system table when its lowercased schema name contains `information_schema`
 * or `system`, or when its lowercased table name contains `system_`.
 * The names of these can sometimes be found in the specific JDBC integration.
 *
 * @param tableMetadata the metadata of the table to classify
 * @return `true` if the table is considered a system table
 */
override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
    val schema = tableMetadata.schemaName?.lowercase()
    if (schema != null && ("information_schema" in schema || "system" in schema)) {
        return true
    }
    return "system_" in tableMetadata.name.lowercase()
}
| 206 | + |
/**
 * Builds the [TableMetadata] for the current row of [tables] when using
 * [DataFrame.readAllSqlTables][DataFrame.Companion.readAllSqlTables] and
 * [DataFrame.getSchemaForAllSqlTables][DataFrame.Companion.getSchemaForAllSqlTables].
 *
 * The columns `TABLE_NAME`, `TABLE_SCHEM` and `TABLE_CAT` are read, in that order. Their names can
 * be found in the [DatabaseMetaData] implementation of the DuckDB JDBC integration.
 *
 * @param tables a result set of tables, positioned on the row to convert
 * @return the metadata built from the three columns above
 */
override fun buildTableMetadata(tables: ResultSet): TableMetadata {
    val tableName = tables.getString("TABLE_NAME")
    val schemaName = tables.getString("TABLE_SCHEM")
    val catalogName = tables.getString("TABLE_CAT")
    return TableMetadata(tableName, schemaName, catalogName)
}
| 219 | +} |
0 commit comments