Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions streamaccumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ type ChatCompletionAccumulator struct {

type FinishedChatCompletionToolCall struct {
ChatCompletionMessageFunctionToolCallFunction
Index int
ID string
Index int
ChoiceIndex int
ID string
}

type chatCompletionResponseState struct {
state chatCompletionResponseStateEnum
index int
state chatCompletionResponseStateEnum
choiceIndex int
index int
}

type chatCompletionResponseStateEnum int
Expand Down Expand Up @@ -46,9 +48,14 @@ func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
return true
}

chunkIndex := int(chunk.Choices[0].Index)
acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex)
acc.justFinished = acc.choiceChatCompletionStates[chunkIndex].update(chunk)
for _, choice := range chunk.Choices {
chunkIndex := int(choice.Index)
acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex)
justFinished := acc.choiceChatCompletionStates[chunkIndex].update(choice)
if acc.justFinished.state == emptyResponseState && justFinished.state != emptyResponseState {
acc.justFinished = justFinished
}
}
return true
}

Expand All @@ -58,7 +65,7 @@ func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
// an empty string is returned and the boolean will be false.
func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok bool) {
if acc.justFinished.state == contentResponseState {
return acc.Choices[0].Message.Content, true
return acc.Choices[acc.justFinished.choiceIndex].Message.Content, true
}
return "", false
}
Expand All @@ -69,7 +76,7 @@ func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok
// an empty string is returned and the boolean will be false.
func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok bool) {
if acc.justFinished.state == refusalResponseState {
return acc.Choices[0].Message.Refusal, true
return acc.Choices[acc.justFinished.choiceIndex].Message.Refusal, true
}
return "", false
}
Expand All @@ -83,11 +90,13 @@ func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok
// You cannot rely on this with a stream that has ParallelToolCalls enabled.
func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedChatCompletionToolCall, ok bool) {
if acc.justFinished.state == toolResponseState {
f := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].Function
id := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].ID
choice := acc.Choices[acc.justFinished.choiceIndex]
tool := choice.Message.ToolCalls[acc.justFinished.index]
f := tool.Function
return FinishedChatCompletionToolCall{
ID: id,
Index: acc.justFinished.index,
ID: tool.ID,
Index: acc.justFinished.index,
ChoiceIndex: acc.justFinished.choiceIndex,
ChatCompletionMessageFunctionToolCallFunction: ChatCompletionMessageFunctionToolCallFunction{
Name: f.Name,
Arguments: f.Arguments,
Expand Down Expand Up @@ -169,9 +178,9 @@ func (cc *ChatCompletion) accumulateDelta(chunk ChatCompletionChunk) bool {

// Updates the internal response state and returns the previous state if
// the state changed. This ensures that JustFinished events only fire once.
func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (justFinished chatCompletionResponseState) {
delta := chunk.Choices[0].Delta
new := chatCompletionResponseState{}
func (prev *chatCompletionResponseState) update(choice ChatCompletionChunkChoice) (justFinished chatCompletionResponseState) {
delta := choice.Delta
new := chatCompletionResponseState{choiceIndex: int(choice.Index)}
switch {
case delta.JSON.Content.Valid():
new.state = contentResponseState
Expand Down
48 changes: 48 additions & 0 deletions streamaccumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,54 @@ func TestAccumulatorEmptyToolCallsArray(t *testing.T) {
acc.AddChunk(chunk)
}

func TestAccumulatorJustFinishedToolCallUsesChoiceIndex(t *testing.T) {
acc := openai.ChatCompletionAccumulator{}

addChunk := func(raw string) {
t.Helper()

var chunk openai.ChatCompletionChunk
if err := chunk.UnmarshalJSON([]byte(raw)); err != nil {
t.Fatalf("Failed to unmarshal chunk: %v", err)
}
if !acc.AddChunk(chunk) {
t.Fatal("AddChunk returned false")
}
}

addChunk(`{"id":"test","choices":[{"index":0,"delta":{"content":"first"}},{"index":1,"delta":{"tool_calls":[{"id":"call_123","index":0,"type":"function","function":{"name":"lookup","arguments":"{\"city\":"}}]}}]}`)
if _, ok := acc.JustFinishedToolCall(); ok {
t.Fatal("Unexpected finished tool call")
}

addChunk(`{"id":"test","choices":[{"index":0,"delta":{"content":" choice"}},{"index":1,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Paris\"}"}}]}}]}`)
if _, ok := acc.JustFinishedToolCall(); ok {
t.Fatal("Unexpected finished tool call")
}

addChunk(`{"id":"test","choices":[{"index":0,"delta":{"content":" still streaming"}},{"index":1,"delta":{},"finish_reason":"tool_calls"}]}`)

toolCall, ok := acc.JustFinishedToolCall()
if !ok {
t.Fatal("Expected finished tool call")
}
if toolCall.ChoiceIndex != 1 {
t.Fatalf("ChoiceIndex: expected 1, got %d", toolCall.ChoiceIndex)
}
if toolCall.Index != 0 {
t.Fatalf("Index: expected 0, got %d", toolCall.Index)
}
if toolCall.ID != "call_123" {
t.Fatalf("ID: expected call_123, got %q", toolCall.ID)
}
if toolCall.Name != "lookup" {
t.Fatalf("Name: expected lookup, got %q", toolCall.Name)
}
if toolCall.Arguments != `{"city":"Paris"}` {
t.Fatalf("Arguments: expected city JSON, got %q", toolCall.Arguments)
}
}

// manually created on 11/3/2024
var mockResponseBody = `data: {"id":"chatcmpl-A3Tguz3LSXTHBTY2NAPBCSyfBltxF","object":"chat.completion.chunk","created":1725392480,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_157b3831f5","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}],"usage":null}

Expand Down