From 821fe48376e433824315106ee0ad500f324d4131 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Tue, 30 Sep 2025 13:42:04 +0530 Subject: [PATCH] Implement lazy loading for inline Arrow results This PR introduces lazy loading support for inline Arrow results to improve memory efficiency when handling large result sets. Previously, InlineChunkProvider would eagerly fetch all Arrow batches upfront when results had hasMoreRows = true, which could lead to memory issues with large datasets. This change splits the handling into two separate paths: 1. Lazy path (new): For Thrift-based inline Arrow results (when ARROW_BASED_SET is returned), we now use LazyThriftInlineArrowResult, which fetches Arrow batches on demand as the client iterates through rows. This is similar to how LazyThriftResult works for columnar data. 2. Remote path (existing): For URL-based Arrow results (URL_BASED_SET), we continue using ArrowStreamResult with RemoteChunkProvider, which downloads chunks from cloud storage. The InlineChunkProvider is now used only for SEA results with JSON_ARRAY format and INLINE disposition, which contain all data inline (the hasMoreRows flag is never set). This should reduce memory consumption and improve performance when dealing with large inline Arrow result sets. --- .../jdbc/api/impl/ExecutionResultFactory.java | 5 +- .../api/impl/arrow/ArrowStreamResult.java | 92 ++-- .../api/impl/arrow/InlineChunkProvider.java | 122 ----- .../arrow/LazyThriftInlineArrowResult.java | 425 ++++++++++++++++++ .../api/impl/ExecutionResultFactoryTest.java | 7 +- .../api/impl/arrow/ArrowStreamResultTest.java | 21 +- .../impl/arrow/InlineChunkProviderTest.java | 42 -- .../LazyThriftInlineArrowResultTest.java | 285 ++++++++++++ 8 files changed, 771 insertions(+), 228 deletions(-) create mode 100644 src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java create mode 100644 src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java diff --git a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java index 4c719731d..ba6d7acb7 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java +++ b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java @@ -1,6 +1,7 @@ package com.databricks.jdbc.api.impl; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler( case COLUMN_BASED_SET: return new LazyThriftResult(resultsResp, parentStatement, session); case ARROW_BASED_SET: - return new ArrowStreamResult(resultsResp, true, parentStatement, session); + return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session); case URL_BASED_SET: - return new ArrowStreamResult(resultsResp, false, parentStatement, session); + return new ArrowStreamResult(resultsResp, parentStatement, session); case ROW_BASED_SET: throw new DatabricksSQLFeatureNotSupportedException( "Invalid state - row based set cannot be received"); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java index 29a88fd6b..4e011301e 100644 --- 
a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java @@ -85,13 +85,11 @@ public ArrowStreamResult( public ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatementId, IDatabricksSession session) throws DatabricksSQLException { this( resultsResp, - isInlineArrow, parentStatementId, session, DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext())); @@ -100,27 +98,22 @@ public ArrowStreamResult( @VisibleForTesting ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatement, IDatabricksSession session, IDatabricksHttpClient httpClient) throws DatabricksSQLException { this.session = session; setColumnInfo(resultsResp.getResultSetMetadata()); - if (isInlineArrow) { - this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session); - } else { - CompressionCodec compressionCodec = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - this.chunkProvider = - new RemoteChunkProvider( - parentStatement, - resultsResp, - session, - httpClient, - session.getConnectionContext().getCloudFetchThreadPoolSize(), - compressionCodec); - } + CompressionCodec compressionCodec = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + this.chunkProvider = + new RemoteChunkProvider( + parentStatement, + resultsResp, + session, + httpClient, + session.getConnectionContext().getCloudFetchThreadPoolSize(), + compressionCodec); } public List getArrowMetadata() throws DatabricksSQLException { @@ -133,30 +126,15 @@ public List getArrowMetadata() throws DatabricksSQLException { /** {@inheritDoc} */ @Override public Object getObject(int columnIndex) throws DatabricksSQLException { - ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName(); + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); String arrowMetadata = chunkIterator.getType(columnIndex); if (arrowMetadata == null) { - arrowMetadata = columnInfos.get(columnIndex).getTypeText(); - } - - // Handle complex type conversion when complex datatype support is disabled - boolean isComplexDatatypeSupportEnabled = - this.session.getConnectionContext().isComplexDatatypeSupportEnabled(); - if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { - LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); - - Object result = - chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex)); - if (result == null) { - return null; - } - ComplexDataTypeParser parser = new ComplexDataTypeParser(); - return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + arrowMetadata = columnInfo.getTypeText(); } - return chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex)); + return getObjectWithComplexTypeHandling( + session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); } /** @@ -237,4 +215,44 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { columnInfos.add(getColumnInfoFromTColumnDesc(tColumnDesc)); } } + + /** + * Helper method to handle complex type conversion when complex datatype support is disabled. 
+ * + * @param session The databricks session + * @param chunkIterator The chunk iterator + * @param columnIndex The column index + * @param requiredType The required column type + * @param arrowMetadata The arrow metadata + * @param columnInfo The column info + * @return The object value (converted if complex type and support disabled) + * @throws DatabricksSQLException if an error occurs + */ + protected static Object getObjectWithComplexTypeHandling( + IDatabricksSession session, + ArrowResultChunkIterator chunkIterator, + int columnIndex, + ColumnInfoTypeName requiredType, + String arrowMetadata, + ColumnInfo columnInfo) + throws DatabricksSQLException { + boolean isComplexDatatypeSupportEnabled = + session.getConnectionContext().isComplexDatatypeSupportEnabled(); + + if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { + LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); + Object result = + chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo); + if (result == null) { + return null; + } + ComplexDataTypeParser parser = new ComplexDataTypeParser(); + + return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + } + + return chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, requiredType, arrowMetadata, columnInfo); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java index e22d974a4..32f5e1b80 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java @@ -1,31 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; -import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.SchemaUtility; /** Class to manage inline Arrow chunks */ public class InlineChunkProvider implements ChunkProvider { @@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider { private final ArrowResultChunk arrowResultChunk; // There is only one packet of data in case of inline arrow - InlineChunkProvider( - TFetchResultsResp resultsResp, - IDatabricksStatementInternal parentStatement, - IDatabricksSession session) - throws 
DatabricksParsingException { - this.currentChunkIndex = -1; - this.totalRows = 0; - ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement); - ArrowResultChunk.Builder builder = - ArrowResultChunk.builder().withInputStream(byteStream, totalRows); - - if (parentStatement != null) { - builder.withStatementId(parentStatement.getStatementId()); - } - arrowResultChunk = builder.build(); - } - /** * Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}. * @@ -123,97 +92,6 @@ public boolean isClosed() { return isClosed; } - private ByteArrayInputStream initializeByteStream( - TFetchResultsResp resultsResp, - IDatabricksSession session, - IDatabricksStatementInternal parentStatement) - throws DatabricksParsingException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - CompressionCodec compressionType = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - try { - byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); - if (serializedSchema != null) { - baos.write(serializedSchema); - } - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - while (resultsResp.hasMoreRows) { - resultsResp = session.getDatabricksClient().getMoreResults(parentStatement); - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - } - return new ByteArrayInputStream(baos.toByteArray()); - } catch (DatabricksSQLException | IOException e) { - handleError(e); - } - return null; - } - - private void writeToByteOutputStream( - CompressionCodec compressionCodec, - IDatabricksStatementInternal parentStatement, - List arrowBatchList, - ByteArrayOutputStream baos) - throws DatabricksSQLException, IOException { - for (TSparkArrowBatch arrowBatch : arrowBatchList) { - byte[] decompressedBytes = - decompress( - arrowBatch.getBatch(), - compressionCodec, - String.format( - "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", - arrowBatch.getRowCount(), parentStatement, compressionCodec)); - totalRows += arrowBatch.getRowCount(); - baos.write(decompressedBytes); - } - } - - private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) - throws DatabricksSQLException { - if (metadata.getArrowSchema() != null) { - return metadata.getArrowSchema(); - } - Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); - try { - return SchemaUtility.serialize(arrowSchema); - } catch (IOException e) { - handleError(e); - } - // should never reach here; - return null; - } - - private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) - throws DatabricksParsingException { - List fields = new ArrayList<>(); - if (hiveSchema == null) { - return new Schema(fields); - } - try { - hiveSchema - .getColumns() - .forEach( - columnDesc -> { - try { - fields.add(getArrowField(columnDesc)); - } catch (SQLException e) { - throw new RuntimeException(e); - } - }); - } catch (RuntimeException e) { - handleError(e); - } - return new Schema(fields); - } - - private Field getArrowField(TColumnDesc columnDesc) throws SQLException { - TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); - ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); - FieldType fieldType = new FieldType(true, arrowType, null); - return new Field(columnDesc.getColumnName(), fieldType, null); - } - 
@VisibleForTesting void handleError(Exception e) throws DatabricksParsingException { String errorMessage = diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java new file mode 100644 index 000000000..08950339c --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -0,0 +1,425 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_RESULT_ROW_LIMIT; +import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; +import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; + +import com.databricks.jdbc.api.impl.IExecutionResult; +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.core.ColumnInfo; +import com.databricks.jdbc.model.core.ColumnInfoTypeName; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import com.google.common.annotations.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.SchemaUtility; + +/** + * Lazy implementation for thrift-based inline Arrow results that fetches arrow batches on demand. + * Similar to LazyThriftResult but processes Arrow data instead of columnar thrift data. + */ +public class LazyThriftInlineArrowResult implements IExecutionResult { + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(LazyThriftInlineArrowResult.class); + + private TFetchResultsResp currentResponse; + private ArrowResultChunk currentChunk; + private ArrowResultChunkIterator currentChunkIterator; + private long globalRowIndex; + private final IDatabricksSession session; + private final IDatabricksStatementInternal statement; + private final int maxRows; + private boolean hasReachedEnd; + private boolean isClosed; + private long totalRowsFetched; + private List columnInfos; + + /** + * Creates a new LazyThriftInlineArrowResult that lazily fetches arrow data on demand. + * + * @param initialResponse the initial response from the server + * @param statement the statement that generated this result + * @param session the session to use for fetching additional data + * @throws DatabricksSQLException if the initial response cannot be processed + */ + public LazyThriftInlineArrowResult( + TFetchResultsResp initialResponse, + IDatabricksStatementInternal statement, + IDatabricksSession session) + throws DatabricksSQLException { + this.currentResponse = initialResponse; + this.statement = statement; + this.session = session; + this.maxRows = statement != null ? 
statement.getMaxRows() : DEFAULT_RESULT_ROW_LIMIT; + this.globalRowIndex = -1; + this.hasReachedEnd = false; + this.isClosed = false; + this.totalRowsFetched = 0; + + // Initialize column info from metadata + setColumnInfo(initialResponse.getResultSetMetadata()); + + // Load initial chunk + loadCurrentChunk(); + LOGGER.debug( + "LazyThriftInlineArrowResult initialized with {} rows in first chunk, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } + + /** + * Gets the value at the specified column index for the current row. + * + * @param columnIndex the zero-based column index + * @return the value at the specified column + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + @Override + public Object getObject(int columnIndex) throws DatabricksSQLException { + if (isClosed) { + throw new DatabricksSQLException( + "Result is already closed", DatabricksDriverErrorCode.STATEMENT_CLOSED); + } + if (globalRowIndex == -1) { + throw new DatabricksSQLException( + "Cursor is before first row", DatabricksDriverErrorCode.INVALID_STATE); + } + if (currentChunkIterator == null) { + throw new DatabricksSQLException( + "No current chunk available", DatabricksDriverErrorCode.INVALID_STATE); + } + if (columnIndex < 0 || columnIndex >= columnInfos.size()) { + throw new DatabricksSQLException( + "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); + } + + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); + String arrowMetadata = currentChunkIterator.getType(columnIndex); + if (arrowMetadata == null) { + arrowMetadata = columnInfo.getTypeText(); + } + + return ArrowStreamResult.getObjectWithComplexTypeHandling( + session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); + } + + /** + * Gets the current row index (0-based). Returns -1 if before the first row. + * + * @return the current row index + */ + @Override + public long getCurrentRow() { + return globalRowIndex; + } + + /** + * Moves the cursor to the next row. Fetches additional data from server if needed. + * + * @return true if there is a next row, false if at the end + * @throws DatabricksSQLException if an error occurs while fetching data + */ + @Override + public boolean next() throws DatabricksSQLException { + if (isClosed || hasReachedEnd) { + return false; + } + + if (!hasNext()) { + return false; + } + + // Check if we've reached the maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + hasReachedEnd = true; + return false; + } + + // Try to advance in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + + // Need to fetch next chunk + while (currentResponse.hasMoreRows) { + fetchNextChunk(); + + // If we got a chunk with data, advance to first row + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + } + + // No more data available + hasReachedEnd = true; + return false; + } + + /** + * Checks if there are more rows available without advancing the cursor. 
+ * + * @return true if there are more rows, false otherwise + */ + @Override + public boolean hasNext() { + if (isClosed || hasReachedEnd) { + return false; + } + + // Check maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + return false; + } + + // Check if there are more rows in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + return true; + } + + // Check if there are more chunks to fetch + return currentResponse.hasMoreRows; + } + + /** Closes this result and releases associated resources. */ + @Override + public void close() { + this.isClosed = true; + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + this.currentChunk = null; + this.currentChunkIterator = null; + this.currentResponse = null; + LOGGER.debug( + "LazyThriftInlineArrowResult closed after fetching {} total rows", totalRowsFetched); + } + + /** + * Gets the number of rows in the current chunk. + * + * @return the number of rows in the current chunk + */ + @Override + public long getRowCount() { + return currentChunk != null ? currentChunk.numRows : 0; + } + + /** + * Gets the chunk count. Always returns 0 for lazy thrift inline arrow results. + * + * @return 0 (lazy results don't use chunks in the same sense as buffered results) + */ + @Override + public long getChunkCount() { + return 0; + } + + private void loadCurrentChunk() throws DatabricksSQLException { + try { + ByteArrayInputStream byteStream = createArrowByteStream(currentResponse); + long rowCount = getTotalRowsInResponse(currentResponse); + + ArrowResultChunk.Builder builder = + ArrowResultChunk.builder().withInputStream(byteStream, rowCount); + + if (statement != null) { + builder.withStatementId(statement.getStatementId()); + } + + currentChunk = builder.build(); + currentChunkIterator = currentChunk.getChunkIterator(); + totalRowsFetched += rowCount; + + LOGGER.debug( + "Loaded arrow chunk with {} rows, total fetched: {}", rowCount, totalRowsFetched); + } catch (DatabricksParsingException e) { + LOGGER.error("Failed to load current chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw new DatabricksSQLException( + "Failed to process arrow data", DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + } + + /** + * Fetches the next chunk of data from the server and creates arrow chunks. 
+ * + * @throws DatabricksSQLException if the fetch operation fails + */ + private void fetchNextChunk() throws DatabricksSQLException { + try { + LOGGER.debug("Fetching next arrow chunk, current total rows fetched: {}", totalRowsFetched); + currentResponse = session.getDatabricksClient().getMoreResults(statement); + + // Release previous chunk to free memory + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + + loadCurrentChunk(); + + LOGGER.debug( + "Fetched arrow chunk with {} rows, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } catch (DatabricksSQLException e) { + LOGGER.error("Failed to fetch next arrow chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw e; + } + } + + private ByteArrayInputStream createArrowByteStream(TFetchResultsResp resultsResp) + throws DatabricksParsingException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CompressionCodec compressionType = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + try { + byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); + if (serializedSchema != null) { + baos.write(serializedSchema); + } + writeArrowBatchesToStream(compressionType, resultsResp.getResults().getArrowBatches(), baos); + return new ByteArrayInputStream(baos.toByteArray()); + } catch (DatabricksSQLException | IOException e) { + handleError(e); + } + return null; + } + + private void writeArrowBatchesToStream( + CompressionCodec compressionCodec, + List arrowBatchList, + ByteArrayOutputStream baos) + throws DatabricksSQLException, IOException { + for (TSparkArrowBatch arrowBatch : arrowBatchList) { + byte[] decompressedBytes = + decompress( + arrowBatch.getBatch(), + compressionCodec, + String.format( + "Data fetch for lazy inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", + arrowBatch.getRowCount(), statement, compressionCodec)); + baos.write(decompressedBytes); + } + } + + private long getTotalRowsInResponse(TFetchResultsResp resultsResp) { + long totalRows = 0; + if (resultsResp.getResults() != null && resultsResp.getResults().getArrowBatches() != null) { + for (TSparkArrowBatch arrowBatch : resultsResp.getResults().getArrowBatches()) { + totalRows += arrowBatch.getRowCount(); + } + } + return totalRows; + } + + private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) + throws DatabricksSQLException { + if (metadata.getArrowSchema() != null) { + return metadata.getArrowSchema(); + } + Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); + try { + return SchemaUtility.serialize(arrowSchema); + } catch (IOException e) { + handleError(e); + } + return null; + } + + private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) + throws DatabricksParsingException { + List fields = new ArrayList<>(); + if (hiveSchema == null) { + return new Schema(fields); + } + try { + hiveSchema + .getColumns() + .forEach( + columnDesc -> { + try { + fields.add(getArrowField(columnDesc)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + } catch (RuntimeException e) { + handleError(e); + } + return new Schema(fields); + } + + private Field getArrowField(TColumnDesc columnDesc) throws SQLException { + TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); + ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); + FieldType fieldType = new FieldType(true, arrowType, null); + return new 
Field(columnDesc.getColumnName(), fieldType, null); + } + + private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { + columnInfos = new ArrayList<>(); + if (resultManifest.getSchema() == null) { + return; + } + for (TColumnDesc tColumnDesc : resultManifest.getSchema().getColumns()) { + columnInfos.add( + com.databricks.jdbc.common.util.DatabricksThriftUtil.getColumnInfoFromTColumnDesc( + tColumnDesc)); + } + } + + @VisibleForTesting + void handleError(Exception e) throws DatabricksParsingException { + String errorMessage = + String.format("Cannot process lazy thrift inline arrow format. Error: %s", e.getMessage()); + LOGGER.error(errorMessage); + throw new DatabricksParsingException( + errorMessage, e, DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + + /** + * Gets the total number of rows fetched from the server so far. + * + * @return the total number of rows fetched from the server + */ + public long getTotalRowsFetched() { + return totalRowsFetched; + } + + /** + * Checks if all data has been fetched from the server. + * + * @return true if all data has been fetched (either reached end or maxRows limit) + */ + public boolean isCompletelyFetched() { + return hasReachedEnd || !currentResponse.hasMoreRows; + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java index 2efb1e33a..1e1461592 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java @@ -1,10 +1,10 @@ package com.databricks.jdbc.api.impl; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.when; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -128,14 +128,11 @@ public void testGetResultSet_thriftURL() throws SQLException { @Test public void testGetResultSet_thriftInlineArrow() throws SQLException { - when(connectionContext.getConnectionUuid()).thenReturn("sample-uuid"); when(resultSetMetadataResp.getResultFormat()).thenReturn(TSparkRowSetType.ARROW_BASED_SET); when(fetchResultsResp.getResultSetMetadata()).thenReturn(resultSetMetadataResp); when(fetchResultsResp.getResults()).thenReturn(tRowSet); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(tRowSet.getArrowBatches()).thenReturn(ARROW_BATCH_LIST); IExecutionResult result = ExecutionResultFactory.getResultSet(fetchResultsResp, session, parentStatement); - assertInstanceOf(ArrowStreamResult.class, result); + assertInstanceOf(LazyThriftInlineArrowResult.class, result); } } diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java index 5f42fbdf1..9f2eb213a 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java @@ -133,25 +133,6 @@ public void testIteration() throws Exception { assertFalse(result.next()); } - @Test - public void testInlineArrow() throws DatabricksSQLException { - 
IDatabricksConnectionContext connectionContext = - DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(metadataResp.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(resultData); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); - ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, true, parentStatement, session); - assertEquals(-1, result.getCurrentRow()); - assertTrue(result.hasNext()); - assertFalse(result.next()); - assertEquals(0, result.getCurrentRow()); - assertFalse(result.hasNext()); - assertDoesNotThrow(result::close); - assertFalse(result.hasNext()); - } - @Test public void testCloudFetchArrow() throws Exception { IDatabricksConnectionContext connectionContext = @@ -164,7 +145,7 @@ public void testCloudFetchArrow() throws Exception { when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); when(parentStatement.getStatementId()).thenReturn(STATEMENT_ID); ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, false, parentStatement, session, mockHttpClient); + new ArrowStreamResult(fetchResultsResp, parentStatement, session, mockHttpClient); assertEquals(-1, result.getCurrentRow()); assertTrue(result.hasNext()); assertDoesNotThrow(result::close); diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java index 86be512d4..8392daf68 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java @@ -1,27 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; -import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; -import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp; -import com.databricks.jdbc.model.client.thrift.generated.TRowSet; -import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowBatch; import com.databricks.jdbc.model.core.ColumnInfo; import com.databricks.jdbc.model.core.ColumnInfoTypeName; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Collections; import net.jpountz.lz4.LZ4FrameOutputStream; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -37,41 +27,9 @@ public class InlineChunkProviderTest { private static final long TOTAL_ROWS = 2L; - @Mock TGetResultSetMetadataResp metadata; - @Mock TFetchResultsResp fetchResultsResp; - @Mock IDatabricksStatementInternal parentStatement; - @Mock IDatabricksSession session; @Mock private ResultData mockResultData; @Mock private ResultManifest mockResultManifest; - @Test - void testInitialisation() 
throws DatabricksParsingException { - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(metadata.getArrowSchema()).thenReturn(null); - when(metadata.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(new TRowSet().setArrowBatches(ARROW_BATCH_LIST)); - when(metadata.isSetLz4Compressed()).thenReturn(false); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertTrue(inlineChunkProvider.hasNextChunk()); - assertTrue(inlineChunkProvider.next()); - assertFalse(inlineChunkProvider.next()); - } - - @Test - void handleErrorTest() throws DatabricksParsingException { - TSparkArrowBatch arrowBatch = - new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(fetchResultsResp.getResults()) - .thenReturn(new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch))); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertThrows( - DatabricksParsingException.class, - () -> inlineChunkProvider.handleError(new RuntimeException())); - } - @Test void testConstructorSuccessfulCreation() throws DatabricksSQLException, IOException { // Create valid Arrow data with two rows and one column: [1, 2] diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java new file mode 100644 index 000000000..9c43d3e78 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java @@ -0,0 +1,285 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.io.IOException; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class LazyThriftInlineArrowResultTest { + + @Mock private IDatabricksSession session; + @Mock private IDatabricksStatementInternal statement; + private static final StatementId STATEMENT_ID = new StatementId("test_statement_id"); + private static final byte[] DUMMY_ARROW_BYTES = new byte[] {65, 66, 67}; + + private TFetchResultsResp createFetchResultsResp(int rowCount, boolean hasMoreRows) { + TSparkArrowBatch arrowBatch = + new TSparkArrowBatch().setRowCount(rowCount).setBatch(DUMMY_ARROW_BYTES); + TRowSet rowSet = new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch)); + + TGetResultSetMetadataResp metadata = + new TGetResultSetMetadataResp().setSchema(TEST_TABLE_SCHEMA); + + TFetchResultsResp response = + new TFetchResultsResp().setResultSetMetadata(metadata).setResults(rowSet); + response.hasMoreRows = hasMoreRows; + + return 
response; + } + + @Test + void testConstructorInitializesCorrectly() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + assertEquals(0, result.getTotalRowsFetched()); + assertFalse(result.hasNext()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Result is already closed", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.STATEMENT_CLOSED.name(), exception.getSQLState()); + } + + @Test + void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Cursor is before first row", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), exception.getSQLState()); + } + + @Test + void testCloseReleasesResources() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + result.close(); + + assertFalse(result.hasNext()); + assertFalse(result.next()); + } + + @Test + void testIsCompletelyFetchedWhenNoMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, true); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.isCompletelyFetched()); + } + + @Test + void testGetChunkCount() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = 
+ new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getChunkCount()); + } + + @Test + void testHandleErrorThrowsParsingException() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + Exception testException = new IOException("Test error"); + DatabricksParsingException exception = + assertThrows(DatabricksParsingException.class, () -> result.handleError(testException)); + assertTrue(exception.getMessage().contains("Cannot process lazy thrift inline arrow format")); + assertEquals( + DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR.name(), exception.getSQLState()); + } + + @Test + void testEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(0, result.getRowCount()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testNullStatement() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + } + + @Test + void testGetCurrentRowBeforeNext() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + } + + @Test + void testGetTotalRowsFetched() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getTotalRowsFetched()); + } + + @Test + void testNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.hasNext()); + } + + @Test + void testNextReturnsFalseAfterClose() throws 
DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseAfterClose() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.hasNext()); + } + + @Test + void testConstructorWithNullStatementUsesDefaultMaxRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertNotNull(result); + assertEquals(-1, result.getCurrentRow()); + } +}
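
From the application's side, the lazy path described in the commit message changes nothing: rows are still consumed through the standard JDBC API, and the driver now pulls additional inline Arrow batches only as the cursor advances instead of buffering them all up front. A minimal consumer sketch; the connection URL, credentials, and table name below are placeholders and are not part of this patch:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.Properties;

public class LazyArrowReadSketch {
  public static void main(String[] args) throws Exception {
    // Placeholder URL; authentication properties are omitted for brevity.
    String url = "jdbc:databricks://example.cloud.databricks.com:443/default;transportMode=http";
    try (Connection conn = DriverManager.getConnection(url, new Properties());
        Statement stmt = conn.createStatement();
        // With ARROW_BASED_SET results, Arrow batches are now fetched on demand while iterating,
        // rather than being buffered up front by InlineChunkProvider.
        ResultSet rs = stmt.executeQuery("SELECT * FROM some_large_table")) {
      long rows = 0;
      while (rs.next()) { // each next() may trigger a getMoreResults() round trip under the hood
        rows++;
      }
      System.out.println("Read " + rows + " rows");
    }
  }
}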
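
The cursor logic in LazyThriftInlineArrowResult.next() reduces to a simple pattern: serve rows from the chunk built from the current TFetchResultsResp, and once that chunk is exhausted keep calling getMoreResults while hasMoreRows is set until a chunk with data arrives. Below is a distilled sketch of that pattern using hypothetical stand-in types (Page, fetchNext); it deliberately omits the maxRows limit, close() handling, and Arrow decoding that the real class performs:

import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Supplier;

/** Distilled sketch: exhaust the current page, then fetch the next page on demand. */
final class LazyPagedIterator<T> implements Iterator<T> {

  /** Hypothetical page abstraction standing in for one TFetchResultsResp worth of rows. */
  static final class Page<T> {
    final List<T> rows;
    final boolean hasMore; // analogous to TFetchResultsResp.hasMoreRows

    Page(List<T> rows, boolean hasMore) {
      this.rows = rows;
      this.hasMore = hasMore;
    }
  }

  private final Supplier<Page<T>> fetchNext; // analogous to client.getMoreResults(statement)
  private Page<T> current;
  private int pos = 0;

  LazyPagedIterator(Page<T> first, Supplier<Page<T>> fetchNext) {
    this.current = first;
    this.fetchNext = fetchNext;
  }

  @Override
  public boolean hasNext() {
    // Rows left in the current page, or the server has indicated that more pages exist.
    return pos < current.rows.size() || current.hasMore;
  }

  @Override
  public T next() {
    // Keep fetching until a page with data arrives (the server may return empty pages).
    while (pos >= current.rows.size() && current.hasMore) {
      current = fetchNext.get();
      pos = 0;
    }
    if (pos >= current.rows.size()) {
      throw new NoSuchElementException();
    }
    return current.rows.get(pos++);
  }
}

Replacing the old eager loop in InlineChunkProvider (append every batch to one ByteArrayOutputStream while hasMoreRows is set) with this style is what bounds memory to roughly one response worth of Arrow data at a time.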
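
Like the removed InlineChunkProvider, the new class falls back to deriving an Arrow schema from the Thrift table schema when the server does not include a serialized arrowSchema, serializes it with Arrow's SchemaUtility, and writes those bytes ahead of the decompressed record batches so the result can be read as a single Arrow stream. A self-contained sketch of that serialization step; the column names and types here are illustrative only:

import java.util.Arrays;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.SchemaUtility;

public class ArrowSchemaSerializationSketch {
  public static void main(String[] args) throws Exception {
    // Illustrative columns; the driver maps these from TColumnDesc via mapThriftToArrowType.
    Field id = new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null);
    Field name = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null);
    Schema schema = new Schema(Arrays.asList(id, name));

    // Serialized IPC schema message; the driver prepends these bytes to the decompressed
    // Arrow batch bytes before handing the combined stream to ArrowResultChunk.
    byte[] schemaBytes = SchemaUtility.serialize(schema);
    System.out.println("Serialized schema is " + schemaBytes.length + " bytes");
  }
}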