Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/ai/src/providers/openai-completions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenA
partial: output,
});
} else if (block.type === "toolCall") {
if (!block.id || !block.name) {
throw new Error(`Malformed tool call from provider: missing ${!block.id ? "id" : "name"}`);
}
block.arguments = parseStreamingJson(block.partialArgs);
// Finalize in-place and strip the scratch buffers so replay only
// carries parsed arguments.
Expand Down Expand Up @@ -406,6 +409,7 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenA
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
output.content = output.content.filter((block) => block.type !== "toolCall" || (block.id && block.name));
for (const block of output.content) {
delete (block as { index?: number }).index;
// Streaming scratch buffers are only used during parsing; never persist them.
Expand Down
119 changes: 119 additions & 0 deletions packages/ai/test/openai-completions-malformed-tool-call.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { streamOpenAICompletions } from "../src/providers/openai-completions.ts";
import type { Context, Model } from "../src/types.ts";

const mockState = vi.hoisted(() => ({
chunks: [] as Array<{
id: string;
choices: Array<{
index: number;
delta: Record<string, unknown>;
finish_reason: string | null;
}>;
} | null>,
}));

vi.mock("openai", () => {
class FakeOpenAI {
chat = {
completions: {
create: () => {
const stream = {
async *[Symbol.asyncIterator]() {
for (const chunk of mockState.chunks) {
yield chunk;
}
},
};
const promise = Promise.resolve(stream) as Promise<typeof stream> & {
withResponse: () => Promise<{
data: typeof stream;
response: { status: number; headers: Headers };
}>;
};
promise.withResponse = async () => ({
data: stream,
response: { status: 200, headers: new Headers() },
});
return promise;
},
},
};
}
return { default: FakeOpenAI };
});

const model: Model<"openai-completions"> = {
id: "test-model",
name: "Test Model",
api: "openai-completions",
provider: "test-provider",
baseUrl: "https://example.invalid",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 1000,
maxTokens: 100,
};

const context: Context = {
systemPrompt: "",
messages: [{ role: "user", content: [{ type: "text", text: "hi" }], timestamp: 0 }],
tools: [],
};

beforeEach(() => {
mockState.chunks = [];
});

describe("openai-completions malformed tool call handling", () => {
it.each([
{
name: "id",
toolCall: {
index: 0,
type: "function",
function: { name: "read", arguments: '{"path":"README.md"}' },
},
},
{
name: "name",
toolCall: {
index: 0,
id: "call_test",
type: "function",
function: { arguments: '{"path":"README.md"}' },
},
},
])("fails when the provider emits a tool call without $name", async ({ name, toolCall }) => {
mockState.chunks = [
{
id: "chatcmpl-test",
choices: [
{
index: 0,
delta: {
tool_calls: [toolCall],
},
finish_reason: "tool_calls",
},
],
},
];

const stream = streamOpenAICompletions(model, context, { apiKey: "test" });
let errorMessage: string | undefined;
for await (const event of stream) {
if (event.type === "error") {
errorMessage = event.error.errorMessage;
}
}

const expectedError = `Malformed tool call from provider: missing ${name}`;
const result = await stream.result();
expect(errorMessage).toBe(expectedError);
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toBe(expectedError);
expect(result.content).toEqual([]);
});
});
Loading