Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions internal/auth/apikey.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package auth

import (
"crypto/rand"
"encoding/hex"
"fmt"

"github.com/github/gh-aw-mcpg/internal/strutil"
)

// GenerateRandomAPIKey generates a cryptographically random API key.
// Per spec §7.3, the gateway SHOULD generate a random API key on startup
// if none is provided. Returns a 32-byte hex-encoded string (64 chars).
func GenerateRandomAPIKey() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
key, err := strutil.RandomHex(32)
if err != nil {
return "", fmt.Errorf("failed to generate random API key: %w", err)
}
return hex.EncodeToString(bytes), nil
return key, nil
}
52 changes: 28 additions & 24 deletions internal/auth/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ var (
ErrInvalidAuthHeader = errors.New("invalid Authorization header format")
)

// supportedAuthSchemes lists the recognized Authorization header scheme prefixes.
// Each entry includes the trailing space that separates the scheme from the value.
var supportedAuthSchemes = []string{"Bearer ", "Agent "}

// stripAuthScheme extracts the value from a scheme-prefixed Authorization header.
// Recognizes "Bearer " and "Agent " formats.
// Returns (scheme, value, true) on match, or ("", authHeader, false) for plain values.
func stripAuthScheme(authHeader string) (scheme, value string, matched bool) {
for _, prefix := range supportedAuthSchemes {
if strings.HasPrefix(authHeader, prefix) {
scheme = strings.TrimSuffix(prefix, " ")
value = strings.TrimPrefix(authHeader, prefix)
return scheme, value, true
}
}
return "", authHeader, false
}

// ParseAuthHeader parses the Authorization header and extracts the API key and agent ID.
// Per MCP spec 7.1, the Authorization header should contain the API key directly
// without any Bearer prefix or other scheme.
Expand All @@ -72,18 +90,9 @@ func ParseAuthHeader(authHeader string) (apiKey string, agentID string, error er
return "", "", ErrMissingAuthHeader
}

// Handle "Bearer <token>" format (backward compatibility)
if strings.HasPrefix(authHeader, "Bearer ") {
log.Print("Detected Bearer token format (backward compatibility)")
token := strings.TrimPrefix(authHeader, "Bearer ")
return token, token, nil
}

// Handle "Agent <agent-id>" format
if strings.HasPrefix(authHeader, "Agent ") {
log.Print("Detected Agent ID format")
agentIDValue := strings.TrimPrefix(authHeader, "Agent ")
return agentIDValue, agentIDValue, nil
if scheme, value, matched := stripAuthScheme(authHeader); matched {
log.Printf("Detected %s format", scheme)
return value, value, nil
}

// Per MCP spec 7.1: Authorization header contains API key directly
Expand Down Expand Up @@ -148,18 +157,13 @@ func ExtractSessionID(authHeader string) string {
return ""
}

// Handle "Bearer <token>" format (backward compatibility)
// Trim spaces for backward compatibility with older clients
if strings.HasPrefix(authHeader, "Bearer ") {
log.Print("Detected Bearer format, trimming spaces for backward compatibility")
sessionID := strings.TrimPrefix(authHeader, "Bearer ")
return strings.TrimSpace(sessionID)
}

// Handle "Agent <agent-id>" format
if strings.HasPrefix(authHeader, "Agent ") {
log.Print("Detected Agent format")
return strings.TrimPrefix(authHeader, "Agent ")
if scheme, value, matched := stripAuthScheme(authHeader); matched {
log.Printf("Detected %s format", scheme)
if scheme == "Bearer" {
// Trim spaces for backward compatibility with older clients
return strings.TrimSpace(value)
}
return value
}

// Plain format (per spec 7.1 - API key is session ID)
Expand Down
64 changes: 64 additions & 0 deletions internal/auth/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,70 @@ func TestExtractSessionID(t *testing.T) {
}
}

func TestStripAuthScheme(t *testing.T) {
assert := assert.New(t)

tests := []struct {
name string
authHeader string
wantScheme string
wantValue string
wantMatched bool
}{
{
name: "Bearer prefix",
authHeader: "Bearer my-token",
wantScheme: "Bearer",
wantValue: "my-token",
wantMatched: true,
},
{
name: "Agent prefix",
authHeader: "Agent agent-123",
wantScheme: "Agent",
wantValue: "agent-123",
wantMatched: true,
},
{
name: "Plain value (no scheme)",
authHeader: "my-plain-key",
wantScheme: "",
wantValue: "my-plain-key",
wantMatched: false,
},
{
name: "Lowercase bearer (not recognized)",
authHeader: "bearer my-token",
wantScheme: "",
wantValue: "bearer my-token",
wantMatched: false,
},
{
name: "Bearer with extra spaces",
authHeader: "Bearer my-token",
wantScheme: "Bearer",
wantValue: " my-token",
wantMatched: true,
},
{
name: "Empty string",
authHeader: "",
wantScheme: "",
wantValue: "",
wantMatched: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scheme, value, matched := stripAuthScheme(tt.authHeader)
assert.Equal(tt.wantScheme, scheme)
assert.Equal(tt.wantValue, value)
assert.Equal(tt.wantMatched, matched)
})
}
}

func TestTruncateSessionID(t *testing.T) {
assert := assert.New(t)

Expand Down
9 changes: 4 additions & 5 deletions internal/middleware/jqschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package middleware

import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
Expand All @@ -13,6 +11,7 @@ import (
"unicode/utf8"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/strutil"
"github.com/itchyny/gojq"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
Expand Down Expand Up @@ -114,12 +113,12 @@ func init() {

// generateRandomID generates a random ID for payload storage
func generateRandomID() string {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
id, err := strutil.RandomHex(16)
if err != nil {
// Fallback to timestamp-based ID if random fails
return fmt.Sprintf("fallback-%d", os.Getpid())
}
return hex.EncodeToString(bytes)
return id
}

// applyJqSchema applies the jq schema transformation to JSON data
Expand Down
17 changes: 17 additions & 0 deletions internal/strutil/random_hex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package strutil

import (
"crypto/rand"
"encoding/hex"
"fmt"
)

// RandomHex returns a hex-encoded string of n cryptographically random bytes.
// The returned string has length 2*n.
func RandomHex(n int) (string, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate %d random bytes: %w", n, err)
}
return hex.EncodeToString(b), nil
}
51 changes: 51 additions & 0 deletions internal/strutil/random_hex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package strutil

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRandomHex(t *testing.T) {
tests := []struct {
name string
n int
wantLen int
}{
{
name: "16 bytes produces 32 hex chars",
n: 16,
wantLen: 32,
},
{
name: "32 bytes produces 64 hex chars",
n: 32,
wantLen: 64,
},
{
name: "1 byte produces 2 hex chars",
n: 1,
wantLen: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := RandomHex(tt.n)
require.NoError(t, err)
assert.Len(t, result, tt.wantLen)
})
}
}

func TestRandomHex_Uniqueness(t *testing.T) {
seen := make(map[string]bool)
for i := 0; i < 100; i++ {
id, err := RandomHex(16)
require.NoError(t, err)
assert.NotEmpty(t, id)
assert.False(t, seen[id], "RandomHex should produce unique values")
seen[id] = true
}
}
Loading