1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ out
coverage.xml
.idea/
.vscode/
.qoder/
bazel-bin
bazel-out
bazel-testlogs
46 changes: 37 additions & 9 deletions plugins/wasm-go/extensions/ai-security-guard/README.md
@@ -41,15 +41,43 @@ description: Alibaba Cloud content security detection
| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers |
| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions |

A note on `denyMessage`: illegal requests are handled as follows:
- If `denyMessage` is configured, its content is returned as an OpenAI-format streaming/non-streaming response
- If `denyMessage` is not configured, the suggested answer from Alibaba Cloud content security is returned first, as an OpenAI-format streaming/non-streaming response
- If Alibaba Cloud content security returns no suggested answer, the built-in fallback answer `"很抱歉,我无法回答您的问题"` is returned as an OpenAI-format streaming/non-streaming response

If the user uses a non-OpenAI protocol, illegal requests are handled as follows:
- If `denyMessage` is configured, its content is returned as a non-streaming response
- If `denyMessage` is not configured, the suggested answer from Alibaba Cloud content security is returned first, as a non-streaming response
- If Alibaba Cloud content security returns no suggested answer, the built-in fallback answer `"很抱歉,我无法回答您的问题"` is returned as a non-streaming response
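
The three-step fallback order above can be sketched as a small Go helper. This is illustrative only; `chooseDenyAnswer` and its signature are hypothetical, not the plugin's actual implementation:

```go
package main

import "fmt"

// chooseDenyAnswer is an illustrative helper (not the plugin's real function):
// it encodes the three-step fallback order described above.
func chooseDenyAnswer(denyMessage, serviceSuggestion string) string {
	if denyMessage != "" {
		return denyMessage // 1. a configured denyMessage always wins
	}
	if serviceSuggestion != "" {
		return serviceSuggestion // 2. suggested answer from Alibaba Cloud content security
	}
	return "很抱歉,我无法回答您的问题" // 3. built-in fallback answer
}

func main() {
	fmt.Println(chooseDenyAnswer("", "")) // prints the built-in fallback
}
```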

### Deny Response Body

When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. Where it is carried in the response depends on the protocol:

```json
{
"blockedDetails": [
{
"Type": "contentModeration",
"Level": "high",
"Suggestion": "block"
}
],
"requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
"guardCode": 200
}
```

Field descriptions:

| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the triggered blocking dimensions; synthesised from top-level risk signals when the security service returns no detail entries |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low`, etc. |
| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` on a successful check) |

How the body is embedded per protocol:

- **`text_generation` (OpenAI non-streaming)**: the structure above is serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: the structure above is returned directly as the JSON response body
- **`image_generation`**: the structure above is returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via an SSE event

A note on the four risk levels of the three risk dimensions (content moderation, prompt attack detection, and sensitive content detection):

37 changes: 37 additions & 0 deletions plugins/wasm-go/extensions/ai-security-guard/README_EN.md
@@ -41,6 +41,43 @@ Plugin Priority: `300`
| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers |
| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions |

### Deny Response Body

When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. The location in the response depends on the protocol:

```json
{
"blockedDetails": [
{
"Type": "contentModeration",
"Level": "high",
"Suggestion": "block"
}
],
"requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
"guardCode": 200
}
```

Field descriptions:

| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level risk signals when the security service returns no detail entries. |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low` etc. |
| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` indicates a successful check that detected a risk) |

How the body is embedded per protocol:

- **`text_generation` (OpenAI non-streaming)**: serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body
- **`image_generation`**: returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via SSE event
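
A consumer that needs to distinguish a normal model answer from a guard denial can attempt to parse the content as the deny body. A minimal sketch; `isBlocked` is a hypothetical helper, not part of the plugin:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// isBlocked is a hypothetical consumer-side check: a reply is treated as blocked
// when the message content parses as the deny body described above.
func isBlocked(content string) bool {
	var body struct {
		BlockedDetails []json.RawMessage `json:"blockedDetails"`
		GuardCode      int               `json:"guardCode"`
	}
	if err := json.Unmarshal([]byte(content), &body); err != nil {
		return false // ordinary free-text model answer
	}
	return len(body.BlockedDetails) > 0
}

func main() {
	deny := `{"blockedDetails":[{"Type":"promptAttack","Level":"high","Suggestion":"block"}],"requestId":"x","guardCode":200}`
	fmt.Println(isBlocked(deny))                         // true
	fmt.Println(isBlocked("Hello! How can I help you?")) // false
}
```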

## Examples of configuration
### Check if the input is legal
161 changes: 161 additions & 0 deletions plugins/wasm-go/extensions/ai-security-guard/config/config.go
@@ -1,12 +1,15 @@
package config

import (
	"encoding/json"
	"errors"
	"fmt"
	"regexp"
	"strconv"
	"strings"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/higress-group/wasm-go/pkg/wrapper"
	"github.com/tidwall/gjson"
)
@@ -61,6 +64,8 @@ const (

	DefaultTextModerationPlusTextInputCheckService  = "llm_query_moderation"
	DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"

	DefaultCheckRecordTTL = 172800 // 2 days in seconds
)

// api types
@@ -166,6 +171,10 @@ type AISecurityConfig struct {
	ApiType string
	// openai, qwen, comfyui, etc.
	ProviderType string
	// Redis-based message dedup
	RedisClient      wrapper.RedisClient
	CheckAllMessages bool
	CheckRecordTTL   int // TTL in seconds for check records, default 172800 (2 days)
}

func (config *AISecurityConfig) Parse(json gjson.Result) error {
@@ -337,6 +346,68 @@ func (config *AISecurityConfig) Parse(json gjson.Result) error {
		Host: serviceHost,
	})
	config.Metrics = make(map[string]proxywasm.MetricCounter)
	// Parse checkAllMessages dedup config
	config.CheckAllMessages = json.Get("checkAllMessages").Bool()
	log.Infof("[config] checkAllMessages=%v, action=%s, apiType=%s, requestContentJsonPath=%s", config.CheckAllMessages, config.Action, config.ApiType, config.RequestContentJsonPath)
	if obj := json.Get("checkRecordTTL"); obj.Exists() {
		if obj.Int() <= 0 && config.CheckAllMessages {
			return errors.New("checkRecordTTL must be greater than 0")
		}
		config.CheckRecordTTL = int(obj.Int())
	}
	// Initialize Redis client for dedup if checkAllMessages is enabled
	if config.CheckAllMessages {
		log.Infof("[config] initializing redis client for dedup...")
		if err := config.initRedisClient(json); err != nil {
			log.Warnf("failed to init redis for dedup, checkAllMessages will be disabled: %v", err)
			config.CheckAllMessages = false
		} else {
			log.Infof("[config] redis client initialized, ready=%v", config.RedisClient.Ready())
		}
	}
	return nil
}

func (config *AISecurityConfig) initRedisClient(json gjson.Result) error {
	redisConfig := json.Get("redis")
	if !redisConfig.Exists() {
		return errors.New("missing redis config for checkAllMessages")
	}
	serviceName := redisConfig.Get("service_name").String()
	if serviceName == "" {
		return errors.New("redis service_name must not be empty")
	}
	servicePort := int(redisConfig.Get("service_port").Int())
	if servicePort == 0 {
		if strings.HasSuffix(serviceName, ".static") {
			servicePort = 80
		} else {
			servicePort = 6379
		}
	}
	username := redisConfig.Get("username").String()
	password := redisConfig.Get("password").String()
	timeout := int(redisConfig.Get("timeout").Int())
	if timeout == 0 {
		timeout = 1000
	}
	database := int(redisConfig.Get("database").Int())

	cluster := wrapper.FQDNCluster{
		FQDN: serviceName,
		Port: int64(servicePort),
	}
	log.Infof("[redis-dedup] cluster=%s, hasPassword=%v, username=%q, database=%d, timeout=%d",
		cluster.ClusterName(), password != "", username, database, timeout)

	config.RedisClient = wrapper.NewRedisClusterClient(cluster)
	// Note: Init() always returns nil in the current SDK; Ready() reflects actual init status
	config.RedisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database))
	if config.RedisClient.Ready() {
		log.Info("[redis-dedup] redis client initialized successfully")
	} else {
		log.Warn("[redis-dedup] redis client init pending, will retry on first command")
	}
	return nil
}

@@ -364,6 +435,7 @@ func (config *AISecurityConfig) SetDefaultValues() {
	config.BufferLimit = 1000
	config.ApiType = ApiTextGeneration
	config.ProviderType = ProviderOpenAI
	config.CheckRecordTTL = DefaultCheckRecordTTL
}

func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
@@ -584,3 +656,92 @@ func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, co
		return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
	}
}

type DenyResponseBody struct {
	BlockedDetails []Detail `json:"blockedDetails"`
	RequestId      string   `json:"requestId"`
	// GuardCode is the business code returned by the security service (typically 200 when the check
	// succeeded and a risk was detected). It is NOT an HTTP status code.
	GuardCode int `json:"guardCode"`
}

func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) {
	body := DenyResponseBody{
		BlockedDetails: GetUnacceptableDetail(response.Data, config, consumer),
		RequestId:      response.RequestId,
		GuardCode:      response.Code,
	}
	return json.Marshal(body)
}

func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail {
	result := []Detail{}
	for _, detail := range data.Detail {
		switch detail.Type {
		case ContentModerationType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
				result = append(result, detail)
			}
		case PromptAttackType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
				result = append(result, detail)
			}
		case SensitiveDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
				result = append(result, detail)
			}
		case MaliciousUrlDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
				result = append(result, detail)
			}
		case ModelHallucinationDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
				result = append(result, detail)
			}
		}
	}
	// Fallback: when the security service returns a top-level risk signal but no Detail entries,
	// synthesise detail items from RiskLevel/AttackLevel so blockedDetails is never empty on a
	// real block event.
	if len(result) == 0 {
		if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
			result = append(result, Detail{
				Type:       ContentModerationType,
				Level:      data.RiskLevel,
				Suggestion: "block",
			})
		}
		if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
			result = append(result, Detail{
				Type:       PromptAttackType,
				Level:      data.AttackLevel,
				Suggestion: "block",
			})
		}
	}
	return result
}

// BuildPolicyFingerprint generates a fingerprint string that encodes all policy dimensions
// affecting the security check result. This ensures Redis dedup keys are invalidated when
// any relevant policy parameter changes, preventing cross-policy cache pollution.
func (config *AISecurityConfig) BuildPolicyFingerprint(consumer string) string {
	if config.Action == MultiModalGuard {
		return fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s:%s:%s",
			config.Action,
			config.GetRequestCheckService(consumer),
			config.GetRequestImageCheckService(consumer),
			config.GetContentModerationLevelBar(consumer),
			config.GetPromptAttackLevelBar(consumer),
			config.GetSensitiveDataLevelBar(consumer),
			config.GetMaliciousUrlLevelBar(consumer),
			config.GetModelHallucinationLevelBar(consumer),
			strconv.FormatBool(config.CheckRequestImage),
		)
	}
	return fmt.Sprintf("%s:%s:%s",
		config.Action,
		config.GetRequestCheckService(consumer),
		config.GetRiskLevelBar(consumer),
	)
}