Skip to content

Commit 03e6f00

Browse files
committed
Merge branch 'main' of https://github.com/strands-agents/sdk-python into gitikavj/add-gemini-model-provider
2 parents df6dc67 + fe7a700 commit 03e6f00

39 files changed

+2998
-92
lines changed

.github/workflows/integration-test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
approval-env: ${{ steps.collab-check.outputs.result }}
1313
steps:
1414
- name: Collaborator Check
15-
uses: actions/github-script@v7
15+
uses: actions/github-script@v8
1616
id: collab-check
1717
with:
1818
result-encoding: string
@@ -46,7 +46,7 @@ jobs:
4646
contents: read
4747
steps:
4848
- name: Configure Credentials
49-
uses: aws-actions/configure-aws-credentials@v4
49+
uses: aws-actions/configure-aws-credentials@v5
5050
with:
5151
role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }}
5252
aws-region: us-east-1
@@ -57,7 +57,7 @@ jobs:
5757
ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo
5858
persist-credentials: false # Don't persist credentials for subsequent actions
5959
- name: Set up Python
60-
uses: actions/setup-python@v5
60+
uses: actions/setup-python@v6
6161
with:
6262
python-version: '3.10'
6363
- name: Install dependencies

.github/workflows/pypi-publish-on-release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
persist-credentials: false
2828

2929
- name: Set up Python
30-
uses: actions/setup-python@v5
30+
uses: actions/setup-python@v6
3131
with:
3232
python-version: '3.10'
3333

.github/workflows/test-lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
ref: ${{ inputs.ref }} # Explicitly define which commit to check out
5757
persist-credentials: false # Don't persist credentials for subsequent actions
5858
- name: Set up Python
59-
uses: actions/setup-python@v5
59+
uses: actions/setup-python@v6
6060
with:
6161
python-version: ${{ matrix.python-version }}
6262
- name: Install dependencies
@@ -79,7 +79,7 @@ jobs:
7979
persist-credentials: false
8080

8181
- name: Set up Python
82-
uses: actions/setup-python@v5
82+
uses: actions/setup-python@v6
8383
with:
8484
python-version: '3.10'
8585
cache: 'pip'

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ from strands.models import BedrockModel
130130
from strands.models.ollama import OllamaModel
131131
from strands.models.llamaapi import LlamaAPIModel
132132
from strands.models.gemini import GeminiModel
133+
from strands.models.llamacpp import LlamaCppModel
133134

