Skip to content

Commit 2366458

Browse files
committed
Cap CloudFetch Arrow batches to server-declared RowCount
CloudFetch Arrow IPC files can contain padding rows beyond the RowCount declared on each result link, which the driver surfaced as extra rows (e.g. 301,407 rows returned for SELECT ... LIMIT 300000). Cap each decoded CloudFetch batch to its link's RowCount and anchor batch offsets to the link's StartRowOffset. This matches the official JDBC driver, whose ArrowResultChunkIterator stops iterating once rowsReadByIterator >= numRows (the server-declared TSparkArrowResultLink.RowCount), silently ignoring any padding rows in the Arrow file. The cap is scoped to the CloudFetch path only, via a new positionedIPCStreamIterator interface implemented solely by cloudIPCStreamIterator. The inline/local Arrow path is intentionally left uncapped: those batches are returned verbatim with no padding and their per-batch RowCount has historically been untrusted, so capping there could silently drop rows. A RowCount <= 0 is treated as "unknown" and never drops rows. Adds a unit test driving batchIterator -> limitArrowRecords (cap-down, exact boundary, over-count, RowCount==0 safety, inline-never-capped) and an env-gated CloudFetch e2e asserting an exact 2,000,000-row drain. Co-authored-by: Isaac
1 parent 94154ac commit 2366458

4 files changed

Lines changed: 283 additions & 7 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Release History
22

