Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"math/rand"
"net"
"strings"
"sync"
"time"

"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -89,6 +90,22 @@ func NewDirectorWithConfig(
}
}

// responseBodyWork represents a unit of work to be processed by the async response body queue.
// One value is enqueued per streaming chunk; the fields capture everything
// runResponseBodyPlugins needs so the plugin call can run on another goroutine.
type responseBodyWork struct {
	ctx            context.Context         // context captured from the HandleResponseBody call for this chunk
	request        *fwksched.LLMRequest    // the scheduling request the chunk belongs to
	response       *fwk.Response           // per-chunk response snapshot handed to the plugins
	targetEndpoint *fwkdl.EndpointMetadata // endpoint (target pod) that served the response
}

// responseBodyQueue is a per-request async queue for processing response body plugin calls.
// It ensures chunks are processed in order via a channel while keeping plugin execution
// off the critical streaming path.
type responseBodyQueue struct {
	ch   chan responseBodyWork // buffered FIFO of chunk work, drained by a single goroutine
	done chan struct{}         // closed when the processing goroutine exits
}

// Director orchestrates the request handling flow after initial parsing by the handler.
// Its responsibilities include:
// - Retrieving request metadata and relevant objectives.
Expand All @@ -109,6 +126,11 @@ type Director struct {
// and value types cannot be nil
defaultPriority int
parser fwkrh.Parser

// responseBodyQueues maps request IDs to their async processing channels.
// Each request gets a dedicated channel and goroutine to ensure chunks are
// processed in order while not blocking the streaming response path.
responseBodyQueues sync.Map
}

// getInferenceObjective fetches the inferenceObjective from the datastore otherwise creates a new one based on reqCtx.
Expand Down Expand Up @@ -341,6 +363,8 @@ func (d *Director) toSchedulerEndpoints(endpoints []fwkdl.Endpoint) []fwksched.E
}

