Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,9 @@ func TestOpenRouter(t *testing.T) {
// TestZhipuAI is the go test entry point for the ZhipuAI cases; it only
// delegates to test.RunZhipuAIClaudeAutoConversionTests.
func TestZhipuAI(t *testing.T) {
test.RunZhipuAIClaudeAutoConversionTests(t)
}

// TestCooldown is the go test entry point for the cooldown cases. It runs the
// three cooldown helpers from the test package one after another: config
// parsing, HTTP response header handling, and recovery.
func TestCooldown(t *testing.T) {
test.RunCooldownParseConfigTests(t)
test.RunCooldownOnHttpResponseHeadersTests(t)
test.RunCooldownRecoveryTests(t)
}
159 changes: 136 additions & 23 deletions plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type failover struct {
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title zh-CN apiToken 不可用后的冷却恢复时间,单位毫秒,配置后无需健康检测即可自动恢复
cooldownDuration int64 `required:"false" yaml:"cooldownDuration" json:"cooldownDuration"`
// @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配
failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"`
// @Title zh-CN 本次请求使用的 apiToken
Expand All @@ -49,6 +51,8 @@ type failover struct {
ctxHealthCheckEndpoint string
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
ctxVmLease string
// @Title zh-CN 记录 apiToken 被标记为不可用的时间戳,用于冷却恢复
ctxApiTokenUnavailableSince string
}

type Lease struct {
Expand Down Expand Up @@ -96,6 +100,7 @@ func (f *failover) FromJson(json gjson.Result) {
f.healthCheckTimeout = 5000
}
f.healthCheckModel = json.Get("healthCheckModel").String()
f.cooldownDuration = json.Get("cooldownDuration").Int()

for _, status := range json.Get("failoverOnStatus").Array() {
f.failoverOnStatus = append(f.failoverOnStatus, status.String())
Expand All @@ -107,8 +112,8 @@ func (f *failover) FromJson(json gjson.Result) {
}

func (f *failover) Validate() error {
if f.healthCheckModel == "" {
return errors.New("missing healthCheckModel in failover config")
if f.healthCheckModel == "" && f.cooldownDuration <= 0 {
return errors.New("either healthCheckModel or cooldownDuration must be configured in failover config")
}
return nil
}
Expand All @@ -124,6 +129,7 @@ func (c *ProviderConfig) initVariable() {
c.failover.ctxUnavailableApiTokens = provider + "-" + id + "-unavailableApiTokens"
c.failover.ctxHealthCheckEndpoint = provider + "-" + id + "-requestHostAndPath"
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
c.failover.ctxApiTokenUnavailableSince = provider + "-" + id + "-apiTokenUnavailableSince"
}

func parseConfig(json gjson.Result, config *any) error {
Expand All @@ -132,7 +138,8 @@ func parseConfig(json gjson.Result, config *any) error {

func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
c.initVariable()
// Reset shared data in case plugin configuration is updated
// Reset failover shared data on config updates so stale cooldown/health-check
// state from the previous config does not leak into the new one.
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
c.resetSharedData()

Expand All @@ -156,29 +163,57 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
return
}
if len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Cluster: healthCheckEndpoint.Cluster,
})

ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)

modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
// Cooldown recovery: restore tokens whose cooldown period has elapsed
if c.failover.cooldownDuration > 0 {
timestamps, _, err := getApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
log.Errorf("Failed to get apiToken unavailable timestamps: %v", err)
} else {
now := time.Now().UnixMilli()
var recoveredTokens []string
for _, apiToken := range unavailableTokens {
if since, ok := timestamps[apiToken]; ok && now-since >= c.failover.cooldownDuration {
log.Infof("cooldown recovery: apiToken %s has cooled down for %dms, restoring to available list", apiToken, now-since)
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
addApiToken(c.failover.ctxApiTokens, apiToken)
removeApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
recoveredTokens = append(recoveredTokens, apiToken)
}
}
// Remove recovered tokens from the list to skip health check for them
for _, token := range recoveredTokens {
unavailableTokens = removeElement(unavailableTokens, token)
}
}
}

// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken)
// Health check: probe remaining unavailable tokens with a real request
if c.failover.healthCheckModel != "" && len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Cluster: healthCheckEndpoint.Cluster,
})

ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)

modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
}

// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}
}
Expand Down Expand Up @@ -355,6 +390,10 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
removeApiToken(c.failover.ctxApiTokens, apiToken)
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
// Record the time when the apiToken becomes unavailable for cooldown recovery
if c.failover.cooldownDuration > 0 {
setApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken, time.Now().UnixMilli())
}
// Set the request host and path to shared data in case they are needed in apiToken health check
c.setHealthCheckEndpoint(ctx)
} else {
Expand Down Expand Up @@ -527,7 +566,76 @@ func modifyApiTokenRequestCount(key, apiToken string, op string) {
}

func (c *ProviderConfig) initApiTokens() error {
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
_, cas, _ := getApiTokens(c.failover.ctxApiTokens)
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, cas)
}

// getApiTokenUnavailableSince loads the apiToken -> "unavailable since"
// timestamp map that is stored as JSON in proxy-wasm shared data under key.
//
// It returns the decoded map together with the CAS value the host reported for
// the key, so callers can perform a compare-and-swap write afterwards.
//
// On success the returned map is never nil: a missing key, an empty value, and
// a stored JSON null all yield an empty map, which lets callers insert into it
// directly. Any other read failure, or a value that is not valid JSON for
// map[string]int64, is returned as an error with a nil map and CAS 0.
func getApiTokenUnavailableSince(key string) (map[string]int64, uint32, error) {
	data, cas, err := proxywasm.GetSharedData(key)
	if err != nil {
		if errors.Is(err, types.ErrorStatusNotFound) {
			return make(map[string]int64), cas, nil
		}
		return nil, 0, err
	}
	// Check the length rather than nil: a cleared key may come back as an
	// empty but non-nil slice, which json.Unmarshal would reject.
	if len(data) == 0 {
		return make(map[string]int64), cas, nil
	}

	var timestamps map[string]int64
	if err = json.Unmarshal(data, &timestamps); err != nil {
		return nil, 0, fmt.Errorf("failed to unmarshal unavailableSince: %w", err)
	}
	// A JSON null decodes without error but leaves the map nil; writing to a
	// nil map panics, so hand back an empty map instead.
	if timestamps == nil {
		timestamps = make(map[string]int64)
	}
	return timestamps, cas, nil
}

// setApiTokenUnavailableSince stores timestamp for apiToken in the JSON map
// kept in shared data under key, using a read-modify-write cycle guarded by the
// CAS value returned from the read.
//
// At most casMaxRetries attempts are made. A failed read or a CAS mismatch on
// write uses up one attempt and the cycle is retried; a marshal failure or any
// other write error ends the function immediately. Failures are only logged,
// nothing is returned to the caller.
func setApiTokenUnavailableSince(key, apiToken string, timestamp int64) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		current, cas, getErr := getApiTokenUnavailableSince(key)
		if getErr != nil {
			log.Errorf("Failed to get %s: %v", key, getErr)
			continue
		}

		current[apiToken] = timestamp
		payload, marshalErr := json.Marshal(current)
		if marshalErr != nil {
			log.Errorf("Failed to marshal unavailableSince: %v", marshalErr)
			return
		}

		setErr := proxywasm.SetSharedData(key, payload, cas)
		switch {
		case setErr == nil:
			return
		case errors.Is(setErr, types.ErrorStatusCasMismatch):
			// Another writer changed the value since our read; read again.
			log.Errorf("CAS mismatch when setting %s, retrying...", key)
		default:
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, setErr)
			return
		}
	}
}

