1- import { createTestServer } from '@ai-sdk/test-server/with-vitest' ;
21import { DownloadError } from '@ai-sdk/provider-utils' ;
32import { download } from './download' ;
4- import { describe , it , expect , vi , afterEach } from 'vitest' ;
5-
6- const server = createTestServer ( {
7- 'http://example.com/file' : { } ,
8- 'http://example.com/large' : { } ,
9- } ) ;
3+ import { describe , it , expect , vi , afterEach , beforeEach } from 'vitest' ;
104
115describe ( 'download SSRF protection' , ( ) => {
126 it ( 'should reject private IPv4 addresses' , async ( ) => {
@@ -100,16 +94,32 @@ describe('download SSRF redirect protection', () => {
10094} ) ;
10195
10296describe ( 'download' , ( ) => {
97+ const originalFetch = globalThis . fetch ;
98+
99+ beforeEach ( ( ) => {
100+ vi . resetAllMocks ( ) ;
101+ } ) ;
102+
103+ afterEach ( ( ) => {
104+ globalThis . fetch = originalFetch ;
105+ } ) ;
106+
103107 it ( 'should download data successfully and match expected bytes' , async ( ) => {
104108 const expectedBytes = new Uint8Array ( [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) ;
105109
106- server . urls [ 'http://example.com/file' ] . response = {
107- type : 'binary' ,
108- headers : {
110+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
111+ ok : true ,
112+ status : 200 ,
113+ headers : new Headers ( {
109114 'content-type' : 'application/octet-stream' ,
110- } ,
111- body : Buffer . from ( expectedBytes ) ,
112- } ;
115+ } ) ,
116+ body : new ReadableStream ( {
117+ start ( controller ) {
118+ controller . enqueue ( expectedBytes ) ;
119+ controller . close ( ) ;
120+ } ,
121+ } ) ,
122+ } as unknown as Response ) ;
113123
114124 const result = await download ( {
115125 url : new URL ( 'http://example.com/file' ) ,
@@ -119,16 +129,21 @@ describe('download', () => {
119129 expect ( result ! . data ) . toEqual ( expectedBytes ) ;
120130 expect ( result ! . mediaType ) . toBe ( 'application/octet-stream' ) ;
121131
122- // UA header assertion
123- expect ( server . calls [ 0 ] . requestUserAgent ) . toContain ( 'ai-sdk/' ) ;
132+ expect ( fetch ) . toHaveBeenCalledWith (
133+ 'http://example.com/file' ,
134+ expect . objectContaining ( {
135+ headers : expect . any ( Object ) ,
136+ } ) ,
137+ ) ;
124138 } ) ;
125139
126140 it ( 'should throw DownloadError when response is not ok' , async ( ) => {
127- server . urls [ 'http://example.com/file' ] . response = {
128- type : 'error' ,
141+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
142+ ok : false ,
129143 status : 404 ,
130- body : 'Not Found' ,
131- } ;
144+ statusText : 'Not Found' ,
145+ headers : new Headers ( ) ,
146+ } as unknown as Response ) ;
132147
133148 try {
134149 await download ( {
@@ -143,11 +158,7 @@ describe('download', () => {
143158 } ) ;
144159
145160 it ( 'should throw DownloadError when fetch throws an error' , async ( ) => {
146- server . urls [ 'http://example.com/file' ] . response = {
147- type : 'error' ,
148- status : 500 ,
149- body : 'Network error' ,
150- } ;
161+ globalThis . fetch = vi . fn ( ) . mockRejectedValue ( new Error ( 'Network error' ) ) ;
151162
152163 try {
153164 await download ( {
@@ -160,15 +171,20 @@ describe('download', () => {
160171 } ) ;
161172
162173 it ( 'should abort when response exceeds default size limit' , async ( ) => {
163- // Create a response that claims to be larger than 2 GiB
164- server . urls [ 'http://example.com/large' ] . response = {
165- type : 'binary' ,
166- headers : {
174+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
175+ ok : true ,
176+ status : 200 ,
177+ headers : new Headers ( {
167178 'content-type' : 'application/octet-stream' ,
168179 'content-length' : `${ 3 * 1024 * 1024 * 1024 } ` ,
169- } ,
170- body : Buffer . from ( new Uint8Array ( 10 ) ) ,
171- } ;
180+ } ) ,
181+ body : new ReadableStream ( {
182+ start ( controller ) {
183+ controller . enqueue ( new Uint8Array ( 10 ) ) ;
184+ controller . close ( ) ;
185+ } ,
186+ } ) ,
187+ } as unknown as Response ) ;
172188
173189 try {
174190 await download ( {
@@ -187,13 +203,11 @@ describe('download', () => {
187203 const controller = new AbortController ( ) ;
188204 controller . abort ( ) ;
189205
190- server . urls [ 'http://example.com/file' ] . response = {
191- type : 'binary' ,
192- headers : {
193- 'content-type' : 'application/octet-stream' ,
194- } ,
195- body : Buffer . from ( new Uint8Array ( [ 1 , 2 , 3 ] ) ) ,
196- } ;
206+ globalThis . fetch = vi
207+ . fn ( )
208+ . mockRejectedValue (
209+ new DOMException ( 'The operation was aborted.' , 'AbortError' ) ,
210+ ) ;
197211
198212 try {
199213 await download ( {
@@ -202,8 +216,14 @@ describe('download', () => {
202216 } ) ;
203217 expect . fail ( 'Expected download to throw' ) ;
204218 } catch ( error : unknown ) {
205- // The fetch should be aborted, resulting in a DownloadError wrapping an AbortError
206219 expect ( DownloadError . isInstance ( error ) ) . toBe ( true ) ;
207220 }
221+
222+ expect ( fetch ) . toHaveBeenCalledWith (
223+ 'http://example.com/file' ,
224+ expect . objectContaining ( {
225+ signal : controller . signal ,
226+ } ) ,
227+ ) ;
208228 } ) ;
209229} ) ;
0 commit comments