77import json
88import logging
99import os
10+ import warnings
1011from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union , cast
1112
1213import boto3
2324 ModelThrottledException ,
2425)
2526from ..types .streaming import CitationsDelta , StreamEvent
26- from ..types .tools import ToolResult , ToolSpec
27- from ._config_validation import validate_config_keys
27+ from ..types .tools import ToolChoice , ToolResult , ToolSpec
28+ from ._validation import validate_config_keys
2829from .model import Model
2930
3031logger = logging .getLogger (__name__ )
3132
33+ # See: `BedrockModel._get_default_model_with_warning` for why we need both
3234DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
35+ _DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
3336DEFAULT_BEDROCK_REGION = "us-west-2"
3437
3538BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
@@ -130,13 +133,16 @@ def __init__(
130133 if region_name and boto_session :
131134 raise ValueError ("Cannot specify both `region_name` and `boto_session`." )
132135
133- self .config = BedrockModel .BedrockConfig (model_id = DEFAULT_BEDROCK_MODEL_ID , include_tool_result_status = "auto" )
136+ session = boto_session or boto3 .Session ()
137+ resolved_region = region_name or session .region_name or os .environ .get ("AWS_REGION" ) or DEFAULT_BEDROCK_REGION
138+ self .config = BedrockModel .BedrockConfig (
139+ model_id = BedrockModel ._get_default_model_with_warning (resolved_region , model_config ),
140+ include_tool_result_status = "auto" ,
141+ )
134142 self .update_config (** model_config )
135143
136144 logger .debug ("config=<%s> | initializing" , self .config )
137145
138- session = boto_session or boto3 .Session ()
139-
140146 # Add strands-agents to the request user agent
141147 if boto_client_config :
142148 existing_user_agent = getattr (boto_client_config , "user_agent_extra" , None )
@@ -151,8 +157,6 @@ def __init__(
151157 else :
152158 client_config = BotocoreConfig (user_agent_extra = "strands-agents" , read_timeout = DEFAULT_READ_TIMEOUT )
153159
154- resolved_region = region_name or session .region_name or os .environ .get ("AWS_REGION" ) or DEFAULT_BEDROCK_REGION
155-
156160 self .client = session .client (
157161 service_name = "bedrock-runtime" ,
158162 config = client_config ,
@@ -197,13 +201,15 @@ def format_request(
197201 messages : Messages ,
198202 tool_specs : Optional [list [ToolSpec ]] = None ,
199203 system_prompt : Optional [str ] = None ,
204+ tool_choice : ToolChoice | None = None ,
200205 ) -> dict [str , Any ]:
201206 """Format a Bedrock converse stream request.
202207
203208 Args:
204209 messages: List of message objects to be processed by the model.
205210 tool_specs: List of tool specifications to make available to the model.
206211 system_prompt: System prompt to provide context to the model.
212+ tool_choice: Selection strategy for tool invocation.
207213
208214 Returns:
209215 A Bedrock converse stream request.
@@ -226,7 +232,7 @@ def format_request(
226232 else []
227233 ),
228234 ],
229- "toolChoice" : {"auto" : {}},
235+ ** ({ "toolChoice" : tool_choice if tool_choice else {"auto" : {}}}) ,
230236 }
231237 }
232238 if tool_specs
@@ -418,6 +424,7 @@ async def stream(
418424 messages : Messages ,
419425 tool_specs : Optional [list [ToolSpec ]] = None ,
420426 system_prompt : Optional [str ] = None ,
427+ tool_choice : ToolChoice | None = None ,
421428 ** kwargs : Any ,
422429 ) -> AsyncGenerator [StreamEvent , None ]:
423430 """Stream conversation with the Bedrock model.
@@ -429,6 +436,7 @@ async def stream(
429436 messages: List of message objects to be processed by the model.
430437 tool_specs: List of tool specifications to make available to the model.
431438 system_prompt: System prompt to provide context to the model.
439+ tool_choice: Selection strategy for tool invocation.
432440 **kwargs: Additional keyword arguments for future extensibility.
433441
434442 Yields:
@@ -447,7 +455,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
447455 loop = asyncio .get_event_loop ()
448456 queue : asyncio .Queue [Optional [StreamEvent ]] = asyncio .Queue ()
449457
450- thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt )
458+ thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt , tool_choice )
451459 task = asyncio .create_task (thread )
452460
453461 while True :
@@ -465,6 +473,7 @@ def _stream(
465473 messages : Messages ,
466474 tool_specs : Optional [list [ToolSpec ]] = None ,
467475 system_prompt : Optional [str ] = None ,
476+ tool_choice : ToolChoice | None = None ,
468477 ) -> None :
469478 """Stream conversation with the Bedrock model.
470479
@@ -476,14 +485,15 @@ def _stream(
476485 messages: List of message objects to be processed by the model.
477486 tool_specs: List of tool specifications to make available to the model.
478487 system_prompt: System prompt to provide context to the model.
488+ tool_choice: Selection strategy for tool invocation.
479489
480490 Raises:
481491 ContextWindowOverflowException: If the input exceeds the model's context window.
482492 ModelThrottledException: If the model service is throttling requests.
483493 """
484494 try :
485495 logger .debug ("formatting request" )
486- request = self .format_request (messages , tool_specs , system_prompt )
496+ request = self .format_request (messages , tool_specs , system_prompt , tool_choice )
487497 logger .debug ("request=<%s>" , request )
488498
489499 logger .debug ("invoking model" )
@@ -740,6 +750,7 @@ async def structured_output(
740750 messages = prompt ,
741751 tool_specs = [tool_spec ],
742752 system_prompt = system_prompt ,
753+ tool_choice = cast (ToolChoice , {"any" : {}}),
743754 ** kwargs ,
744755 )
745756 async for event in streaming .process_stream (response ):
@@ -764,3 +775,46 @@ async def structured_output(
764775 raise ValueError ("No valid tool use or tool use input was found in the Bedrock response." )
765776
766777 yield {"output" : output_model (** output_response )}
778+
779+ @staticmethod
780+ def _get_default_model_with_warning (region_name : str , model_config : Optional [BedrockConfig ] = None ) -> str :
781+ """Get the default Bedrock modelId based on region.
782+
783+ If the region is not **known** to support inference then we show a helpful warning
784+ that compliments the exception that Bedrock will throw.
785+ If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID`
786+ then we should not process further.
787+
788+ Args:
789+ region_name (str): region for bedrock model
790+ model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
791+ """
792+ if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID .format ("us" ):
793+ return DEFAULT_BEDROCK_MODEL_ID
794+
795+ model_config = model_config or {}
796+ if model_config .get ("model_id" ):
797+ return model_config ["model_id" ]
798+
799+ prefix_inference_map = {"ap" : "apac" } # some inference endpoints can be a bit different than the region prefix
800+
801+ prefix = "-" .join (region_name .split ("-" )[:- 2 ]).lower () # handles `us-east-1` or `us-gov-east-1`
802+ if prefix not in {"us" , "eu" , "ap" , "us-gov" }:
803+ warnings .warn (
804+ f"""
805+ ================== WARNING ==================
806+
807+ This region { region_name } does not support
808+ our default inference endpoint: { _DEFAULT_BEDROCK_MODEL_ID .format (prefix )} .
809+ Update the agent to pass in a 'model_id' like so:
810+ ```
811+ Agent(..., model='valid_model_id', ...)
812+ ````
813+ Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
814+
815+ ==================================================
816+ """ ,
817+ stacklevel = 2 ,
818+ )
819+
820+ return _DEFAULT_BEDROCK_MODEL_ID .format (prefix_inference_map .get (prefix , prefix ))
0 commit comments