Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/loud-apples-shop.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/gateway': patch
---

fix: image generation via Gateway warning schema mismatch
5 changes: 5 additions & 0 deletions .changeset/moody-geese-move.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/gateway': patch
---

feat: report image generation usage info in Gateway
178 changes: 177 additions & 1 deletion packages/gateway/src/gateway-image-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ describe('GatewayImageModel', () => {
providerMetadata,
}: {
images?: string[];
warnings?: Array<{ type: 'other'; message: string }>;
warnings?: Array<
| { type: 'unsupported'; feature: string; details?: string }
| { type: 'compatibility'; feature: string; details?: string }
| { type: 'other'; message: string }
>;
providerMetadata?: Record<string, unknown>;
} = {}) {
server.urls['https://api.test.com/image-model'].response = {
Expand Down Expand Up @@ -303,6 +307,99 @@ describe('GatewayImageModel', () => {
expect(result.warnings).toEqual(mockWarnings);
});

it('should return unsupported warnings correctly', async () => {
const mockWarnings = [
{
type: 'unsupported' as const,
feature: 'size',
details:
'This model does not support the `size` option. Use `aspectRatio` instead.',
},
];

prepareJsonResponse({
images: ['base64-1'],
warnings: mockWarnings,
});

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: '1024x1024',
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
});

expect(result.warnings).toEqual(mockWarnings);
});

it('should return compatibility warnings correctly', async () => {
const mockWarnings = [
{
type: 'compatibility' as const,
feature: 'seed',
details: 'Seed support is approximate for this model.',
},
];

prepareJsonResponse({
images: ['base64-1'],
warnings: mockWarnings,
});

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: undefined,
aspectRatio: undefined,
seed: 42,
providerOptions: {},
});

expect(result.warnings).toEqual(mockWarnings);
});

it('should handle mixed warning types', async () => {
const mockWarnings = [
{
type: 'unsupported' as const,
feature: 'size',
},
{
type: 'compatibility' as const,
feature: 'seed',
details: 'Approximate seed support.',
},
{
type: 'other' as const,
message: 'Rate limit approaching.',
},
];

prepareJsonResponse({
images: ['base64-1'],
warnings: mockWarnings,
});

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: '1024x1024',
aspectRatio: undefined,
seed: 42,
providerOptions: {},
});

expect(result.warnings).toEqual(mockWarnings);
});

it('should return empty warnings array when not provided', async () => {
prepareJsonResponse({
images: ['base64-1'],
Expand Down Expand Up @@ -343,6 +440,85 @@ describe('GatewayImageModel', () => {
expect(result.response.headers).toBeDefined();
});

it('should return usage when provided', async () => {
server.urls['https://api.test.com/image-model'].response = {
type: 'json-value',
body: {
images: ['base64-1'],
usage: {
inputTokens: 27,
outputTokens: 6240,
totalTokens: 6267,
},
},
};

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: undefined,
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
});

expect(result.usage).toEqual({
inputTokens: 27,
outputTokens: 6240,
totalTokens: 6267,
});
});

it('should return usage with partial token counts', async () => {
server.urls['https://api.test.com/image-model'].response = {
type: 'json-value',
body: {
images: ['base64-1'],
usage: {
inputTokens: 10,
},
},
};

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: undefined,
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
});

expect(result.usage).toEqual({
inputTokens: 10,
outputTokens: undefined,
totalTokens: undefined,
});
});

it('should not include usage when not provided', async () => {
prepareJsonResponse({
images: ['base64-1'],
});

const result = await createTestModel().doGenerate({
prompt: 'Test prompt',
files: undefined,
mask: undefined,
n: 1,
size: undefined,
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
});

expect(result.usage).toBeUndefined();
});

it('should merge custom headers with config headers', async () => {
prepareJsonResponse();

Expand Down
40 changes: 32 additions & 8 deletions packages/gateway/src/gateway-image-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ export class GatewayImageModel implements ImageModelV3 {
modelId: this.modelId,
headers: responseHeaders,
},
...(responseBody.usage != null && {
usage: {
inputTokens: responseBody.usage.inputTokens ?? undefined,
outputTokens: responseBody.usage.outputTokens ?? undefined,
totalTokens: responseBody.usage.totalTokens ?? undefined,
},
}),
};
} catch (error) {
throw asGatewayError(error, await parseAuthMethod(resolvedHeaders));
Expand Down Expand Up @@ -129,17 +136,34 @@ const providerMetadataEntrySchema = z
})
.catchall(z.unknown());

const gatewayImageWarningSchema = z.discriminatedUnion('type', [
z.object({
type: z.literal('unsupported'),
feature: z.string(),
details: z.string().optional(),
}),
z.object({
type: z.literal('compatibility'),
feature: z.string(),
details: z.string().optional(),
}),
z.object({
type: z.literal('other'),
message: z.string(),
}),
]);

const gatewayImageUsageSchema = z.object({
inputTokens: z.number().nullish(),
outputTokens: z.number().nullish(),
totalTokens: z.number().nullish(),
});

const gatewayImageResponseSchema = z.object({
images: z.array(z.string()), // Always base64 strings over the wire
warnings: z
.array(
z.object({
type: z.literal('other'),
message: z.string(),
}),
)
.optional(),
warnings: z.array(gatewayImageWarningSchema).optional(),
providerMetadata: z
.record(z.string(), providerMetadataEntrySchema)
.optional(),
usage: gatewayImageUsageSchema.optional(),
});
Loading