Skip to content

Commit 425c700

Browse files
authored
[ES-1934053] Detach result streaming from QueryContext cancellation (#373)
> Supersedes #371 (which was opened from a fork and therefore could not run the required JFrog-dependent CI checks — forks don't receive the OIDC `id-token`). Same context-only fix, same-repo branch so CI can run. Review history and discussion are on #371. ## Summary Fixes the Sev1 result-streaming truncation in **ES-1934053**. Since #295, the caller `QueryContext` was threaded into `NewRows` → `ResultPageIterator`, so result paging (`FetchResults`) and `CloseOperation` inherited the caller's deadline. A short timeout meant to gate only statement submission + status polling then fired mid-stream and **silently truncated** large CloudFetch results (a query expected to return 29,232,004 rows returned only 2,159,144; the `ArrowBatchIterator` surfaced `io.EOF` rather than the deadline error). This PR is **context-fix-only**. The Arrow row-count cap that was originally bundled with this work is split into #372 (an independent *over-reporting* fix). ## What changed - Detach the result context from the caller's cancellation via `context.WithoutCancel`, preserving its values for auth/logging. - Wire the detached context to a cancel func invoked from `Rows.Close()`, so in-flight `FetchResults` and CloudFetch downloads are never left uncancellable (addresses review finding **F2**). - `GetArrowBatches`/`GetArrowIPCStreams` now build their iterator from the detached context, so CloudFetch S3 downloads also survive the caller's deadline and remain abortable via `Close` (addresses **F3** — previously paging was detached but downloads were not). ## Test plan - `go test ./internal/rows/...` — adds `TestNewRows_DetachesResultRPCContextFromQueryContextCancellation` and `TestNewRows_CloseAbortsDetachedResultContext`. - `go test ./...` — full suite green. - E2E against a real SQL warehouse: reproduced the regression (100,000-row stream truncated to **12,288** rows under a 2s mid-stream deadline) and confirmed the fix drains the full result. The e2e regression passes the **cancelled** `QueryContext` into `GetArrowBatches` (previously `context.Background()`, which masked the download path). This pull request and its description were written by Isaac. Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com>
1 parent 94154ac commit 425c700

3 files changed

Lines changed: 245 additions & 10 deletions

File tree

driver_e2e_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package dbsql
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"encoding/json"
78
"fmt"
9+
"io"
810
"net/http/httptest"
911
"net/url"
1012
"os"
@@ -17,6 +19,7 @@ import (
1719
"github.com/databricks/databricks-sql-go/internal/cli_service"
1820
"github.com/databricks/databricks-sql-go/internal/client"
1921
"github.com/databricks/databricks-sql-go/logger"
22+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
2023
"github.com/pkg/errors"
2124
"github.com/stretchr/testify/assert"
2225
"github.com/stretchr/testify/require"
@@ -255,6 +258,68 @@ func TestWorkflowExample(t *testing.T) {
255258
}
256259
}
257260

261+
func TestE2EArrowBatchesSurviveQueryContextCancellation(t *testing.T) {
262+
host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME")
263+
httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2")
264+
token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN")
265+
if token == "" {
266+
token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL")
267+
}
268+
if host == "" || httpPath == "" || token == "" {
269+
t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run")
270+
}
271+
272+
connector, err := NewConnector(
273+
WithServerHostname(host),
274+
WithPort(443),
275+
WithHTTPPath(httpPath),
276+
WithAccessToken(token),
277+
WithMaxRows(1),
278+
)
279+
require.NoError(t, err)
280+
281+
db := sql.OpenDB(connector)
282+
defer db.Close() //nolint:errcheck
283+
284+
conn, err := db.Conn(context.Background())
285+
require.NoError(t, err)
286+
defer conn.Close() //nolint:errcheck
287+
288+
queryCtx, cancel := context.WithCancel(context.Background())
289+
defer cancel()
290+
291+
var driverRows driver.Rows
292+
err = conn.Raw(func(d any) error {
293+
var queryErr error
294+
driverRows, queryErr = d.(driver.QueryerContext).QueryContext(queryCtx, "SELECT id FROM range(3)", nil)
295+
return queryErr
296+
})
297+
require.NoError(t, err)
298+
defer driverRows.Close() //nolint:errcheck
299+
300+
cancel()
301+
302+
// Pass the already-cancelled queryCtx (not context.Background()) so the test
303+
// exercises the detached-iterator path: result paging AND CloudFetch
304+
// downloads must survive cancellation of the ctx handed to GetArrowBatches.
305+
batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(queryCtx)
306+
require.NoError(t, err)
307+
defer batches.Close()
308+
309+
var rowCount int64
310+
for {
311+
record, nextErr := batches.Next()
312+
if nextErr == io.EOF {
313+
break
314+
}
315+
require.NoError(t, nextErr)
316+
rowCount += record.NumRows()
317+
record.Release()
318+
}
319+
320+
require.Equal(t, int64(3), rowCount)
321+
}
322+
258323
func TestContextTimeoutExample(t *testing.T) {
259324

260325
_ = logger.SetLogLevel("debug")

internal/rows/rows.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1414
"github.com/databricks/databricks-sql-go/internal/cli_service"
1515
dbsqlclient "github.com/databricks/databricks-sql-go/internal/client"
16+
context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
1617
"github.com/databricks/databricks-sql-go/internal/config"
1718
dbsqlerr_int "github.com/databricks/databricks-sql-go/internal/errors"
1819
"github.com/databricks/databricks-sql-go/internal/rows/arrowbased"
@@ -57,7 +58,16 @@ type rows struct {
5758

5859
logger_ *dbsqllog.DBSQLLogger
5960

61+
// ctx is the context used for all server-side result RPCs (FetchResults,
62+
// GetResultSetMetadata, CloseOperation) and CloudFetch downloads. It is
63+
// detached from the caller's QueryContext cancellation so that a deadline
64+
// gating statement submission does not truncate result streaming, while
65+
// preserving context values used for auth/logging. It remains abortable via
66+
// Close() through resultsCancel.
6067
ctx context.Context
68+
// resultsCancel aborts in-flight result RPCs/downloads when Close() is
69+
// called, so the detached ctx never leaves an operation uncancellable.
70+
resultsCancel context.CancelFunc
6171

6272
// Telemetry tracking
6373
// telemetryUpdate is called after each chunk is fetched with:
@@ -134,6 +144,15 @@ func NewRows(
134144

135145
logger.Debug().Msgf("databricks: creating Rows, pageSize: %d, location: %v", pageSize, location)
136146

147+
// QueryContext may use a short deadline to gate statement submission and
148+
// status polling (see ES-1934053 / #295 / #371). Result handles can outlive
149+
// that phase, especially for paginated CloudFetch streams, so detach
150+
// server-side result RPCs from the caller's cancellation while preserving
151+
// context values used for auth/logging. The detached context is still wired
152+
// to a cancel func invoked from Close(), so the result handle remains
153+
// abortable (no uncancellable in-flight FetchResults or CloudFetch download).
154+
resultsCtx, resultsCancel := context.WithCancel(context2.WithoutCancel(ctx))
155+
137156
r := &rows{
138157
client: client,
139158
opHandle: opHandle,
@@ -142,7 +161,8 @@ func NewRows(
142161
location: location,
143162
config: config,
144163
logger_: logger,
145-
ctx: ctx,
164+
ctx: resultsCtx,
165+
resultsCancel: resultsCancel,
146166
chunkCount: 0,
147167
bytesDownloaded: 0,
148168
}
@@ -201,7 +221,7 @@ func NewRows(
201221
// the operations.
202222
closedOnServer := directResults != nil && directResults.CloseOperation != nil
203223
r.ResultPageIterator = rowscanner.NewResultPageIterator(
204-
ctx,
224+
resultsCtx,
205225
d,
206226
pageSize,
207227
opHandle,
@@ -244,6 +264,12 @@ func (r *rows) Close() error {
244264
return nil
245265
}
246266

267+
// Release the detached results context after the close RPC runs, aborting
268+
// any in-flight FetchResults/CloudFetch downloads still referencing it.
269+
if r.resultsCancel != nil {
270+
defer r.resultsCancel()
271+
}
272+
247273
if r.RowScanner != nil {
248274
// make sure the row scanner frees up any resources
249275
r.RowScanner.Close()
@@ -635,27 +661,34 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger {
635661
}
636662

637663
func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {
638-
// update context with correlationId and connectionId which will be used in logging and errors
639-
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
664+
// Result fetching must outlive the caller's QueryContext deadline: both the
665+
// inter-page FetchResults RPCs (via r.ResultPageIterator) AND the CloudFetch
666+
// S3 downloads created from the iterator context. Build the iterator from the
667+
// detached results context (abortable via Close) rather than the caller ctx,
668+
// so passing a deadline-bound ctx here cannot truncate the stream. Driver
669+
// values for logging/auth are already carried by r.ctx; re-apply the ids
670+
// defensively. See ES-1934053 / #371.
671+
iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId)
640672

641673
// If a row scanner exists we use it to create the iterator, that way the iterator includes
642674
// data returned as direct results
643675
if r.RowScanner != nil {
644-
return r.RowScanner.GetArrowBatches(ctx, *r.config, r.ResultPageIterator)
676+
return r.RowScanner.GetArrowBatches(iterCtx, *r.config, r.ResultPageIterator)
645677
}
646678

647-
return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil
679+
return arrowbased.NewArrowRecordIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil
648680
}
649681

650682
func (r *rows) GetArrowIPCStreams(ctx context.Context) (dbsqlrows.ArrowIPCStreamIterator, error) {
651-
// update context with correlationId and connectionId which will be used in logging and errors
652-
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
683+
// See GetArrowBatches: result fetching is detached from the caller ctx so a
684+
// submit-gating deadline cannot truncate streaming; it stays abortable via Close.
685+
iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId)
653686

654687
// If a row scanner exists we use it to create the iterator, that way the iterator includes
655688
// data returned as direct results
656689
if r.RowScanner != nil {
657-
return r.RowScanner.GetArrowIPCStreams(ctx, *r.config, r.ResultPageIterator)
690+
return r.RowScanner.GetArrowIPCStreams(iterCtx, *r.config, r.ResultPageIterator)
658691
}
659692

660-
return arrowbased.NewArrowIPCStreamIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil
693+
return arrowbased.NewArrowIPCStreamIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil
661694
}

internal/rows/rows_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,143 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) {
15641564
assert.ErrorContains(t, actualErr, errorMsg)
15651565
}
15661566

1567+
func TestNewRows_DetachesResultRPCContextFromQueryContextCancellation(t *testing.T) {
1568+
t.Parallel()
1569+
1570+
baseCtx := driverctx.NewContextWithConnId(context.Background(), "connId")
1571+
baseCtx = driverctx.NewContextWithCorrelationId(baseCtx, "corrId")
1572+
queryCtx, cancel := context.WithCancel(baseCtx)
1573+
cancel()
1574+
1575+
assertResultCtx := func(ctx context.Context) {
1576+
assert.NoError(t, ctx.Err(), "result RPC context should not inherit query cancellation")
1577+
assert.Equal(t, "connId", driverctx.ConnIdFromContext(ctx))
1578+
assert.Equal(t, "corrId", driverctx.CorrelationIdFromContext(ctx))
1579+
}
1580+
1581+
metaCalled := false
1582+
metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
1583+
metaCalled = true
1584+
assertResultCtx(ctx)
1585+
return &cli_service.TGetResultSetMetadataResp{
1586+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
1587+
Schema: &cli_service.TTableSchema{
1588+
Columns: []*cli_service.TColumnDesc{
1589+
{ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{
1590+
Types: []*cli_service.TTypeEntry{{
1591+
PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE},
1592+
}},
1593+
}},
1594+
},
1595+
},
1596+
}, nil
1597+
}
1598+
1599+
noMoreRows := false
1600+
fetchCalled := false
1601+
fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
1602+
fetchCalled = true
1603+
assertResultCtx(ctx)
1604+
return &cli_service.TFetchResultsResp{
1605+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
1606+
HasMoreRows: &noMoreRows,
1607+
Results: &cli_service.TRowSet{
1608+
StartRowOffset: 0,
1609+
Columns: []*cli_service.TColumn{
1610+
{BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}},
1611+
},
1612+
},
1613+
}, nil
1614+
}
1615+
1616+
closeCalled := false
1617+
closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
1618+
closeCalled = true
1619+
assertResultCtx(ctx)
1620+
return &cli_service.TCloseOperationResp{
1621+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
1622+
}, nil
1623+
}
1624+
1625+
testClient := &client.TestClient{
1626+
FnFetchResults: fetchFn,
1627+
FnGetResultSetMetadata: metaFn,
1628+
FnCloseOperation: closeFn,
1629+
}
1630+
opHandle := &cli_service.TOperationHandle{
1631+
OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")},
1632+
}
1633+
cfg := config.WithDefaults()
1634+
1635+
dr, dbErr := NewRows(queryCtx, opHandle, testClient, cfg, nil, nil)
1636+
assert.Nil(t, dbErr)
1637+
1638+
dest := make([]driver.Value, 1)
1639+
assert.NoError(t, dr.Next(dest))
1640+
assert.Equal(t, true, dest[0])
1641+
assert.True(t, fetchCalled, "FetchResults should use the detached result context")
1642+
assert.True(t, metaCalled, "GetResultSetMetadata should use the detached result context")
1643+
assert.True(t, closeCalled, "CloseOperation should use the detached result context")
1644+
}
1645+
1646+
// TestNewRows_CloseAbortsDetachedResultContext verifies the detachment is not
1647+
// total: the result context survives the caller's QueryContext cancellation
1648+
// (so streaming is not truncated) but is still cancelled by Close(), so an
1649+
// in-flight FetchResults/CloudFetch download can never be left uncancellable.
1650+
func TestNewRows_CloseAbortsDetachedResultContext(t *testing.T) {
1651+
t.Parallel()
1652+
1653+
queryCtx, cancel := context.WithCancel(context.Background())
1654+
1655+
var capturedCtx context.Context
1656+
noMoreRows := false
1657+
fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
1658+
capturedCtx = ctx
1659+
return &cli_service.TFetchResultsResp{
1660+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
1661+
HasMoreRows: &noMoreRows,
1662+
Results: &cli_service.TRowSet{
1663+
StartRowOffset: 0,
1664+
Columns: []*cli_service.TColumn{{BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}}},
1665+
},
1666+
}, nil
1667+
}
1668+
metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
1669+
return &cli_service.TGetResultSetMetadataResp{
1670+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
1671+
Schema: &cli_service.TTableSchema{Columns: []*cli_service.TColumnDesc{
1672+
{ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{Types: []*cli_service.TTypeEntry{{
1673+
PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE},
1674+
}}}},
1675+
}},
1676+
}, nil
1677+
}
1678+
closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
1679+
// The close RPC itself must still run with a live (un-cancelled) context.
1680+
assert.NoError(t, ctx.Err(), "CloseOperation must run before the result context is cancelled")
1681+
return &cli_service.TCloseOperationResp{Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}}, nil
1682+
}
1683+
1684+
testClient := &client.TestClient{FnFetchResults: fetchFn, FnGetResultSetMetadata: metaFn, FnCloseOperation: closeFn}
1685+
opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")}}
1686+
1687+
dr, dbErr := NewRows(queryCtx, opHandle, testClient, config.WithDefaults(), nil, nil)
1688+
assert.Nil(t, dbErr)
1689+
1690+
dest := make([]driver.Value, 1)
1691+
assert.NoError(t, dr.Next(dest))
1692+
assert.NotNil(t, capturedCtx)
1693+
1694+
// Caller cancels the QueryContext: result context must remain alive.
1695+
cancel()
1696+
assert.NoError(t, capturedCtx.Err(), "result context must survive QueryContext cancellation")
1697+
assert.NotNil(t, capturedCtx.Done(), "result context must be abortable (non-nil Done)")
1698+
1699+
// Close() must cancel the detached result context so nothing is left uncancellable.
1700+
assert.NoError(t, dr.Close())
1701+
assert.ErrorIs(t, capturedCtx.Err(), context.Canceled, "Close() should cancel the detached result context")
1702+
}
1703+
15671704
// TestRows_CloseCallback_ReceivesChunkCount verifies that when rows.Close() is called,
15681705
// the closeCallback receives the correct chunkCount reflecting the number of result pages
15691706
// that were fetched during iteration.

0 commit comments

Comments
 (0)