Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions internal/cmd/auth_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,12 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {
if err != nil {
return fmt.Errorf("fetch authorized email: %w", err)
}
if normalizeEmail(authorizedEmail) != normalizeEmail(c.Email) {
return fmt.Errorf("authorized as %s, expected %s", authorizedEmail, c.Email)
// If c.Email looks like a client name (no "@"), skip the email match check —
// the user intentionally authorized with a client alias.
if strings.Contains(normalizeEmail(c.Email), "@") {
if normalizeEmail(authorizedEmail) != normalizeEmail(c.Email) {
return fmt.Errorf("authorized as %s, expected %s", authorizedEmail, c.Email)
}
}

store, err := openSecretsStore()
Expand Down
32 changes: 32 additions & 0 deletions internal/cmd/auth_add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,38 @@ func TestAuthAddCmd_KeepRejected(t *testing.T) {
}
}

func TestAuthAddCmd_ClientAliasSkipsEmailCheck(t *testing.T) {
// When a bare client name (no "@") is passed, the authorized email check
// should be skipped — the user is intentionally using a client alias.
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})

ensureKeychainAccess = func() error { return nil }
store := newMemSecretsStore()
openSecretsStore = func() (secrets.Store, error) { return store, nil }
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) {
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) {
// Authorized email has a different domain than the client alias
return "user@sub.company.ai", nil
}

// "company.ai" has no "@" so it's a client alias — should succeed without mismatch error
err := Execute([]string{"auth", "add", "company.ai", "--services", "gmail"})
if err != nil {
t.Fatalf("expected success with client alias, got: %v", err)
}
}

func TestAuthAddCmd_EmailMismatch(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
Expand Down
6 changes: 6 additions & 0 deletions internal/config/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ func ResolveClientForAccount(cfg File, email string, override string) (string, e
}

email = strings.ToLower(strings.TrimSpace(email))

// If input has no "@", treat it as a direct client name (e.g. "baher.ai")
if email != "" && !strings.Contains(email, "@") {
return NormalizeClientNameOrDefault(email)
}

if email != "" {
if client, ok := cfg.AccountClients[email]; ok && strings.TrimSpace(client) != "" {
return NormalizeClientNameOrDefault(client)
Expand Down
30 changes: 30 additions & 0 deletions internal/config/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,36 @@ func TestResolveClientForAccount(t *testing.T) {
t.Fatalf("got %q, want %q", got, DefaultClientName)
}
})

t.Run("client name without @ is used as-is", func(t *testing.T) {
// When a bare client name like "company.ai" is passed (no @),
// it should be returned directly as the client name.
cfg := File{}
got, err := ResolveClientForAccount(cfg, "company.ai", "")
if err != nil {
t.Fatalf("ResolveClientForAccount: %v", err)
}
if got != "company.ai" {
t.Fatalf("got %q, want %q", got, "company.ai")
}
})

t.Run("subdomain email routes via account_clients mapping", func(t *testing.T) {
// Emails with subdomain domains (e.g. user@sub.company.ai) should be
// routable via account_clients config to the correct client.
cfg := File{
AccountClients: map[string]string{
"user@sub.company.ai": "company.ai",
},
}
got, err := ResolveClientForAccount(cfg, "user@sub.company.ai", "")
if err != nil {
t.Fatalf("ResolveClientForAccount: %v", err)
}
if got != "company.ai" {
t.Fatalf("got %q, want %q", got, "company.ai")
}
})
}

func TestListClientCredentials(t *testing.T) {
Expand Down