@@ -1,6 +1,7 @@
 package com.databricks.jdbc.api.impl;
 
 import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult;
+import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult;
 import com.databricks.jdbc.api.impl.volume.VolumeOperationResult;
 import com.databricks.jdbc.api.internal.IDatabricksSession;
 import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
@@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler(
       case COLUMN_BASED_SET:
         return new LazyThriftResult(resultsResp, parentStatement, session);
       case ARROW_BASED_SET:
-        return new ArrowStreamResult(resultsResp, true, parentStatement, session);
+        return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session);
       case URL_BASED_SET:
-        return new ArrowStreamResult(resultsResp, false, parentStatement, session);
+        return new ArrowStreamResult(resultsResp, parentStatement, session);
       case ROW_BASED_SET:
         throw new DatabricksSQLFeatureNotSupportedException(
             "Invalid state - row based set cannot be received");
com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java
@@ -85,13 +85,11 @@ public ArrowStreamResult(
 
   public ArrowStreamResult(
       TFetchResultsResp resultsResp,
-      boolean isInlineArrow,
       IDatabricksStatementInternal parentStatementId,
       IDatabricksSession session)
       throws DatabricksSQLException {
     this(
         resultsResp,
-        isInlineArrow,
         parentStatementId,
         session,
         DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext()));
@@ -100,27 +98,22 @@ public ArrowStreamResult(
   @VisibleForTesting
   ArrowStreamResult(
       TFetchResultsResp resultsResp,
-      boolean isInlineArrow,
       IDatabricksStatementInternal parentStatement,
       IDatabricksSession session,
       IDatabricksHttpClient httpClient)
       throws DatabricksSQLException {
     this.session = session;
     setColumnInfo(resultsResp.getResultSetMetadata());
-    if (isInlineArrow) {
-      this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session);
-    } else {
-      CompressionCodec compressionCodec =
-          CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
-      this.chunkProvider =
-          new RemoteChunkProvider(
-              parentStatement,
-              resultsResp,
-              session,
-              httpClient,
-              session.getConnectionContext().getCloudFetchThreadPoolSize(),
-              compressionCodec);
-    }
+    CompressionCodec compressionCodec =
+        CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
+    this.chunkProvider =
+        new RemoteChunkProvider(
+            parentStatement,
+            resultsResp,
+            session,
+            httpClient,
+            session.getConnectionContext().getCloudFetchThreadPoolSize(),
+            compressionCodec);
   }
 
   public List<String> getArrowMetadata() throws DatabricksSQLException {
@@ -133,30 +126,15 @@ public List<String> getArrowMetadata() throws DatabricksSQLException {
   /** {@inheritDoc} */
   @Override
   public Object getObject(int columnIndex) throws DatabricksSQLException {
-    ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName();
+    ColumnInfo columnInfo = columnInfos.get(columnIndex);
+    ColumnInfoTypeName requiredType = columnInfo.getTypeName();
     String arrowMetadata = chunkIterator.getType(columnIndex);
     if (arrowMetadata == null) {
-      arrowMetadata = columnInfos.get(columnIndex).getTypeText();
-    }
-
-    // Handle complex type conversion when complex datatype support is disabled
-    boolean isComplexDatatypeSupportEnabled =
-        this.session.getConnectionContext().isComplexDatatypeSupportEnabled();
-    if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) {
-      LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING");
-
-      Object result =
-          chunkIterator.getColumnObjectAtCurrentRow(
-              columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex));
-      if (result == null) {
-        return null;
-      }
-      ComplexDataTypeParser parser = new ComplexDataTypeParser();
-      return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata);
+      arrowMetadata = columnInfo.getTypeText();
     }
 
-    return chunkIterator.getColumnObjectAtCurrentRow(
-        columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex));
+    return getObjectWithComplexTypeHandling(
+        session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo);
   }
@@ -237,4 +215,44 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) {
       columnInfos.add(getColumnInfoFromTColumnDesc(tColumnDesc));
     }
   }
+
+  /**
+   * Helper method to handle complex type conversion when complex datatype support is disabled.
+   *
+   * @param session The databricks session
+   * @param chunkIterator The chunk iterator
+   * @param columnIndex The column index
+   * @param requiredType The required column type
+   * @param arrowMetadata The arrow metadata
+   * @param columnInfo The column info
+   * @return The object value (converted if complex type and support disabled)
+   * @throws DatabricksSQLException if an error occurs
+   */
+  protected static Object getObjectWithComplexTypeHandling(
+      IDatabricksSession session,
+      ArrowResultChunkIterator chunkIterator,
+      int columnIndex,
+      ColumnInfoTypeName requiredType,
+      String arrowMetadata,
+      ColumnInfo columnInfo)
+      throws DatabricksSQLException {
+    boolean isComplexDatatypeSupportEnabled =
+        session.getConnectionContext().isComplexDatatypeSupportEnabled();
+
+    if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) {
+      LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING");
+      Object result =
+          chunkIterator.getColumnObjectAtCurrentRow(
+              columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo);
+      if (result == null) {
+        return null;
+      }
+      ComplexDataTypeParser parser = new ComplexDataTypeParser();
+
+      return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata);
+    }
+
+    return chunkIterator.getColumnObjectAtCurrentRow(
+        columnIndex, requiredType, arrowMetadata, columnInfo);
+  }
 }
