Skip to content

Commit f88a618

Browse files
committed
Fix Arrow batch row counts for CloudFetch results
Respect server-declared row counts when decoding Arrow IPC result streams so CloudFetch payload padding is not exposed as extra rows. Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com>
1 parent fd65633 commit f88a618

3 files changed

Lines changed: 230 additions & 20 deletions

File tree

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,67 @@ type arrowRecordIterator struct {
3838
isFinished bool
3939
arrowSchemaBytes []byte
4040
arrowSchema *arrow.Schema
41+
nextRowNumber int64
42+
hasNextRowNumber bool
4143
}
4244

4345
var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
4446

4547
// Retrieve the next arrow record
4648
func (ri *arrowRecordIterator) Next() (arrow.Record, error) {
47-
if !ri.HasNext() {
48-
// returning EOF indicates that there are no more records to iterate
49-
return nil, io.EOF
49+
for {
50+
if !ri.HasNext() {
51+
// returning EOF indicates that there are no more records to iterate
52+
return nil, io.EOF
53+
}
54+
55+
// make sure we have the current batch
56+
err := ri.getCurrentBatch()
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
// return next record in current batch
62+
r, err := ri.currentBatch.Next()
63+
if err != nil {
64+
ri.checkFinished()
65+
return nil, err
66+
}
67+
68+
r2 := ri.skipReturnedRows(r)
69+
ri.checkFinished()
70+
if r2 == nil {
71+
continue
72+
}
73+
74+
return r2, nil
5075
}
76+
}
5177

52-
// make sure we have the current batch
53-
err := ri.getCurrentBatch()
54-
if err != nil {
55-
return nil, err
78+
func (ri *arrowRecordIterator) skipReturnedRows(r SparkArrowRecord) arrow.Record {
79+
if !ri.hasNextRowNumber {
80+
ri.nextRowNumber = r.Start()
81+
ri.hasNextRowNumber = true
5682
}
5783

58-
// return next record in current batch
59-
r, err := ri.currentBatch.Next()
84+
if r.End() < ri.nextRowNumber {
85+
r.Release()
86+
return nil
87+
}
6088

61-
ri.checkFinished()
89+
if r.Start() < ri.nextRowNumber {
90+
start := ri.nextRowNumber - r.Start()
91+
sliced := r.NewSlice(start, r.NumRows())
92+
r.Release()
93+
if sliced == nil {
94+
return nil
95+
}
96+
ri.nextRowNumber += sliced.NumRows()
97+
return sliced
98+
}
6299

63-
return r, err
100+
ri.nextRowNumber = r.End() + 1
101+
return r
64102
}
65103

66104
// Indicate whether there are any more records available

internal/rows/arrowbased/arrowRecordIterator_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"testing"
1010

11+
"github.com/apache/arrow/go/v12/arrow"
1112
"github.com/databricks/databricks-sql-go/driverctx"
1213
"github.com/databricks/databricks-sql-go/internal/cli_service"
1314
"github.com/databricks/databricks-sql-go/internal/client"
@@ -193,6 +194,114 @@ func TestArrowRecordIterator(t *testing.T) {
193194
})
194195
}
195196

