Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/fix-ssrf-redirect-bypass.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@ai-sdk/provider-utils': patch
'ai': patch
---

fix(security): validate redirect targets in download functions to prevent SSRF bypass

Both `downloadBlob` and `download` now validate the final URL after following HTTP redirects, preventing attackers from bypassing SSRF protections via open redirects to internal/private addresses.
169 changes: 130 additions & 39 deletions packages/ai/src/util/download/download.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
import { DownloadError } from '@ai-sdk/provider-utils';
import { download } from './download';
import { describe, it, expect, vi } from 'vitest';

const server = createTestServer({
'http://example.com/file': {},
'http://example.com/large': {},
});
import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';

describe('download SSRF protection', () => {
it('should reject private IPv4 addresses', async () => {
Expand All @@ -28,17 +22,104 @@ describe('download SSRF protection', () => {
});
});

describe('download SSRF redirect protection', () => {
const originalFetch = globalThis.fetch;

afterEach(() => {
globalThis.fetch = originalFetch;
});

it('should reject redirects to private IP addresses', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'http://169.254.169.254/latest/meta-data/',
headers: new Headers({ 'content-type': 'text/plain' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode('secret'));
controller.close();
},
}),
} as unknown as Response);

await expect(
download({ url: new URL('https://evil.com/redirect') }),
).rejects.toThrow(DownloadError);
});

it('should reject redirects to localhost', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'http://localhost:8080/admin',
headers: new Headers({ 'content-type': 'text/plain' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode('secret'));
controller.close();
},
}),
} as unknown as Response);

await expect(
download({ url: new URL('https://evil.com/redirect') }),
).rejects.toThrow(DownloadError);
});

it('should allow redirects to safe URLs', async () => {
const content = new Uint8Array([1, 2, 3]);
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'https://cdn.example.com/image.png',
headers: new Headers({ 'content-type': 'image/png' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(content);
controller.close();
},
}),
} as unknown as Response);

const result = await download({
url: new URL('https://example.com/image.png'),
});
expect(result.data).toEqual(content);
expect(result.mediaType).toBe('image/png');
});
});

describe('download', () => {
const originalFetch = globalThis.fetch;

beforeEach(() => {
vi.resetAllMocks();
});

afterEach(() => {
globalThis.fetch = originalFetch;
});

it('should download data successfully and match expected bytes', async () => {
const expectedBytes = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]);

server.urls['http://example.com/file'].response = {
type: 'binary',
headers: {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
headers: new Headers({
'content-type': 'application/octet-stream',
},
body: Buffer.from(expectedBytes),
};
}),
body: new ReadableStream({
start(controller) {
controller.enqueue(expectedBytes);
controller.close();
},
}),
} as unknown as Response);

