Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions api_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import "time"

type APIKey struct {
ID int64
Name string
KeyHash string
Active bool
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
60 changes: 60 additions & 0 deletions api_key_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package main

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog"
"net/http"

"github.com/jackc/pgx/v5"
)

type APIKeyStore interface {
GetAPIKeyByHash(ctx context.Context, keyHash string) (*APIKey, error)
UpdateAPIKeyLastUsed(ctx context.Context, id int64) error
}

// hashAPIKey hashes the raw API key using SHA-256 and returns the hex-encoded string.
func hashAPIKey(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}

// requireAPIKey is middleware that checks for a valid API key in the X-API-Key header.
func requireAPIKey(store APIKeyStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawKey := r.Header.Get("X-API-Key")
if rawKey == "" {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "missing API key"})
return
}

keyHash := hashAPIKey(rawKey)

apiKey, err := store.GetAPIKeyByHash(r.Context(), keyHash)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "invalid API key"})
return
}
slog.Error("failed to fetch api key", "error", err)
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "internal server error"})
return
}

if !apiKey.Active {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "inactive API key"})
return
}

if err := store.UpdateAPIKeyLastUsed(r.Context(), apiKey.ID); err != nil {
slog.Error("failed to update api key last_used_at", "api_key_id", apiKey.ID, "error", err)
}

next.ServeHTTP(w, r)
})
}
}
173 changes: 173 additions & 0 deletions api_key_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package main

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockAPIKeyStore struct {
apiKey *APIKey
getErr error
lastUsedCalled bool
lastUsedID int64
updateErr error
}

// GetAPIKeyByHash returns the API key if getErr is nil, otherwise returns getErr.
func (m *mockAPIKeyStore) GetAPIKeyByHash(ctx context.Context, keyHash string) (*APIKey, error) {
if m.getErr != nil {
return nil, m.getErr
}
return m.apiKey, nil
}

// UpdateAPIKeyLastUsed updates the last used timestamp for the API key.
func (m *mockAPIKeyStore) UpdateAPIKeyLastUsed(ctx context.Context, id int64) error {
if m.updateErr != nil {
return m.updateErr
}
m.lastUsedCalled = true
m.lastUsedID = id
return nil
}

// TestHashAPIKey verifies that the hashAPIKey function produces consistent and correctly sized hashes.
func TestHashAPIKey(t *testing.T) {
h1 := hashAPIKey("abc123")
h2 := hashAPIKey("abc123")
h3 := hashAPIKey("different")

assert.Equal(t, h1, h2)
assert.NotEqual(t, h1, h3)
assert.Len(t, h1, 64)
}

// TestRequireAPIKey tests the requireAPIKey middleware with various scenarios, including missing header, invalid key, inactive key, store failure, update failure, and valid key.
func TestRequireAPIKey_MissingHeader(t *testing.T) {
store := &mockAPIKeyStore{}
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// Invalid API key should result in 401 Unauthorized
func TestRequireAPIKey_InvalidKey(t *testing.T) {
store := &mockAPIKeyStore{getErr: pgx.ErrNoRows}
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
req.Header.Set("X-API-Key", "bad-key")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// Inactive API key should result in 401 Unauthorized
func TestRequireAPIKey_InactiveKey(t *testing.T) {
store := &mockAPIKeyStore{
apiKey: &APIKey{
ID: 1,
Name: "test",
Active: false,
},
}
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
req.Header.Set("X-API-Key", "inactive-key")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.False(t, store.lastUsedCalled)
}

// Store failure should result in 500 Internal Server Error
func TestRequireAPIKey_StoreFailure(t *testing.T) {
store := &mockAPIKeyStore{getErr: errors.New("db down")}
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
req.Header.Set("X-API-Key", "some-key")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusInternalServerError, w.Code)
}

// Update last_used_at failure should be logged but must not block feed access.
func TestRequireAPIKey_UpdateLastUsedFailure_DoesNotBlockRequest(t *testing.T) {
store := &mockAPIKeyStore{
apiKey: &APIKey{
ID: 7,
Name: "feed consumer",
Active: true,
},
updateErr: errors.New("update failed"),
}

called := false
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
req.Header.Set("X-API-Key", "valid-key")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

assert.True(t, called)
assert.Equal(t, http.StatusOK, w.Code)
}

// Valid API key should call next handler and update last used timestamp
func TestRequireAPIKey_ValidKey(t *testing.T) {
store := &mockAPIKeyStore{
apiKey: &APIKey{
ID: 42,
Name: "consumer",
KeyHash: hashAPIKey("valid-key"),
Active: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}

called := false
handler := requireAPIKey(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/gtfs-rt/vehicle-positions", nil)
req.Header.Set("X-API-Key", "valid-key")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

require.True(t, called)
assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, store.lastUsedCalled)
assert.Equal(t, int64(42), store.lastUsedID)
}
10 changes: 10 additions & 0 deletions db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions db/query.sql
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ RETURNING id, label, agency_tag, active, created_at, updated_at;
UPDATE vehicles
SET active = false, updated_at = NOW()
WHERE id = $1;

-- name: GetAPIKeyByHash :one
SELECT id, name, key_hash, active, last_used_at, created_at, updated_at
FROM api_keys
WHERE key_hash = $1;

-- name: UpdateAPIKeyLastUsed :exec
UPDATE api_keys
SET last_used_at = NOW()
WHERE id = $1;

-- name: CreateAPIKey :one
INSERT INTO api_keys (name, key_hash, active)
VALUES ($1, $2, $3)
RETURNING id, name, key_hash, active, last_used_at, created_at, updated_at;
59 changes: 59 additions & 0 deletions db/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ services:
PORT: "8080"
DATABASE_URL: "postgres://postgres:postgres@db:5432/vehicle_positions?sslmode=disable"
STALENESS_THRESHOLD: "5m"
JWT_SECRET: "this-is-a-local-dev-secret-1234567890"
depends_on:
db:
condition: service_healthy
Expand Down
Loading
Loading