Skip to content

Commit 7637150

Browse files
author
Raphael Eidus
committed
feat(amazon-bedrock): add guardrail content support for text and image parts
1 parent 7943a4b commit 7637150

File tree

5 files changed

+342
-9
lines changed

5 files changed

+342
-9
lines changed

.changeset/wild-coins-fly.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"@ai-sdk/amazon-bedrock": patch
3+
---
4+
5+
Adds support for specifying which parts of a message to guard using GuardrailConverseContentBlock
6+
7+
ref: https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html

packages/amazon-bedrock/src/bedrock-api-types.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,16 @@ export interface BedrockDocumentBlock {
158158
};
159159
}
160160

161+
export interface BedrockGuardrailTextBlock extends BedrockTextBlock {
162+
qualifiers?: Array<'grounding_source' | 'query' | 'guard_content'>;
163+
}
164+
161165
export interface BedrockGuardrailConverseContentBlock {
162-
guardContent: unknown;
166+
guardContent:
167+
| {
168+
text: BedrockGuardrailTextBlock;
169+
}
170+
| BedrockImageBlock;
163171
}
164172

165173
export interface BedrockImageBlock {

packages/amazon-bedrock/src/bedrock-chat-options.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,34 @@ export type BedrockFilePartProviderOptions = z.infer<
9898
typeof bedrockFilePartProviderOptions
9999
>;
100100

101+
/**
102+
* Bedrock text part provider options for guardrail content.
103+
* These options apply to individual text parts.
104+
*/
105+
export const bedrockTextPartProviderOptions = z.object({
106+
guardContent: z.boolean().optional(),
107+
guardContentQualifiers: z
108+
.array(z.enum(['grounding_source', 'query', 'guard_content']))
109+
.optional(),
110+
});
111+
112+
export type BedrockTextPartProviderOptions = z.infer<
113+
typeof bedrockTextPartProviderOptions
114+
>;
115+
116+
/**
117+
* Bedrock image part provider options for guardrail content.
118+
* These options apply to individual image parts.
119+
*/
120+
121+
export const bedrockImagePartProviderOptions = z.object({
122+
guardContent: z.boolean().optional(),
123+
});
124+
125+
export type BedrockImagePartProviderOptions = z.infer<
126+
typeof bedrockImagePartProviderOptions
127+
>;
128+
101129
export const amazonBedrockLanguageModelOptions = z.object({
102130
/**
103131
* Additional inference parameters that the model supports,

packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.test.ts

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,259 @@ describe('user messages', () => {
396396
system: [],
397397
});
398398
});
399+
400+
it('should convert text part to guardContent when guardContent provider option is true', async () => {
401+
const { messages } = await convertToBedrockChatMessages([
402+
{
403+
role: 'user',
404+
content: [
405+
{
406+
type: 'text',
407+
text: 'Grounding text',
408+
providerOptions: {
409+
bedrock: {
410+
guardContent: true,
411+
},
412+
},
413+
},
414+
],
415+
},
416+
]);
417+
418+
expect(messages).toEqual([
419+
{
420+
role: 'user',
421+
content: [
422+
{
423+
guardContent: {
424+
text: {
425+
text: 'Grounding text',
426+
},
427+
},
428+
},
429+
],
430+
},
431+
]);
432+
});
433+
434+
it('should convert text part to guardContent with qualifiers', async () => {
435+
const { messages } = await convertToBedrockChatMessages([
436+
{
437+
role: 'user',
438+
content: [
439+
{
440+
type: 'text',
441+
text: 'Grounding text',
442+
providerOptions: {
443+
bedrock: {
444+
guardContent: true,
445+
guardContentQualifiers: ['grounding_source'],
446+
},
447+
},
448+
},
449+
],
450+
},
451+
]);
452+
453+
expect(messages).toEqual([
454+
{
455+
role: 'user',
456+
content: [
457+
{
458+
guardContent: {
459+
text: {
460+
text: 'Grounding text',
461+
qualifiers: ['grounding_source'],
462+
},
463+
},
464+
},
465+
],
466+
},
467+
]);
468+
});
469+
470+
it('should convert text part to guardContent with multiple qualifiers', async () => {
471+
const { messages } = await convertToBedrockChatMessages([
472+
{
473+
role: 'user',
474+
content: [
475+
{
476+
type: 'text',
477+
text: 'Query text',
478+
providerOptions: {
479+
bedrock: {
480+
guardContent: true,
481+
guardContentQualifiers: ['grounding_source', 'query'],
482+
},
483+
},
484+
},
485+
],
486+
},
487+
]);
488+
489+
expect(messages).toEqual([
490+
{
491+
role: 'user',
492+
content: [
493+
{
494+
guardContent: {
495+
text: {
496+
text: 'Query text',
497+
qualifiers: ['grounding_source', 'query'],
498+
},
499+
},
500+
},
501+
],
502+
},
503+
]);
504+
});
505+
506+
it('should convert text part as normal text when guardContent is false', async () => {
507+
const { messages } = await convertToBedrockChatMessages([
508+
{
509+
role: 'user',
510+
content: [
511+
{
512+
type: 'text',
513+
text: 'Normal text',
514+
providerOptions: {
515+
bedrock: {
516+
guardContent: false,
517+
},
518+
},
519+
},
520+
],
521+
},
522+
]);
523+
524+
expect(messages).toEqual([
525+
{
526+
role: 'user',
527+
content: [{ text: 'Normal text' }],
528+
},
529+
]);
530+
});
531+
532+
it('should convert text part as normal text when no provider options', async () => {
533+
const { messages } = await convertToBedrockChatMessages([
534+
{
535+
role: 'user',
536+
content: [{ type: 'text', text: 'Normal text' }],
537+
},
538+
]);
539+
540+
expect(messages).toEqual([
541+
{
542+
role: 'user',
543+
content: [{ text: 'Normal text' }],
544+
},
545+
]);
546+
});
547+
548+
it('should convert image part to guardContent when guardContent provider option is true', async () => {
549+
const imageData = new Uint8Array([0, 1, 2, 3]);
550+
551+
const { messages } = await convertToBedrockChatMessages([
552+
{
553+
role: 'user',
554+
content: [
555+
{
556+
type: 'file',
557+
data: Buffer.from(imageData).toString('base64'),
558+
mediaType: 'image/png',
559+
providerOptions: {
560+
bedrock: {
561+
guardContent: true,
562+
},
563+
},
564+
},
565+
],
566+
},
567+
]);
568+
569+
expect(messages).toEqual([
570+
{
571+
role: 'user',
572+
content: [
573+
{
574+
guardContent: {
575+
image: {
576+
format: 'png',
577+
source: { bytes: 'AAECAw==' },
578+
},
579+
},
580+
},
581+
],
582+
},
583+
]);
584+
});
585+
586+
it('should convert image part as normal image when guardContent is false', async () => {
587+
const imageData = new Uint8Array([0, 1, 2, 3]);
588+
589+
const { messages } = await convertToBedrockChatMessages([
590+
{
591+
role: 'user',
592+
content: [
593+
{
594+
type: 'file',
595+
data: Buffer.from(imageData).toString('base64'),
596+
mediaType: 'image/png',
597+
providerOptions: {
598+
bedrock: {
599+
guardContent: false,
600+
},
601+
},
602+
},
603+
],
604+
},
605+
]);
606+
607+
expect(messages).toEqual([
608+
{
609+
role: 'user',
610+
content: [
611+
{
612+
image: {
613+
format: 'png',
614+
source: { bytes: 'AAECAw==' },
615+
},
616+
},
617+
],
618+
},
619+
]);
620+
});
621+
622+
it('should convert image part as normal image when no provider options', async () => {
623+
const imageData = new Uint8Array([0, 1, 2, 3]);
624+
625+
const { messages } = await convertToBedrockChatMessages([
626+
{
627+
role: 'user',
628+
content: [
629+
{
630+
type: 'file',
631+
data: Buffer.from(imageData).toString('base64'),
632+
mediaType: 'image/png',
633+
},
634+
],
635+
},
636+
]);
637+
638+
expect(messages).toEqual([
639+
{
640+
role: 'user',
641+
content: [
642+
{
643+
image: {
644+
format: 'png',
645+
source: { bytes: 'AAECAw==' },
646+
},
647+
},
648+
],
649+
},
650+
]);
651+
});
399652
});
400653

401654
describe('assistant messages', () => {

packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.ts

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ import {
2525
BedrockUserMessage,
2626
} from './bedrock-api-types';
2727
import { bedrockReasoningMetadataSchema } from './bedrock-chat-language-model';
28-
import { bedrockFilePartProviderOptions } from './bedrock-chat-options';
28+
import {
29+
bedrockFilePartProviderOptions,
30+
bedrockTextPartProviderOptions,
31+
bedrockImagePartProviderOptions,
32+
} from './bedrock-chat-options';
2933
import { normalizeToolCallId } from './normalize-tool-call-id';
3034

3135
function getCachePoint(
@@ -106,9 +110,26 @@ export async function convertToBedrockChatMessages(
106110

107111
switch (part.type) {
108112
case 'text': {
109-
bedrockContent.push({
110-
text: part.text,
113+
const textOptions = await parseProviderOptions({
114+
provider: 'bedrock',
115+
providerOptions: part.providerOptions,
116+
schema: bedrockTextPartProviderOptions,
111117
});
118+
119+
if (textOptions?.guardContent) {
120+
bedrockContent.push({
121+
guardContent: {
122+
text: {
123+
text: part.text,
124+
qualifiers: textOptions.guardContentQualifiers,
125+
},
126+
},
127+
});
128+
} else {
129+
bedrockContent.push({
130+
text: part.text,
131+
});
132+
}
112133
break;
113134
}
114135

@@ -127,12 +148,28 @@ export async function convertToBedrockChatMessages(
127148
}
128149

129150
if (part.mediaType.startsWith('image/')) {
130-
bedrockContent.push({
131-
image: {
132-
format: getBedrockImageFormat(part.mediaType),
133-
source: { bytes: convertToBase64(part.data) },
134-
},
151+
const imageOptions = await parseProviderOptions({
152+
provider: 'bedrock',
153+
providerOptions: part.providerOptions,
154+
schema: bedrockImagePartProviderOptions,
135155
});
156+
if (imageOptions?.guardContent) {
157+
bedrockContent.push({
158+
guardContent: {
159+
image: {
160+
format: getBedrockImageFormat(part.mediaType),
161+
source: { bytes: convertToBase64(part.data) },
162+
},
163+
},
164+
});
165+
} else {
166+
bedrockContent.push({
167+
image: {
168+
format: getBedrockImageFormat(part.mediaType),
169+
source: { bytes: convertToBase64(part.data) },
170+
},
171+
});
172+
}
136173
} else {
137174
if (!part.mediaType) {
138175
throw new UnsupportedFunctionalityError({

0 commit comments

Comments
 (0)