3+
## Unreleased
4+
- Fix CloudFetch results over-reporting rows: Arrow IPC files can carry padding rows beyond a result link's server-declared `RowCount`, which surfaced as extra rows (e.g. 301,407 returned for a `LIMIT 300000`). Decoded CloudFetch batches are now capped to the server-declared `RowCount`, matching the JDBC driver's behavior. The inline/local Arrow path is unchanged (databricks/databricks-sql-go#371)
5+
36
## v1.12.0 (2026-05-25)
47
- Retry transient S3 errors in CloudFetch downloads and staging PUT/GET/REMOVE operations (databricks/databricks-sql-go#355, #361)
58
- Telemetry: normalize host key for per-host client + breaker registries; stop retrying into 429s, honour Retry-After, fix userAgent (databricks/databricks-sql-go#354, #364)

driver_e2e_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package dbsql
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"encoding/json"
78
"fmt"
9+
"io"
810
"net/http/httptest"
911
"net/url"
1012
"os"
@@ -17,6 +19,7 @@ import (
1719
"github.com/databricks/databricks-sql-go/internal/cli_service"
1820
"github.com/databricks/databricks-sql-go/internal/client"
1921
"github.com/databricks/databricks-sql-go/logger"
22+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
2023
"github.com/pkg/errors"
2124
"github.com/stretchr/testify/assert"
2225
"github.com/stretchr/testify/require"
@@ -592,3 +595,71 @@ func getServer(state *callState) *httptest.Server {
592595
},
593596
})
594597
}
598+
599+
// TestE2ECloudFetchExactRowCount validates that a large CloudFetch result drains
600+
// the EXACT number of rows requested. CloudFetch Arrow IPC files can carry padding
601+
// rows beyond a link's server-declared RowCount; without capping to RowCount the
602+
// driver over-reports (e.g. 301,407 rows for a LIMIT 300000). This is the
603+
// regression guard for the row-count cap. Skipped in -short mode because it
604+
// drains a multi-million-row result over several CloudFetch link pages.
605+
func TestE2ECloudFetchExactRowCount(t *testing.T) {
606+
if testing.Short() {
607+
t.Skip("skipping large CloudFetch drain in -short mode")
608+
}
609+
host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME")
610+
httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2")
611+
token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN")
612+
if token == "" {
613+
token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL")
614+
}
615+
if host == "" || httpPath == "" || token == "" {
616+
t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run")
617+
}
618+
619+
const wantRows = 2000000
620+
621+
connector, err := NewConnector(
622+
WithServerHostname(host),
623+
WithPort(443),
624+
WithHTTPPath(httpPath),
625+
WithAccessToken(token),
626+
WithMaxRows(500000),
627+
)
628+
require.NoError(t, err)
629+
630+
db := sql.OpenDB(connector)
631+
defer db.Close() //nolint:errcheck
632+
633+
conn, err := db.Conn(context.Background())
634+
require.NoError(t, err)
635+
defer conn.Close() //nolint:errcheck
636+
637+
// A wide-ish row (id + 64-byte pad) over 2M rows forces a multi-page
638+
// CloudFetch (URL-based) result rather than inline Arrow.
639+
query := fmt.Sprintf("SELECT id, repeat('x', 64) AS pad FROM range(%d)", wantRows)
640+
var driverRows driver.Rows
641+
err = conn.Raw(func(d any) error {
642+
var queryErr error
643+
driverRows, queryErr = d.(driver.QueryerContext).QueryContext(context.Background(), query, nil)
644+
return queryErr
645+
})
646+
require.NoError(t, err)
647+
defer driverRows.Close() //nolint:errcheck
648+
649+
batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(context.Background())
650+
require.NoError(t, err)
651+
defer batches.Close()
652+
653+
var rowCount int64
654+
for {
655+
record, nextErr := batches.Next()
656+
if nextErr == io.EOF {
657+
break
658+
}
659+
require.NoError(t, nextErr)
660+
rowCount += record.NumRows()
661+
record.Release()
662+
}
663+
664+
require.Equal(t, int64(wantRows), rowCount, "CloudFetch must surface exactly the requested rows, with no Arrow padding")
665+
}

internal/rows/arrowbased/batchloader.go

Lines changed: 107 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,27 @@ type IPCStreamIterator interface {
3131
Close()
3232
}
3333

34+
// positionedIPCStreamIterator is an optional extension of IPCStreamIterator for
// streams that carry server-declared positioning metadata alongside each IPC
// payload. It is implemented ONLY by the CloudFetch iterator: CloudFetch result
// links carry an authoritative StartRowOffset and RowCount, and the Arrow IPC
// files they point at may be padded with extra rows beyond RowCount. batchIterator
// uses this metadata to (a) anchor each batch at its true stream offset and
// (b) cap the decoded records to RowCount so padding rows are not surfaced as
// real data (see limitArrowRecords).
//
// The inline/local Arrow path intentionally does NOT implement this: those
// batches are returned verbatim by the server with no padding, and their
// per-batch RowCount has historically been untrusted, so capping there would
// risk silently dropping rows.
//
// Row-count contract: batchIterator only caps when expectedRows is strictly
// positive. Any expectedRows <= 0 (negative "unknown", or a zero count that
// cannot be trusted when rows were actually decoded) means "do not cap".
type positionedIPCStreamIterator interface {
	// NextWithMetadata returns the next IPC payload along with its absolute
	// stream start offset and the server-declared row count (expectedRows).
	//
	// An expectedRows <= 0 means the count is unknown or untrusted and no
	// capping should occur. When err is non-nil (including io.EOF at the end
	// of the stream) the reader is nil and both numeric results carry no
	// meaning; callers must check err before using them.
	NextWithMetadata() (reader io.Reader, startRowOffset int64, expectedRows int64, err error)
}
54+
3455
func NewCloudIPCStreamIterator(
3556
ctx context.Context,
3657
files []*cli_service.TSparkArrowResultLink,
@@ -174,8 +195,17 @@ type cloudIPCStreamIterator struct {
174195
}
175196

176197
var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
198+
var _ positionedIPCStreamIterator = (*cloudIPCStreamIterator)(nil)
177199

178200
func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
201+
reader, _, _, err := bi.NextWithMetadata()
202+
return reader, err
203+
}
204+
205+
// NextWithMetadata returns the next downloaded CloudFetch IPC payload together
206+
// with the link's authoritative StartRowOffset and RowCount. The Arrow file may
207+
// contain padding rows beyond RowCount; the caller caps to RowCount.
208+
func (bi *cloudIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
179209
for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) {
180210
link := bi.pendingLinks.Dequeue()
181211
logger.Debug().Msgf(
@@ -204,15 +234,15 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
204234

205235
task := bi.downloadTasks.Dequeue()
206236
if task == nil {
207-
return nil, io.EOF
237+
return nil, 0, 0, io.EOF
208238
}
209239

210240
data, downloadMs, err := task.GetResult()
211241

212242
// once we've got an errored out task - cancel the remaining ones
213243
if err != nil {
214244
bi.Close()
215-
return nil, err
245+
return nil, 0, 0, err
216246
}
217247

218248
// explicitly call cancel function on successfully completed task to avoid context leak
@@ -226,7 +256,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
226256
bi.onFileDownloaded(downloadMs)
227257
}
228258

229-
return data, nil
259+
return data, task.link.StartRowOffset, task.link.RowCount, nil
230260
}
231261

232262
func (bi *cloudIPCStreamIterator) HasNext() bool {
@@ -558,16 +588,38 @@ func NewBatchIterator(ipcIterator IPCStreamIterator, startRowOffset int64) Batch
558588
}
559589

560590
func (bi *batchIterator) Next() (SparkArrowBatch, error) {
561-
reader, err := bi.ipcIterator.Next()
591+
// startRowOffset is the absolute offset of this batch within the result
592+
// stream. For positioned (CloudFetch) streams it comes from the server's
593+
// link metadata; otherwise we track it locally by accumulating decoded rows.
594+
startRowOffset := bi.startRowOffset
595+
// expectedRows is the server-declared row count for this batch. A value < 0
596+
// means "unknown" and disables capping (the inline/local path).
597+
expectedRows := int64(-1)
598+
var reader io.Reader
599+
var err error
600+
if positionedIterator, ok := bi.ipcIterator.(positionedIPCStreamIterator); ok {
601+
reader, startRowOffset, expectedRows, err = positionedIterator.NextWithMetadata()
602+
} else {
603+
reader, err = bi.ipcIterator.Next()
604+
}
562605
if err != nil {
563606
return nil, err
564607
}
565608

566-
records, err := getArrowRecords(reader, bi.startRowOffset)
609+
records, err := getArrowRecords(reader, startRowOffset)
567610
if err != nil {
568611
return nil, err
569612
}
570613

614+
// Cap the decoded records to the server-declared row count, dropping the
615+
// padding rows some CloudFetch Arrow files carry beyond their link's
616+
// RowCount. Only cap when the count is strictly positive: expectedRows == 0
617+
// with decoded rows is treated as "untrustworthy, do not cap" rather than
618+
// silently dropping the whole batch (see #371 review F1).
619+
if expectedRows > 0 {
620+
records = limitArrowRecords(records, expectedRows)
621+
}
622+
571623
// When using CloudFetch, cached Arrow IPC files may contain stale column
572624
// names from a previous query. Replace the embedded schema with the
573625
// authoritative schema from GetResultSetMetadata.
@@ -593,14 +645,62 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
593645
}
594646

595647
batch := &sparkArrowBatch{
596-
Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, totalRows),
648+
Delimiter: rowscanner.NewDelimiter(startRowOffset, totalRows),
597649
arrowRecords: records,
598650
}
599651

