Skip to content

Commit 05db00c

Browse files
committed
feat: add tests for auth token refresh on retry and during pagination
1 parent d0ffb16 commit 05db00c

1 file changed

Lines changed: 276 additions & 0 deletions

File tree

core/dbio/api/api_test.go

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,3 +3558,279 @@ endpoints:
35583558
})
35593559
}
35603560
}
3561+
3562+
func TestAuthTokenRefreshOnRetry(t *testing.T) {
3563+
// This test reproduces an issue where DoRequest's retry path (goto retry)
3564+
// re-sends the same httpReq with stale auth headers instead of refreshing
3565+
// them from conn.State.Auth.Headers.
3566+
//
3567+
// Scenario: A user's API token expires server-side. The server returns 401.
3568+
// A retry rule fires, and meanwhile Auth.Headers has been updated with a
3569+
// new token (by EnsureAuthenticated or equivalent). But the retry re-sends
3570+
// the same httpReq with the OLD token baked in at MakeRequest() time.
3571+
//
3572+
// The mock server:
3573+
// - On the FIRST request: rotates the valid token from v1→v2, updates
3574+
// Auth.Headers on the APIConnection, then returns 401 (old token rejected)
3575+
// - On the SECOND request (retry): only accepts token_v2
3576+
//
3577+
// This guarantees deterministic ordering without relying on timing/sleeps.
3578+
3579+
var mu sync.Mutex
3580+
currentToken := "token_v1"
3581+
requestCount := 0
3582+
var ac *APIConnection // set after creation, before server receives requests
3583+
3584+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3585+
mu.Lock()
3586+
requestCount++
3587+
reqNum := requestCount
3588+
mu.Unlock()
3589+
3590+
w.Header().Set("Content-Type", "application/json")
3591+
3592+
authHeader := r.Header.Get("Authorization")
3593+
3594+
if reqNum == 1 {
3595+
// First request: simulate server-side token expiration.
3596+
// Rotate the valid token to v2 and update Auth.Headers
3597+
// (simulating what EnsureAuthenticated does after re-auth).
3598+
mu.Lock()
3599+
currentToken = "token_v2"
3600+
mu.Unlock()
3601+
3602+
// Update the APIConnection's auth headers to the new token
3603+
ac.State.Auth.Mutex.Lock()
3604+
ac.State.Auth.Headers["Authorization"] = "Bearer token_v2"
3605+
ac.State.Auth.Mutex.Unlock()
3606+
3607+
// Reject this request — it carries the old token
3608+
w.WriteHeader(http.StatusUnauthorized)
3609+
json.NewEncoder(w).Encode(map[string]any{
3610+
"error": fmt.Sprintf("token expired (request #%d): got %q, want %q",
3611+
reqNum, authHeader, "Bearer token_v2"),
3612+
})
3613+
return
3614+
}
3615+
3616+
// Subsequent requests: validate against current token
3617+
mu.Lock()
3618+
validToken := "Bearer " + currentToken
3619+
mu.Unlock()
3620+
3621+
if authHeader != validToken {
3622+
w.WriteHeader(http.StatusUnauthorized)
3623+
json.NewEncoder(w).Encode(map[string]any{
3624+
"error": fmt.Sprintf("invalid token (request #%d): got %q, want %q",
3625+
reqNum, authHeader, validToken),
3626+
})
3627+
return
3628+
}
3629+
3630+
w.WriteHeader(http.StatusOK)
3631+
json.NewEncoder(w).Encode(map[string]any{
3632+
"data": []map[string]any{
3633+
{"id": 1, "name": "Test"},
3634+
{"id": 2, "name": "Data"},
3635+
},
3636+
})
3637+
}))
3638+
defer server.Close()
3639+
3640+
specYAML := fmt.Sprintf(`
3641+
name: test_auth_retry_refresh
3642+
authentication:
3643+
type: static
3644+
headers:
3645+
Authorization: "Bearer token_v1"
3646+
endpoints:
3647+
test_endpoint:
3648+
request:
3649+
url: %s/data
3650+
method: GET
3651+
response:
3652+
records:
3653+
jmespath: data
3654+
rules:
3655+
- condition: "response.status == 401"
3656+
action: retry
3657+
max_attempts: 3
3658+
`, server.URL)
3659+
3660+
spec, err := LoadSpec(specYAML)
3661+
if !assert.NoError(t, err) {
3662+
return
3663+
}
3664+
3665+
ac, err = NewAPIConnection(context.Background(), spec, map[string]any{
3666+
"state": map[string]any{},
3667+
"secrets": map[string]any{},
3668+
})
3669+
assert.NoError(t, err)
3670+
3671+
err = ac.Authenticate()
3672+
assert.NoError(t, err)
3673+
assert.Equal(t, "Bearer token_v1", ac.State.Auth.Headers["Authorization"])
3674+
3675+
// ReadDataflow → MakeRequest bakes "Bearer token_v1" into httpReq.Header.
3676+
// First request hits server → server rotates to token_v2, updates Auth.Headers,
3677+
// returns 401. Retry rule fires → goto retry → PerformRequest re-sends httpReq.
3678+
//
3679+
// BUG: retry uses the SAME httpReq with old "Bearer token_v1" headers.
3680+
// FIX: retry should refresh httpReq.Header from conn.State.Auth.Headers.
3681+
df, err := ac.ReadDataflow("test_endpoint", APIStreamConfig{
3682+
Limit: 10,
3683+
})
3684+
if err == nil {
3685+
_, err = df.Collect()
3686+
}
3687+
3688+
mu.Lock()
3689+
totalRequests := requestCount
3690+
mu.Unlock()
3691+
3692+
// After fix: retry refreshes auth headers from conn.State.Auth.Headers,
3693+
// so the second attempt uses "Bearer token_v2" and succeeds.
3694+
assert.NoError(t, err, "retry should succeed after refreshing auth headers")
3695+
assert.Equal(t, 2, totalRequests, "should take exactly 2 requests (1 failed + 1 retry)")
3696+
t.Logf("collected records after %d HTTP requests (retry refreshed headers)", totalRequests)
3697+
}
3698+
3699+
func TestAuthTokenRefreshDuringPagination(t *testing.T) {
3700+
// This test reproduces an issue where paginated requests use stale auth
3701+
// headers after token rotation. During long-running paginated extractions,
3702+
// EnsureAuthenticated() may refresh the token (updating Auth.Headers),
3703+
// but subsequent requests still carry the OLD headers baked in at
3704+
// MakeRequest() time.
3705+
//
3706+
// The mock server:
3707+
// - Page 1: succeeds with token_v1, returns has_more=true
3708+
// - Page 2: server rotates token to v2, updates Auth.Headers, returns 401
3709+
// - Page 2 retry / page 2 re-request: should use token_v2
3710+
//
3711+
// This simulates a real API whose token expires mid-extraction.
3712+
3713+
var mu sync.Mutex
3714+
currentToken := "token_v1"
3715+
requestCount := 0
3716+
var ac *APIConnection
3717+
3718+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3719+
mu.Lock()
3720+
requestCount++
3721+
reqNum := requestCount
3722+
mu.Unlock()
3723+
3724+
w.Header().Set("Content-Type", "application/json")
3725+
authHeader := r.Header.Get("Authorization")
3726+
3727+
// On the 2nd request (page 2), simulate token expiration
3728+
if reqNum == 2 {
3729+
mu.Lock()
3730+
currentToken = "token_v2"
3731+
mu.Unlock()
3732+
3733+
ac.State.Auth.Mutex.Lock()
3734+
ac.State.Auth.Headers["Authorization"] = "Bearer token_v2"
3735+
ac.State.Auth.Mutex.Unlock()
3736+
3737+
w.WriteHeader(http.StatusUnauthorized)
3738+
json.NewEncoder(w).Encode(map[string]any{
3739+
"error": fmt.Sprintf("token expired (request #%d): got %q", reqNum, authHeader),
3740+
})
3741+
return
3742+
}
3743+
3744+
// All other requests: validate against current token
3745+
mu.Lock()
3746+
validToken := "Bearer " + currentToken
3747+
mu.Unlock()
3748+
3749+
if authHeader != validToken {
3750+
w.WriteHeader(http.StatusUnauthorized)
3751+
json.NewEncoder(w).Encode(map[string]any{
3752+
"error": fmt.Sprintf("invalid token (request #%d): got %q, want %q",
3753+
reqNum, authHeader, validToken),
3754+
})
3755+
return
3756+
}
3757+
3758+
page := r.URL.Query().Get("page")
3759+
hasMore := true
3760+
nextPage := "2"
3761+
if page == "2" {
3762+
hasMore = false
3763+
nextPage = ""
3764+
}
3765+
3766+
w.WriteHeader(http.StatusOK)
3767+
json.NewEncoder(w).Encode(map[string]any{
3768+
"data": []map[string]any{
3769+
{"id": reqNum, "value": fmt.Sprintf("page_%s", page)},
3770+
},
3771+
"next_page": nextPage,
3772+
"has_more": hasMore,
3773+
})
3774+
}))
3775+
defer server.Close()
3776+
3777+
specYAML := fmt.Sprintf(`
3778+
name: test_auth_pagination_refresh
3779+
authentication:
3780+
type: static
3781+
headers:
3782+
Authorization: "Bearer token_v1"
3783+
endpoints:
3784+
test_endpoint:
3785+
request:
3786+
url: %s/data
3787+
method: GET
3788+
parameters:
3789+
page: "{state.next_page}"
3790+
response:
3791+
records:
3792+
jmespath: data
3793+
processors:
3794+
- expression: response.json.next_page
3795+
output: state.next_page
3796+
aggregation: last
3797+
rules:
3798+
- condition: "response.status == 401"
3799+
action: retry
3800+
max_attempts: 3
3801+
pagination:
3802+
stop_condition: "response.json.has_more == false"
3803+
next_state:
3804+
next_page: "{state.next_page}"
3805+
`, server.URL)
3806+
3807+
spec, err := LoadSpec(specYAML)
3808+
if !assert.NoError(t, err) {
3809+
return
3810+
}
3811+
3812+
ac, err = NewAPIConnection(context.Background(), spec, map[string]any{
3813+
"state": map[string]any{"next_page": "1"},
3814+
"secrets": map[string]any{},
3815+
})
3816+
assert.NoError(t, err)
3817+
3818+
err = ac.Authenticate()
3819+
assert.NoError(t, err)
3820+
3821+
df, err := ac.ReadDataflow("test_endpoint", APIStreamConfig{
3822+
Limit: 100,
3823+
})
3824+
if err == nil {
3825+
_, err = df.Collect()
3826+
}
3827+
3828+
mu.Lock()
3829+
totalRequests := requestCount
3830+
mu.Unlock()
3831+
3832+
// After fix: page 2 retry refreshes auth headers and uses "Bearer token_v2".
3833+
assert.NoError(t, err, "paginated request should succeed after auth header refresh")
3834+
assert.Equal(t, 3, totalRequests, "should take 3 requests (page1 ok + page2 fail + page2 retry ok)")
3835+
t.Logf("collected records from both pages across %d HTTP requests", totalRequests)
3836+
}

0 commit comments

Comments
 (0)