// removeApiTokenUnavailableSince deletes the entry for apiToken from the JSON
// map kept in shared data under key, using a read-modify-write cycle guarded by
// the CAS value returned from the read.
//
// If the token has no entry the function returns without writing. Otherwise at
// most casMaxRetries attempts are made: a failed read or a CAS mismatch on
// write uses up one attempt and the cycle is retried, while a marshal failure
// or any other write error ends the function. Failures are only logged.
func removeApiTokenUnavailableSince(key, apiToken string) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		current, cas, getErr := getApiTokenUnavailableSince(key)
		if getErr != nil {
			log.Errorf("Failed to get %s: %v", key, getErr)
			continue
		}

		_, present := current[apiToken]
		if !present {
			// Nothing recorded for this token, so there is nothing to write.
			return
		}
		delete(current, apiToken)

		payload, marshalErr := json.Marshal(current)
		if marshalErr != nil {
			log.Errorf("Failed to marshal unavailableSince: %v", marshalErr)
			return
		}

		setErr := proxywasm.SetSharedData(key, payload, cas)
		switch {
		case setErr == nil:
			return
		case errors.Is(setErr, types.ErrorStatusCasMismatch):
			// Another writer changed the value since our read; read again.
			log.Errorf("CAS mismatch when setting %s, retrying...", key)
		default:
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, setErr)
			return
		}
	}
}

func (c *ProviderConfig) GetGlobalRandomToken() string {
Expand Down Expand Up @@ -571,11 +679,16 @@ func (c *ProviderConfig) isFailoverEnabled() bool {
}

// resetSharedData clears every failover-related shared data key of this
// provider config by overwriting each one with a nil value.
func (c *ProviderConfig) resetSharedData() {
	// Keys are listed in the order they are cleared.
	sharedKeys := []string{
		c.failover.ctxVmLease,
		c.failover.ctxApiTokens,
		c.failover.ctxUnavailableApiTokens,
		c.failover.ctxApiTokenRequestSuccessCount,
		c.failover.ctxApiTokenRequestFailureCount,
		c.failover.ctxApiTokenUnavailableSince,
		c.failover.ctxHealthCheckEndpoint,
	}

	// In the real proxy-wasm host, cas=0 means "ignore CAS and overwrite"
	// instead of "match CAS=0". We rely on that behavior here so config updates
	// can unconditionally clear previous shared data state. The write errors
	// are deliberately dropped: the reset is best-effort.
	for _, sharedKey := range sharedKeys {
		_ = proxywasm.SetSharedData(sharedKey, nil, 0)
	}
}

func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
Expand Down
44 changes: 44 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,50 @@ func TestFailover_FromJson_Defaults(t *testing.T) {
assert.Equal(t, int64(8000), f.healthCheckTimeout)
assert.Equal(t, "test-model", f.healthCheckModel)
})

t.Run("cooldown_duration_default", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(0), f.cooldownDuration)
})

t.Run("cooldown_duration_custom", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true, "cooldownDuration": 60000}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(60000), f.cooldownDuration)
})
}

// TestFailover_Validate checks failover.Validate against the combinations of
// healthCheckModel and cooldownDuration: validation passes when at least one of
// them is usefully set and fails when neither is, including a negative
// cooldownDuration with no healthCheckModel.
func TestFailover_Validate(t *testing.T) {
	cases := []struct {
		name        string
		cfg         *failover
		wantErr     bool
		errContains string // only checked when non-empty
	}{
		{
			name: "only_healthCheckModel",
			cfg:  &failover{healthCheckModel: "gpt-3.5-turbo"},
		},
		{
			name: "only_cooldownDuration",
			cfg:  &failover{cooldownDuration: 60000},
		},
		{
			name: "both_healthCheckModel_and_cooldownDuration",
			cfg:  &failover{healthCheckModel: "gpt-3.5-turbo", cooldownDuration: 60000},
		},
		{
			name:        "neither_healthCheckModel_nor_cooldownDuration",
			cfg:         &failover{},
			wantErr:     true,
			errContains: "either healthCheckModel or cooldownDuration",
		},
		{
			name:    "negative_cooldownDuration",
			cfg:     &failover{cooldownDuration: -1},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.cfg.Validate()
			if !tc.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}

func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
Expand Down
Loading
Loading