com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java
@@ -1,31 +1,17 @@
package com.databricks.jdbc.api.impl.arrow;

import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*;
import static com.databricks.jdbc.common.util.DecompressionUtil.decompress;

import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
import com.databricks.jdbc.common.CompressionCodec;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.client.thrift.generated.*;
import com.databricks.jdbc.model.core.ResultData;
import com.databricks.jdbc.model.core.ResultManifest;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import com.google.common.annotations.VisibleForTesting;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.SchemaUtility;

/** Class to manage inline Arrow chunks */
public class InlineChunkProvider implements ChunkProvider {
@@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider {
   private final ArrowResultChunk
       arrowResultChunk; // There is only one packet of data in case of inline arrow
 
-  InlineChunkProvider(
-      TFetchResultsResp resultsResp,
-      IDatabricksStatementInternal parentStatement,
-      IDatabricksSession session)
-      throws DatabricksParsingException {
-    this.currentChunkIndex = -1;
-    this.totalRows = 0;
-    ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement);
-    ArrowResultChunk.Builder builder =
-        ArrowResultChunk.builder().withInputStream(byteStream, totalRows);
-
-    if (parentStatement != null) {
-      builder.withStatementId(parentStatement.getStatementId());
-    }
-    arrowResultChunk = builder.build();
-  }
-
   /**
    * Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}.
    *
@@ -123,97 +92,6 @@ public boolean isClosed() {
     return isClosed;
   }
 
-  private ByteArrayInputStream initializeByteStream(
-      TFetchResultsResp resultsResp,
-      IDatabricksSession session,
-      IDatabricksStatementInternal parentStatement)
-      throws DatabricksParsingException {
-    ByteArrayOutputStream baos = new ByteArrayOutputStream();
-    CompressionCodec compressionType =
-        CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
-    try {
-      byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata());
-      if (serializedSchema != null) {
-        baos.write(serializedSchema);
-      }
-      writeToByteOutputStream(
-          compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos);
-      while (resultsResp.hasMoreRows) {
-        resultsResp = session.getDatabricksClient().getMoreResults(parentStatement);
-        writeToByteOutputStream(
-            compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos);
-      }
-      return new ByteArrayInputStream(baos.toByteArray());
-    } catch (DatabricksSQLException | IOException e) {
-      handleError(e);
-    }
-    return null;
-  }
-
-  private void writeToByteOutputStream(
-      CompressionCodec compressionCodec,
-      IDatabricksStatementInternal parentStatement,
-      List<TSparkArrowBatch> arrowBatchList,
-      ByteArrayOutputStream baos)
-      throws DatabricksSQLException, IOException {
-    for (TSparkArrowBatch arrowBatch : arrowBatchList) {
-      byte[] decompressedBytes =
-          decompress(
-              arrowBatch.getBatch(),
-              compressionCodec,
-              String.format(
-                  "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]",
-                  arrowBatch.getRowCount(), parentStatement, compressionCodec));
-      totalRows += arrowBatch.getRowCount();
-      baos.write(decompressedBytes);
-    }
-  }
-
-  private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata)
-      throws DatabricksSQLException {
-    if (metadata.getArrowSchema() != null) {
-      return metadata.getArrowSchema();
-    }
-    Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema());
-    try {
-      return SchemaUtility.serialize(arrowSchema);
-    } catch (IOException e) {
-      handleError(e);
-    }
-    // should never reach here;
-    return null;
-  }
-
-  private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema)
-      throws DatabricksParsingException {
-    List<Field> fields = new ArrayList<>();
-    if (hiveSchema == null) {
-      return new Schema(fields);
-    }
-    try {
-      hiveSchema
-          .getColumns()
-          .forEach(
-              columnDesc -> {
-                try {
-                  fields.add(getArrowField(columnDesc));
-                } catch (SQLException e) {
-                  throw new RuntimeException(e);
-                }
-              });
-    } catch (RuntimeException e) {
-      handleError(e);
-    }
-    return new Schema(fields);
-  }
-
-  private Field getArrowField(TColumnDesc columnDesc) throws SQLException {
-    TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc());
-    ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType());
-    FieldType fieldType = new FieldType(true, arrowType, null);
-    return new Field(columnDesc.getColumnName(), fieldType, null);
-  }
-
   @VisibleForTesting
   void handleError(Exception e) throws DatabricksParsingException {
     String errorMessage =