Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions kirolink.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"strings"
"time"

"crypto/sha256"
"github.com/alexandeism/kirolink/protocol"
)

Expand Down Expand Up @@ -65,9 +66,18 @@ type CodeWhispererTool struct {
// HistoryUserMessage defines a user message in history
type HistoryUserMessage struct {
UserInputMessage struct {
Content string `json:"content"`
ModelId string `json:"modelId"`
Origin string `json:"origin"`
Content string `json:"content"`
ModelId string `json:"modelId"`
Origin string `json:"origin"`
UserInputMessageContext struct {
ToolResults []struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
Status string `json:"status"`
ToolUseId string `json:"toolUseId"`
} `json:"toolResults,omitempty"`
} `json:"userInputMessageContext,omitempty"`
} `json:"userInputMessage"`
}

Expand All @@ -89,6 +99,8 @@ type AnthropicRequest struct {
Stream bool `json:"stream"`
Temperature *float64 `json:"temperature,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
// KiroLink extensions
ConversationId *string `json:"conversation_id,omitempty"`
}

// AnthropicStreamResponse defines the Anthropic streaming response structure
Expand Down Expand Up @@ -331,6 +343,15 @@ func generateUUID() string {
return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
}

// generateDeterministicUUID generates a stable UUID based on input hash
func generateDeterministicUUID(seed string) string {
hash := sha256.Sum256([]byte(seed))
b := hash[:16]
b[6] = (b[6] & 0x0f) | 0x40 // Version 4
b[8] = (b[8] & 0x3f) | 0x80 // Variant bits
return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
}

// truncateString truncates a string to maxLen, appending "..." if truncated.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
Expand Down Expand Up @@ -695,7 +716,19 @@ func buildCodeWhispererRequest(anthropicReq AnthropicRequest) CodeWhispererReque
}
resolvedModel := resolveModelID(anthropicReq.Model)
cwReq.ConversationState.ChatTriggerType = "MANUAL"
cwReq.ConversationState.ConversationId = generateUUID()

// Session continuity: use client-provided ID or a deterministic one based on the first message
if anthropicReq.ConversationId != nil && *anthropicReq.ConversationId != "" {
cwReq.ConversationState.ConversationId = *anthropicReq.ConversationId
} else if len(anthropicReq.Messages) > 0 {
// Heuristic: Use the first user message as a stable seed for the conversation.
// Note: We skip potential system prompts or earlier turns to keep it stable.
firstMsg := anthropicReq.Messages[0]
cwReq.ConversationState.ConversationId = generateDeterministicUUID(getMessageContent(firstMsg.Content))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid deriving session ID solely from first message

Generating conversationId from only getMessageContent(firstMsg.Content) causes unrelated chats to collide whenever they open with the same text (for example, many users start with "Hello") if the client does not send conversation_id. Because CodeWhisperer state is keyed by this ID, those collisions can merge tool/history context across distinct sessions, which is a functional and privacy regression compared with the previous random UUID behavior.

Useful? React with 👍 / 👎.

} else {
cwReq.ConversationState.ConversationId = generateUUID()
}

cwReq.ConversationState.CurrentMessage.UserInputMessage.Content = buildCurrentMessageContent(anthropicReq)
cwReq.ConversationState.CurrentMessage.UserInputMessage.ModelId = resolvedModel
cwReq.ConversationState.CurrentMessage.UserInputMessage.Origin = "AI_EDITOR"
Expand All @@ -705,24 +738,26 @@ func buildCodeWhispererRequest(anthropicReq AnthropicRequest) CodeWhispererReque
cwReq.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext.Tools = buildCodeWhispererTools(anthropicReq.Tools)
}

// NOTE: We do NOT map tool_result to CodeWhisperer's toolResults format.
// Instead, tool results are included as text content via getMessageContent,
// which the model can understand from context. This avoids format mismatches
// with CodeWhisperer's undocumented schema.
// Extract tool results for the current message if they exist
if lastMsg := anthropicReq.Messages[len(anthropicReq.Messages)-1]; lastMsg.Role == "user" {
if results := extractToolResults(lastMsg.Content); len(results) > 0 {
cwReq.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext.ToolResults = results
}
}

// Build history from caller-provided conversation only.
history := make([]any, 0, len(anthropicReq.Messages)-1)
for i := 0; i < len(anthropicReq.Messages)-1; i++ {
msg := anthropicReq.Messages[i]
content := getMessageContent(msg.Content)
if strings.TrimSpace(content) == "" {
if strings.TrimSpace(content) == "" && !hasToolResults(msg.Content) {
continue
}

if msg.Role == "assistant" {
assistantMsg := HistoryAssistantMessage{}
assistantMsg.AssistantResponseMessage.Content = content
assistantMsg.AssistantResponseMessage.ToolUses = make([]any, 0)
assistantMsg.AssistantResponseMessage.ToolUses = extractToolUses(msg.Content)
history = append(history, assistantMsg)
continue
}
Expand All @@ -731,6 +766,12 @@ func buildCodeWhispererRequest(anthropicReq AnthropicRequest) CodeWhispererReque
userMsg.UserInputMessage.Content = content
userMsg.UserInputMessage.ModelId = resolvedModel
userMsg.UserInputMessage.Origin = "AI_EDITOR"

// Extract tool results for history if they exist
if results := extractToolResults(msg.Content); len(results) > 0 {
userMsg.UserInputMessage.UserInputMessageContext.ToolResults = results
}

history = append(history, userMsg)
}
cwReq.ConversationState.History = history
Expand Down Expand Up @@ -1306,7 +1347,7 @@ func eventIndex(value any) int {
}
}

func buildAnthropicResponsePayload(model string, inputTokens int, translated translatedAnthropicResponse) map[string]any {
func buildAnthropicResponsePayload(conversationId, model string, inputTokens int, translated translatedAnthropicResponse) map[string]any {
content := make([]map[string]any, 0, len(translated.Blocks))
for _, block := range translated.Blocks {
switch block.Type {
Expand All @@ -1326,32 +1367,34 @@ func buildAnthropicResponsePayload(model string, inputTokens int, translated tra
}

return map[string]any{
"content": content,
"model": model,
"role": "assistant",
"stop_reason": translated.StopReason,
"stop_sequence": nil,
"type": "message",
"content": content,
"model": model,
"role": "assistant",
"stop_reason": translated.StopReason,
"stop_sequence": nil,
"type": "message",
"conversation_id": conversationId,
"usage": map[string]any{
"input_tokens": inputTokens,
"output_tokens": translated.OutputTokens,
},
}
}

func buildAnthropicStreamEvents(messageId, model string, inputTokens int, translated translatedAnthropicResponse) []protocol.SSEEvent {
func buildAnthropicStreamEvents(conversationId, messageId, model string, inputTokens int, translated translatedAnthropicResponse) []protocol.SSEEvent {
events := []protocol.SSEEvent{{
Event: "message_start",
Data: map[string]any{
"type": "message_start",
"message": map[string]any{
"id": messageId,
"type": "message",
"role": "assistant",
"content": []any{},
"model": model,
"stop_reason": nil,
"stop_sequence": nil,
"id": messageId,
"type": "message",
"role": "assistant",
"content": []any{},
"model": model,
"stop_reason": nil,
"stop_sequence": nil,
"conversation_id": conversationId,
"usage": map[string]any{
"input_tokens": inputTokens,
"output_tokens": 1,
Expand Down Expand Up @@ -1586,6 +1629,7 @@ func handleStreamRequest(w http.ResponseWriter, anthropicReq AnthropicRequest, a
if len(parsedEvents) > 0 {
translated := assembleAnthropicResponse(parsedEvents)
streamEvents := buildAnthropicStreamEvents(
cwReq.ConversationState.ConversationId,
messageId,
responseModelID(cwReq, anthropicReq),
len(cwReq.ConversationState.CurrentMessage.UserInputMessage.Content),
Expand Down Expand Up @@ -1663,6 +1707,7 @@ func handleNonStreamRequest(w http.ResponseWriter, anthropicReq AnthropicRequest

// Build Anthropic response
anthropicResp := buildAnthropicResponsePayload(
cwReq.ConversationState.ConversationId,
responseModelID(cwReq, anthropicReq),
len(cwReq.ConversationState.CurrentMessage.UserInputMessage.Content),
translated,
Expand Down
3 changes: 2 additions & 1 deletion response_translation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func TestBuildAnthropicResponsePayloadUsesResolvedModel(t *testing.T) {
cwReq.ConversationState.CurrentMessage.UserInputMessage.ModelId = "CLAUDE_SONNET_4_5"

payload := buildAnthropicResponsePayload(
"conv-123",
responseModelID(cwReq, AnthropicRequest{Model: "claude-sonnet-4-5-20250929"}),
11,
translatedAnthropicResponse{
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestBuildAnthropicStreamEventsUsesTranslatedBlocks(t *testing.T) {
OutputTokens: 2,
}

events := buildAnthropicStreamEvents("msg_123", "CLAUDE_SONNET_4_5", 11, translated)
events := buildAnthropicStreamEvents("conv-123", "msg_123", "CLAUDE_SONNET_4_5", 11, translated)
if len(events) != 10 {
t.Fatalf("expected 10 stream events, got %d", len(events))
}
Expand Down
94 changes: 94 additions & 0 deletions tool_mapping_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package main

import (
"testing"
)

func TestBuildCodeWhispererRequest_SessionContinuity(t *testing.T) {
req1 := AnthropicRequest{
Messages: []AnthropicRequestMessage{
{Role: "user", Content: "Hello"},
},
}
req2 := AnthropicRequest{
Messages: []AnthropicRequestMessage{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
}

cwReq1 := buildCodeWhispererRequest(req1)
cwReq2 := buildCodeWhispererRequest(req2)

if cwReq1.ConversationState.ConversationId != cwReq2.ConversationState.ConversationId {
t.Errorf("ConversationId mismatch: %s != %s", cwReq1.ConversationState.ConversationId, cwReq2.ConversationState.ConversationId)
}
}

func TestBuildCodeWhispererRequest_ToolMapping(t *testing.T) {
req := AnthropicRequest{
Messages: []AnthropicRequestMessage{
{Role: "user", Content: "Call tool"},
{Role: "assistant", Content: []interface{}{
map[string]interface{}{
"type": "tool_use",
"id": "tool-1",
"name": "my_tool",
"input": map[string]interface{}{
"arg": "val",
},
},
}},
{Role: "user", Content: []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": "tool-1",
"content": "Result content",
"is_error": false,
},
}},
},
}

cwReq := buildCodeWhispererRequest(req)

// Check History items
history := cwReq.ConversationState.History
if len(history) != 2 {
t.Fatalf("Expected history length 2, got %d", len(history))
}

userMsg0, ok := history[0].(HistoryUserMessage)
if !ok {
t.Fatalf("Expected first history item to be HistoryUserMessage, got %T", history[0])
}
if userMsg0.UserInputMessage.Content != "Call tool" {
t.Errorf("Expected first history item content 'Call tool', got %s", userMsg0.UserInputMessage.Content)
}

assistantMsg, ok := history[1].(HistoryAssistantMessage)
if !ok {
t.Fatalf("Expected second history item to be HistoryAssistantMessage, got %T", history[1])
}
if len(assistantMsg.AssistantResponseMessage.ToolUses) != 1 {
t.Fatalf("Expected 1 tool use in history, got %d", len(assistantMsg.AssistantResponseMessage.ToolUses))
}
toolUse := assistantMsg.AssistantResponseMessage.ToolUses[0].(map[string]interface{})
if toolUse["name"] != "my_tool" {
t.Errorf("Expected tool name my_tool, got %v", toolUse["name"])
}

// Check Current Message ToolResults
currentContext := cwReq.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if len(currentContext.ToolResults) != 1 {
t.Fatalf("Expected 1 tool result in current message context, got %d", len(currentContext.ToolResults))
}
result := currentContext.ToolResults[0]
if result.ToolUseId != "tool-1" {
t.Errorf("Expected tool use id tool-1, got %s", result.ToolUseId)
}
if result.Content[0].Text != "Result content" {
t.Errorf("Expected result content 'Result content', got %s", result.Content[0].Text)
}
}