@@ -21,6 +21,7 @@ import { AI } from '../lib/public-types';
2121import { VertexAIBackend } from '../lib/backend' ;
2222import { TemplateGenerativeModel } from '../lib/models/template-generative-model' ;
2323import * as generateContentMethods from '../lib/methods/generate-content' ;
24+ import { EnhancedGenerateContentResponse , GenerateContentStreamResult } from '../lib/types' ;
2425
2526const fakeAI : AI = {
2627 app : {
@@ -39,6 +40,15 @@ const fakeAI: AI = {
3940const TEMPLATE_ID = 'my-template' ;
4041const TEMPLATE_VARS = { a : 1 , b : '2' } ;
4142
43+ function streamResult ( response : EnhancedGenerateContentResponse ) : GenerateContentStreamResult {
44+ return {
45+ stream : ( async function * ( ) {
46+ yield response ;
47+ } ) ( ) ,
48+ response : Promise . resolve ( response ) ,
49+ } ;
50+ }
51+
4252describe ( 'TemplateGenerativeModel' , function ( ) {
4353 afterEach ( function ( ) {
4454 jest . restoreAllMocks ( ) ;
@@ -202,6 +212,134 @@ describe('TemplateGenerativeModel', function () {
202212 ] ) ;
203213 } ) ;
204214
215+ it ( 'automatically calls functionReference from template chat function calls' , async function ( ) {
216+ const getWeather = jest . fn < ( args : object ) => object > ( ) . mockReturnValue ( { temperature : 72 } ) ;
217+ const functionCallResponse = {
218+ candidates : [
219+ {
220+ content : {
221+ role : 'model' ,
222+ parts : [
223+ {
224+ functionCall : {
225+ name : 'getWeather' ,
226+ args : { city : 'London' } ,
227+ } ,
228+ } ,
229+ ] ,
230+ } ,
231+ } ,
232+ ] ,
233+ functionCalls : ( ) => [ { name : 'getWeather' , args : { city : 'London' } } ] ,
234+ } as EnhancedGenerateContentResponse ;
235+ const finalResponse = {
236+ candidates : [
237+ {
238+ content : {
239+ role : 'model' ,
240+ parts : [ { text : 'It is 72 degrees.' } ] ,
241+ } ,
242+ } ,
243+ ] ,
244+ text : ( ) => 'It is 72 degrees.' ,
245+ functionCalls : ( ) => undefined ,
246+ } as EnhancedGenerateContentResponse ;
247+ const templateGenerateContentSpy = jest
248+ . spyOn ( generateContentMethods , 'templateGenerateContent' )
249+ . mockResolvedValueOnce ( { response : functionCallResponse } )
250+ . mockResolvedValueOnce ( { response : finalResponse } ) ;
251+ const model = new TemplateGenerativeModel ( fakeAI , { timeout : 5000 } ) ;
252+ const chat = model . startChat ( {
253+ templateId : TEMPLATE_ID ,
254+ templateVariables : TEMPLATE_VARS ,
255+ tools : [
256+ {
257+ functionDeclarations : [
258+ {
259+ name : 'getWeather' ,
260+ functionReference : getWeather ,
261+ } ,
262+ ] ,
263+ } ,
264+ ] ,
265+ } ) ;
266+
267+ const result = await chat . sendMessage ( 'weather in London' ) ;
268+ const history = await chat . getHistory ( ) ;
269+
270+ expect ( result . response . text ( ) ) . toBe ( 'It is 72 degrees.' ) ;
271+ expect ( getWeather ) . toHaveBeenCalledWith ( { city : 'London' } ) ;
272+ expect ( templateGenerateContentSpy ) . toHaveBeenCalledTimes ( 2 ) ;
273+ expect ( templateGenerateContentSpy ) . toHaveBeenLastCalledWith (
274+ model . _apiSettings ,
275+ TEMPLATE_ID ,
276+ expect . objectContaining ( {
277+ inputs : TEMPLATE_VARS ,
278+ contents : [
279+ {
280+ role : 'user' ,
281+ parts : [ { text : 'weather in London' } ] ,
282+ } ,
283+ {
284+ role : 'model' ,
285+ parts : [
286+ {
287+ functionCall : {
288+ name : 'getWeather' ,
289+ args : { city : 'London' } ,
290+ } ,
291+ } ,
292+ ] ,
293+ } ,
294+ {
295+ role : 'function' ,
296+ parts : [
297+ {
298+ functionResponse : {
299+ name : 'getWeather' ,
300+ response : { temperature : 72 } ,
301+ } ,
302+ } ,
303+ ] ,
304+ } ,
305+ ] ,
306+ } ) ,
307+ { timeout : 5000 } ,
308+ ) ;
309+ expect ( history ) . toEqual ( [
310+ {
311+ role : 'user' ,
312+ parts : [ { text : 'weather in London' } ] ,
313+ } ,
314+ {
315+ role : 'model' ,
316+ parts : [
317+ {
318+ functionCall : {
319+ name : 'getWeather' ,
320+ args : { city : 'London' } ,
321+ } ,
322+ } ,
323+ ] ,
324+ } ,
325+ {
326+ role : 'function' ,
327+ parts : [
328+ {
329+ functionResponse : {
330+ name : 'getWeather' ,
331+ response : { temperature : 72 } ,
332+ } ,
333+ } ,
334+ ] ,
335+ } ,
336+ {
337+ role : 'model' ,
338+ parts : [ { text : 'It is 72 degrees.' } ] ,
339+ } ,
340+ ] ) ;
341+ } ) ;
342+
205343 it ( 'should call templateGenerateContentStream with template chat parameters' , async function ( ) {
206344 const templateGenerateContentStreamSpy = jest
207345 . spyOn ( generateContentMethods , 'templateGenerateContentStream' )
@@ -233,5 +371,133 @@ describe('TemplateGenerativeModel', function () {
233371 { role : 'model' , parts : [ { text : 'stream back' } ] } ,
234372 ] ) ;
235373 } ) ;
374+
375+ it ( 'automatically calls functionReference from streaming template chat function calls' , async function ( ) {
376+ const getWeather = jest . fn < ( args : object ) => object > ( ) . mockReturnValue ( { temperature : 72 } ) ;
377+ const functionCallResponse = {
378+ candidates : [
379+ {
380+ content : {
381+ role : 'model' ,
382+ parts : [
383+ {
384+ functionCall : {
385+ name : 'getWeather' ,
386+ args : { city : 'London' } ,
387+ } ,
388+ } ,
389+ ] ,
390+ } ,
391+ } ,
392+ ] ,
393+ functionCalls : ( ) => [ { name : 'getWeather' , args : { city : 'London' } } ] ,
394+ } as EnhancedGenerateContentResponse ;
395+ const finalResponse = {
396+ candidates : [
397+ {
398+ content : {
399+ role : 'model' ,
400+ parts : [ { text : 'It is 72 degrees.' } ] ,
401+ } ,
402+ } ,
403+ ] ,
404+ text : ( ) => 'It is 72 degrees.' ,
405+ functionCalls : ( ) => undefined ,
406+ } as EnhancedGenerateContentResponse ;
407+ const templateGenerateContentStreamSpy = jest
408+ . spyOn ( generateContentMethods , 'templateGenerateContentStream' )
409+ . mockResolvedValueOnce ( streamResult ( functionCallResponse ) )
410+ . mockResolvedValueOnce ( streamResult ( finalResponse ) ) ;
411+ const model = new TemplateGenerativeModel ( fakeAI , { timeout : 5000 } ) ;
412+ const chat = model . startChat ( {
413+ templateId : TEMPLATE_ID ,
414+ templateVariables : TEMPLATE_VARS ,
415+ tools : [
416+ {
417+ functionDeclarations : [
418+ {
419+ name : 'getWeather' ,
420+ functionReference : getWeather ,
421+ } ,
422+ ] ,
423+ } ,
424+ ] ,
425+ } ) ;
426+
427+ const result = await chat . sendMessageStream ( 'weather in London' ) ;
428+ await result . response ;
429+ const history = await chat . getHistory ( ) ;
430+
431+ expect ( getWeather ) . toHaveBeenCalledWith ( { city : 'London' } ) ;
432+ expect ( templateGenerateContentStreamSpy ) . toHaveBeenCalledTimes ( 2 ) ;
433+ expect ( templateGenerateContentStreamSpy ) . toHaveBeenLastCalledWith (
434+ model . _apiSettings ,
435+ TEMPLATE_ID ,
436+ expect . objectContaining ( {
437+ inputs : TEMPLATE_VARS ,
438+ contents : [
439+ {
440+ role : 'user' ,
441+ parts : [ { text : 'weather in London' } ] ,
442+ } ,
443+ {
444+ role : 'model' ,
445+ parts : [
446+ {
447+ functionCall : {
448+ name : 'getWeather' ,
449+ args : { city : 'London' } ,
450+ } ,
451+ } ,
452+ ] ,
453+ } ,
454+ {
455+ role : 'function' ,
456+ parts : [
457+ {
458+ functionResponse : {
459+ name : 'getWeather' ,
460+ response : { temperature : 72 } ,
461+ } ,
462+ } ,
463+ ] ,
464+ } ,
465+ ] ,
466+ } ) ,
467+ { timeout : 5000 } ,
468+ ) ;
469+ expect ( history ) . toEqual ( [
470+ {
471+ role : 'user' ,
472+ parts : [ { text : 'weather in London' } ] ,
473+ } ,
474+ {
475+ role : 'model' ,
476+ parts : [
477+ {
478+ functionCall : {
479+ name : 'getWeather' ,
480+ args : { city : 'London' } ,
481+ } ,
482+ } ,
483+ ] ,
484+ } ,
485+ {
486+ role : 'function' ,
487+ parts : [
488+ {
489+ functionResponse : {
490+ name : 'getWeather' ,
491+ response : { temperature : 72 } ,
492+ } ,
493+ } ,
494+ ] ,
495+ } ,
496+ {
497+ role : 'model' ,
498+ parts : [ { text : 'It is 72 degrees.' } ] ,
499+ } ,
500+ ] ) ;
501+ } ) ;
236502 } ) ;
237503} ) ;
0 commit comments