77import json
88import logging
99import os
10- from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union
10+ from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union , cast
1111
1212import boto3
1313from botocore .config import Config as BotocoreConfig
2020from ..types .content import ContentBlock , Message , Messages
2121from ..types .exceptions import ContextWindowOverflowException , ModelThrottledException
2222from ..types .streaming import StreamEvent
23- from ..types .tools import ToolResult , ToolSpec
23+ from ..types .tools import ToolChoice , ToolResult , ToolSpec
2424from .model import Model
2525
2626logger = logging .getLogger (__name__ )
@@ -168,13 +168,15 @@ def format_request(
168168 messages : Messages ,
169169 tool_specs : Optional [list [ToolSpec ]] = None ,
170170 system_prompt : Optional [str ] = None ,
171+ tool_choice : Optional [ToolChoice ] = None ,
171172 ) -> dict [str , Any ]:
172173 """Format a Bedrock converse stream request.
173174
174175 Args:
175176 messages: List of message objects to be processed by the model.
176177 tool_specs: List of tool specifications to make available to the model.
177178 system_prompt: System prompt to provide context to the model.
179+ tool_choice: Selection strategy for tool invocation.
178180
179181 Returns:
180182 A Bedrock converse stream request.
@@ -197,7 +199,7 @@ def format_request(
197199 else []
198200 ),
199201 ],
200- "toolChoice" : { "auto" : {}} ,
202+ ** ({ "toolChoice" : tool_choice } if tool_choice else {}) ,
201203 }
202204 }
203205 if tool_specs
@@ -355,6 +357,7 @@ async def stream(
355357 messages : Messages ,
356358 tool_specs : Optional [list [ToolSpec ]] = None ,
357359 system_prompt : Optional [str ] = None ,
360+ tool_choice : Optional [ToolChoice ] = None ,
358361 ** kwargs : Any ,
359362 ) -> AsyncGenerator [StreamEvent , None ]:
360363 """Stream conversation with the Bedrock model.
@@ -366,6 +369,7 @@ async def stream(
366369 messages: List of message objects to be processed by the model.
367370 tool_specs: List of tool specifications to make available to the model.
368371 system_prompt: System prompt to provide context to the model.
372+ tool_choice: Selection strategy for tool invocation.
369373 **kwargs: Additional keyword arguments for future extensibility.
370374
371375 Yields:
@@ -384,7 +388,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
384388 loop = asyncio .get_event_loop ()
385389 queue : asyncio .Queue [Optional [StreamEvent ]] = asyncio .Queue ()
386390
387- thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt )
391+ thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt , tool_choice )
388392 task = asyncio .create_task (thread )
389393
390394 while True :
@@ -402,6 +406,7 @@ def _stream(
402406 messages : Messages ,
403407 tool_specs : Optional [list [ToolSpec ]] = None ,
404408 system_prompt : Optional [str ] = None ,
409+ tool_choice : Optional [ToolChoice ] = None ,
405410 ) -> None :
406411 """Stream conversation with the Bedrock model.
407412
@@ -413,14 +418,15 @@ def _stream(
413418 messages: List of message objects to be processed by the model.
414419 tool_specs: List of tool specifications to make available to the model.
415420 system_prompt: System prompt to provide context to the model.
421+ tool_choice: Selection strategy for tool invocation.
416422
417423 Raises:
418424 ContextWindowOverflowException: If the input exceeds the model's context window.
419425 ModelThrottledException: If the model service is throttling requests.
420426 """
421427 try :
422428 logger .debug ("formatting request" )
423- request = self .format_request (messages , tool_specs , system_prompt )
429+ request = self .format_request (messages , tool_specs , system_prompt , tool_choice )
424430 logger .debug ("request=<%s>" , request )
425431
426432 logger .debug ("invoking model" )
@@ -624,7 +630,13 @@ async def structured_output(
624630 """
625631 tool_spec = convert_pydantic_to_tool_spec (output_model )
626632
627- response = self .stream (messages = prompt , tool_specs = [tool_spec ], system_prompt = system_prompt , ** kwargs )
633+ response = self .stream (
634+ messages = prompt ,
635+ tool_specs = [tool_spec ],
636+ system_prompt = system_prompt ,
637+ tool_choice = cast (ToolChoice , {"any" : {}}),
638+ ** kwargs ,
639+ )
628640 async for event in streaming .process_stream (response ):
629641 yield event
630642
0 commit comments