Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions extension/background.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
const API_URL = 'https://api.tryhavril.com';

async function getConfig() {
const session = await new Promise((r) => chrome.storage.session.get(['token'], r));
const session = await new Promise((r) =>
chrome.storage.session.get(['token', 'refreshToken'], r),
);
return {
token: session.token || '',
serverUrl: API_URL,
Expand All @@ -31,6 +33,10 @@ async function apiFetch(path, options = {}) {

if (!response.ok) {
const body = await response.json().catch(() => ({}));
if (body.code === 'expired_token') {
await rotateAccessToken();
return apiFetch(path, options);
}
throw new Error(body.error || `HTTP ${response.status}`);
}

Expand All @@ -49,6 +55,26 @@ async function submitConversation(conversation, sourceModel) {
});
}

async function rotateAccessToken() {
const { refreshToken, serverUrl } = await getConfig();
if (!refreshToken) {
throw new Error('Not logged in please open the extension to log in');
}

const response = await fetch(`${serverUrl}${`/v1/auth/refresh`}`, {
method: 'POST',
body: JSON.stringify({ refresh_token: refreshToken }),
});

if (!response.ok) {
const body = await response.json().catch(() => ({}));
throw new Error(body.error || `HTTP ${response.status}`);
}

const { token } = await response.json();
await chrome.storage.session.set({ token });
}

// ── OAuth tab handling ────────────────────────────────────────────────────────

// Watch for the extension OAuth callback URL.
Expand All @@ -61,12 +87,13 @@ chrome.tabs.onUpdated.addListener(async (tabId, changeInfo, tab) => {

const url = new URL(tab.url);
const token = url.searchParams.get('token');
const refreshToken = url.searchParams.get('refresh_token');
const userName = url.searchParams.get('name') || '';
const userEmail = url.searchParams.get('email') || '';
const userAvatar = url.searchParams.get('avatar') || '';
if (!token) return;

await chrome.storage.session.set({ token });
await chrome.storage.session.set({ token, refreshToken });
await chrome.storage.sync.set({ userName, userEmail, userAvatar });

await chrome.storage.session.remove(['authTabId']);
Expand Down Expand Up @@ -111,7 +138,7 @@ chrome.runtime.onMessage.addListener((message, _sender, sendResponse) => {
return { started: true };
}
case 'LOGOUT': {
await chrome.storage.session.remove(['token']);
await chrome.storage.session.remove(['token', 'refreshToken']);
await chrome.storage.sync.remove([
'userName',
'userEmail',
Expand Down
21 changes: 20 additions & 1 deletion havril/internal/api/handlers/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (h *AuthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}

user, rawToken, err := h.users.HandleOAuthCallback(
user, rawToken, refreshRawToken, err := h.users.HandleOAuthCallback(
r.Context(),
gothUser.Provider,
gothUser.UserID,
Expand Down Expand Up @@ -74,6 +74,7 @@ func (h *AuthHandler) Callback(w http.ResponseWriter, r *http.Request) {
SameSite: http.SameSiteLaxMode,
})
q.Set("token", rawToken)
q.Set("refresh_token", refreshRawToken)
}
// Returning user: no token param — background.js will update profile only.
http.Redirect(w, r, "/v1/auth/ext/done?"+q.Encode(), http.StatusFound)
Expand Down Expand Up @@ -112,3 +113,21 @@ func (h *AuthHandler) NewMcpToken(w http.ResponseWriter, r *http.Request) {

writeJSON(w, map[string]string{"mcp_token": mcpToken}, http.StatusOK)
}

func (h *AuthHandler) NewRefreshAccessToken(w http.ResponseWriter, r *http.Request) {
var body struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.RefreshToken == "" {
writeError(w, "invalid_request", "malformed JSON body", http.StatusBadRequest)
return
}

_, accessToken, err := h.users.RefreshAccessToken(r.Context(), body.RefreshToken)
if err != nil {
writeError(w, "failed to generate access token", "internal_error", http.StatusUnauthorized)
return
}

writeJSON(w, map[string]string{"token": accessToken, "expires_in": "900"}, http.StatusOK)
}
5 changes: 5 additions & 0 deletions havril/internal/api/middleware/auth_mid.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"net/http"
"strings"
"time"

userSvc "github.com/freedisch/havril/internal/user"
"github.com/google/uuid"
Expand Down Expand Up @@ -39,6 +40,10 @@ func (m *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
http.Error(w, `{"error": "invalid token", "code": "invalid_token"}`, http.StatusUnauthorized)
return
}
if user.TokenExpiresAt != nil && time.Now().After(*user.TokenExpiresAt) {
http.Error(w, `{"error": "token expired", "code": "expired_token"}`, http.StatusUnauthorized)
return
}

go m.users.TouchLastSeen(context.Background(), user.ID)

Expand Down
33 changes: 31 additions & 2 deletions havril/internal/user/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package user
import (
"context"
"fmt"
"time"

"github.com/freedisch/havril/pkg/models"
"github.com/freedisch/havril/pkg/utils"
"gorm.io/gorm"
)

Expand Down Expand Up @@ -39,12 +41,21 @@ func (r *Repository) FindOrCreate(ctx context.Context, provider, oauthID, email,

func (r *Repository) SaveToken(ctx context.Context, userID, tokenHash, tokenPrefix string) error {
result := r.db.WithContext(ctx).Model(&models.User{}).Where("id = ?", userID).Updates(map[string]any{
"token_hash": tokenHash,
"token_prefix": tokenPrefix,
"token_hash": tokenHash,
"token_prefix": tokenPrefix,
"token_expires_at": *utils.TimePtr(time.Now().Add(15 * time.Minute)),
})
return result.Error
}

func (r *Repository) SaveRefreshToken(ctx context.Context, userID, refreshTokenHash, refreshTokenPrefix string) error {
result := r.db.WithContext(ctx).Model(&models.User{}).Where("id = ?", userID).Updates(map[string]any{
"refresh_token_hash": refreshTokenHash,
"refresh_token_prefix": refreshTokenPrefix,
"refresh_token_expires_at": *utils.TimePtr(time.Now().Add(30 * 24 * time.Hour)),
})
return result.Error
}
func (r *Repository) SaveMcpToken(ctx context.Context, userID, tokenMcpHash, tokenMcpPrefix string) error {
result := r.db.WithContext(ctx).Model(&models.User{}).Where("id = ?", userID).Updates(map[string]any{
"mcp_token_hash": tokenMcpHash,
Expand All @@ -62,6 +73,24 @@ func (r *Repository) GetByTokenHash(ctx context.Context, tokenHash string) (*mod

return &user, nil
}
func (r *Repository) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
var user models.User
result := r.db.WithContext(ctx).Where("id = ?", userID).First(&user)
if result.Error != nil {
return nil, fmt.Errorf("get user by ID: %w", result.Error)
}
return &user, nil
}

func (r *Repository) GetByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (*models.User, error) {
var user models.User
result := r.db.WithContext(ctx).Where("refresh_token_hash = ?", refreshTokenHash).First(&user)
if result.Error != nil {
return nil, fmt.Errorf("get user by token: %w", result.Error)
}

return &user, nil
}

func (r *Repository) GetByMcpTokenHash(ctx context.Context, tokenMcpHash string) (*models.User, error) {
var user models.User
Expand Down
62 changes: 57 additions & 5 deletions havril/internal/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"crypto/sha256"
"encoding/hex"
"fmt"
"time"

"github.com/freedisch/havril/pkg/models"
"github.com/freedisch/havril/pkg/utils"
"github.com/google/uuid"
)

Expand All @@ -19,30 +21,67 @@ func NewService(repo *Repository) *Service {
return &Service{repo: repo}
}

func (s *Service) HandleOAuthCallback(ctx context.Context, provider, oauthID, email, displayName, avatarURL string) (*models.User, string, error) {
func (s *Service) HandleOAuthCallback(ctx context.Context, provider, oauthID, email, displayName, avatarURL string) (*models.User, string, string, error) {
user, err := s.repo.FindOrCreate(ctx, provider, oauthID, email, displayName, avatarURL)
if err != nil {
return nil, "", err
return nil, "", "", err
}

rawToken, tokenHash, tokenPrefix, err := generateToken()
if err != nil {
return nil, "", fmt.Errorf("generate token: %w", err)
return nil, "", "", fmt.Errorf("generate token: %w", err)
}

refreshRawToken, refreshTokenHash, refreshTokenPrefix, err := generateToken()
if err != nil {
return nil, "", "", fmt.Errorf("generate token: %w", err)
}

if err := s.repo.SaveToken(ctx, user.ID, tokenHash, tokenPrefix); err != nil {
return nil, "", fmt.Errorf("save token: %w", err)
return nil, "", "", fmt.Errorf("save token: %w", err)
}
if err := s.repo.SaveRefreshToken(ctx, user.ID, refreshTokenHash, refreshTokenPrefix); err != nil {
return nil, "", "", fmt.Errorf("save token: %w", err)
}

user.TokenHash = tokenHash
user.TokenPrefix = tokenPrefix
return user, rawToken, nil
user.TokenExpiresAt = utils.TimePtr(time.Now().Add(15 * time.Minute))
user.RefreshTokenHash = refreshTokenHash
user.RefreshTokenPrefix = refreshTokenPrefix
user.RefreshTokenExpiresAt = utils.TimePtr(time.Now().Add(30 * 24 * time.Hour))
return user, rawToken, refreshRawToken, nil
}

func (s *Service) RefreshAccessToken(ctx context.Context, rawRefreshToken string) (*models.User, string, error) {
sum := sha256.Sum256([]byte(rawRefreshToken))
hashStr := hex.EncodeToString(sum[:])
user, err := s.GetByRefreshTokenHash(ctx, hashStr)

if err != nil {
return nil, "", fmt.Errorf("invalid refresh token: %w", err)

}
if user.RefreshTokenExpiresAt == nil || time.Now().After(*user.RefreshTokenExpiresAt) {
return nil, "", fmt.Errorf("refresh token expired")
}
rawToken, tokenHash, tokenPrefix, err := generateToken()
if err != nil {
return nil, "", fmt.Errorf("generate token: %w", err)
}
if err := s.repo.SaveToken(ctx, user.ID, tokenHash, tokenPrefix); err != nil {
return nil, "", fmt.Errorf("save token: %w", err)
}

return user, rawToken, nil
}
func (s *Service) GetByTokenHash(ctx context.Context, hash string) (*models.User, error) {
return s.repo.GetByTokenHash(ctx, hash)
}

func (s *Service) GetByRefreshTokenHash(ctx context.Context, hash string) (*models.User, error) {
return s.repo.GetByRefreshTokenHash(ctx, hash)
}
func (s *Service) GetByMcpTokenHash(ctx context.Context, hash string) (*models.User, error) {
return s.repo.GetByMcpTokenHash(ctx, hash)
}
Expand Down Expand Up @@ -90,3 +129,16 @@ func (s *Service) GenerateMcpToken(ctx context.Context, userID string) (string,

return mcpToken, nil
}

// func (s * Service) GenerateNewAccessToken(ctx context.Context, userID string) (string, error) {
// accessToken, accessTokenHash, accessTokenPrefix, err := generateToken()
// if err != nil{
// return "", fmt.Errorf("generate access token: %w", err)
// }

// if err := s.repo.SaveToken(ctx, userID, accessTokenHash, accessTokenPrefix); err != nil {
// return "", fmt.Errorf("save token: %w", err)
// }

// return accessToken, nil
// }
1 change: 1 addition & 0 deletions havril/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func main() {
r.Get("/v1/auth/{provider}", authHandler.Begin)
r.Get("/v1/auth/{provider}/callback", authHandler.Callback)
r.Get("/v1/auth/ext/done", authHandler.ExtDone)
r.Post("/v1/auth/refresh", authHandler.NewRefreshAccessToken)

// OAuth discovery stubs — required by MCP clients before
// they will attempt a connection; access_denied from /authorize causes
Expand Down
30 changes: 17 additions & 13 deletions havril/pkg/models/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@ import (

type User struct {
gorm.Model
ID string `gorm:"type:uuid;default:gen_random_uuid();primaryKey"`
Email string `gorm:"uniqueIndex;not null"`
OAuthProvider string `gorm:"not null;uniqueIndex:idx_users_oauth"`
OAuthID string `gorm:"not null;uniqueIndex:idx_users_oauth"`
DisplayName string
AvatarURL string
TokenHash string `gorm:"uniqueIndex"`
TokenPrefix string
McpTokenHash string
McpTokenPrefix string
LastSeenAt *time.Time
ConnectedModels []ConnectedModel `gorm:"foreignKey:UserID"`
Memories []Memory `gorm:"foreignKey:UserID"`
ID string `gorm:"type:uuid;default:gen_random_uuid();primaryKey"`
Email string `gorm:"uniqueIndex;not null"`
OAuthProvider string `gorm:"not null;uniqueIndex:idx_users_oauth"`
OAuthID string `gorm:"not null;uniqueIndex:idx_users_oauth"`
DisplayName string
AvatarURL string
TokenHash string `gorm:"uniqueIndex"`
TokenPrefix string
TokenExpiresAt *time.Time
RefreshTokenHash string
RefreshTokenPrefix string
RefreshTokenExpiresAt *time.Time
McpTokenHash string
McpTokenPrefix string
LastSeenAt *time.Time
ConnectedModels []ConnectedModel `gorm:"foreignKey:UserID"`
Memories []Memory `gorm:"foreignKey:UserID"`
}

func (User) TableName() string {
Expand Down
5 changes: 5 additions & 0 deletions havril/pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"log"
"os"
"time"
)

func GetEnv(key, fallback string) string {
Expand All @@ -19,3 +20,7 @@ func MustEnv(key string) string {
}
return v
}

func TimePtr(t time.Time) *time.Time {
return &t
}
Loading