Skip to content

Commit 7f31d9c

Browse files
committed
feat(ai): support template chat function auto-calling
1 parent 6d65d5f commit 7f31d9c

3 files changed

Lines changed: 424 additions & 15 deletions

File tree

packages/ai/__tests__/template-generative-model.test.ts

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { AI } from '../lib/public-types';
2121
import { VertexAIBackend } from '../lib/backend';
2222
import { TemplateGenerativeModel } from '../lib/models/template-generative-model';
2323
import * as generateContentMethods from '../lib/methods/generate-content';
24+
import { EnhancedGenerateContentResponse, GenerateContentStreamResult } from '../lib/types';
2425

2526
const fakeAI: AI = {
2627
app: {
@@ -39,6 +40,15 @@ const fakeAI: AI = {
3940
const TEMPLATE_ID = 'my-template';
4041
const TEMPLATE_VARS = { a: 1, b: '2' };
4142

43+
function streamResult(response: EnhancedGenerateContentResponse): GenerateContentStreamResult {
44+
return {
45+
stream: (async function* () {
46+
yield response;
47+
})(),
48+
response: Promise.resolve(response),
49+
};
50+
}
51+
4252
describe('TemplateGenerativeModel', function () {
4353
afterEach(function () {
4454
jest.restoreAllMocks();
@@ -202,6 +212,134 @@ describe('TemplateGenerativeModel', function () {
202212
]);
203213
});
204214

215+
it('automatically calls functionReference from template chat function calls', async function () {
216+
const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 });
217+
const functionCallResponse = {
218+
candidates: [
219+
{
220+
content: {
221+
role: 'model',
222+
parts: [
223+
{
224+
functionCall: {
225+
name: 'getWeather',
226+
args: { city: 'London' },
227+
},
228+
},
229+
],
230+
},
231+
},
232+
],
233+
functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }],
234+
} as EnhancedGenerateContentResponse;
235+
const finalResponse = {
236+
candidates: [
237+
{
238+
content: {
239+
role: 'model',
240+
parts: [{ text: 'It is 72 degrees.' }],
241+
},
242+
},
243+
],
244+
text: () => 'It is 72 degrees.',
245+
functionCalls: () => undefined,
246+
} as EnhancedGenerateContentResponse;
247+
const templateGenerateContentSpy = jest
248+
.spyOn(generateContentMethods, 'templateGenerateContent')
249+
.mockResolvedValueOnce({ response: functionCallResponse })
250+
.mockResolvedValueOnce({ response: finalResponse });
251+
const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 });
252+
const chat = model.startChat({
253+
templateId: TEMPLATE_ID,
254+
templateVariables: TEMPLATE_VARS,
255+
tools: [
256+
{
257+
functionDeclarations: [
258+
{
259+
name: 'getWeather',
260+
functionReference: getWeather,
261+
},
262+
],
263+
},
264+
],
265+
});
266+
267+
const result = await chat.sendMessage('weather in London');
268+
const history = await chat.getHistory();
269+
270+
expect(result.response.text()).toBe('It is 72 degrees.');
271+
expect(getWeather).toHaveBeenCalledWith({ city: 'London' });
272+
expect(templateGenerateContentSpy).toHaveBeenCalledTimes(2);
273+
expect(templateGenerateContentSpy).toHaveBeenLastCalledWith(
274+
model._apiSettings,
275+
TEMPLATE_ID,
276+
expect.objectContaining({
277+
inputs: TEMPLATE_VARS,
278+
contents: [
279+
{
280+
role: 'user',
281+
parts: [{ text: 'weather in London' }],
282+
},
283+
{
284+
role: 'model',
285+
parts: [
286+
{
287+
functionCall: {
288+
name: 'getWeather',
289+
args: { city: 'London' },
290+
},
291+
},
292+
],
293+
},
294+
{
295+
role: 'function',
296+
parts: [
297+
{
298+
functionResponse: {
299+
name: 'getWeather',
300+
response: { temperature: 72 },
301+
},
302+
},
303+
],
304+
},
305+
],
306+
}),
307+
{ timeout: 5000 },
308+
);
309+
expect(history).toEqual([
310+
{
311+
role: 'user',
312+
parts: [{ text: 'weather in London' }],
313+
},
314+
{
315+
role: 'model',
316+
parts: [
317+
{
318+
functionCall: {
319+
name: 'getWeather',
320+
args: { city: 'London' },
321+
},
322+
},
323+
],
324+
},
325+
{
326+
role: 'function',
327+
parts: [
328+
{
329+
functionResponse: {
330+
name: 'getWeather',
331+
response: { temperature: 72 },
332+
},
333+
},
334+
],
335+
},
336+
{
337+
role: 'model',
338+
parts: [{ text: 'It is 72 degrees.' }],
339+
},
340+
]);
341+
});
342+
205343
it('should call templateGenerateContentStream with template chat parameters', async function () {
206344
const templateGenerateContentStreamSpy = jest
207345
.spyOn(generateContentMethods, 'templateGenerateContentStream')
@@ -233,5 +371,133 @@ describe('TemplateGenerativeModel', function () {
233371
{ role: 'model', parts: [{ text: 'stream back' }] },
234372
]);
235373
});
374+
375+
it('automatically calls functionReference from streaming template chat function calls', async function () {
376+
const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 });
377+
const functionCallResponse = {
378+
candidates: [
379+
{
380+
content: {
381+
role: 'model',
382+
parts: [
383+
{
384+
functionCall: {
385+
name: 'getWeather',
386+
args: { city: 'London' },
387+
},
388+
},
389+
],
390+
},
391+
},
392+
],
393+
functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }],
394+
} as EnhancedGenerateContentResponse;
395+
const finalResponse = {
396+
candidates: [
397+
{
398+
content: {
399+
role: 'model',
400+
parts: [{ text: 'It is 72 degrees.' }],
401+
},
402+
},
403+
],
404+
text: () => 'It is 72 degrees.',
405+
functionCalls: () => undefined,
406+
} as EnhancedGenerateContentResponse;
407+
const templateGenerateContentStreamSpy = jest
408+
.spyOn(generateContentMethods, 'templateGenerateContentStream')
409+
.mockResolvedValueOnce(streamResult(functionCallResponse))
410+
.mockResolvedValueOnce(streamResult(finalResponse));
411+
const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 });
412+
const chat = model.startChat({
413+
templateId: TEMPLATE_ID,
414+
templateVariables: TEMPLATE_VARS,
415+
tools: [
416+
{
417+
functionDeclarations: [
418+
{
419+
name: 'getWeather',
420+
functionReference: getWeather,
421+
},
422+
],
423+
},
424+
],
425+
});
426+
427+
const result = await chat.sendMessageStream('weather in London');
428+
await result.response;
429+
const history = await chat.getHistory();
430+
431+
expect(getWeather).toHaveBeenCalledWith({ city: 'London' });
432+
expect(templateGenerateContentStreamSpy).toHaveBeenCalledTimes(2);
433+
expect(templateGenerateContentStreamSpy).toHaveBeenLastCalledWith(
434+
model._apiSettings,
435+
TEMPLATE_ID,
436+
expect.objectContaining({
437+
inputs: TEMPLATE_VARS,
438+
contents: [
439+
{
440+
role: 'user',
441+
parts: [{ text: 'weather in London' }],
442+
},
443+
{
444+
role: 'model',
445+
parts: [
446+
{
447+
functionCall: {
448+
name: 'getWeather',
449+
args: { city: 'London' },
450+
},
451+
},
452+
],
453+
},
454+
{
455+
role: 'function',
456+
parts: [
457+
{
458+
functionResponse: {
459+
name: 'getWeather',
460+
response: { temperature: 72 },
461+
},
462+
},
463+
],
464+
},
465+
],
466+
}),
467+
{ timeout: 5000 },
468+
);
469+
expect(history).toEqual([
470+
{
471+
role: 'user',
472+
parts: [{ text: 'weather in London' }],
473+
},
474+
{
475+
role: 'model',
476+
parts: [
477+
{
478+
functionCall: {
479+
name: 'getWeather',
480+
args: { city: 'London' },
481+
},
482+
},
483+
],
484+
},
485+
{
486+
role: 'function',
487+
parts: [
488+
{
489+
functionResponse: {
490+
name: 'getWeather',
491+
response: { temperature: 72 },
492+
},
493+
},
494+
],
495+
},
496+
{
497+
role: 'model',
498+
parts: [{ text: 'It is 72 degrees.' }],
499+
},
500+
]);
501+
});
236502
});
237503
});

0 commit comments

Comments
 (0)