Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,32 @@ def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
else:
prompt_stack = prompt_input

for attempt in self.retrying():
with attempt:
self.before_run(prompt_stack)

result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)

self.after_run(result)

return result
raise Exception("prompt driver failed after all retry attempts")
try:
for attempt in self.retrying():
with attempt:
self.before_run(prompt_stack)

result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)

self.after_run(result)

return result
raise Exception("prompt driver failed after all retry attempts")
except Exception as e:
wrapped = self._wrap_exception(e)
if wrapped is e:
raise
raise wrapped from e

def _wrap_exception(self, exc: Exception) -> Exception:
"""Map a provider or runtime exception to a Griptape-native exception.

Subclasses override this to translate their provider SDK's errors into Griptape
exceptions (see ``OpenAiChatPromptDriver``). This runs OUTSIDE the retry loop, so the
exception type seen by ``ignored_exception_types`` / ``tenacity`` is unchanged. The
default returns ``exc`` unchanged, so drivers that don't override it behave as before.
"""
return exc

def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
"""Converts a Prompt Stack to a string for token counting or model prompt_input.
Expand Down
9 changes: 9 additions & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from griptape.configs.defaults_config import Defaults
from griptape.drivers.prompt import BasePromptDriver
from griptape.exceptions import PromptDriverError
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property
Expand Down Expand Up @@ -160,6 +161,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

return self._to_delta_message_stream(result)

def _wrap_exception(self, exc: Exception) -> Exception:
openai = import_optional_dependency("openai")
if isinstance(exc, openai.APIStatusError):
return PromptDriverError(str(exc), status_code=getattr(exc, "status_code", None))
if isinstance(exc, openai.OpenAIError):
return PromptDriverError(str(exc))
return exc

def _to_message(self, result: ChatCompletion) -> Message:
if len(result.choices) == 1:
choice_message = result.choices[0].message
Expand Down
4 changes: 3 additions & 1 deletion griptape/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .dummy_exception import DummyError
from .griptape_error import GriptapeError
from .prompt_driver_error import PromptDriverError

__all__ = ["DummyError"]
__all__ = ["DummyError", "GriptapeError", "PromptDriverError"]
9 changes: 9 additions & 0 deletions griptape/exceptions/griptape_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations


class GriptapeError(Exception):
"""Base class for Griptape's typed exceptions.

New Griptape exception types should inherit from this so callers can catch them
all with a single ``except GriptapeError``.
"""
17 changes: 17 additions & 0 deletions griptape/exceptions/prompt_driver_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from griptape.exceptions.griptape_error import GriptapeError


class PromptDriverError(GriptapeError):
"""Raised when a Prompt Driver's underlying provider call fails.

The original provider SDK exception is preserved as this error's ``__cause__``
(drivers re-raise with ``raise ... from``), so callers can still inspect it.
``status_code`` is populated for HTTP-style failures (e.g. 401, 429) and is
``None`` for non-HTTP failures such as connection errors.
"""

def __init__(self, message: str, *, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code
91 changes: 91 additions & 0 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import base64
from collections.abc import Iterator
from copy import deepcopy
from unittest.mock import ANY, MagicMock, Mock

import httpx
import openai
import pytest
import schema

Expand All @@ -12,6 +15,7 @@
from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction
from griptape.common.prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.exceptions import PromptDriverError
from griptape.tokenizers import OpenAiTokenizer
from tests.mocks.mock_tokenizer import MockTokenizer
from tests.mocks.mock_tool.tool import MockTool
Expand Down Expand Up @@ -939,3 +943,90 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa
max_tokens=1,
)
assert event.value[0].value == "model-output"


class TestOpenAiChatPromptDriverExceptionWrapping(TestOpenAiChatPromptDriverFixtureMixin):
"""#1946 — OpenAI SDK exceptions are wrapped as Griptape ``PromptDriverError``."""

@pytest.fixture()
def simple_prompt_stack(self):
prompt_stack = PromptStack()
prompt_stack.add_user_message("hello")
return prompt_stack

@staticmethod
def _openai_status_error(error_cls: type[openai.APIStatusError], status_code: int) -> openai.APIStatusError:
request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
response = httpx.Response(status_code, request=request)
return error_cls("boom", response=response, body=None)

def _driver(self, **kwargs):
return OpenAiChatPromptDriver(
model="gpt-4o", api_key="x", tokenizer=MockTokenizer(model="test-model"), **kwargs
)

def test_run_wraps_openai_status_error(self, mock_chat_completion_create, simple_prompt_stack):
mock_chat_completion_create.side_effect = self._openai_status_error(openai.AuthenticationError, 401)

with pytest.raises(PromptDriverError) as exc_info:
self._driver().run(simple_prompt_stack)

assert exc_info.value.status_code == 401
assert isinstance(exc_info.value.__cause__, openai.AuthenticationError)

def test_run_does_not_retry_4xx(self, mock_chat_completion_create, simple_prompt_stack):
mock_chat_completion_create.side_effect = self._openai_status_error(openai.BadRequestError, 400)

with pytest.raises(PromptDriverError):
self._driver(max_attempts=2, min_retry_delay=0, max_retry_delay=0).run(simple_prompt_stack)

assert mock_chat_completion_create.call_count == 1 # fast-fail: not retried

def test_run_retries_then_wraps_429(self, mock_chat_completion_create, simple_prompt_stack):
mock_chat_completion_create.side_effect = self._openai_status_error(openai.RateLimitError, 429)

with pytest.raises(PromptDriverError) as exc_info:
self._driver(max_attempts=2, min_retry_delay=0, max_retry_delay=0).run(simple_prompt_stack)

assert exc_info.value.status_code == 429
assert mock_chat_completion_create.call_count == 2 # retried up to max_attempts

def test_run_does_not_wrap_non_openai_exception(self, mock_chat_completion_create, simple_prompt_stack):
mock_chat_completion_create.side_effect = ValueError("internal griptape error")

with pytest.raises(ValueError):
self._driver(max_attempts=1, min_retry_delay=0, max_retry_delay=0).run(simple_prompt_stack)

def test_stream_wraps_openai_status_error(self, mock_chat_completion_create, simple_prompt_stack):
mock_chat_completion_create.side_effect = self._openai_status_error(openai.AuthenticationError, 401)

with pytest.raises(PromptDriverError) as exc_info:
self._driver(stream=True).run(simple_prompt_stack)

assert exc_info.value.status_code == 401
assert isinstance(exc_info.value.__cause__, openai.AuthenticationError)

def test_run_wraps_connection_error_with_status_none(self, mock_chat_completion_create, simple_prompt_stack):
request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
mock_chat_completion_create.side_effect = openai.APIConnectionError(request=request)

with pytest.raises(PromptDriverError) as exc_info:
self._driver(max_attempts=1, min_retry_delay=0, max_retry_delay=0).run(simple_prompt_stack)

assert exc_info.value.status_code is None
assert isinstance(exc_info.value.__cause__, openai.APIConnectionError)

def test_stream_wraps_error_raised_mid_iteration(self, mock_chat_completion_create, simple_prompt_stack):
error = self._openai_status_error(openai.InternalServerError, 500)

def exploding_stream() -> Iterator[object]:
yield from ()
raise error

mock_chat_completion_create.return_value = exploding_stream()

with pytest.raises(PromptDriverError) as exc_info:
self._driver(stream=True, max_attempts=1, min_retry_delay=0, max_retry_delay=0).run(simple_prompt_stack)

assert exc_info.value.status_code == 500
assert isinstance(exc_info.value.__cause__, openai.InternalServerError)