Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/lib/tools/BetaToolRunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export class BetaToolRunner<Stream extends boolean> {
#consumed = false;
/** Whether parameters have been mutated since the last API call */
#mutated = false;
/** Whether the messages array was explicitly replaced since the last API call */
#messagesMutated = false;
/** Current state containing the request parameters */
#state: { params: BetaToolRunnerParams };
#options: BetaToolRunnerRequestOptions;
Expand Down Expand Up @@ -185,6 +187,7 @@ export class BetaToolRunner<Stream extends boolean> {
}

this.#mutated = false;
this.#messagesMutated = false;
this.#toolResponse = undefined;
this.#iterationCount++;
this.#message = undefined;
Expand All @@ -205,12 +208,20 @@ export class BetaToolRunner<Stream extends boolean> {

const isCompacted = await this.#checkAndCompact();
if (!isCompacted) {
if (!this.#mutated) {
const { role, content } = await this.#message;
const message = await this.#message;

if (message.container?.id != null && this.#state.params.container == null) {
this.#state.params.container = message.container.id;
}

if (!this.#messagesMutated) {
const { role, content } = message;
this.#state.params.messages.push({ role, content });
}

const toolMessage = await this.#generateToolResponse(this.#state.params.messages.at(-1)!);
const toolMessage = await this.#generateToolResponse(
this.#messagesMutated ? this.#state.params.messages.at(-1)! : message,
);
if (toolMessage) {
this.#state.params.messages.push(toolMessage);
} else if (!this.#mutated) {
Expand Down Expand Up @@ -263,10 +274,16 @@ export class BetaToolRunner<Stream extends boolean> {
setMessagesParams(
paramsOrMutator: BetaToolRunnerParams | ((prevParams: BetaToolRunnerParams) => BetaToolRunnerParams),
) {
const previousParams = this.#state.params;
const nextParams =
typeof paramsOrMutator === 'function' ? paramsOrMutator(previousParams) : paramsOrMutator;

this.#messagesMutated = nextParams.messages !== previousParams.messages;

if (typeof paramsOrMutator === 'function') {
this.#state.params = paramsOrMutator(this.#state.params);
this.#state.params = nextParams;
} else {
this.#state.params = paramsOrMutator;
this.#state.params = nextParams;
}
this.#mutated = true;
// Invalidate cached tool response since parameters changed
Expand Down
131 changes: 131 additions & 0 deletions tests/lib/tools/ToolRunner.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,38 @@ function betaMessageToStreamEvents(message: BetaMessage): BetaRawMessageStreamEv
return events;
}

function createAssistantMessage(
content: BetaContentBlock[],
overrides: Partial<BetaMessage> = {},
): BetaMessage {
const hasToolUse = content.some((block) => block.type === 'tool_use' || block.type === 'server_tool_use');

return {
id: overrides.id ?? 'msg_custom',
type: 'message',
role: 'assistant',
content,
model: 'claude-3-5-sonnet-latest',
stop_reason: hasToolUse ? 'tool_use' : 'end_turn',
stop_sequence: null,
container: null,
context_management: null,
usage: {
input_tokens: 10,
output_tokens: 20,
cache_creation: null,
cache_creation_input_tokens: null,
cache_read_input_tokens: null,
server_tool_use: null,
service_tier: null,
inference_geo: null,
iterations: null,
speed: null,
},
...overrides,
};
}

// Overloaded setupTest function for both streaming and non-streaming
interface SetupTestResult<Stream extends boolean> {
runner: Anthropic.Beta.Messages.BetaToolRunner<Stream>;
Expand Down Expand Up @@ -768,6 +800,105 @@ describe('ToolRunner', () => {
});
await expectDone(iterator);
});

it('preserves assistant tool messages when only non-message params change', async () => {
const { runner, handleRequest } = setupTest({
messages: [{ role: 'user', content: 'Get weather' }],
});

const requestBodies: Array<Record<string, unknown>> = [];
const firstMessage = createAssistantMessage([getWeatherToolUse('Paris')], { id: 'msg_0' });
const finalMessage = createAssistantMessage([getTextContent('Done')], { id: 'msg_1' });

handleRequest(async (_req, init) => {
requestBodies.push(JSON.parse(String(init?.body ?? '{}')));
return new Response(JSON.stringify(firstMessage), {
status: 200,
headers: { 'content-type': 'application/json' },
});
});

handleRequest(async (_req, init) => {
requestBodies.push(JSON.parse(String(init?.body ?? '{}')));
return new Response(JSON.stringify(finalMessage), {
status: 200,
headers: { 'content-type': 'application/json' },
});
});

const iterator = runner[Symbol.asyncIterator]();

await expectEvent(iterator, (message) => {
expect(message.content).toMatchObject([getWeatherToolUse('Paris')]);
});

runner.setMessagesParams((params) => ({
...params,
max_tokens: 200,
}));

await expectEvent(iterator, (message) => {
expect(message.content).toMatchObject([getTextContent('Done')]);
});

expect(requestBodies).toHaveLength(2);
expect(requestBodies[1]?.max_tokens).toBe(200);
expect(requestBodies[1]?.messages).toMatchObject([
{ role: 'user', content: 'Get weather' },
{ role: 'assistant', content: [getWeatherToolUse('Paris')] },
{ role: 'user', content: [getWeatherToolResult('Paris')] },
]);

await expectDone(iterator);
});

it('forwards container.id from the assistant response to the next iteration', async () => {
const { runner, handleRequest } = setupTest({
messages: [{ role: 'user', content: 'Get weather' }],
});

const requestBodies: Array<Record<string, unknown>> = [];
const firstMessage = createAssistantMessage([getWeatherToolUse('Paris')], {
id: 'msg_0',
container: {
id: 'container_123',
expires_at: '2026-03-31T00:00:00Z',
skills: null,
},
});
const finalMessage = createAssistantMessage([getTextContent('Done')], { id: 'msg_1' });

handleRequest(async (_req, init) => {
requestBodies.push(JSON.parse(String(init?.body ?? '{}')));
return new Response(JSON.stringify(firstMessage), {
status: 200,
headers: { 'content-type': 'application/json' },
});
});

handleRequest(async (_req, init) => {
requestBodies.push(JSON.parse(String(init?.body ?? '{}')));
return new Response(JSON.stringify(finalMessage), {
status: 200,
headers: { 'content-type': 'application/json' },
});
});

const iterator = runner[Symbol.asyncIterator]();

await expectEvent(iterator, (message) => {
expect(message.container).toMatchObject({ id: 'container_123' });
});

await expectEvent(iterator, (message) => {
expect(message.content).toMatchObject([getTextContent('Done')]);
});

expect(requestBodies).toHaveLength(2);
expect(requestBodies[1]?.container).toBe('container_123');

await expectDone(iterator);
});
});

describe('.runUntilDone()', () => {
Expand Down