Skip to content

Commit eed794f

Browse files
[Internal] Add internal/retrier package (#5746)
## Summary

The repository has several ways to implement simple retry loops. This PR adds a new `internal/retrier` package that captures the common denominator across all of them. This PR already opportunistically migrates `common/retry.go` to use this new package internally. Subsequent PRs will migrate existing call sites to it, removing unnecessary middleware.

Design choices:

- **Value-aware predicate.** `Retrier[V].IsRetriable(V, error)` sees both the polled value and any error, so state-driven polling (the dominant pattern across the provider's waiters) is supported natively without sentinel-error wrapping.
- **Factory-per-Run.** `Run` takes a `func() Retrier[V]` and invokes it once per call. Each invocation gets its own retrier instance with independent backoff state, making `Run` safe to use from multiple goroutines with the same factory without locking.
- **Deterministic exponential backoff, no jitter.** The Databricks SDK already absorbs transient errors, and state-polling waiters here do not need cross-client decorrelation. The package is intended to sit on top of the SDK for longer polls (e.g. resource state transitions); jitter adds complexity without a corresponding benefit in that scenario. This is not a one-way door; it can be revisited if a use case calls for it.

## Test plan

- [x] `go test ./internal/retrier/...` passes (12 cases across BackoffPolicy initialization and exponential growth, `Run` retry/value-aware/ctx-cancellation paths, `RunErr`, `RetryIf` wiring and factory independence, `RetryIfErr` forwarding, `sleep` cancellation).
- [x] `go tool staticcheck ./internal/retrier/...` clean.
- [x] `make fmt lint ws` clean.
- [x] CI green.
1 parent 4004910 commit eed794f

5 files changed

Lines changed: 820 additions & 150 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,5 @@
1616

1717
### Internal Changes
1818

19+
* Add `internal/retrier` package for unified retry and backoff handling ([#5746](https://github.com/databricks/terraform-provider-databricks/pull/5746)).
1920
* Pass `excludedAttributes=entitlements` on SCIM `/Me` requests ([#5725](https://github.com/databricks/terraform-provider-databricks/pull/5725)).
20-
21-
The provider only needs identity fields (`userName`, `id`, `externalId`) from `/Me`, never entitlements. Skipping the entitlement computation avoids an expensive `getEffectivePermissions` traversal on the SCIM backend, which has caused incidents on workspaces with large grant counts.

common/retry.go

Lines changed: 34 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -4,40 +4,47 @@ import (
44
"context"
55
"errors"
66
"regexp"
7+
"time"
78

89
"github.com/databricks/databricks-sdk-go/apierr"
9-
"github.com/databricks/databricks-sdk-go/logger"
10-
"github.com/databricks/databricks-sdk-go/retries"
10+
"github.com/databricks/terraform-provider-databricks/internal/retrier"
1111
)
1212

13-
var timeoutRegex = regexp.MustCompile(`request timed out after .* of inactivity`)
14-
13+
// RetryOnTimeout retries f while it returns an SDK inactivity-timeout error.
14+
//
15+
// TODO: Deprecate this function in favor of retrier.Run.
1516
func RetryOnTimeout[T any](ctx context.Context, f func(context.Context) (*T, error)) (*T, error) {
16-
r := retries.New[T](retries.WithRetryFunc(func(err error) bool {
17-
msg := err.Error()
18-
isTimeout := timeoutRegex.MatchString(msg)
19-
if isTimeout {
20-
logger.Debugf(ctx, "Retrying due to timeout: %s", msg)
21-
}
22-
return isTimeout
23-
}))
24-
return r.Run(ctx, func(ctx context.Context) (*T, error) {
25-
return f(ctx)
26-
})
17+
return retrier.Run(ctx, retryOnErr[*T](isTimeoutError), f)
2718
}
2819

29-
// RetryOn504 returns a [retries.Retrier] that calls the given method
30-
// until it either succeeds or returns an error that is different from
31-
// [apierr.ErrDeadlineExceeded].
20+
// RetryOn504 retries f while it returns an error wrapping [apierr.ErrDeadlineExceeded].
21+
//
22+
// TODO: Deprecate this function in favor of retrier.Run.
3223
func RetryOn504[T any](ctx context.Context, f func(context.Context) (*T, error)) (*T, error) {
33-
r := retries.New[T](retries.WithTimeout(-1), retries.WithRetryFunc(func(err error) bool {
34-
if !errors.Is(err, apierr.ErrDeadlineExceeded) {
35-
return false
36-
}
37-
logger.Debugf(ctx, "Retrying on error 504")
38-
return true
39-
}))
40-
return r.Run(ctx, func(ctx context.Context) (*T, error) {
41-
return f(ctx)
24+
return retrier.Run(ctx, retryOnErr[*T](is504Error), f)
25+
}
26+
27+
// retryOnErr adapts an error-only predicate to a value-aware retrier factory
28+
// using the transient backoff policy.
29+
func retryOnErr[V any](isRetriable func(error) bool) func() retrier.Retrier[V] {
30+
return retrier.RetryIf(transientBackoff(), func(_ V, err error) bool {
31+
return isRetriable(err)
4232
})
4333
}
34+
35+
// transientBackoff returns a fresh backoff tuned for transient errors:
36+
// start fast, cap quickly.
37+
func transientBackoff() retrier.BackoffPolicy {
38+
return retrier.BackoffPolicy{Initial: time.Second, Maximum: 30 * time.Second}
39+
}
40+
41+
// TODO: Replace the regex check with type-aware error inspection.
42+
var timeoutRegex = regexp.MustCompile(`request timed out after .* of inactivity`)
43+
44+
func isTimeoutError(err error) bool {
45+
return err != nil && timeoutRegex.MatchString(err.Error())
46+
}
47+
48+
func is504Error(err error) bool {
49+
return errors.Is(err, apierr.ErrDeadlineExceeded)
50+
}

common/retry_test.go

Lines changed: 105 additions & 121 deletions
Original file line number | Diff line number | Diff line change
@@ -3,132 +3,116 @@ package common
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"testing"
78

89
"github.com/databricks/databricks-sdk-go/apierr"
9-
"github.com/databricks/databricks-sdk-go/experimental/mocks"
10-
"github.com/databricks/databricks-sdk-go/service/workspace"
11-
"github.com/stretchr/testify/assert"
12-
"github.com/stretchr/testify/mock"
1310
)
1411

15-
func TestRetryOnTimeout_NoError(t *testing.T) {
16-
w := mocks.NewMockWorkspaceClient(t)
17-
expected := &workspace.ObjectInfo{}
18-
api := w.GetMockWorkspaceAPI().EXPECT()
19-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(expected, nil)
20-
res, err := RetryOnTimeout(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
21-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
22-
})
23-
assert.NoError(t, err)
24-
assert.Equal(t, expected, res)
12+
func TestRetryOnTimeout(t *testing.T) {
13+
timeoutErr := errors.New("request failed: request timed out after 1m0s of inactivity")
14+
otherErr := errors.New("non-retriable")
15+
16+
testCases := []struct {
17+
name string
18+
callErrs []error
19+
wantErr error
20+
wantCalls int
21+
}{
22+
{
23+
name: "success on first call",
24+
callErrs: []error{nil},
25+
wantCalls: 1,
26+
},
27+
{
28+
name: "timeout then succeed",
29+
callErrs: []error{timeoutErr, nil},
30+
wantCalls: 2,
31+
},
32+
{
33+
name: "non-timeout halts",
34+
callErrs: []error{otherErr},
35+
wantErr: otherErr,
36+
wantCalls: 1,
37+
},
38+
{
39+
name: "timeout then non-timeout halts",
40+
callErrs: []error{timeoutErr, otherErr},
41+
wantErr: otherErr,
42+
wantCalls: 2,
43+
},
44+
}
45+
46+
for _, tc := range testCases {
47+
t.Run(tc.name, func(t *testing.T) {
48+
calls := 0
49+
_, err := RetryOnTimeout(context.Background(), func(ctx context.Context) (*struct{}, error) {
50+
e := tc.callErrs[calls]
51+
calls++
52+
return nil, e
53+
})
54+
if calls != tc.wantCalls {
55+
t.Errorf("call count = %d, want %d", calls, tc.wantCalls)
56+
}
57+
if !errors.Is(err, tc.wantErr) {
58+
t.Errorf("err = %v, want %v", err, tc.wantErr)
59+
}
60+
})
61+
}
2562
}
2663

27-
func TestRetryOnTimeout_OneError(t *testing.T) {
28-
w := mocks.NewMockWorkspaceClient(t)
29-
expected := &workspace.ObjectInfo{}
30-
api := w.GetMockWorkspaceAPI().EXPECT()
31-
call1 := api.GetStatusByPath(mock.Anything, mock.Anything).Return(nil, errors.New("request failed: request timed out after 1m0s of inactivity"))
32-
call1.Repeatability = 1
33-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(expected, nil)
34-
res, err := RetryOnTimeout(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
35-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
36-
})
37-
assert.NoError(t, err)
38-
assert.Equal(t, expected, res)
39-
}
40-
41-
func TestRetryOnTimeout_NonRetriableError(t *testing.T) {
42-
w := mocks.NewMockWorkspaceClient(t)
43-
expected := errors.New("request failed: non-retriable error")
44-
api := w.GetMockWorkspaceAPI().EXPECT()
45-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(nil, expected)
46-
_, err := RetryOnTimeout(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
47-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
48-
})
49-
assert.ErrorIs(t, err, expected)
50-
}
51-
52-
func TestRetryOn504_noError(t *testing.T) {
53-
wantErr := error(nil)
54-
wantRes := (*workspace.ObjectInfo)(nil)
55-
wantCalls := 1
56-
57-
w := mocks.NewMockWorkspaceClient(t)
58-
api := w.GetMockWorkspaceAPI().EXPECT()
59-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(wantRes, wantErr)
60-
61-
gotCalls := 0
62-
gotRes, gotErr := RetryOn504(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
63-
gotCalls += 1
64-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
65-
})
66-
67-
assert.ErrorIs(t, gotErr, wantErr)
68-
assert.Equal(t, gotRes, wantRes)
69-
assert.Equal(t, gotCalls, wantCalls)
70-
}
71-
72-
func TestRetryOn504_errorNot504(t *testing.T) {
73-
wantErr := errors.New("test error")
74-
wantRes := (*workspace.ObjectInfo)(nil)
75-
wantCalls := 1
76-
77-
w := mocks.NewMockWorkspaceClient(t)
78-
api := w.GetMockWorkspaceAPI().EXPECT()
79-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(wantRes, wantErr)
80-
81-
gotCalls := 0
82-
gotRes, gotErr := RetryOn504(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
83-
gotCalls += 1
84-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
85-
})
86-
87-
assert.ErrorIs(t, gotErr, wantErr)
88-
assert.Equal(t, gotRes, wantRes)
89-
assert.Equal(t, gotCalls, wantCalls)
90-
}
91-
92-
func TestRetryOn504_error504ThenFail(t *testing.T) {
93-
wantErr := errors.New("test error")
94-
wantRes := (*workspace.ObjectInfo)(nil)
95-
wantCalls := 2
96-
97-
w := mocks.NewMockWorkspaceClient(t)
98-
api := w.GetMockWorkspaceAPI().EXPECT()
99-
call := api.GetStatusByPath(mock.Anything, mock.Anything).Return(nil, apierr.ErrDeadlineExceeded)
100-
call.Repeatability = 1
101-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(wantRes, wantErr)
102-
103-
gotCalls := 0
104-
gotRes, gotErr := RetryOn504(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
105-
gotCalls++
106-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
107-
})
108-
109-
assert.ErrorIs(t, gotErr, wantErr)
110-
assert.Equal(t, gotRes, wantRes)
111-
assert.Equal(t, gotCalls, wantCalls)
112-
}
113-
114-
func TestRetryOn504_error504ThenSuccess(t *testing.T) {
115-
wantErr := error(nil)
116-
wantRes := &workspace.ObjectInfo{}
117-
wantCalls := 2
118-
119-
w := mocks.NewMockWorkspaceClient(t)
120-
api := w.GetMockWorkspaceAPI().EXPECT()
121-
call := api.GetStatusByPath(mock.Anything, mock.Anything).Return(nil, apierr.ErrDeadlineExceeded)
122-
call.Repeatability = 1
123-
api.GetStatusByPath(mock.Anything, mock.Anything).Return(wantRes, wantErr)
124-
125-
gotCalls := 0
126-
gotRes, gotErr := RetryOn504(context.Background(), func(ctx context.Context) (*workspace.ObjectInfo, error) {
127-
gotCalls++
128-
return w.WorkspaceClient.Workspace.GetStatusByPath(ctx, "path")
129-
})
130-
131-
assert.ErrorIs(t, gotErr, wantErr)
132-
assert.Equal(t, gotRes, wantRes)
133-
assert.Equal(t, gotCalls, wantCalls)
64+
func TestRetryOn504(t *testing.T) {
65+
otherErr := errors.New("not 504")
66+
67+
testCases := []struct {
68+
name string
69+
callErrs []error
70+
wantErr error
71+
wantCalls int
72+
}{
73+
{
74+
name: "success on first call",
75+
callErrs: []error{nil},
76+
wantCalls: 1,
77+
},
78+
{
79+
name: "504 then succeed",
80+
callErrs: []error{apierr.ErrDeadlineExceeded, nil},
81+
wantCalls: 2,
82+
},
83+
{
84+
name: "wrapped 504 then succeed",
85+
callErrs: []error{fmt.Errorf("got 504: %w", apierr.ErrDeadlineExceeded), nil},
86+
wantCalls: 2,
87+
},
88+
{
89+
name: "non-504 halts",
90+
callErrs: []error{otherErr},
91+
wantErr: otherErr,
92+
wantCalls: 1,
93+
},
94+
{
95+
name: "504 then non-504 halts",
96+
callErrs: []error{apierr.ErrDeadlineExceeded, otherErr},
97+
wantErr: otherErr,
98+
wantCalls: 2,
99+
},
100+
}
101+
102+
for _, tc := range testCases {
103+
t.Run(tc.name, func(t *testing.T) {
104+
calls := 0
105+
_, err := RetryOn504(context.Background(), func(ctx context.Context) (*struct{}, error) {
106+
e := tc.callErrs[calls]
107+
calls++
108+
return nil, e
109+
})
110+
if calls != tc.wantCalls {
111+
t.Errorf("call count = %d, want %d", calls, tc.wantCalls)
112+
}
113+
if !errors.Is(err, tc.wantErr) {
114+
t.Errorf("err = %v, want %v", err, tc.wantErr)
115+
}
116+
})
117+
}
134118
}

0 commit comments

Comments
 (0)