Skip to content

Commit b14e2c8

Browse files
committed
Add node to convert record batches into csp.Structs and fix arrow adapter
Signed-off-by: Arham Chopra <[email protected]>
1 parent f70df59 commit b14e2c8

File tree

6 files changed

+244
-106
lines changed

6 files changed

+244
-106
lines changed

cpp/csp/adapters/parquet/ParquetReader.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,13 @@ void SingleFileParquetReader::clear()
382382
}
383383

384384
InMemoryTableParquetReader::InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
385-
bool allowMissingColumns, std::optional<std::string> symbolColumnName )
385+
bool allowMissingColumns, std::optional<std::string> symbolColumnName, bool call_init )
386386
: SingleTableParquetReader( columns, true, allowMissingColumns, symbolColumnName ), m_generatorPtr( generatorPtr )
387387
{
388-
init();
388+
if( call_init )
389+
{
390+
init();
391+
}
389392
}
390393

391394
bool InMemoryTableParquetReader::openNextFile()

cpp/csp/adapters/parquet/ParquetReader.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,21 +424,28 @@ class SingleFileParquetReader final : public SingleTableParquetReader
424424
bool m_allowMissingFiles;
425425
};
426426

427-
class InMemoryTableParquetReader final : public SingleTableParquetReader
427+
class InMemoryTableParquetReader : public SingleTableParquetReader
428428
{
429429
public:
430430
using GeneratorPtr = csp::Generator<std::shared_ptr<arrow::Table>, csp::DateTime, csp::DateTime>::Ptr;
431431

432432
InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
433433
bool allowMissingColumns,
434-
std::optional<std::string> symbolColumnName = {} );
434+
std::optional<std::string> symbolColumnName = {},
435+
bool call_init = true);
435436
std::string getCurFileOrTableName() const override{ return "IN_MEMORY_TABLE"; }
436437

437438
protected:
438-
bool openNextFile() override;
439-
bool readNextRowGroup() override;
439+
virtual bool openNextFile() override;
440+
virtual bool readNextRowGroup() override;
441+
void setTable( std::shared_ptr<arrow::Table> table )
442+
{
443+
m_fullTable = table;
444+
m_nextChunkIndex = 0;
445+
m_curTable = nullptr;
446+
}
440447

441-
void clear() override;
448+
virtual void clear() override;
442449

443450
private:
444451
GeneratorPtr m_generatorPtr;

