Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
// 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数
// generic provider 不做能力映射,不添加 stream_options
if providerConfig.IsOpenAIProtocol() && !providerConfig.IsGeneric() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
newBody = normalizeOpenAiRequestBody(newBody)
newBody = normalizeOpenAiRequestBody(newBody, providerConfig.IsStreamUsageStatsDisabled())
}
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
body = newBody
Expand Down Expand Up @@ -626,7 +626,10 @@ func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte,
return convertedBody, nil
}

func normalizeOpenAiRequestBody(body []byte) []byte {
func normalizeOpenAiRequestBody(body []byte, disableStreamUsageStats bool) []byte {
if disableStreamUsageStats {
return body
}
var err error
// Default setting include_usage.
if gjson.GetBytes(body, "stream").Bool() && (!gjson.GetBytes(body, "stream_options").Exists() || !gjson.GetBytes(body, "stream_options.include_usage").Exists()) {
Expand Down
53 changes: 53 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
"github.com/tidwall/gjson"
)

func Test_getApiName(t *testing.T) {
Expand Down Expand Up @@ -117,6 +118,58 @@ func Test_isSupportedRequestContentType(t *testing.T) {
}
}

func Test_normalizeOpenAiRequestBody(t *testing.T) {
tests := []struct {
name string
body string
disableStreamUsageStats bool
wantIncludeUsage bool
wantExists bool
}{
{
name: "stream enabled, stats enabled",
body: `{"stream":true,"messages":[]}`,
disableStreamUsageStats: false,
wantExists: true,
wantIncludeUsage: true,
},
{
name: "stream enabled, stats disabled",
body: `{"stream":true,"messages":[]}`,
disableStreamUsageStats: true,
wantExists: false,
wantIncludeUsage: false,
},
{
name: "stream disabled, stats enabled",
body: `{"stream":false,"messages":[]}`,
disableStreamUsageStats: false,
wantExists: false,
wantIncludeUsage: false,
},
{
name: "stream_options already set, stats enabled",
body: `{"stream":true,"stream_options":{"include_usage":false}}`,
disableStreamUsageStats: false,
wantExists: true,
wantIncludeUsage: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeOpenAiRequestBody([]byte(tt.body), tt.disableStreamUsageStats)
parsed := gjson.ParseBytes(got)
exists := parsed.Get("stream_options.include_usage").Exists()
if exists != tt.wantExists {
t.Errorf("stream_options.include_usage exists=%v, want %v", exists, tt.wantExists)
}
if exists && parsed.Get("stream_options.include_usage").Bool() != tt.wantIncludeUsage {
t.Errorf("stream_options.include_usage=%v, want %v", parsed.Get("stream_options.include_usage").Bool(), tt.wantIncludeUsage)
}
})
}
}

func TestAi360(t *testing.T) {
test.RunAi360ParseConfigTests(t)
test.RunAi360OnHttpRequestHeadersTests(t)
Expand Down
8 changes: 8 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ type ProviderConfig struct {
// @Title zh-CN Provider 基础路径
// @Description zh-CN 当配置了此值时,各个 Provider 在改写请求路径时会将其添加到路径前面,例如配置"/api/ai"后,请求路径"/v1/chat/completions"会被改写为"/api/ai/v1/chat/completions"
providerBasePath string `required:"false" yaml:"providerBasePath" json:"providerBasePath"`
// @Title zh-CN 禁用Stream Usage统计
// @Description zh-CN 开启后,stream请求不会自动添加stream_options.include_usage参数,用于兼容不支持该参数的旧版推理引擎。
disableStreamUsageStats bool `required:"false" yaml:"disableStreamUsageStats" json:"disableStreamUsageStats"`
}

func (c *ProviderConfig) GetId() string {
Expand Down Expand Up @@ -520,6 +523,10 @@ func (c *ProviderConfig) IsOpenAIProtocol() bool {
return c.protocol == protocolOpenAI
}

func (c *ProviderConfig) IsStreamUsageStatsDisabled() bool {
return c.disableStreamUsageStats
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.id = json.Get("id").String()
c.typ = json.Get("type").String()
Expand Down Expand Up @@ -723,6 +730,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.promoteThinkingOnEmpty = true
}
c.providerBasePath = json.Get("providerBasePath").String()
c.disableStreamUsageStats = json.Get("disableStreamUsageStats").Bool()
}

func (c *ProviderConfig) Validate() error {
Expand Down
Loading