Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ description: AI 代理插件配置参考
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ---------------------- | ---------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件默认会随机选择;对于 Responses、Files、Batches、Fine-tuning 等有状态 API,会优先使用 `x-mse-consumer` 做稳定路由;如果未提供 `x-mse-consumer`,插件会在首个成功响应里自动下发 `higress-ai-affinity` cookie,后续请求基于该 cookie 保持稳定映射。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟。此项配置目前仅用于获取上下文信息,并不影响实际转发大模型请求。 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-\*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "\*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。<br/>4. 支持以 `~` 前缀使用正则匹配。例如用 "~gpt(.\*)" 匹配所有以 "gpt" 开头的模型并支持在目标模型中使用 capture group 引用匹配到的内容。示例: "~gpt(.\*): openai/gpt\$1" |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Plugin execution priority: `100`
| Name | Data Type | Requirement | Default | Description |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | Required | - | Name of the AI service provider |
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. If multiple tokens are configured, the plugin randomly selects one for each request. Some service providers only support configuring a single token. |
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. When multiple tokens are configured, the plugin uses random selection by default. For stateful APIs such as Responses, Files, Batches, and Fine-tuning, it first tries stable routing with `x-mse-consumer`; if that header is absent, the plugin automatically returns a `higress-ai-affinity` cookie on the first successful response and keeps subsequent requests pinned by that cookie. Some service providers only support configuring a single token. |
| `timeout` | number | Optional | - | Timeout for accessing AI services, in milliseconds. The default value is 120000, which equals 2 minutes. Only used when retrieving context data. Won't affect the request forwarded to the LLM upstream. |
| `modelMapping` | map of string | Optional | - | Mapping table for AI models, used to map model names in requests to names supported by the service provider.<br/>1. Supports prefix matching. For example, "gpt-3-\*" matches all model names starting with “gpt-3-”;<br/>2. Supports using "\*" as a key for a general fallback mapping;<br/>3. If the mapped target name is an empty string "", the original model name is preserved. |
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider). **Note: Auto protocol detection is now supported, no need to configure this field to support both OpenAI and Claude protocols** |
Expand Down
1 change: 1 addition & 0 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)