197+
func TestArrowRecordIterator_SkipsOverlappingReturnedRows(t *testing.T) {
198+
var releasedOverlappingRecord bool
199+
var slicedStart int64
200+
var slicedEnd int64
201+
202+
overlappingRecord := fakeRecord{
203+
fnNumRows: func() int64 { return 10 },
204+
fnRelease: func() {
205+
releasedOverlappingRecord = true
206+
},
207+
fnNewSlice: func(i, j int64) arrow.Record {
208+
slicedStart = i
209+
slicedEnd = j
210+
return fakeRecord{fnNumRows: func() int64 { return j - i }}
211+
},
212+
}
213+
214+
rs := &arrowRecordIterator{
215+
batchIterator: &fakeBatchIterator{
216+
index: -1,
217+
batches: []SparkArrowBatch{
218+
&sparkArrowBatch{
219+
Delimiter: rowscanner.NewDelimiter(0, 5),
220+
arrowRecords: []SparkArrowRecord{
221+
&sparkArrowRecord{
222+
Delimiter: rowscanner.NewDelimiter(0, 5),
223+
Record: fakeRecord{fnNumRows: func() int64 { return 5 }},
224+
},
225+
},
226+
},
227+
&sparkArrowBatch{
228+
Delimiter: rowscanner.NewDelimiter(0, 10),
229+
arrowRecords: []SparkArrowRecord{
230+
&sparkArrowRecord{
231+
Delimiter: rowscanner.NewDelimiter(0, 10),
232+
Record: overlappingRecord,
233+
},
234+
},
235+
},
236+
},
237+
},
238+
}
239+
defer rs.Close()
240+
241+
r1, err := rs.Next()
242+
assert.NoError(t, err)
243+
assert.Equal(t, int64(5), r1.NumRows())
244+
r1.Release()
245+
246+
r2, err := rs.Next()
247+
assert.NoError(t, err)
248+
assert.Equal(t, int64(5), r2.NumRows())
249+
r2.Release()
250+
251+
assert.True(t, releasedOverlappingRecord)
252+
assert.Equal(t, int64(5), slicedStart)
253+
assert.Equal(t, int64(10), slicedEnd)
254+
255+
r3, err := rs.Next()
256+
assert.Nil(t, r3)
257+
assert.ErrorIs(t, err, io.EOF)
258+
}
259+
260+
func TestLimitArrowRecordsUsesExpectedRowCount(t *testing.T) {
261+
var releasedOriginal bool
262+
var releasedExtra bool
263+
var slicedStart int64
264+
var slicedEnd int64
265+
266+
records := []SparkArrowRecord{
267+
&sparkArrowRecord{
268+
Delimiter: rowscanner.NewDelimiter(10, 5),
269+
Record: fakeRecord{
270+
fnNumRows: func() int64 { return 5 },
271+
fnRelease: func() {
272+
releasedOriginal = true
273+
},
274+
fnNewSlice: func(i, j int64) arrow.Record {
275+
slicedStart = i
276+
slicedEnd = j
277+
return fakeRecord{fnNumRows: func() int64 { return j - i }}
278+
},
279+
},
280+
},
281+
&sparkArrowRecord{
282+
Delimiter: rowscanner.NewDelimiter(15, 5),
283+
Record: fakeRecord{
284+
fnNumRows: func() int64 { return 5 },
285+
fnRelease: func() {
286+
releasedExtra = true
287+
},
288+
},
289+
},
290+
}
291+
292+
limited := limitArrowRecords(records, 3)
293+
defer limited[0].Release()
294+
295+
assert.Len(t, limited, 1)
296+
assert.Equal(t, int64(10), limited[0].Start())
297+
assert.Equal(t, int64(3), limited[0].Count())
298+
assert.Equal(t, int64(3), limited[0].NumRows())
299+
assert.True(t, releasedOriginal)
300+
assert.True(t, releasedExtra)
301+
assert.Equal(t, int64(0), slicedStart)
302+
assert.Equal(t, int64(3), slicedEnd)
303+
}
304+
196305
func TestArrowRecordIteratorSchema(t *testing.T) {
197306
// Test with arrowSchemaBytes available
198307
t.Run("schema with initial schema bytes", func(t *testing.T) {

internal/rows/arrowbased/batchloader.go

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ type IPCStreamIterator interface {
3131
Close()
3232
}
3333

34+
// positionedIPCStreamIterator is implemented by IPC stream iterators that can
// report, alongside each stream, where that stream sits in the result set.
type positionedIPCStreamIterator interface {
	// NextWithMetadata returns the next IPC stream together with the row
	// offset at which it starts and the row count declared for it.
	NextWithMetadata() (reader io.Reader, startRowOffset int64, rowCount int64, err error)
}
37+
3438
func NewCloudIPCStreamIterator(
3539
ctx context.Context,
3640
files []*cli_service.TSparkArrowResultLink,
@@ -136,21 +140,30 @@ type localIPCStreamIterator struct {
136140
var _ IPCStreamIterator = (*localIPCStreamIterator)(nil)
137141

138142
func (bi *localIPCStreamIterator) Next() (io.Reader, error) {
143+
reader, _, _, err := bi.NextWithMetadata()
144+
return reader, err
145+
}
146+
147+
func (bi *localIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
139148
cnt := len(bi.batches)
140149
bi.index++
141150
if bi.index < cnt {
142151
ab := bi.batches[bi.index]
152+
startRowOffset := bi.startRowOffset
153+
for i := 0; i < bi.index; i++ {
154+
startRowOffset += bi.batches[i].RowCount
155+
}
143156

144157
reader := io.MultiReader(
145158
bytes.NewReader(bi.arrowSchemaBytes),
146159
getReader(bytes.NewReader(ab.Batch), bi.cfg.UseLz4Compression),
147160
)
148161

149-
return reader, nil
162+
return reader, startRowOffset, ab.RowCount, nil
150163
}
151164

152165
bi.index = cnt
153-
return nil, io.EOF
166+
return nil, 0, 0, io.EOF
154167
}
155168

156169
func (bi *localIPCStreamIterator) HasNext() bool {
@@ -176,6 +189,11 @@ type cloudIPCStreamIterator struct {
176189
var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
177190

178191
func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
192+
reader, _, _, err := bi.NextWithMetadata()
193+
return reader, err
194+
}
195+
196+
func (bi *cloudIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
179197
for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) {
180198
link := bi.pendingLinks.Dequeue()
181199
logger.Debug().Msgf(
@@ -204,15 +222,15 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
204222

205223
task := bi.downloadTasks.Dequeue()
206224
if task == nil {
207-
return nil, io.EOF
225+
return nil, 0, 0, io.EOF
208226
}
209227

210228
data, downloadMs, err := task.GetResult()
211229

212230
// once we've got an errored out task - cancel the remaining ones
213231
if err != nil {
214232
bi.Close()
215-
return nil, err
233+
return nil, 0, 0, err
216234
}
217235

218236
// explicitly call cancel function on successfully completed task to avoid context leak
@@ -226,7 +244,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
226244
bi.onFileDownloaded(downloadMs)
227245
}
228246

229-
return data, nil
247+
return data, task.link.StartRowOffset, task.link.RowCount, nil
230248
}
231249

232250
func (bi *cloudIPCStreamIterator) HasNext() bool {
@@ -558,15 +576,26 @@ func NewBatchIterator(ipcIterator IPCStreamIterator, startRowOffset int64) Batch
558576
}
559577

560578
func (bi *batchIterator) Next() (SparkArrowBatch, error) {
561-
reader, err := bi.ipcIterator.Next()
579+
startRowOffset := bi.startRowOffset
580+
expectedRows := int64(-1)
581+
var reader io.Reader
582+
var err error
583+
if positionedIterator, ok := bi.ipcIterator.(positionedIPCStreamIterator); ok {
584+
reader, startRowOffset, expectedRows, err = positionedIterator.NextWithMetadata()
585+
} else {
586+
reader, err = bi.ipcIterator.Next()
587+
}
562588
if err != nil {
563589
return nil, err
564590
}
565591

566-
records, err := getArrowRecords(reader, bi.startRowOffset)
592+
records, err := getArrowRecords(reader, startRowOffset)
567593
if err != nil {
568594
return nil, err
569595
}
596+
if expectedRows >= 0 {
597+
records = limitArrowRecords(records, expectedRows)
598+
}
570599

571600
// When using CloudFetch, cached Arrow IPC files may contain stale column
572601
// names from a previous query. Replace the embedded schema with the
@@ -593,14 +622,48 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
593622
}
594623

595624
batch := &sparkArrowBatch{
596-
Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, totalRows),
625+
Delimiter: rowscanner.NewDelimiter(startRowOffset, totalRows),
597626
arrowRecords: records,
598627
}
599628

600-
bi.startRowOffset += totalRows
629+
bi.startRowOffset = startRowOffset + totalRows
601630
return batch, nil
602631
}
603632

633+
func limitArrowRecords(records []SparkArrowRecord, expectedRows int64) []SparkArrowRecord {
634+
if expectedRows < 0 {
635+
return records
636+
}
637+
638+
remaining := expectedRows
639+
limited := records[:0]
640+
for _, record := range records {
641+
if remaining <= 0 {
642+
record.Release()
643+
continue
644+
}
645+
646+
if record.NumRows() <= remaining {
647+
limited = append(limited, record)
648+
remaining -= record.NumRows()
649+
continue
650+
}
651+
652+
start := record.Start()
653+
sliced := record.NewSlice(0, remaining)
654+
record.Release()
655+
if sliced != nil {
656+
limited = append(limited, &sparkArrowRecord{
657+
Delimiter: rowscanner.NewDelimiter(start, sliced.NumRows()),
658+
Record: sliced,
659+
})
660+
}
661+
remaining = 0
662+
}
663+
664+
return limited
665+
}
666+
604667
func (bi *batchIterator) HasNext() bool {
605668
return bi.ipcIterator.HasNext()
606669
}

0 commit comments

Comments
 (0)