// HandleResponseHeader is called when the response headers are received.
// Response header plugins are run asynchronously since they do not produce data
// that is needed by the caller before the next processing step.
func (d *Director) HandleResponseHeader(ctx context.Context, reqCtx *handlers.RequestContext) *handlers.RequestContext {
response := &fwk.Response{
RequestId: reqCtx.Request.Headers[reqcommon.RequestIdHeaderKey],
Expand All @@ -349,23 +373,66 @@ func (d *Director) HandleResponseHeader(ctx context.Context, reqCtx *handlers.Re
}
// TODO: to extend fallback functionality, handle cases where target pod is unavailable
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
d.runResponseHeaderPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runResponseHeaderPluginsAsync(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
return reqCtx
}

// HandleResponseBody is invoked by the director for every chunk received in a streaming
// response, or exactly once for a non-streaming response.
//
// For intermediate streaming chunks (endOfStream=false), the work is sent to a per-request
// async queue (channel + goroutine) so plugins run off the critical path while preserving
// chunk ordering. For the final chunk (endOfStream=true), the queue is drained first, then
// plugins run synchronously because they may produce DynamicMetadata that must be attached
// to the ext_proc response sent back to Envoy.
func (d *Director) HandleResponseBody(ctx context.Context, reqCtx *handlers.RequestContext, endOfStream bool) *handlers.RequestContext {
	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
	logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")

	// Fast path: nothing to do when no response body plugins are registered.
	if len(d.requestControlPlugins.responseStreamingPlugins) == 0 {
		logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
		return reqCtx
	}

	requestId := reqCtx.Request.Headers[reqcommon.RequestIdHeaderKey]
	response := &fwk.Response{
		RequestId:   requestId,
		Headers:     reqCtx.Response.Headers,
		EndOfStream: endOfStream,
		Usage:       reqCtx.Usage,
	}

	if endOfStream {
		// Drain the async queue: close the channel and wait for the goroutine to finish
		// processing all previously queued chunks before running the final chunk synchronously.
		if val, ok := d.responseBodyQueues.LoadAndDelete(requestId); ok {
			q := val.(*responseBodyQueue)
			close(q.ch)
			<-q.done // wait for all queued chunks to be processed
		}
		// Run the final chunk synchronously so DynamicMetadata is available for the response.
		d.runResponseBodyPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
		reqCtx.Response.DynamicMetadata = response.DynamicMetadata
	} else {
		work := responseBodyWork{
			ctx:            ctx,
			request:        reqCtx.SchedulingRequest,
			response:       response,
			targetEndpoint: reqCtx.TargetPod,
		}
		// Get or create the async queue for this request. LoadOrStore (instead of a
		// separate Load-then-Store) closes a check-then-act race: if two chunks for
		// the same request were ever handled concurrently, the old code could create
		// two queues and two goroutines, losing ordering and leaking one goroutine.
		val, ok := d.responseBodyQueues.Load(requestId)
		if !ok {
			newQ := &responseBodyQueue{
				ch:   make(chan responseBodyWork, 100),
				done: make(chan struct{}),
			}
			var loaded bool
			val, loaded = d.responseBodyQueues.LoadOrStore(requestId, newQ)
			if !loaded {
				// We won the race: start the single consumer for this request.
				go d.processResponseBodyQueue(newQ)
			}
		}
		// NOTE(review): this send blocks the streaming path if the buffer (100) fills;
		// presumably acceptable back-pressure — confirm against expected chunk rates.
		val.(*responseBodyQueue).ch <- work
	}
	logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
	return reqCtx
}
Expand Down Expand Up @@ -425,6 +492,14 @@ func (d *Director) runResponseHeaderPlugins(ctx context.Context, request *fwksch
}
}

// runResponseHeaderPluginsAsync runs all response header plugins in a goroutine.
func (d *Director) runResponseHeaderPluginsAsync(ctx context.Context, request *fwksched.LLMRequest, response *fwk.Response, targetEndpoint *fwkdl.EndpointMetadata) {
if len(d.requestControlPlugins.responseReceivedPlugins) == 0 {
return
}
go d.runResponseHeaderPlugins(ctx, request, response, targetEndpoint)
Copy link
Copy Markdown
Contributor

@LukeAVanDrie LukeAVanDrie Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we don't block, don't we want to ensure that the handling for chunks A-->B-->C for request, R, are still processed in that order?

If we fire these off as routines and chunks arrive quickly, couldn't this result in out of order execution?

Copy link
Copy Markdown
Contributor Author

@gyliu513 gyliu513 Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good catch, thanks @LukeAVanDrie. You're right. The PredictedLatency plugin's ResponseBody is stateful across chunks (it tracks TTFT on the first chunk, accumulates TPOT across subsequent chunks via predictedLatencyCtx). Firing goroutines per chunk could result in out-of-order execution and corrupt that state.

How about this approach to keep the order: instead of firing a goroutine per chunk, each request gets a dedicated channel + goroutine queue. The flow would be:

  • Intermediate chunks (endOfStream=false): sent to a per-request buffered channel. A single goroutine reads from the channel and runs plugins sequentially — so chunks A→B→C are always processed in that order.
  • Final chunk (endOfStream=true): the channel is closed, we wait for the goroutine to drain all queued chunks, then run plugins synchronously on the main goroutine. This ensures DynamicMetadata written by plugins (e.g., requestattributereporter) is available before the Envoy response is generated.
  • Response header plugins: remain fire-and-forget async since they are stateless and only called once per request.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds good to me!

}

func (d *Director) runResponseBodyPlugins(ctx context.Context, request *fwksched.LLMRequest, response *fwk.Response, targetEndpoint *fwkdl.EndpointMetadata) {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
Expand All @@ -435,3 +510,13 @@ func (d *Director) runResponseBodyPlugins(ctx context.Context, request *fwksched
loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
}
}

// processResponseBodyQueue is the single consumer for one request's response body
// queue. It pulls work items off q.ch and runs the response body plugins for each,
// one at a time, so chunks are handled strictly in arrival order. Once q.ch is
// closed and emptied it returns, signalling completion by closing q.done.
func (d *Director) processResponseBodyQueue(q *responseBodyQueue) {
	// Closing done in a defer guarantees the drain-waiter is released even if a
	// plugin panics mid-queue.
	defer close(q.done)
	for {
		work, open := <-q.ch
		if !open {
			return
		}
		d.runResponseBodyPlugins(work.ctx, work.request, work.response, work.targetEndpoint)
	}
}
127 changes: 121 additions & 6 deletions pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import (
"errors"
"fmt"
"sort"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -1093,13 +1095,25 @@ func TestDirector_HandleResponseReceived(t *testing.T) {

director.HandleResponseHeader(ctx, reqCtx)

if diff := cmp.Diff("test-req-id-for-response", pr1.lastRespOnResponse.RequestId); diff != "" {
// HandleResponseHeader runs plugins asynchronously, so wait for completion.
require.Eventually(t, func() bool {
pr1.mu.Lock()
defer pr1.mu.Unlock()
return pr1.lastRespOnResponse != nil
}, time.Second, 10*time.Millisecond, "response header plugin should have been called")

pr1.mu.Lock()
lastResp := pr1.lastRespOnResponse
lastTargetPod := pr1.lastTargetPodOnResponse
pr1.mu.Unlock()

if diff := cmp.Diff("test-req-id-for-response", lastResp.RequestId); diff != "" {
t.Errorf("Scheduler.OnResponse RequestId mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(reqCtx.Response.Headers, pr1.lastRespOnResponse.Headers); diff != "" {
if diff := cmp.Diff(reqCtx.Response.Headers, lastResp.Headers); diff != "" {
t.Errorf("Scheduler.OnResponse Headers mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff("namespace1/test-pod-name", pr1.lastTargetPodOnResponse); diff != "" {
if diff := cmp.Diff("namespace1/test-pod-name", lastTargetPod); diff != "" {
t.Errorf("Scheduler.OnResponse TargetPodName mismatch (-want +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -1127,14 +1141,30 @@ func TestDirector_HandleResponseBody(t *testing.T) {

director.HandleResponseBody(ctx, reqCtx, false)
director.HandleResponseBody(ctx, reqCtx, false)

// Intermediate chunks (endOfStream=false) run asynchronously, wait for them.
require.Eventually(t, func() bool {
ps1.mu.Lock()
defer ps1.mu.Unlock()
return len(ps1.respsOnStreaming) >= 2
}, time.Second, 10*time.Millisecond, "async response body plugins should have been called for intermediate chunks")

// Final chunk (endOfStream=true) runs synchronously (drains queue first).
director.HandleResponseBody(ctx, reqCtx, true)

assert.Equal(t, 3, len(ps1.respsOnStreaming), "Should have received 3 streaming calls")
ps1.mu.Lock()
resps := make([]*fwk.Response, len(ps1.respsOnStreaming))
copy(resps, ps1.respsOnStreaming)
targetPods := make([]string, len(ps1.targetPodsOnStreaming))
copy(targetPods, ps1.targetPodsOnStreaming)
ps1.mu.Unlock()

assert.Equal(t, 3, len(resps), "Should have received 3 streaming calls")

for i, resp := range ps1.respsOnStreaming {
for i, resp := range resps {
assert.Equal(t, "test-req-id-for-streaming", resp.RequestId)
assert.Equal(t, reqCtx.Response.Headers, resp.Headers)
assert.Equal(t, "namespace1/test-pod-name", ps1.targetPodsOnStreaming[i])
assert.Equal(t, "namespace1/test-pod-name", targetPods[i])
if i < 2 {
assert.False(t, resp.EndOfStream, "EndOfStream should be false for chunk %d", i)
} else {
Expand All @@ -1143,19 +1173,100 @@ func TestDirector_HandleResponseBody(t *testing.T) {
}
}

// TestDirector_HandleResponseBody_ChunkOrdering verifies that streaming chunks for a
// single request are processed strictly in the order they were sent, even though
// intermediate chunks travel through the per-request async queue.
func TestDirector_HandleResponseBody_ChunkOrdering(t *testing.T) {
	// orderTrackingPlugin records the CompletionTokens of each chunk it processes.
	// Every chunk carries a distinct token count, so the recorded sequence reveals
	// any reordering introduced by the async queue.
	tracker := &orderTrackingPlugin{
		typedName: fwkplugin.TypedName{Type: "order-tracker", Name: "order-tracker"},
	}

	ctx := logutil.NewTestLoggerIntoContext(context.Background())
	ds := datastore.NewDatastore(t.Context(), nil, 0)
	locator := NewCachedPodLocator(context.Background(), NewDatastorePodLocator(ds), time.Minute)
	director := NewDirectorWithConfig(ds, &mockScheduler{}, nil, nil, locator, NewConfig().WithResponseStreamingPlugins(tracker))

	const numChunks = 50

	// newChunkCtx builds the request context for one chunk. All chunks share the
	// same request ID so they all flow through a single per-request queue.
	newChunkCtx := func(tokens int) *handlers.RequestContext {
		return &handlers.RequestContext{
			Request: &handlers.Request{
				Headers: map[string]string{
					reqcommon.RequestIdHeaderKey: "ordering-test-request",
				},
			},
			Response: &handlers.Response{
				Headers: map[string]string{},
			},
			TargetPod: &fwkdl.EndpointMetadata{},
			Usage:     fwk.Usage{CompletionTokens: tokens},
		}
	}

	for i := 0; i < numChunks; i++ {
		director.HandleResponseBody(ctx, newChunkCtx(i), false)
	}
	// The final chunk (endOfStream=true) drains the queue synchronously.
	director.HandleResponseBody(ctx, newChunkCtx(numChunks), true)

	tracker.mu.Lock()
	observed := append([]int(nil), tracker.observedTokenCounts...)
	tracker.mu.Unlock()

	// Total calls: numChunks async + 1 sync final.
	require.Equal(t, numChunks+1, len(observed), "should have received all chunk calls")

	// Each recorded CompletionTokens value must match its position: 0, 1, ..., numChunks.
	for i, tokens := range observed {
		assert.Equal(t, i, tokens, "chunk %d was processed out of order", i)
	}
}

// orderTrackingPlugin records the CompletionTokens from each ResponseBody call to verify ordering.
type orderTrackingPlugin struct {
	mu                  sync.Mutex // guards observedTokenCounts; ResponseBody may run on the queue goroutine
	typedName           fwkplugin.TypedName
	observedTokenCounts []int // token counts in the exact order chunks were processed
}

// TypedName returns the plugin's type/name identity.
func (p *orderTrackingPlugin) TypedName() fwkplugin.TypedName {
	return p.typedName
}

// ResponseBody appends the chunk's CompletionTokens under the mutex so the test
// can later inspect the processing order.
func (p *orderTrackingPlugin) ResponseBody(_ context.Context, _ *fwksched.LLMRequest, response *fwk.Response, _ *fwkdl.EndpointMetadata) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.observedTokenCounts = append(p.observedTokenCounts, response.Usage.CompletionTokens)
}

const (
testResponseReceivedType = "test-response-received"
testPostStreamingType = "test-response-streaming"
testPostCompleteType = "test-response-complete"
)

type testResponseReceived struct {
mu sync.Mutex
typedName fwkplugin.TypedName
lastRespOnResponse *fwk.Response
lastTargetPodOnResponse string
}

type testResponseStreaming struct {
mu sync.Mutex
typedName fwkplugin.TypedName
respsOnStreaming []*fwk.Response
targetPodsOnStreaming []string
Expand Down Expand Up @@ -1186,11 +1297,15 @@ func (p *testResponseStreaming) TypedName() fwkplugin.TypedName {
}

// ResponseHeader records the response and the target pod's namespaced name.
// It takes the mutex because the director may invoke header plugins on a
// separate goroutine from the test's assertions.
func (p *testResponseReceived) ResponseHeader(_ context.Context, _ *fwksched.LLMRequest, response *fwk.Response, targetPod *fwkdl.EndpointMetadata) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.lastRespOnResponse = response
	p.lastTargetPodOnResponse = targetPod.NamespacedName.String()
}

func (p *testResponseStreaming) ResponseBody(_ context.Context, _ *fwksched.LLMRequest, response *fwk.Response, targetPod *fwkdl.EndpointMetadata) {
p.mu.Lock()
defer p.mu.Unlock()
p.respsOnStreaming = append(p.respsOnStreaming, response)
p.targetPodsOnStreaming = append(p.targetPodsOnStreaming, targetPod.NamespacedName.String())

Expand Down
Loading