47 changes: 41 additions & 6 deletions src/strands/models/openai.py
@@ -15,6 +15,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import validate_config_keys
@@ -372,6 +373,10 @@ async def stream(

Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
@@ -383,7 +388,20 @@
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
response = await client.chat.completions.create(**request)
try:
response = await client.chat.completions.create(**request)
except openai.BadRequestError as e:
# Check if this is a context length exceeded error
if hasattr(e, "code") and e.code == "context_length_exceeded":
logger.warning("OpenAI threw context window overflow error")
raise ContextWindowOverflowException(str(e)) from e
# Re-raise other BadRequestError exceptions
raise
except openai.RateLimitError as e:
# All rate limit errors should be treated as throttling, not context overflow
# Rate limits (including TPM) require waiting/retrying, not context reduction
logger.warning("OpenAI threw rate limit error")
raise ModelThrottledException(str(e)) from e

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
@@ -452,16 +470,33 @@ async def structured_output(

Yields:
Model events with the last being the structured output.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)
try:
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)
except openai.BadRequestError as e:
# Check if this is a context length exceeded error
if hasattr(e, "code") and e.code == "context_length_exceeded":
logger.warning("OpenAI threw context window overflow error")
raise ContextWindowOverflowException(str(e)) from e
# Re-raise other BadRequestError exceptions
raise
except openai.RateLimitError as e:
# All rate limit errors should be treated as throttling, not context overflow
# Rate limits (including TPM) require waiting/retrying, not context reduction
logger.warning("OpenAI threw rate limit error")
raise ModelThrottledException(str(e)) from e

parsed: T | None = None
# Find the first choice with tool_calls
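For context when reviewing: a minimal sketch of how calling code might consume the two exceptions raised above. The backoff policy and the stream_with_backoff helper are illustrative assumptions, not part of this change.

import asyncio

from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


async def stream_with_backoff(model: OpenAIModel, messages, max_retries: int = 3):
    """Illustrative only: back off on throttling, surface overflow to the caller."""
    for attempt in range(max_retries):
        started = False
        try:
            async for chunk in model.stream(messages):
                started = True
                yield chunk
            return
        except ModelThrottledException:
            if started:
                # Retrying mid-stream would duplicate chunks already yielded.
                raise
            # Rate limits (including TPM) are transient: wait and retry.
            await asyncio.sleep(2**attempt)
        except ContextWindowOverflowException:
            # Not transient: the caller must shrink the conversation
            # (e.g. drop or summarize older messages) before retrying.
            raise
    raise ModelThrottledException(f"still throttled after {max_retries} attempts")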
148 changes: 148 additions & 0 deletions tests/strands/models/test_openai.py
@@ -1,10 +1,12 @@
import unittest.mock

import openai
import pydantic
import pytest

import strands
from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


@pytest.fixture
@@ -752,3 +754,149 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
model.format_request(messages, tool_choice=None)

assert len(captured_warnings) == 0


@pytest.mark.asyncio
async def test_stream_context_overflow_exception(openai_client, model, messages):
"""Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
"""Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""
# Create a mock OpenAI BadRequestError with a different error code
mock_error = openai.BadRequestError(
message="Invalid parameter value",
response=unittest.mock.MagicMock(),
body={"error": {"code": "invalid_parameter"}},
)
mock_error.code = "invalid_parameter"

# Configure the mock client to raise the non-context error
openai_client.chat.completions.create.side_effect = mock_error

# Test that other BadRequestError exceptions pass through unchanged
with pytest.raises(openai.BadRequestError) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the original exception is raised, not ContextWindowOverflowException
assert exc_info.value == mock_error


@pytest.mark.asyncio
async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
"""Test that structured output also handles context overflow properly."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_rate_limit_as_throttle(openai_client, model, messages):
"""Test that all rate limit errors are converted to ModelThrottledException."""

# Create a mock OpenAI RateLimitError (any type of rate limit)
mock_error = openai.RateLimitError(
message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the rate limit error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "tokens per min" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_request_rate_limit_as_throttle(openai_client, model, messages):
"""Test that request-based rate limit errors are converted to ModelThrottledException."""

# Create a mock OpenAI RateLimitError for request-based rate limiting
mock_error = openai.RateLimitError(
message="Rate limit reached for requests per minute.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the request rate limit error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "Rate limit reached" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls):
"""Test that structured output handles rate limit errors properly."""

# Create a mock OpenAI RateLimitError
mock_error = openai.RateLimitError(
message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the rate limit error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "tokens per min" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error
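

One gap a reviewer might flag: the passthrough behavior for non-context BadRequestError is only exercised for stream(). A sketch of the matching structured_output test, reusing the fixtures and pattern above (not part of this change):

@pytest.mark.asyncio
async def test_structured_output_other_bad_request_errors_passthrough(
    openai_client, model, messages, test_output_model_cls
):
    """Test that other BadRequestError exceptions pass through structured_output unchanged."""
    # Create a mock OpenAI BadRequestError with a non-context error code
    mock_error = openai.BadRequestError(
        message="Invalid parameter value",
        response=unittest.mock.MagicMock(),
        body={"error": {"code": "invalid_parameter"}},
    )
    mock_error.code = "invalid_parameter"

    # Configure the mock client to raise the non-context error
    openai_client.beta.chat.completions.parse.side_effect = mock_error

    # The original exception should propagate, not ContextWindowOverflowException
    with pytest.raises(openai.BadRequestError) as exc_info:
        async for _ in model.structured_output(test_output_model_cls, messages):
            pass

    assert exc_info.value == mock_error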