diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 17f1bbb94..ac95492ff 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -72,7 +72,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
         self._apply_proxy_prefix()
 
     @override
-    def get_config(self) -> LiteLLMConfig:
+    def get_config(self) -> LiteLLMConfig:  # type: ignore[override]
         """Get the LiteLLM model configuration.
 
         Returns:
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 435c82cab..8c4622653 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
             params: Model parameters (e.g., max_tokens).
                 For a complete list of supported parameters, see
                 https://platform.openai.com/docs/api-reference/chat/create.
+            streaming: Whether to use the provider's streaming API.
+                Defaults to True when omitted (preserves existing behaviour).
         """
 
         model_id: str
         params: Optional[dict[str, Any]]
+        streaming: Optional[bool]
 
     def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
         """Initialize provider instance.
 
         Args:
@@ -332,7 +335,7 @@ def format_request(
                 messages, system_prompt, system_prompt_content=system_prompt_content
             ),
             "model": self.config["model_id"],
-            "stream": True,
+            "stream": self.config.get("streaming", True),
             "stream_options": {"include_usage": True},
             "tools": [
                 {
@@ -422,6 +425,73 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
+    def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
+        """Convert a provider non-streaming response into streaming-style events.
+
+        This helper intentionally does not emit the initial content_start event,
+        as the caller handles it to ensure parity with the streaming flow.
+
+        Args:
+            response: The non-streaming response object from the provider.
+
+        Returns:
+            list[StreamEvent]: The converted streaming events.
+ """ + events: list[StreamEvent] = [] + + # Extract main text content from first choice if available + if getattr(response, "choices", None): + choice = response.choices[0] + content = None + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = choice.message.content + + # handle str content + if isinstance(content, str): + events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content})) + # handle list content (list of blocks/dicts) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + # reasoning content + if "reasoningContent" in block and isinstance(block["reasoningContent"], dict): + try: + text = block["reasoningContent"]["reasoningText"]["text"] + events.append( + self.format_chunk( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text} + ) + ) + except Exception: + logger.warning("block=<%s> | failed to parse reasoning content", block, exc_info=True) + # text block + elif "text" in block: + events.append( + self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": block["text"]} + ) + ) + # ignore other block types for now + elif isinstance(block, str): + events.append( + self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block}) + ) + + # content stop + events.append(self.format_chunk({"chunk_type": "content_stop"})) + + # message stop — convert finish reason if available + stop_reason = None + if getattr(response, "choices", None): + stop_reason = getattr(response.choices[0], "finish_reason", None) + events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"})) + + # metadata (usage) if present + if getattr(response, "usage", None): + events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage})) + + return events + @override async def stream( self, @@ -480,57 +550,71 @@ async def stream( finish_reason = None # Store finish_reason for later use event = None # Initialize for scope safety - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] - - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - chunks, data_type = self._stream_switch_content("reasoning_content", data_type) - for chunk in chunks: - yield chunk - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": data_type, - "data": choice.delta.reasoning_content, - } - ) - - if choice.delta.content: - chunks, data_type = self._stream_switch_content("text", data_type) - for chunk in chunks: - yield chunk + streaming = self.config.get("streaming", True) + + if streaming: + # response is an async iterator when streaming=True + async for event in response: + # skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": data_type, + "data": choice.delta.reasoning_content, + } + ) + + if choice.delta.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": 
data_type, "data": choice.delta.content} + ) + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + finish_reason = choice.finish_reason # Store for use outside loop + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + break + + for tool_deltas in tool_calls.values(): yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} + {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} ) - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - if choice.finish_reason: - finish_reason = choice.finish_reason # Store for use outside loop - if data_type: - yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) - break - - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + for tool_delta in tool_deltas: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + ) - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"}) + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"}) - # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event + # Skip remaining events + async for event in response: + _ = event - if event and hasattr(event, "usage") and event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + if event and hasattr(event, "usage") and event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + else: + # Non-streaming provider response — convert to streaming-style events + # We manually emit the content_start event here to align with the streaming path + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + for ev in self._convert_non_streaming_to_streaming(response): + yield ev logger.debug("finished streaming response from model") diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 0de0c4ebc..fbcb28d13 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -614,6 +614,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist): openai_client.chat.completions.create.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_respects_streaming_flag(openai_client, model_id, alist): + # Model configured to NOT stream + model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False) + + # Mock a non-streaming response object + mock_choice = unittest.mock.Mock() + mock_choice.finish_reason = "stop" + mock_choice.message = unittest.mock.Mock() + mock_choice.message.content = "non-stream result" + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + 
openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + + # Consume the generator and verify the events + response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}]) + tru_events = await alist(response_gen) + + expected_request = { + "max_tokens": 1, + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}], + "stream": False, + "stream_options": {"include_usage": True}, + "tools": [], + } + openai_client.chat.completions.create.assert_called_once_with(**expected_request) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "non-stream result"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 0}, + } + }, + ] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_stream_empty(openai_client, model_id, model, agenerator, alist): mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index feb591d1a..74a9a6bd5 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -257,3 +257,27 @@ def test_system_prompt_backward_compatibility_integration(model): # The response should contain our specific system prompt instruction assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_openai_non_streaming(alist): + """Integration test for non-streaming OpenAI responses.""" + model = OpenAIModel( + model_id="gpt-4o-mini", + streaming=False, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + + response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}]) + events = await alist(response_gen) + + # In non-streaming mode, we expect a consolidated response converted to stream events. + # The exact number of events can vary slightly, but the core structure should be consistent. + assert len(events) >= 5, "Should receive at least 5 events for a non-streaming response" + + assert events[0] == {"messageStart": {"role": "assistant"}}, "First event should be messageStart" + assert events[1] == {"contentBlockStart": {"start": {}}}, "Second event should be contentBlockStart" + assert "contentBlockDelta" in events[2], "Third event should be contentBlockDelta" + assert "text" in events[2]["contentBlockDelta"]["delta"], "Delta should contain text" + assert events[3] == {"contentBlockStop": {}}, "Fourth event should be contentBlockStop" + assert "messageStop" in events[4], "Fifth event should be messageStop"
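For reference, a minimal usage sketch of the new `streaming` flag (not part of the patch). It assumes the import path `strands.models.openai.OpenAIModel` implied by the file paths above, the message shape used in the tests, and an `OPENAI_API_KEY` environment variable; the event sequence a caller sees matches the one asserted in `test_stream_respects_streaming_flag`.

```python
import asyncio
import os

from strands.models.openai import OpenAIModel


async def main() -> None:
    # streaming=False makes format_request send "stream": False; stream() then
    # synthesizes the events via _convert_non_streaming_to_streaming.
    model = OpenAIModel(
        client_args={"api_key": os.getenv("OPENAI_API_KEY")},
        model_id="gpt-4o-mini",
        streaming=False,
    )

    # The caller-facing interface is unchanged: stream() still yields
    # messageStart, contentBlockStart/Delta/Stop, messageStop, and metadata events.
    async for event in model.stream([{"role": "user", "content": [{"text": "hi"}]}]):
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```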