Skip to content

Commit 92e5ff9

Browse files
committed
Support NumpyNDArray, Numpy1DArray in record_batches_to_struct/struct_to_record_batches
Signed-off-by: Arham Chopra <[email protected]>
1 parent e5523eb commit 92e5ff9

File tree

9 files changed, with 650 additions and 315 deletions.

cpp/csp/adapters/parquet/ParquetOutputAdapter.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -127,10 +127,10 @@ ListColumnParquetOutputHandler::ListColumnParquetOutputHandler( Engine *engine,
127127
listWriterInterface ) )
128128
{
129129
m_valueHandler = std::make_unique<ValueHandler>(
130-
[ this ]( const TimeSeriesProvider *input )
130+
[ this ]( const DialectGenericType& input )
131131
{
132132
static_cast<ListColumnArrayBuilder *>(this -> m_columnArrayBuilder.get())
133-
-> setValue( input -> lastValueTyped<DialectGenericType>() );
133+
-> setValue( input );
134134
} );
135135
}
136136

@@ -176,7 +176,7 @@ std::shared_ptr<::arrow::ArrayBuilder> ListColumnParquetOutputHandler::createVal
176176

177177
void ListColumnParquetOutputAdapter::executeImpl()
178178
{
179-
( *m_valueHandler )( input() );
179+
( *m_valueHandler )( input() -> lastValueTyped<DialectGenericType>() );
180180
m_parquetWriter.scheduleEndCycleEvent();
181181
}
182182

cpp/csp/adapters/parquet/ParquetOutputAdapter.h

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,18 +93,23 @@ class ListColumnParquetOutputHandler : public ParquetOutputHandler
9393
static_cast<ColumnBuilderType *>(this -> m_columnArrayBuilder.get()) -> setValue( value );
9494
}
9595

96-
void writeValueFromTs( const TimeSeriesProvider *input ) override final
96+
void writeValueFromArgs( const DialectGenericType& input )
9797
{
9898
( *m_valueHandler )( input );
9999
}
100100

101+
void writeValueFromTs( const TimeSeriesProvider *input ) override final
102+
{
103+
( *m_valueHandler )( input -> lastValueTyped<DialectGenericType>() );
104+
}
105+
101106
private:
102107
std::shared_ptr<arrow::ArrayBuilder> createValueBuilder( const CspTypePtr &elemType,
103108
DialectGenericListWriterInterface::Ptr &listWriterInterface );
104109

105110

106111
protected :
107-
using ValueHandler = std::function<void( const TimeSeriesProvider * )>;
112+
using ValueHandler = std::function<void( const DialectGenericType& )>;
108113

109114
std::unique_ptr<ValueHandler> m_valueHandler;
110115
std::shared_ptr<ListColumnArrayBuilder> m_columnArrayBuilder;

cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -749,6 +749,24 @@ void BaseListColumnAdapter<ArrowListArrayType, ValueArrayType, ValueType>::addSu
749749
<< " in file " << m_parquetReader.getCurFileOrTableName() );
750750
}
751751

