Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

logger.debug(response.model_dump())

if response.stop_reason == "max_tokens":
logger.warning(
"Response reached the max_tokens limit (%d). Output may be truncated. "
"Increase the max_tokens parameter to get longer responses.",
self.max_tokens,
)

return Message(
content=[self.__to_prompt_stack_message_content(content) for content in response.content],
role=Message.ASSISTANT_ROLE,
Expand All @@ -126,6 +133,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
elif event.type == "message_start":
yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens))
elif event.type == "message_delta":
if event.delta.stop_reason == "max_tokens":
logger.warning(
"Response reached the max_tokens limit (%d). Output may be truncated. "
"Increase the max_tokens parameter to get longer responses.",
self.max_tokens,
)
yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens))

def _base_params(self, prompt_stack: PromptStack) -> dict:
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,63 @@ def test_try_run_omits_sampling_params_for_opus_4_7(self, mock_client, prompt_st
assert "temperature" not in call_kwargs.kwargs
assert "top_p" not in call_kwargs.kwargs
assert "top_k" not in call_kwargs.kwargs

def test_try_run_warns_on_max_tokens(self, mocker, prompt_stack, caplog):
import logging

# Given
mock_client = mocker.patch("anthropic.Anthropic")
mock_client.return_value = Mock(
messages=Mock(
create=Mock(
return_value=Mock(
stop_reason="max_tokens",
usage=Mock(input_tokens=5, output_tokens=10),
content=[Mock(type="text", text="truncated-output")],
)
)
)
)
driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", max_tokens=10)

# When
with caplog.at_level(logging.WARNING, logger="griptape"):
driver.try_run(prompt_stack)

# Then
assert any("max_tokens" in r.message for r in caplog.records if r.levelno == logging.WARNING)

def test_try_stream_warns_on_max_tokens(self, mocker, prompt_stack, caplog):
import logging

# Given
mock_client = mocker.patch("anthropic.Anthropic")
mock_client.return_value = Mock(
messages=Mock(
create=Mock(
return_value=iter(
[
Mock(type="message_start", message=Mock(usage=Mock(input_tokens=5))),
Mock(
type="content_block_start",
index=0,
content_block=Mock(type="text", text="truncated"),
),
Mock(
type="message_delta",
usage=Mock(output_tokens=10),
delta=Mock(stop_reason="max_tokens"),
),
]
)
)
)
)
driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", stream=True, max_tokens=10)

# When
with caplog.at_level(logging.WARNING, logger="griptape"):
list(driver.try_stream(prompt_stack))

# Then
assert any("max_tokens" in r.message for r in caplog.records if r.levelno == logging.WARNING)
Loading