cpp/csp/python/CMakeLists.txt

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,35 @@ target_compile_definitions(cspimpl PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERS
9090
target_compile_definitions(cspimpl PRIVATE CSPIMPL_EXPORTS=1)
9191

9292

93+
find_package(Arrow REQUIRED)
94+
find_package(Parquet REQUIRED)
95+
96+
if(WIN32)
97+
if(CSP_USE_VCPKG)
98+
set(ARROW_PACKAGES_TO_LINK Arrow::arrow_static Parquet::parquet_static )
99+
target_compile_definitions(csp_parquet_adapter PUBLIC ARROW_STATIC)
100+
target_compile_definitions(csp_parquet_adapter PUBLIC PARQUET_STATIC)
101+
else()
102+
# use dynamic variants
103+
# Until we manage to get the fix for ws2_32.dll in arrow-16 into conda, manually fix the error here
104+
get_target_property(LINK_LIBS Arrow::arrow_shared INTERFACE_LINK_LIBRARIES)
105+
string(REPLACE "ws2_32.dll" "ws2_32" FIXED_LINK_LIBS "${LINK_LIBS}")
106+
set_target_properties(Arrow::arrow_shared PROPERTIES INTERFACE_LINK_LIBRARIES "${FIXED_LINK_LIBS}")
107+
set(ARROW_PACKAGES_TO_LINK parquet_shared arrow_shared)
108+
endif()
109+
else()
110+
if(CSP_USE_VCPKG)
111+
# use static variants
112+
set(ARROW_PACKAGES_TO_LINK parquet_static arrow_static)
113+
else()
114+
# use dynamic variants
115+
set(ARROW_PACKAGES_TO_LINK parquet arrow)
116+
endif()
117+
endif()
118+
93119
## Baselib c++ module
94120
add_library(cspbaselibimpl SHARED cspbaselibimpl.cpp)
95-
target_link_libraries(cspbaselibimpl cspimpl baselibimpl)
121+
target_link_libraries(cspbaselibimpl cspimpl baselibimpl csp_parquet_adapter ${ARROW_PACKAGES_TO_LINK})
96122

97123
# Include exprtk include directory for exprtk node
98124
target_include_directories(cspbaselibimpl PRIVATE ${EXPRTK_INCLUDE_DIRS})

cpp/csp/python/cspbaselibimpl.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
#include <exprtk.hpp>
66
#include <numpy/ndarrayobject.h>
77

8+
#include <arrow/type.h>
9+
#include <arrow/table.h>
10+
#include <arrow/c/abi.h>
11+
#include <arrow/c/bridge.h>
12+
13+
#include <csp/adapters/parquet/ParquetReader.h>
14+
#include <csp/adapters/utils/StructAdapterInfo.h>
15+
816
static void * init_nparray()
917
{
1018
csp::python::AcquireGIL gil;
@@ -325,6 +333,133 @@ DECLARE_CPPNODE( exprtk_impl )
325333

326334
EXPORT_CPPNODE( exprtk_impl );
327335

336+
// Node converting Arrow record batches, handed over from Python as PyCapsules
// (Arrow C Data Interface), into csp structs.
//
// Scalar inputs:
//   schema_ptr - Python capsule named "arrow_schema" describing the incoming batches
//   cls        - struct meta of the struct type to build
//   properties - dictionary holding "field_map"; every key of the map must be a column of the schema
// Time-series input:
//   data       - array of Python tuples whose items 0 and 1 are the "arrow_schema" and
//                "arrow_array" capsules of one record batch
// Output:
//   array of the structs dispatched by the struct adapter while the tick's batches are parsed.
DECLARE_CPPNODE( record_batches_to_struct )
{
    using InMemoryTableParquetReader = csp::adapters::parquet::InMemoryTableParquetReader;

    // Reader that is fed tables built from record batches instead of pulling them from a generator.
    class RecordBatchReader : public InMemoryTableParquetReader
    {
    public:
        // The base class is constructed without a generator and with call_init == false;
        // column adapters are created later through initialize().
        RecordBatchReader( std::vector<std::string> columns, std::shared_ptr<arrow::Schema> schema ):
            InMemoryTableParquetReader( nullptr, columns, false, {}, false )
        {
            m_schema = schema;
        }

        std::string getCurFileOrTableName() const override{ return "IN_RECORD_BATCH"; }

        void initialize() { setColumnAdaptersFromCurrentTable(); }

        // Combines the batches into a single table, then reads it row by row, dispatching
        // every struct adapter once per row.
        // Throws ValueError if the table cannot be built or its first row group cannot be read.
        void parseBatches( const std::vector<std::shared_ptr<arrow::RecordBatch>> & record_batches )
        {
            auto table_result = arrow::Table::FromRecordBatches( m_schema, record_batches );
            if( !table_result.ok() )
                CSP_THROW( ValueError, "Failed to load all the record batches into a table: " << table_result.status().ToString() );

            setTable( table_result.ValueUnsafe() );

            if( !readNextRowGroup() )
                CSP_THROW( ValueError, "Unable to read the first row group from table" );

            while( readNextRow() )
            {
                for( auto & adapter: getStructAdapters() )
                    adapter -> dispatchValue( nullptr );
            }
        }

        void stop()
        {
            InMemoryTableParquetReader::clear();
        }

    protected:
        // There is never a "next file": tables are pushed in through parseBatches().
        bool openNextFile() override { return false; }
        void clear() override { setTable( nullptr ); }
    };

    SCALAR_INPUT( DialectGenericType, schema_ptr );
    SCALAR_INPUT( StructMetaPtr, cls );
    SCALAR_INPUT( DictionaryPtr, properties );
    TS_INPUT( Generic, data );

    TS_OUTPUT( Generic );

    std::shared_ptr<RecordBatchReader> reader;
    // Points at the output buffer only while parseBatches() is running; nullptr otherwise.
    std::vector<StructPtr> * m_structsVecPtr = nullptr;

    // Returns the pointer stored in a capsule after verifying that the object really is a
    // capsule with the expected name. The Arrow import functions dereference the pointer
    // unconditionally, so an invalid capsule has to be rejected up front.
    static void * capsulePointer( PyObject * capsule, const char * name )
    {
        if( !PyCapsule_IsValid( capsule, name ) )
            CSP_THROW( ValueError, "record_batches_to_struct expected a valid PyCapsule named '" << name << "'" );
        return PyCapsule_GetPointer( capsule, name );
    }

    INIT_CPPNODE( record_batches_to_struct )
    {
        auto & input_def = tsinputDef( "data" );
        if( input_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( TypeError, "record_batches_to_struct expected ts array type, got " << input_def.type -> type() );

        auto * aType = static_cast<const CspArrayType *>( input_def.type.get() );
        CspTypePtr elemType = aType -> elemType();
        if( elemType -> type() != CspType::Type::DIALECT_GENERIC )
            CSP_THROW( TypeError, "record_batches_to_struct expected ts array of DIALECT_GENERIC type, got " << elemType -> type() );

        auto & output_def = tsoutputDef( "" );
        if( output_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( TypeError, "record_batches_to_struct expected ts array type, got " << output_def.type -> type() );
    }

    START()
    {
        // Create Adapters for Schema
        PyObject * capsule = csp::python::toPythonBorrowed(schema_ptr);
        auto * c_schema = static_cast<struct ArrowSchema *>( capsulePointer( capsule, "arrow_schema" ) );
        auto result = arrow::ImportSchema( c_schema );
        if( !result.ok() )
            CSP_THROW( ValueError, "Failed to load the arrow schema: " << result.status().ToString() );
        std::shared_ptr<arrow::Schema> schema = result.ValueUnsafe();

        // Only the columns named in field_map are read; each of them must exist in the schema.
        std::vector<std::string> columns;
        auto field_map = properties.value() -> get<DictionaryPtr>( "field_map" );
        for( auto it = field_map -> begin(); it != field_map -> end(); ++it )
        {
            if( schema -> GetFieldByName( it.key() ) )
                columns.push_back( it.key() );
            else
                CSP_THROW( ValueError, "column " << it.key() << " not found in schema" );
        }
        reader = std::make_shared<RecordBatchReader>( columns, schema );
        reader -> initialize();

        CspTypePtr outType = std::make_shared<csp::CspStructType>( cls.value() );
        csp::adapters::utils::StructAdapterInfo key{ std::move(outType), std::move(field_map) };
        auto & struct_adapter = reader -> getStructAdapter( key );
        // Every struct built by the adapter is appended to the output buffer of the current tick.
        struct_adapter.addSubscriber( [this]( StructPtr * s )
            {
                if( s ) this -> m_structsVecPtr -> push_back( *s );
                else CSP_THROW( ValueError, "Failed to create struct while parsing the record batches" );
            }, {} );
    }

    INVOKE()
    {
        if( csp.ticked( data ) )
        {
            auto & py_batches = data.lastValue<std::vector<DialectGenericType>>();
            std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
            batches.reserve( py_batches.size() );
            for( auto & py_batch: py_batches )
            {
                PyObject * py_tuple = csp::python::toPythonBorrowed( py_batch );
                // PyTuple_GET_ITEM performs no checking, so validate the container first.
                if( !PyTuple_Check( py_tuple ) || PyTuple_GET_SIZE( py_tuple ) < 2 )
                    CSP_THROW( TypeError, "record_batches_to_struct expected each record batch as a ( schema capsule, array capsule ) tuple" );

                PyObject * py_schema = PyTuple_GET_ITEM( py_tuple, 0 );
                PyObject * py_array  = PyTuple_GET_ITEM( py_tuple, 1 );
                auto * c_schema = static_cast<struct ArrowSchema *>( capsulePointer( py_schema, "arrow_schema" ) );
                auto * c_array  = static_cast<struct ArrowArray *>( capsulePointer( py_array, "arrow_array" ) );
                auto result = arrow::ImportRecordBatch( c_array, c_schema );
                if( !result.ok() )
                    CSP_THROW( ValueError, "Failed to load record batches through PyCapsule C Data interface: " << result.status().ToString() );
                batches.emplace_back( result.ValueUnsafe() );
            }

            std::vector<StructPtr> & out = unnamed_output().reserveSpace<std::vector<StructPtr>>();
            out.clear();
            m_structsVecPtr = &out;
            try
            {
                reader -> parseBatches( batches );
            }
            catch( ... )
            {
                // Never leave the member pointing at the output buffer once parsing is over.
                m_structsVecPtr = nullptr;
                throw;
            }
            m_structsVecPtr = nullptr;
        }
    }
};
461+
EXPORT_CPPNODE( record_batches_to_struct );
462+
328463
}
329464

330465
// Base nodes
@@ -350,6 +485,7 @@ REGISTER_CPPNODE( csp::cppnodes, struct_fromts );
350485
REGISTER_CPPNODE( csp::cppnodes, struct_collectts );
351486

352487
REGISTER_CPPNODE( csp::cppnodes, exprtk_impl );
488+
REGISTER_CPPNODE( csp::cppnodes, record_batches_to_struct );
353489

354490
static PyModuleDef _cspbaselibimpl_module = {
355491
PyModuleDef_HEAD_INIT,

0 commit comments

Comments
 (0)