const result = await download({
url: new URL('http://example.com/file'),
Expand All @@ -48,16 +129,21 @@ describe('download', () => {
expect(result!.data).toEqual(expectedBytes);
expect(result!.mediaType).toBe('application/octet-stream');

// UA header assertion
expect(server.calls[0].requestUserAgent).toContain('ai-sdk/');
expect(fetch).toHaveBeenCalledWith(
'http://example.com/file',
expect.objectContaining({
headers: expect.any(Object),
}),
);
});

it('should throw DownloadError when response is not ok', async () => {
server.urls['http://example.com/file'].response = {
type: 'error',
globalThis.fetch = vi.fn().mockResolvedValue({
ok: false,
status: 404,
body: 'Not Found',
};
statusText: 'Not Found',
headers: new Headers(),
} as unknown as Response);

try {
await download({
Expand All @@ -72,11 +158,7 @@ describe('download', () => {
});

it('should throw DownloadError when fetch throws an error', async () => {
server.urls['http://example.com/file'].response = {
type: 'error',
status: 500,
body: 'Network error',
};
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'));

try {
await download({
Expand All @@ -89,15 +171,20 @@ describe('download', () => {
});

it('should abort when response exceeds default size limit', async () => {
// Create a response that claims to be larger than 2 GiB
server.urls['http://example.com/large'].response = {
type: 'binary',
headers: {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
headers: new Headers({
'content-type': 'application/octet-stream',
'content-length': `${3 * 1024 * 1024 * 1024}`,
},
body: Buffer.from(new Uint8Array(10)),
};
}),
body: new ReadableStream({
start(controller) {
controller.enqueue(new Uint8Array(10));
controller.close();
},
}),
} as unknown as Response);

try {
await download({
Expand All @@ -116,13 +203,11 @@ describe('download', () => {
const controller = new AbortController();
controller.abort();

server.urls['http://example.com/file'].response = {
type: 'binary',
headers: {
'content-type': 'application/octet-stream',
},
body: Buffer.from(new Uint8Array([1, 2, 3])),
};
globalThis.fetch = vi
.fn()
.mockRejectedValue(
new DOMException('The operation was aborted.', 'AbortError'),
);

try {
await download({
Expand All @@ -131,8 +216,14 @@ describe('download', () => {
});
expect.fail('Expected download to throw');
} catch (error: unknown) {
// The fetch should be aborted, resulting in a DownloadError wrapping an AbortError
expect(DownloadError.isInstance(error)).toBe(true);
}

expect(fetch).toHaveBeenCalledWith(
'http://example.com/file',
expect.objectContaining({
signal: controller.signal,
}),
);
});
});
5 changes: 5 additions & 0 deletions packages/ai/src/util/download/download.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ export const download = async ({
signal: abortSignal,
});

// Validate final URL after redirects to prevent SSRF via open redirect
if (response.redirected) {
validateDownloadUrl(response.url);
}

if (!response.ok) {
throw new DownloadError({
url: urlText,
Expand Down
76 changes: 76 additions & 0 deletions packages/provider-utils/src/download-blob.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,82 @@ describe('downloadBlob() SSRF protection', () => {
DownloadError,
);
});

it('should reject redirects to private IP addresses', async () => {
const originalFetch = globalThis.fetch;
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'http://169.254.169.254/latest/meta-data/',
headers: new Headers({ 'content-type': 'text/plain' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode('secret'));
controller.close();
},
}),
} as unknown as Response);

try {
await expect(downloadBlob('https://evil.com/redirect')).rejects.toThrow(
DownloadError,
);
} finally {
globalThis.fetch = originalFetch;
}
});

it('should reject redirects to localhost', async () => {
const originalFetch = globalThis.fetch;
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'http://localhost:8080/admin',
headers: new Headers({ 'content-type': 'text/plain' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode('secret'));
controller.close();
},
}),
} as unknown as Response);

try {
await expect(downloadBlob('https://evil.com/redirect')).rejects.toThrow(
DownloadError,
);
} finally {
globalThis.fetch = originalFetch;
}
});

it('should allow redirects to safe URLs', async () => {
const originalFetch = globalThis.fetch;
const content = new TextEncoder().encode('safe content');
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
redirected: true,
url: 'https://cdn.example.com/image.png',
headers: new Headers({ 'content-type': 'image/png' }),
body: new ReadableStream({
start(controller) {
controller.enqueue(content);
controller.close();
},
}),
} as unknown as Response);

try {
const result = await downloadBlob('https://example.com/image.png');
expect(result).toBeInstanceOf(Blob);
expect(result.type).toBe('image/png');
} finally {
globalThis.fetch = originalFetch;
}
});
});

describe('DownloadError', () => {
Expand Down
5 changes: 5 additions & 0 deletions packages/provider-utils/src/download-blob.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ export async function downloadBlob(
signal: options?.abortSignal,
});

// Validate final URL after redirects to prevent SSRF via open redirect
if (response.redirected) {
validateDownloadUrl(response.url);
}

if (!response.ok) {
throw new DownloadError({
url,
Expand Down
Loading