Skip to content

Commit ffb50e8

Browse files
authored
Merge pull request #46 from Keyfactor/feat/AB#82257/oauth-token-cache
keyfactor-auth-client-go v1.3.1: Cache Access Token Provider across HttpClient instances
2 parents 6ee6956 + 4d3657d commit ffb50e8

3 files changed

Lines changed: 69 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# v1.3.1
2+
## Fixes
3+
- Reuse OAuth2 token source to prevent unnecessary token fetches for each request.
4+
15
# v1.3.0
26

37
## Features

auth_providers/auth_oauth.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Keyfactor
1+
// Copyright 2026 Keyfactor
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@ import (
2222
"net/http"
2323
"os"
2424
"strings"
25+
"sync"
2526
"time"
2627

2728
"golang.org/x/oauth2"
@@ -115,6 +116,10 @@ type CommandConfigOauth struct {
115116

116117
// TokenURL is the token URL for OAuth authentication
117118
TokenURL string `json:"token_url,omitempty" yaml:"token_url,omitempty"`
119+
120+
// unexported: lazily initialized, shared across GetHttpClient() calls
121+
tokenSource oauth2.TokenSource
122+
tsMu sync.Mutex
118123
}
119124

120125
// NewOAuthAuthenticatorBuilder creates a new CommandConfigOauth instance.
@@ -222,7 +227,15 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
222227
}
223228

224229
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: baseTransport})
225-
tokenSource := config.TokenSource(ctx)
230+
231+
// Lazily initialize the token source and cache it
232+
b.tsMu.Lock()
233+
if b.tokenSource == nil {
234+
log.Printf("[DEBUG] Initializing OAuth2 token source for client ID: %s", b.ClientID)
235+
b.tokenSource = config.TokenSource(ctx)
236+
}
237+
tokenSource := b.tokenSource
238+
b.tsMu.Unlock()
226239

227240
client = http.Client{
228241
Transport: &oauth2Transport{

auth_providers/auth_oauth_test.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Keyfactor
1+
// Copyright 2026 Keyfactor
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -16,13 +16,16 @@ package auth_providers_test
1616

1717
import (
1818
"crypto/tls"
19+
"encoding/json"
1920
"encoding/pem"
2021
"fmt"
2122
"net/http"
23+
"net/http/httptest"
2224
"net/url"
2325
"os"
2426
"path/filepath"
2527
"strings"
28+
"sync/atomic"
2629
"testing"
2730

2831
"github.com/Keyfactor/keyfactor-auth-client-go/auth_providers"
@@ -568,3 +571,49 @@ func DownloadCertificate(input string, outputPath string) error {
568571
fmt.Printf("Certificate chain saved to: %s\n", outputFile)
569572
return nil
570573
}
574+
575+
func TestCommandConfigOauth_TokenSourceIsReused(t *testing.T) {
576+
var tokenRequestCount atomic.Int32
577+
578+
// Fake IdP token endpoint
579+
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
580+
tokenRequestCount.Add(1)
581+
w.Header().Set("Content-Type", "application/json")
582+
json.NewEncoder(w).Encode(map[string]interface{}{
583+
"access_token": "shared-test-token",
584+
"token_type": "Bearer",
585+
"expires_in": 3600,
586+
})
587+
}))
588+
defer tokenServer.Close()
589+
590+
// Fake API endpoint (just needs to accept requests)
591+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
592+
w.WriteHeader(http.StatusOK)
593+
}))
594+
defer apiServer.Close()
595+
596+
config := &auth_providers.CommandConfigOauth{
597+
ClientID: "test-client-id",
598+
ClientSecret: "test-client-secret",
599+
TokenURL: tokenServer.URL + "/token",
600+
}
601+
602+
// Get multiple clients from the same config
603+
const numClients = 3
604+
for i := 0; i < numClients; i++ {
605+
client, err := config.GetHttpClient()
606+
if err != nil {
607+
t.Fatalf("GetHttpClient() call %d failed: %v", i+1, err)
608+
}
609+
resp, err := client.Get(apiServer.URL)
610+
if err != nil {
611+
t.Fatalf("request %d failed: %v", i+1, err)
612+
}
613+
resp.Body.Close()
614+
}
615+
616+
if tokenRequestCount.Load() != 1 {
617+
t.Errorf("expected token endpoint to be called once, got %d — token source is not being reused across GetHttpClient() calls", tokenRequestCount.Load())
618+
}
619+
}

0 commit comments

Comments
 (0)