Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/pretty-laws-dress.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/gateway': patch
---

fix(gateway): add error handling for oidc refresh
94 changes: 67 additions & 27 deletions packages/gateway/src/gateway-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -655,26 +655,24 @@ describe('GatewayProvider', () => {
// Test successful cases
const result = await getGatewayAuthToken(options);

expect(result).not.toBeNull();
expect(result?.authMethod).toBe(testCase.expectedAuthMethod);
expect(result.authMethod).toBe(testCase.expectedAuthMethod);

if (testCase.expectedAuthMethod === 'api-key') {
const expectedToken =
testCase.optionsApiKey || testCase.envApiKey;
expect(result?.token).toBe(expectedToken);
expect(result.token).toBe(expectedToken);

// If we used options API key, OIDC should not be called
if (testCase.optionsApiKey) {
expect(getVercelOidcToken).not.toHaveBeenCalled();
}
} else if (testCase.expectedAuthMethod === 'oidc') {
expect(result?.token).toBe(testCase.oidcTokenMock);
expect(result.token).toBe(testCase.oidcTokenMock);
expect(getVercelOidcToken).toHaveBeenCalled();
}
} else {
// Test failure cases
const result = await getGatewayAuthToken(options);
expect(result).toBeNull();
// Test failure cases - should throw when OIDC fails
await expect(getGatewayAuthToken(options)).rejects.toThrow();
}
});
});
Expand Down Expand Up @@ -771,15 +769,13 @@ describe('GatewayProvider', () => {
AI_GATEWAY_API_KEY: '',
};

vi.mocked(getVercelOidcToken).mockRejectedValue(
new GatewayAuthenticationError({
message: 'OIDC token not available',
statusCode: 401,
}),
);
const oidcError = new GatewayAuthenticationError({
message: 'OIDC token not available',
statusCode: 401,
});
vi.mocked(getVercelOidcToken).mockRejectedValue(oidcError);

const result = await getGatewayAuthToken({});
expect(result).toBeNull();
await expect(getGatewayAuthToken({})).rejects.toThrow(oidcError);
});

it('should handle whitespace-only environment variables', async () => {
Expand All @@ -791,9 +787,8 @@ describe('GatewayProvider', () => {

// The whitespace API key should still be used (it's treated as a valid value)
const result = await getGatewayAuthToken({});
expect(result).not.toBeNull();
expect(result?.authMethod).toBe('api-key');
expect(result?.token).toBe('\t\n ');
expect(result.authMethod).toBe('api-key');
expect(result.token).toBe('\t\n ');
});

it('should prioritize options.apiKey over all environment variables', async () => {
Expand All @@ -806,11 +801,56 @@ describe('GatewayProvider', () => {
const optionsApiKey = 'options-api-key';
const result = await getGatewayAuthToken({ apiKey: optionsApiKey });

expect(result).not.toBeNull();
expect(result?.authMethod).toBe('api-key');
expect(result?.token).toBe(optionsApiKey);
expect(result.authMethod).toBe('api-key');
expect(result.token).toBe(optionsApiKey);
expect(getVercelOidcToken).not.toHaveBeenCalled();
});

it('should surface OIDC error as cause when authentication fails', async () => {
process.env = {
...originalEnv,
VERCEL_OIDC_TOKEN: '',
AI_GATEWAY_API_KEY: '',
};

delete process.env.AI_GATEWAY_API_KEY;

const oidcError = new Error(
'OIDC token generation failed: project not linked',
);
vi.mocked(getVercelOidcToken).mockRejectedValue(oidcError);

vi.mocked(GatewayFetchMetadata).mockImplementation(
(config: any) =>
({
getAvailableModels: async () => {
if (config.headers && typeof config.headers === 'function') {
await config.headers();
}
return mockGetAvailableModels();
},
getCredits: async () => {
if (config.headers && typeof config.headers === 'function') {
await config.headers();
}
return mockGetCredits();
},
}) as any,
);

const provider = createGatewayProvider();

try {
await provider.getAvailableModels();
fail('Expected an error to be thrown');
} catch (error) {
expect(GatewayAuthenticationError.isInstance(error)).toBe(true);
if (GatewayAuthenticationError.isInstance(error)) {
expect(error.cause).toBe(oidcError);
expect(error.message).toContain('No authentication provided');
}
}
});
});

describe('Authentication precedence', () => {
Expand All @@ -823,8 +863,8 @@ describe('GatewayProvider', () => {
const optionsApiKey = 'options-api-key';
const result = await getGatewayAuthToken({ apiKey: optionsApiKey });

expect(result?.authMethod).toBe('api-key');
expect(result?.token).toBe(optionsApiKey);
expect(result.authMethod).toBe('api-key');
expect(result.token).toBe(optionsApiKey);
expect(getVercelOidcToken).not.toHaveBeenCalled();
});

Expand All @@ -837,8 +877,8 @@ describe('GatewayProvider', () => {

const result = await getGatewayAuthToken({});

expect(result?.authMethod).toBe('api-key');
expect(result?.token).toBe('env-api-key');
expect(result.authMethod).toBe('api-key');
expect(result.token).toBe('env-api-key');
expect(getVercelOidcToken).not.toHaveBeenCalled();
});

Expand All @@ -852,8 +892,8 @@ describe('GatewayProvider', () => {

const result = await getGatewayAuthToken({});

expect(result?.authMethod).toBe('oidc');
expect(result?.token).toBe('oidc-token');
expect(result.authMethod).toBe('oidc');
expect(result.token).toBe('oidc-token');
expect(getVercelOidcToken).toHaveBeenCalled();
});
});
Expand Down
36 changes: 15 additions & 21 deletions packages/gateway/src/gateway-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ export function createGatewayProvider(
'https://ai-gateway.vercel.sh/v3/ai';

const getHeaders = async () => {
const auth = await getGatewayAuthToken(options);
if (auth) {
try {
const auth = await getGatewayAuthToken(options);
return withUserAgentSuffix(
{
Authorization: `Bearer ${auth.token}`,
Expand All @@ -128,13 +128,14 @@ export function createGatewayProvider(
},
`ai-sdk/gateway/${VERSION}`,
);
} catch (error) {
throw GatewayAuthenticationError.createContextualError({
apiKeyProvided: false,
oidcTokenProvided: false,
statusCode: 401,
cause: error,
});
}

throw GatewayAuthenticationError.createContextualError({
apiKeyProvided: false,
oidcTokenProvided: false,
statusCode: 401,
});
};

const createO11yHeaders = () => {
Expand Down Expand Up @@ -255,10 +256,7 @@ export const gateway = createGatewayProvider();

export async function getGatewayAuthToken(
options: GatewayProviderSettings,
): Promise<{
token: string;
authMethod: 'api-key' | 'oidc';
} | null> {
): Promise<{ token: string; authMethod: 'api-key' | 'oidc' }> {
const apiKey = loadOptionalSetting({
settingValue: options.apiKey,
environmentVariableName: 'AI_GATEWAY_API_KEY',
Expand All @@ -271,13 +269,9 @@ export async function getGatewayAuthToken(
};
}

try {
const oidcToken = await getVercelOidcToken();
return {
token: oidcToken,
authMethod: 'oidc',
};
} catch {
return null;
}
const oidcToken = await getVercelOidcToken();
return {
token: oidcToken,
authMethod: 'oidc',
};
}
Loading