600-
bi.startRowOffset += totalRows
652+
// Advance the local offset for the next non-positioned batch. Positioned
653+
// streams overwrite startRowOffset from server metadata on the next call.
654+
bi.startRowOffset = startRowOffset + totalRows
601655
return batch, nil
602656
}
603657

658+
// limitArrowRecords caps a decoded batch to expectedRows, releasing any records
659+
// (and the tail of a partially-kept record) that fall beyond the server-declared
660+
// count. It is the mechanism that strips CloudFetch Arrow padding rows.
661+
//
662+
// Contract:
663+
// - Callers must only invoke this when expectedRows is trustworthy and the
664+
// batch may be over-long; expectedRows < 0 is treated as "unknown" and the
665+
// records are returned unchanged.
666+
// - When a record straddles the boundary it is sliced with NewSlice(0, remaining):
667+
// the slice bounds are record-relative (0-based within the record), while the
668+
// Delimiter's start is the ABSOLUTE stream offset of the record. Keep these two
669+
// distinct — do not pass the absolute start as a slice index.
670+
func limitArrowRecords(records []SparkArrowRecord, expectedRows int64) []SparkArrowRecord {
671+
if expectedRows < 0 {
672+
return records
673+
}
674+
675+
remaining := expectedRows
676+
limited := records[:0]
677+
for _, record := range records {
678+
if remaining <= 0 {
679+
record.Release()
680+
continue
681+
}
682+
683+
if record.NumRows() <= remaining {
684+
limited = append(limited, record)
685+
remaining -= record.NumRows()
686+
continue
687+
}
688+
689+
start := record.Start()
690+
sliced := record.NewSlice(0, remaining)
691+
record.Release()
692+
if sliced != nil {
693+
limited = append(limited, &sparkArrowRecord{
694+
Delimiter: rowscanner.NewDelimiter(start, sliced.NumRows()),
695+
Record: sliced,
696+
})
697+
}
698+
remaining = 0
699+
}
700+
701+
return limited
702+
}
703+
604704
// HasNext reports whether the underlying IPC stream iterator has another
// payload to hand out.
func (bi *batchIterator) HasNext() bool {
	more := bi.ipcIterator.HasNext()
	return more
}