752+
template< typename ArrowListArrayType, typename ValueArrayType, typename ValueType>
753+
void BaseListColumnAdapter<ArrowListArrayType, ValueArrayType, ValueType>::addSubscriber( csp::adapters::utils::ValueDispatcher<const DialectGenericType&>::SubscriberType subscriber, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader )
754+
{
755+
CSP_TRUE_OR_THROW_RUNTIME( m_listReader == nullptr,
756+
"Trying to subscribe list column in parquet reader more than once, this is not supported" );
757+
CSP_TRUE_OR_THROW_RUNTIME( listReader != nullptr,
758+
"Trying to subscribe list column in parquet reader with null listReader" );
759+
m_listReader = std::dynamic_pointer_cast<TypedDialectGenericListReaderInterface<ValueType>>( listReader );
760+
CSP_TRUE_OR_THROW_RUNTIME( m_listReader != nullptr,
761+
"Subscribed to parquet column " << getColumnName() << " with type "
762+
<< "NumpyArray[" << listReader -> getValueType() -> type().asString()
763+
<< "] while "
764+
<< " column type in file is NumpyArray["
765+
<< getContainerValueType() -> type().asString() << "]"
766+
<< " in file " << m_parquetReader.getCurFileOrTableName() );
767+
m_dispatcher.addSubscriber( subscriber, symbol );
768+
}
769+
752770
template< typename ArrowListArrayType, typename ValueArrayType, typename ValueType >
753771
void NativeListColumnAdapter<ArrowListArrayType, ValueArrayType, ValueType>::readCurValue()
754772
{

cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@ class ParquetColumnAdapter
4848
virtual void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol = {} ) = 0;
4949
// NOTE: This API is only defined for ListType Column Adapters
5050
virtual void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) = 0;
51+
// NOTE: This API is only used to add subscriber for ListType column adapters in cases where there is no ManagedSimInputAdapter
52+
virtual void addSubscriber( csp::adapters::utils::ValueDispatcher<const DialectGenericType&>::SubscriberType subscriber, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) = 0;
5153

5254
virtual void dispatchValue( const utils::Symbol *symbol ) = 0;
5355

@@ -124,6 +126,10 @@ class MissingColumnAdapter : public ParquetColumnAdapter
124126

125127
virtual void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol = {} ) override {};
126128
virtual void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override {};
129+
virtual void addSubscriber( csp::adapters::utils::ValueDispatcher<const DialectGenericType&>::SubscriberType subscriber, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override
130+
{
131+
CSP_THROW(TypeError, "Trying to add DIALECT_GENERIC subscriber on non container type");
132+
}
127133

128134
virtual void dispatchValue( const utils::Symbol *symbol ) override {};
129135

@@ -172,6 +178,10 @@ class BaseTypedColumnAdapter : public ParquetColumnAdapter
172178
void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol = {} ) override;
173179
void addSubscriber( ManagedSimInputAdapter *inputAdapter,
174180
std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override;
181+
virtual void addSubscriber( csp::adapters::utils::ValueDispatcher<const DialectGenericType&>::SubscriberType subscriber, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override
182+
{
183+
CSP_THROW(TypeError, "Trying to add DIALECT_GENERIC subscriber on non container type");
184+
}
175185
void dispatchValue( const utils::Symbol *symbol ) override;
176186

177187
void ensureType( CspType::Ptr cspType ) override;
@@ -288,9 +298,11 @@ class BaseListColumnAdapter : public BaseTypedColumnAdapter<DialectGenericType,
288298
public:
289299
using BaseTypedColumnAdapter<DialectGenericType, ArrowListArrayType>::BaseTypedColumnAdapter;
290300
using BaseTypedColumnAdapter<DialectGenericType, ArrowListArrayType>::getColumnName;
301+
using BaseTypedColumnAdapter<DialectGenericType, ArrowListArrayType>::m_dispatcher;
291302
void addSubscriber( ManagedSimInputAdapter *inputAdapter, std::optional<utils::Symbol> symbol = {} ) override;
292303
void addSubscriber( ManagedSimInputAdapter *inputAdapter,
293304
std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override;
305+
void addSubscriber( csp::adapters::utils::ValueDispatcher<const DialectGenericType&>::SubscriberType subscriber, std::optional<utils::Symbol> symbol, const DialectGenericListReaderInterface::Ptr &listReader ) override;
294306
CspTypePtr getNativeCspType() const override {return nullptr;}
295307
bool isListType() const override{ return true; };
296308
CspTypePtr getContainerValueType() const override{ return CspType::fromCType<ValueType>::type(); }

cpp/csp/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ set(CSPIMPL_PUBLIC_HEADERS
3636
PyBasketInputProxy.h
3737
PyBasketOutputProxy.h
3838
PyCppNode.h
39+
PyDialectGenericListsInterface.h
3940
PyEngine.h
4041
PyInputAdapterWrapper.h
4142
PyInputProxy.h

0 commit comments

Comments
 (0)