headers := util.GetResponseHeaders()
providerConfig.ApplyAnonymousAffinityCookie(ctx, headers)
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
handler.TransformResponseHeaders(ctx, apiName, headers)
Expand Down
119 changes: 114 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import (
"path"
"regexp"
"strconv"

"strings"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
Expand Down Expand Up @@ -172,6 +172,8 @@ const (
ctxKeyPushedMessage = "pushedMessage"
ctxKeyContentPushed = "contentPushed"
ctxKeyReasoningContentPushed = "reasoningContentPushed"
ctxKeyAnonymousAffinityKey = "anonymousAffinityKey"
ctxKeySetAffinityCookie = "setAffinityCookie"

objectChatCompletion = "chat.completion"
objectChatCompletionChunk = "chat.completion.chunk"
Expand All @@ -184,6 +186,8 @@ const (

defaultTimeout = 2 * 60 * 1000 // ms

AnonymousAffinityCookieName = "higress-ai-affinity"

basePathHandlingRemovePrefix basePathHandling = "removePrefix"
basePathHandlingPrepend basePathHandling = "prepend"
)
Expand Down Expand Up @@ -727,17 +731,17 @@ func (c *ProviderConfig) selectApiToken(ctx wrapper.HttpContext) string {

// For stateful APIs, try to use consumer affinity
if isStatefulAPI(apiName) {
consumer := c.getConsumerFromContext(ctx)
if consumer != "" {
return c.GetTokenWithConsumerAffinity(ctx, consumer)
affinityKey := c.getAffinityKeyFromContext(ctx)
if affinityKey != "" {
return c.GetTokenWithConsumerAffinity(ctx, affinityKey)
}
}

// Fall back to random selection
return c.GetRandomToken()
}

// getConsumerFromContext retrieves the consumer identifier from the request context
// getConsumerFromContext retrieves the explicit consumer identifier from the request headers.
func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string {
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
if err == nil && consumer != "" {
Expand All @@ -746,6 +750,57 @@ func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string
return ""
}

// getAffinityKeyFromContext resolves the key used for stable api-token
// routing. Priority: the explicit `x-mse-consumer` header, then a previously
// issued anonymous affinity cookie, then — only when more than one token is
// configured — a freshly minted UUID that is remembered in the context so the
// response phase can emit the matching Set-Cookie header.
// Returns "" when no affinity routing should be applied.
func (c *ProviderConfig) getAffinityKeyFromContext(ctx wrapper.HttpContext) string {
	// An explicit consumer identity always wins.
	consumer := c.getConsumerFromContext(ctx)
	if consumer != "" {
		return consumer
	}
	// Fall back to the anonymous affinity cookie issued on a prior response.
	cookieKey := getAnonymousAffinityKeyFromCookie()
	if cookieKey != "" {
		return cookieKey
	}
	// With zero or one token there is nothing to pin, so skip cookie issuance.
	if len(c.apiTokens) <= 1 {
		return ""
	}
	// First anonymous request with multiple tokens: mint a new key and flag
	// the response phase to attach the affinity Set-Cookie header.
	newKey := uuid.NewString()
	ctx.SetContext(ctxKeyAnonymousAffinityKey, newKey)
	ctx.SetContext(ctxKeySetAffinityCookie, true)
	return newKey
}

// getAnonymousAffinityKeyFromCookie reads the request's Cookie header and
// returns the value of the anonymous affinity cookie, or "" when the header
// is unavailable or the cookie is not present.
func getAnonymousAffinityKeyFromCookie() string {
	header, err := proxywasm.GetHttpRequestHeader("cookie")
	if err != nil {
		return ""
	}
	if header == "" {
		return ""
	}
	return getCookieValue(header, AnonymousAffinityCookieName)
}

// getCookieValue returns the value of the named cookie from a raw Cookie
// header, or "" when the cookie is absent. Matching is on the exact cookie
// name; the value itself may contain further '=' characters.
func getCookieValue(cookieHeader, cookieName string) string {
	for _, part := range strings.Split(cookieHeader, ";") {
		name, value, found := strings.Cut(strings.TrimSpace(part), "=")
		if found && name == cookieName {
			return value
		}
	}
	return ""
}

func buildAffinityCookie(cookieValue string) string {
return fmt.Sprintf("%s=%s; Path=/; HttpOnly", AnonymousAffinityCookieName, cookieValue)
}

// ApplyAnonymousAffinityCookie appends a Set-Cookie header carrying the
// anonymous affinity key, but only when the request phase flagged
// ctxKeySetAffinityCookie (i.e. a new key was minted) and actually recorded a
// non-empty key value. Otherwise the response headers are left untouched.
func (c *ProviderConfig) ApplyAnonymousAffinityCookie(ctx wrapper.HttpContext, headers http.Header) {
	if flagged, _ := ctx.GetContext(ctxKeySetAffinityCookie).(bool); !flagged {
		return
	}
	key, ok := ctx.GetContext(ctxKeyAnonymousAffinityKey).(string)
	if !ok || key == "" {
		// Defensive: the flag was set but no key was stored.
		return
	}
	headers.Add("Set-Cookie", buildAffinityCookie(key))
}

func (c *ProviderConfig) GetRandomToken() string {
apiTokens := c.apiTokens
count := len(apiTokens)
Expand Down Expand Up @@ -1052,6 +1107,60 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent
return events
}

// ExtractStreamingDataLines splits a streaming (SSE-style) response chunk
// into complete, trimmed, non-empty lines. A trailing partial line is
// buffered in the HTTP context under ctxKeyStreamingBody and prepended to the
// next chunk; when isLastChunk is true the remainder is flushed as a final
// line and the buffer is cleared.
func ExtractStreamingDataLines(ctx wrapper.HttpContext, chunk []byte, isLastChunk bool) []string {
	body := chunk
	// Prepend any partial line left over from the previous invocation.
	if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
		body = append(bufferedStreamingBody, chunk...)
	}
	// Normalize CRLF and bare CR line endings to LF. bytes.ReplaceAll returns
	// a copy of its input, so from here on body does not alias the caller's
	// chunk slice.
	body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n"))
	body = bytes.ReplaceAll(body, []byte("\r"), []byte("\n"))

	lines := make([]string, 0)
	start := 0
	for start < len(body) {
		end := bytes.IndexByte(body[start:], '\n')
		if end < 0 {
			// No more complete lines; the rest is a partial tail.
			break
		}

		line := strings.TrimSpace(string(body[start : start+end]))
		if line != "" {
			lines = append(lines, line)
		}
		start += end + 1
	}

	if isLastChunk {
		// Flush whatever remains (even without a trailing newline) and reset
		// the cross-chunk buffer.
		line := strings.TrimSpace(string(body[start:]))
		if line != "" {
			lines = append(lines, line)
		}
		ctx.SetContext(ctxKeyStreamingBody, nil)
		return lines
	}

	if start < len(body) {
		// Stash an owned copy of the partial tail for the next chunk.
		ctx.SetContext(ctxKeyStreamingBody, append([]byte(nil), body[start:]...))
	} else {
		ctx.SetContext(ctxKeyStreamingBody, nil)
	}
	return lines
}

// ExtractStreamingDataPayload extracts the payload from a single SSE line.
// It reports false for blank lines, SSE comment lines (leading ':'), and the
// terminal "[DONE]" sentinel; otherwise it returns the payload with any
// leading SSE data-field prefix and surrounding whitespace removed.
func ExtractStreamingDataPayload(line string) (string, bool) {
	payload := strings.TrimSpace(line)
	if payload == "" || strings.HasPrefix(payload, ":") {
		// Blank line or SSE comment: nothing to forward.
		return "", false
	}
	if strings.HasPrefix(payload, ssePrefix) {
		payload = strings.TrimSpace(payload[len(ssePrefix):])
	}
	if payload == "" || payload == "[DONE]" {
		return "", false
	}
	return payload, true
}

func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
_, exist := c.capabilities[string(apiName)]
return exist
Expand Down
49 changes: 44 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ func TestIsStatefulAPI(t *testing.T) {

func TestGetTokenWithConsumerAffinity(t *testing.T) {
tests := []struct {
name string
apiTokens []string
consumer string
wantEmpty bool
wantToken string // If not empty, expected specific token (for single token case)
name string
apiTokens []string
consumer string
wantEmpty bool
wantToken string // If not empty, expected specific token (for single token case)
}{
{
name: "no_tokens_returns_empty",
Expand Down Expand Up @@ -273,3 +273,42 @@ func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
})
}
}

// TestGetCookieValue verifies cookie extraction from a raw Cookie header,
// covering a present cookie, a missing cookie, and an empty header.
func TestGetCookieValue(t *testing.T) {
	cases := []struct {
		name   string
		header string
		cookie string
		want   string
	}{
		{
			name:   "find affinity cookie",
			header: "foo=bar; higress-ai-affinity=session-123; other=baz",
			cookie: AnonymousAffinityCookieName,
			want:   "session-123",
		},
		{
			name:   "cookie missing",
			header: "foo=bar; other=baz",
			cookie: AnonymousAffinityCookieName,
			want:   "",
		},
		{
			name:   "empty header",
			header: "",
			cookie: AnonymousAffinityCookieName,
			want:   "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, getCookieValue(tc.header, tc.cookie))
		})
	}
}

// TestBuildAffinityCookie checks the exact Set-Cookie attribute string
// rendered for an anonymous affinity key.
func TestBuildAffinityCookie(t *testing.T) {
	got := buildAffinityCookie("session-123")
	assert.Equal(t, "higress-ai-affinity=session-123; Path=/; HttpOnly", got)
}
21 changes: 10 additions & 11 deletions plugins/wasm-go/extensions/ai-proxy/provider/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,23 +621,16 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
}
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
if isLastChunk {
return []byte(ssePrefix + "[DONE]\n\n"), nil
}
if len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
for _, line := range ExtractStreamingDataLines(ctx, chunk, isLastChunk) {
data, ok := ExtractStreamingDataPayload(line)
if !ok {
continue
}
data = data[6:]
var vertexResp vertexChatResponse
if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
log.Errorf("unable to unmarshal vertex response: %v", err)
Expand All @@ -651,8 +644,14 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
v.appendResponse(responseBuilder, string(responseBody))
}
if isLastChunk {
responseBuilder.WriteString(ssePrefix + "[DONE]\n\n")
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
if modifiedResponseChunk == "" {
return nil, nil
}
return []byte(modifiedResponseChunk), nil
}

Expand Down
Loading
Loading