33#include < csp/adapters/parquet/ParquetOutputAdapterManager.h>
44#include < csp/adapters/parquet/ParquetDictBasketOutputWriter.h>
55#include < csp/adapters/parquet/ParquetStatusUtils.h>
6- #include < csp/core/Generator.h>
76#include < csp/engine/PushInputAdapter.h>
87#include < csp/python/Conversions.h>
98#include < csp/python/Exception.h>
@@ -81,30 +80,50 @@ REGISTER_CPPNODE( csp::cppnodes, parquet_dict_basket_writer );
8180namespace
8281{
8382
84- // Generator that wraps a Python generator yielding (RecordBatch, dict, bool) tuples.
85- // The dict maps basket names to RecordBatches; the bool indicates schema change.
86- class RecordBatchGenerator : public csp ::Generator<RecordBatchWithFlag, csp::DateTime, csp::DateTime>
83+ // Generator that wraps a Python "stream factory" callable.
84+ // The factory signature: factory(starttime, endtime, needed_columns) -> iterator of (reader, basket_dict)
85+ // Each reader is a pyarrow.RecordBatchReader; basket_dict maps basket names to readers.
86+ // Readers are imported via ArrowArrayStream for GIL-free batch pulling in C++.
87+ class PyRecordBatchStreamSource : public csp ::adapters::parquet::RecordBatchStreamSource
8788{
8889public:
89- RecordBatchGenerator ( PyObject *wrappedGenerator )
90- : m_wrappedGenerator ( csp::python::PyObjectPtr::incref( wrappedGenerator ) )
90+ PyRecordBatchStreamSource ( PyObject *factory )
91+ : m_factory ( csp::python::PyObjectPtr::incref( factory ) )
9192 {
9293 }
9394
94- void init ( csp::DateTime start, csp::DateTime end ) override
95+ void init ( csp::DateTime start, csp::DateTime end,
96+ const std::set<std::string> & neededColumns ) override
9597 {
96- auto tp = csp::python::PyObjectPtr::own ( PyTuple_New ( 2 ) );
98+ auto tp = csp::python::PyObjectPtr::own ( PyTuple_New ( 3 ) );
9799 if ( !tp.get () )
98100 CSP_THROW ( csp::python::PythonPassthrough, " " );
99101
100102 PyTuple_SET_ITEM ( tp.get (), 0 , csp::python::toPython ( start ) );
101103 PyTuple_SET_ITEM ( tp.get (), 1 , csp::python::toPython ( end ) );
102- m_iter = csp::python::PyObjectPtr::check ( PyObject_Call ( m_wrappedGenerator.ptr (), tp.get (), nullptr ) );
104+
105+ if ( neededColumns.empty () )
106+ {
107+ Py_INCREF ( Py_None );
108+ PyTuple_SET_ITEM ( tp.get (), 2 , Py_None );
109+ }
110+ else
111+ {
112+ auto pyList = csp::python::PyObjectPtr::own ( PyList_New ( neededColumns.size () ) );
113+ if ( !pyList.get () )
114+ CSP_THROW ( csp::python::PythonPassthrough, " " );
115+ Py_ssize_t idx = 0 ;
116+ for ( auto & col : neededColumns )
117+ PyList_SET_ITEM ( pyList.get (), idx++, PyUnicode_FromStringAndSize ( col.c_str (), col.size () ) );
118+ PyTuple_SET_ITEM ( tp.get (), 2 , pyList.release () );
119+ }
120+
121+ m_iter = csp::python::PyObjectPtr::check ( PyObject_Call ( m_factory.ptr (), tp.get (), nullptr ) );
103122 CSP_TRUE_OR_THROW ( PyIter_Check ( m_iter.ptr () ), csp::TypeError,
104- " RecordBatch generator expected to return iterator" );
123+ " Stream factory expected to return iterator" );
105124 }
106125
107- bool next ( RecordBatchWithFlag &value ) override
126+ bool nextStream ( ) override
108127 {
109128 if ( m_iter.ptr () == nullptr )
110129 return false ;
@@ -115,21 +134,18 @@ class RecordBatchGenerator : public csp::Generator<RecordBatchWithFlag, csp::Dat
115134 if ( nextVal.get () == nullptr )
116135 return false ;
117136
118- // Expect a tuple of (RecordBatch , dict, bool )
119- CSP_TRUE_OR_THROW ( PyTuple_Check ( nextVal.get () ) && PyTuple_GET_SIZE ( nextVal.get () ) == 3 ,
120- csp::TypeError, " RecordBatch generator expected to yield (batch, basket_batches, schema_changed ) tuples" );
137+ // Expect a tuple of (RecordBatchReader , dict)
138+ CSP_TRUE_OR_THROW ( PyTuple_Check ( nextVal.get () ) && PyTuple_GET_SIZE ( nextVal.get () ) == 2 ,
139+ csp::TypeError, " Stream factory expected to yield (reader, basket_dict ) tuples" );
121140
122- PyObject *pyBatch = PyTuple_GET_ITEM ( nextVal.get (), 0 );
123- PyObject *pyBasketDict = PyTuple_GET_ITEM ( nextVal.get (), 1 );
124- PyObject *pySchemaChanged = PyTuple_GET_ITEM ( nextVal.get (), 2 );
141+ PyObject *pyReader = PyTuple_GET_ITEM ( nextVal.get (), 0 );
142+ PyObject *pyBasketDict = PyTuple_GET_ITEM ( nextVal.get (), 1 );
125143
126- value.schemaChanged = PyObject_IsTrue ( pySchemaChanged );
144+ // Import main reader via ArrowArrayStream
145+ m_mainReader = importRecordBatchReader ( pyReader );
127146
128- // Import main RecordBatch via PyCapsule C Data Interface
129- value.batch = importRecordBatch ( pyBatch );
130-
131- // Import basket RecordBatches
132- value.basketBatches .clear ();
147+ // Import basket readers and read their single batch
148+ m_basketBatches.clear ();
133149 if ( PyDict_Check ( pyBasketDict ) )
134150 {
135151 PyObject *key, *val;
@@ -139,43 +155,55 @@ class RecordBatchGenerator : public csp::Generator<RecordBatchWithFlag, csp::Dat
139155 const char *basketName = PyUnicode_AsUTF8 ( key );
140156 if ( !basketName )
141157 CSP_THROW ( csp::python::PythonPassthrough, " " );
142- value.basketBatches [ basketName ] = importRecordBatch ( val );
158+
159+ auto basketReader = importRecordBatchReader ( val );
160+ std::shared_ptr<::arrow::RecordBatch> batch;
161+ auto status = basketReader -> ReadNext ( &batch );
162+ if ( !status.ok () )
163+ CSP_THROW ( csp::ValueError, " Failed to read basket batch: " << status.ToString () );
164+ if ( batch )
165+ m_basketBatches[ basketName ] = batch;
143166 }
144167 }
145168
146169 return true ;
147170 }
148171
149- private:
150- static std::shared_ptr<::arrow::RecordBatch> importRecordBatch ( PyObject *pyBatch )
172+ std::shared_ptr<::arrow::RecordBatchReader> mainReader () override
151173 {
152- auto exportResult = csp::python::PyObjectPtr::own (
153- PyObject_CallMethod ( pyBatch, " __arrow_c_array__" , nullptr ) );
154- if ( !exportResult.get () || PyErr_Occurred () )
155- CSP_THROW ( csp::python::PythonPassthrough, " " );
156-
157- CSP_TRUE_OR_THROW ( PyTuple_Check ( exportResult.get () ) && PyTuple_GET_SIZE ( exportResult.get () ) == 2 ,
158- csp::TypeError, " __arrow_c_array__ expected to return (schema_capsule, array_capsule)" );
174+ return m_mainReader;
175+ }
159176
160- PyObject *pySchemaCapsule = PyTuple_GET_ITEM ( exportResult.get (), 0 );
161- PyObject *pyArrayCapsule = PyTuple_GET_ITEM ( exportResult.get (), 1 );
177+ const std::unordered_map<std::string, std::shared_ptr<::arrow::RecordBatch>> & basketBatches () const override
178+ {
179+ return m_basketBatches;
180+ }
162181
163- auto *c_schema = reinterpret_cast <struct ArrowSchema *>( PyCapsule_GetPointer ( pySchemaCapsule, " arrow_schema" ) );
164- auto *c_array = reinterpret_cast <struct ArrowArray *>( PyCapsule_GetPointer ( pyArrayCapsule, " arrow_array" ) );
182+ private:
183+ static std::shared_ptr<::arrow::RecordBatchReader> importRecordBatchReader ( PyObject *pyReader )
184+ {
185+ // Call __arrow_c_stream__() to export as ArrowArrayStream PyCapsule
186+ auto capsule = csp::python::PyObjectPtr::own (
187+ PyObject_CallMethod ( pyReader, " __arrow_c_stream__" , nullptr ) );
188+ if ( !capsule.get () || PyErr_Occurred () )
189+ CSP_THROW ( csp::python::PythonPassthrough, " " );
165190
166- auto schemaResult = arrow::ImportSchema ( c_schema );
167- if ( !schemaResult.ok () )
168- CSP_THROW ( csp::ValueError, " Failed to import RecordBatch schema: " << schemaResult.status ().ToString () );
191+ auto *stream = reinterpret_cast <struct ArrowArrayStream *>(
192+ PyCapsule_GetPointer ( capsule.get (), " arrow_array_stream" ) );
193+ if ( !stream )
194+ CSP_THROW ( csp::ValueError, " Failed to get ArrowArrayStream from PyCapsule" );
169195
170- auto batchResult = arrow::ImportRecordBatch ( c_array, schemaResult. ValueUnsafe () );
171- if ( !batchResult .ok () )
172- CSP_THROW ( csp::ValueError, " Failed to import RecordBatch : " << batchResult .status ().ToString () );
196+ auto result = :: arrow::ImportRecordBatchReader ( stream );
197+ if ( !result .ok () )
198+ CSP_THROW ( csp::ValueError, " Failed to import RecordBatchReader : " << result .status ().ToString () );
173199
174- return std::move ( batchResult .ValueUnsafe () );
200+ return result .ValueUnsafe ();
175201 }
176202
177- csp::python::PyObjectPtr m_wrappedGenerator ;
203+ csp::python::PyObjectPtr m_factory ;
178204 csp::python::PyObjectPtr m_iter;
205+ std::shared_ptr<::arrow::RecordBatchReader> m_mainReader;
206+ std::unordered_map<std::string, std::shared_ptr<::arrow::RecordBatch>> m_basketBatches;
179207};
180208
181209}
@@ -185,9 +213,9 @@ namespace csp::python
185213
186214// AdapterManager
187215csp::AdapterManager *create_parquet_input_adapter_manager_impl ( PyEngine *engine, const Dictionary &properties,
188- RecordBatchGenerator::Ptr rbGenerator )
216+ ParquetInputAdapterManager::RecordBatchStreamSourcePtr streamSource )
189217{
190- auto res = engine -> engine () -> createOwnedObject<ParquetInputAdapterManager>( properties, rbGenerator );
218+ auto res = engine -> engine () -> createOwnedObject<ParquetInputAdapterManager>( properties, streamSource );
191219 return res;
192220}
193221
@@ -443,9 +471,9 @@ static PyObject *create_parquet_input_adapter_manager( PyObject *args )
443471 &PyFunction_Type, &pyFileGenerator ) )
444472 CSP_THROW ( PythonPassthrough, " " );
445473
446- auto rbGenerator = std::make_shared<RecordBatchGenerator >( pyFileGenerator );
474+ auto streamSource = std::make_shared<PyRecordBatchStreamSource >( pyFileGenerator );
447475 auto *adapterMgr = create_parquet_input_adapter_manager_impl ( pyEngine, fromPython<Dictionary>( pyProperties ),
448- rbGenerator );
476+ streamSource );
449477 auto res = PyCapsule_New ( adapterMgr, " adapterMgr" , nullptr );
450478 return res;
451479 CSP_RETURN_NULL ;
0 commit comments