2 changes: 1 addition & 1 deletion src/strands/models/litellm.py
@@ -72,7 +72,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type:
self._apply_proxy_prefix()

@override
def get_config(self) -> LiteLLMConfig:
def get_config(self) -> LiteLLMConfig: # type: ignore[override]
"""Get the LiteLLM model configuration.

Returns:
174 changes: 129 additions & 45 deletions src/strands/models/openai.py
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
params: Model parameters (e.g., max_tokens).
For a complete list of supported parameters, see
https://platform.openai.com/docs/api-reference/chat/create.
streaming: Optional flag to indicate whether provider streaming should be used.
If omitted, defaults to True (preserves existing behaviour).
"""

model_id: str
params: Optional[dict[str, Any]]
streaming: bool | None

def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
"""Initialize provider instance.
@@ -332,7 +335,7 @@ def format_request(
messages, system_prompt, system_prompt_content=system_prompt_content
),
"model": self.config["model_id"],
"stream": True,
"stream": self.config.get("streaming", True),
"stream_options": {"include_usage": True},
"tools": [
{
@@ -422,6 +425,73 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")

def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
"""Convert a provider non-streaming response into streaming-style events.

This helper intentionally does not emit the initial content_start event,
as the caller handles it to ensure parity with the streaming flow.

Args:
response: The non-streaming response object from the provider.

Returns:
list[StreamEvent]: The converted streaming events.
"""
events: list[StreamEvent] = []

# Extract main text content from first choice if available
if getattr(response, "choices", None):
choice = response.choices[0]
content = None
if hasattr(choice, "message") and hasattr(choice.message, "content"):
content = choice.message.content

# handle str content
if isinstance(content, str):
events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content}))
# handle list content (list of blocks/dicts)
elif isinstance(content, list):
for block in content:
if isinstance(block, dict):
# reasoning content
if "reasoningContent" in block and isinstance(block["reasoningContent"], dict):
try:
text = block["reasoningContent"]["reasoningText"]["text"]
events.append(
self.format_chunk(
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text}
)
)
except Exception:
logger.warning("block=<%s> | failed to parse reasoning content", block, exc_info=True)
# text block
elif "text" in block:
events.append(
self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": block["text"]}
)
)
# ignore other block types for now
elif isinstance(block, str):
events.append(
self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block})
)

# content stop
events.append(self.format_chunk({"chunk_type": "content_stop"}))

# message stop — convert finish reason if available
stop_reason = None
if getattr(response, "choices", None):
stop_reason = getattr(response.choices[0], "finish_reason", None)
events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"}))

# metadata (usage) if present
if getattr(response, "usage", None):
events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage}))

return events
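A rough sketch of what this helper yields for a plain string response, reusing a model instance like the one sketched above; SimpleNamespace is only a stand-in for the provider's response object.

from types import SimpleNamespace

# Minimal stand-in for a non-streaming chat completion with string content.
fake_response = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(content="hello"), finish_reason="stop")],
    usage=None,
)

# Produces contentBlockDelta("hello"), contentBlockStop, and messageStop events;
# the caller emits the initial content_start, and no metadata event is produced
# because usage is None.
events = model._convert_non_streaming_to_streaming(fake_response)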

@override
async def stream(
self,
@@ -480,57 +550,71 @@ async def stream(
finish_reason = None # Store finish_reason for later use
event = None # Initialize for scope safety

async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
for chunk in chunks:
yield chunk
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": data_type,
"data": choice.delta.reasoning_content,
}
)

if choice.delta.content:
chunks, data_type = self._stream_switch_content("text", data_type)
for chunk in chunks:
yield chunk
streaming = self.config.get("streaming", True)

if streaming:
# response is an async iterator when streaming=True
async for event in response:
# skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
for chunk in chunks:
yield chunk
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": data_type,
"data": choice.delta.reasoning_content,
}
)

if choice.delta.content:
chunks, data_type = self._stream_switch_content("text", data_type)
for chunk in chunks:
yield chunk
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
finish_reason = choice.finish_reason # Store for use outside loop
if data_type:
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
break

for tool_deltas in tool_calls.values():
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
{"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
finish_reason = choice.finish_reason # Store for use outside loop
if data_type:
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
break

for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

for tool_delta in tool_deltas:
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
for tool_delta in tool_deltas:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
)

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})
yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})

# Skip remaining events as we don't have use for anything except the final usage payload
async for event in response:
_ = event
# Skip remaining events
async for event in response:
_ = event

if event and hasattr(event, "usage") and event.usage:
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
if event and hasattr(event, "usage") and event.usage:
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
else:
# Non-streaming provider response — convert to streaming-style events
# We manually emit the content_start event here to align with the streaming path
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
for ev in self._convert_non_streaming_to_streaming(response):
yield ev

logger.debug("finished streaming response from model")
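A hedged sketch of consuming the generator with streaming disabled, reusing the model instance sketched earlier; the event shapes match the streaming path, as the unit test below verifies.

import asyncio

async def collect_events(model):
    # Both streaming and non-streaming configurations yield the same event
    # shapes: messageStart, contentBlock*, messageStop, and metadata.
    events = []
    async for event in model.stream([{"role": "user", "content": [{"text": "hi"}]}]):
        events.append(event)
    return events

events = asyncio.run(collect_events(model))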

46 changes: 46 additions & 0 deletions tests/strands/models/test_openai.py
@@ -614,6 +614,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
openai_client.chat.completions.create.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_stream_respects_streaming_flag(openai_client, model_id, alist):
# Model configured to NOT stream
model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False)

# Mock a non-streaming response object
mock_choice = unittest.mock.Mock()
mock_choice.finish_reason = "stop"
mock_choice.message = unittest.mock.Mock()
mock_choice.message.content = "non-stream result"
mock_response = unittest.mock.Mock()
mock_response.choices = [mock_choice]
mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)

# Consume the generator and verify the events
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
tru_events = await alist(response_gen)

expected_request = {
"max_tokens": 1,
"model": model_id,
"messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
"stream": False,
"stream_options": {"include_usage": True},
"tools": [],
}
openai_client.chat.completions.create.assert_called_once_with(**expected_request)

exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "non-stream result"}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
{
"metadata": {
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
"metrics": {"latencyMs": 0},
}
},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
24 changes: 24 additions & 0 deletions tests_integ/models/test_model_openai.py
@@ -257,3 +257,27 @@ def test_system_prompt_backward_compatibility_integration(model):

# The response should contain our specific system prompt instruction
assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"]


@pytest.mark.asyncio
async def test_openai_non_streaming(alist):
"""Integration test for non-streaming OpenAI responses."""
model = OpenAIModel(
model_id="gpt-4o-mini",
streaming=False,
client_args={"api_key": os.getenv("OPENAI_API_KEY")},
)

response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
events = await alist(response_gen)

# In non-streaming mode, we expect a consolidated response converted to stream events.
# The exact number of events can vary slightly, but the core structure should be consistent.
assert len(events) >= 5, "Should receive at least 5 events for a non-streaming response"

assert events[0] == {"messageStart": {"role": "assistant"}}, "First event should be messageStart"
assert events[1] == {"contentBlockStart": {"start": {}}}, "Second event should be contentBlockStart"
assert "contentBlockDelta" in events[2], "Third event should be contentBlockDelta"
assert "text" in events[2]["contentBlockDelta"]["delta"], "Delta should contain text"
assert events[3] == {"contentBlockStop": {}}, "Fourth event should be contentBlockStop"
assert "messageStop" in events[4], "Fifth event should be messageStop"