134135
# Bedrock
135136
bedrock_model = BedrockModel(
@@ -170,6 +171,7 @@ Built-in providers:
170171
- [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/)
171172
- [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/)
172173
- [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
174+
- [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/)
173175
- [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
174176
- [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
175177
- [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/)

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
425425
**kwargs: Additional parameters to pass through the event loop.
426426
427427
Returns:
428-
Result object containing:
428+
Result: object containing:
429429
430430
- stop_reason: Why the event loop stopped (e.g., "end_turn", "max_tokens")
431431
- message: The final message from the model
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from typing_extensions import get_type_hints
77

8+
from ..types.tools import ToolChoice
9+
810

911
def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None:
1012
"""Validate that config keys match the TypedDict fields.
@@ -25,3 +27,16 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) ->
2527
f"\nSee https://github.com/strands-agents/sdk-python/issues/815",
2628
stacklevel=4,
2729
)
30+
31+
32+
def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None:
33+
"""Emits a warning if a tool choice is provided but not supported by the provider.
34+
35+
Args:
36+
tool_choice: the tool_choice provided to the provider
37+
"""
38+
if tool_choice:
39+
warnings.warn(
40+
"A ToolChoice was provided to this provider but is not supported and will be ignored",
41+
stacklevel=4,
42+
)

src/strands/models/anthropic.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from ..types.content import ContentBlock, Messages
1919
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
2020
from ..types.streaming import StreamEvent
21-
from ..types.tools import ToolSpec
22-
from ._config_validation import validate_config_keys
21+
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
22+
from ._validation import validate_config_keys
2323
from .model import Model
2424

2525
logger = logging.getLogger(__name__)
@@ -195,14 +195,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
195195
return formatted_messages
196196

197197
def format_request(
198-
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
198+
self,
199+
messages: Messages,
200+
tool_specs: Optional[list[ToolSpec]] = None,
201+
system_prompt: Optional[str] = None,
202+
tool_choice: ToolChoice | None = None,
199203
) -> dict[str, Any]:
200204
"""Format an Anthropic streaming request.
201205
202206
Args:
203207
messages: List of message objects to be processed by the model.
204208
tool_specs: List of tool specifications to make available to the model.
205209
system_prompt: System prompt to provide context to the model.
210+
tool_choice: Selection strategy for tool invocation.
206211
207212
Returns:
208213
An Anthropic streaming request.
@@ -223,10 +228,25 @@ def format_request(
223228
}
224229
for tool_spec in tool_specs or []
225230
],
231+
**(self._format_tool_choice(tool_choice)),
226232
**({"system": system_prompt} if system_prompt else {}),
227233
**(self.config.get("params") or {}),
228234
}
229235

236+
@staticmethod
237+
def _format_tool_choice(tool_choice: ToolChoice | None) -> dict:
238+
if tool_choice is None:
239+
return {}
240+
241+
if "any" in tool_choice:
242+
return {"tool_choice": {"type": "any"}}
243+
elif "auto" in tool_choice:
244+
return {"tool_choice": {"type": "auto"}}
245+
elif "tool" in tool_choice:
246+
return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}}
247+
else:
248+
return {}
249+
230250
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
231251
"""Format the Anthropic response events into standardized message chunks.
232252
@@ -350,6 +370,7 @@ async def stream(
350370
messages: Messages,
351371
tool_specs: Optional[list[ToolSpec]] = None,
352372
system_prompt: Optional[str] = None,
373+
tool_choice: ToolChoice | None = None,
353374
**kwargs: Any,
354375
) -> AsyncGenerator[StreamEvent, None]:
355376
"""Stream conversation with the Anthropic model.
@@ -358,6 +379,7 @@ async def stream(
358379
messages: List of message objects to be processed by the model.
359380
tool_specs: List of tool specifications to make available to the model.
360381
system_prompt: System prompt to provide context to the model.
382+
tool_choice: Selection strategy for tool invocation.
361383
**kwargs: Additional keyword arguments for future extensibility.
362384
363385
Yields:
@@ -368,7 +390,7 @@ async def stream(
368390
ModelThrottledException: If the request is throttled by Anthropic.
369391
"""
370392
logger.debug("formatting request")
371-
request = self.format_request(messages, tool_specs, system_prompt)
393+
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
372394
logger.debug("request=<%s>", request)
373395

374396
logger.debug("invoking model")
@@ -410,7 +432,13 @@ async def structured_output(
410432
"""
411433
tool_spec = convert_pydantic_to_tool_spec(output_model)
412434

413-
response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
435+
response = self.stream(
436+
messages=prompt,
437+
tool_specs=[tool_spec],
438+
system_prompt=system_prompt,
439+
tool_choice=cast(ToolChoice, {"any": {}}),
440+
**kwargs,
441+
)
414442
async for event in process_stream(response):
415443
yield event
416444

src/strands/models/bedrock.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import logging
99
import os
10+
import warnings
1011
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
1112

1213
import boto3
@@ -23,13 +24,15 @@
2324
ModelThrottledException,
2425
)
2526
from ..types.streaming import CitationsDelta, StreamEvent
26-
from ..types.tools import ToolResult, ToolSpec
27-
from ._config_validation import validate_config_keys
27+
from ..types.tools import ToolChoice, ToolResult, ToolSpec
28+
from ._validation import validate_config_keys
2829
from .model import Model
2930

3031
logger = logging.getLogger(__name__)
3132

33+
# See: `BedrockModel._get_default_model_with_warning` for why we need both
3234
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
35+
_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
3336
DEFAULT_BEDROCK_REGION = "us-west-2"
3437

3538
BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
@@ -130,13 +133,16 @@ def __init__(
130133
if region_name and boto_session:
131134
raise ValueError("Cannot specify both `region_name` and `boto_session`.")
132135

133-
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
136+
session = boto_session or boto3.Session()
137+
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
138+
self.config = BedrockModel.BedrockConfig(
139+
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
140+
include_tool_result_status="auto",
141+
)
134142
self.update_config(**model_config)
135143

136144
logger.debug("config=<%s> | initializing", self.config)
137145

138-
session = boto_session or boto3.Session()
139-
140146
# Add strands-agents to the request user agent
141147
if boto_client_config:
142148
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
@@ -151,8 +157,6 @@ def __init__(
151157
else:
152158
client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT)
153159

154-
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
155-
156160
self.client = session.client(
157161
service_name="bedrock-runtime",
158162
config=client_config,
@@ -197,13 +201,15 @@ def format_request(
197201
messages: Messages,
198202
tool_specs: Optional[list[ToolSpec]] = None,
199203
system_prompt: Optional[str] = None,
204+
tool_choice: ToolChoice | None = None,
200205
) -> dict[str, Any]:
201206
"""Format a Bedrock converse stream request.
202207
203208
Args:
204209
messages: List of message objects to be processed by the model.
205210
tool_specs: List of tool specifications to make available to the model.
206211
system_prompt: System prompt to provide context to the model.
212+
tool_choice: Selection strategy for tool invocation.
207213
208214
Returns:
209215
A Bedrock converse stream request.
@@ -226,7 +232,7 @@ def format_request(
226232
else []
227233
),
228234
],
229-
"toolChoice": {"auto": {}},
235+
**({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
230236
}
231237
}
232238
if tool_specs
@@ -418,6 +424,7 @@ async def stream(
418424
messages: Messages,
419425
tool_specs: Optional[list[ToolSpec]] = None,
420426
system_prompt: Optional[str] = None,
427+
tool_choice: ToolChoice | None = None,
421428
**kwargs: Any,
422429
) -> AsyncGenerator[StreamEvent, None]:
423430
"""Stream conversation with the Bedrock model.
@@ -429,6 +436,7 @@ async def stream(
429436
messages: List of message objects to be processed by the model.
430437
tool_specs: List of tool specifications to make available to the model.
431438
system_prompt: System prompt to provide context to the model.
439+
tool_choice: Selection strategy for tool invocation.
432440
**kwargs: Additional keyword arguments for future extensibility.
433441
434442
Yields:
@@ -447,7 +455,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
447455
loop = asyncio.get_event_loop()
448456
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
449457

450-
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
458+
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice)
451459
task = asyncio.create_task(thread)
452460

453461
while True:
@@ -465,6 +473,7 @@ def _stream(
465473
messages: Messages,
466474
tool_specs: Optional[list[ToolSpec]] = None,
467475
system_prompt: Optional[str] = None,
476+
tool_choice: ToolChoice | None = None,
468477
) -> None:
469478
"""Stream conversation with the Bedrock model.
470479
@@ -476,14 +485,15 @@ def _stream(
476485
messages: List of message objects to be processed by the model.
477486
tool_specs: List of tool specifications to make available to the model.
478487
system_prompt: System prompt to provide context to the model.
488+
tool_choice: Selection strategy for tool invocation.
479489
480490
Raises:
481491
ContextWindowOverflowException: If the input exceeds the model's context window.
482492
ModelThrottledException: If the model service is throttling requests.
483493
"""
484494
try:
485495
logger.debug("formatting request")
486-
request = self.format_request(messages, tool_specs, system_prompt)
496+
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
487497
logger.debug("request=<%s>", request)
488498

489499
logger.debug("invoking model")
@@ -740,6 +750,7 @@ async def structured_output(
740750
messages=prompt,
741751
tool_specs=[tool_spec],
742752
system_prompt=system_prompt,
753+
tool_choice=cast(ToolChoice, {"any": {}}),
743754
**kwargs,
744755
)
745756
async for event in streaming.process_stream(response):
@@ -764,3 +775,46 @@ async def structured_output(
764775
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
765776

766777
yield {"output": output_model(**output_response)}
778+
779+
@staticmethod
780+
def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str:
781+
"""Get the default Bedrock modelId based on region.
782+
783+
If the region is not **known** to support inference then we show a helpful warning
784+
that complements the exception that Bedrock will throw.
785+
If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID`
786+
then we should not process further.
787+
788+
Args:
789+
region_name (str): region for bedrock model
790+
model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
791+
"""
792+
if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"):
793+
return DEFAULT_BEDROCK_MODEL_ID
794+
795+
model_config = model_config or {}
796+
if model_config.get("model_id"):
797+
return model_config["model_id"]
798+
799+
prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix
800+
801+
prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1`
802+
if prefix not in {"us", "eu", "ap", "us-gov"}:
803+
warnings.warn(
804+
f"""
805+
================== WARNING ==================
806+
807+
This region {region_name} does not support
808+
our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}.
809+
Update the agent to pass in a 'model_id' like so:
810+
```
811+
Agent(..., model='valid_model_id', ...)
812+
```
813+
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
814+
815+
==================================================
816+
""",
817+
stacklevel=2,
818+
)
819+
820+
return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix))

0 commit comments

Comments
 (0)