internal/rows/arrowbased/batchloader_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7+
"io"
78
"net/http"
89
"net/http/httptest"
910
"runtime"
@@ -1160,3 +1161,104 @@ func countDownloadTaskGoroutines() int {
11601161
}
11611162
return strings.Count(string(buf), "cloudFetchDownloadTask).Run")
11621163
}
1164+
1165+
// fakePositionedIPCIterator is a test IPCStreamIterator that also implements
1166+
// positionedIPCStreamIterator, so it exercises the CloudFetch row-count-capping
1167+
// path through NewBatchIterator without real CloudFetch downloads.
1168+
type fakePositionedIPCIterator struct {
1169+
data []byte
1170+
startRowOffset int64
1171+
expectedRows int64
1172+
consumed bool
1173+
}
1174+
1175+
var _ IPCStreamIterator = (*fakePositionedIPCIterator)(nil)
1176+
var _ positionedIPCStreamIterator = (*fakePositionedIPCIterator)(nil)
1177+
1178+
func (f *fakePositionedIPCIterator) Next() (io.Reader, error) {
1179+
r, _, _, err := f.NextWithMetadata()
1180+
return r, err
1181+
}
1182+
func (f *fakePositionedIPCIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
1183+
if f.consumed {
1184+
return nil, 0, 0, io.EOF
1185+
}
1186+
f.consumed = true
1187+
return bytes.NewReader(f.data), f.startRowOffset, f.expectedRows, nil
1188+
}
1189+
func (f *fakePositionedIPCIterator) HasNext() bool { return !f.consumed }
1190+
func (f *fakePositionedIPCIterator) Close() {}
1191+
1192+
// fakePlainIPCIterator implements only IPCStreamIterator (the inline/local
1193+
// shape) so the cap must never apply to it.
1194+
type fakePlainIPCIterator struct {
1195+
data []byte
1196+
consumed bool
1197+
}
1198+
1199+
var _ IPCStreamIterator = (*fakePlainIPCIterator)(nil)
1200+
1201+
func (f *fakePlainIPCIterator) Next() (io.Reader, error) {
1202+
if f.consumed {
1203+
return nil, io.EOF
1204+
}
1205+
f.consumed = true
1206+
return bytes.NewReader(f.data), nil
1207+
}
1208+
func (f *fakePlainIPCIterator) HasNext() bool { return !f.consumed }
1209+
func (f *fakePlainIPCIterator) Close() {}
1210+
1211+
// TestBatchIterator_RowCountCap covers the batchIterator -> limitArrowRecords
1212+
// integration (#371 review F1/F6). generateMockArrowBytes writes the 3-row
1213+
// record twice, so each stream decodes to 6 rows.
1214+
func TestBatchIterator_RowCountCap(t *testing.T) {
1215+
const decoded = 6
1216+
const startOffset int64 = 100
1217+
1218+
t.Run("positioned: caps padding rows down to RowCount", func(t *testing.T) {
1219+
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 4}
1220+
bi := NewBatchIterator(it, startOffset)
1221+
batch, err := bi.Next()
1222+
assert.NoError(t, err)
1223+
defer batch.Close()
1224+
assert.Equal(t, int64(4), batch.Count(), "batch should be capped to RowCount")
1225+
assert.Equal(t, startOffset, batch.Start(), "batch must anchor at the server offset")
1226+
})
1227+
1228+
t.Run("positioned: exact boundary keeps all rows", func(t *testing.T) {
1229+
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: decoded}
1230+
bi := NewBatchIterator(it, startOffset)
1231+
batch, err := bi.Next()
1232+
assert.NoError(t, err)
1233+
defer batch.Close()
1234+
assert.Equal(t, int64(decoded), batch.Count())
1235+
})
1236+
1237+
t.Run("positioned: RowCount larger than decoded keeps all rows", func(t *testing.T) {
1238+
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 100}
1239+
bi := NewBatchIterator(it, startOffset)
1240+
batch, err := bi.Next()
1241+
assert.NoError(t, err)
1242+
defer batch.Close()
1243+
assert.Equal(t, int64(decoded), batch.Count())
1244+
})
1245+
1246+
t.Run("positioned: RowCount==0 is NOT trusted, keeps all rows (F1)", func(t *testing.T) {
1247+
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 0}
1248+
bi := NewBatchIterator(it, startOffset)
1249+
batch, err := bi.Next()
1250+
assert.NoError(t, err)
1251+
defer batch.Close()
1252+
assert.Equal(t, int64(decoded), batch.Count(), "RowCount==0 must not silently drop the batch")
1253+
})
1254+
1255+
t.Run("plain/inline iterator is never capped", func(t *testing.T) {
1256+
it := &fakePlainIPCIterator{data: generateMockArrowBytes(generateArrowRecord())}
1257+
bi := NewBatchIterator(it, startOffset)
1258+
batch, err := bi.Next()
1259+
assert.NoError(t, err)
1260+
defer batch.Close()
1261+
assert.Equal(t, int64(decoded), batch.Count(), "inline path must return all decoded rows")
1262+
assert.Equal(t, startOffset, batch.Start())
1263+
})
1264+
}

0 commit comments

Comments
 (0)