
Commit b0af868

Authored and committed by Shang Liu
feat: add support for Bedrock/Anthropic ToolChoice to structured_output
1 parent: 17ccdd2

14 files changed (+368 additions, -32 deletions)

src/strands/models/anthropic.py

Lines changed: 18 additions & 4 deletions
@@ -18,7 +18,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -192,14 +192,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
         return formatted_messages
 
     def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> dict[str, Any]:
         """Format an Anthropic streaming request.
 
         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Returns:
             An Anthropic streaming request.
@@ -220,6 +225,7 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
+            **({"tool_choice": tool_choice} if tool_choice else {}),
             **({"system": system_prompt} if system_prompt else {}),
             **(self.config.get("params") or {}),
         }
@@ -347,6 +353,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Anthropic model.
@@ -355,6 +362,7 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -365,7 +373,7 @@ async def stream(
             ModelThrottledException: If the request is throttled by Anthropic.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
@@ -407,7 +415,13 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)
 
-        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
+        response = self.stream(
+            messages=prompt,
+            tool_specs=[tool_spec],
+            system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
+            **kwargs,
+        )
         async for event in process_stream(response):
             yield event
 
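The net effect in anthropic.py: structured_output no longer relies on the model electing to call the conversion tool; it forces a tool call. Below is a minimal sketch of the three ToolChoice shapes the diff implies. The exact TypedDict definitions live in src/strands/types/tools.py and are assumed here to mirror the Bedrock converse format; the casts are needed because a plain dict literal does not narrow to the union type, and "extract_person" is a hypothetical tool name.

from typing import cast

from strands.types.tools import ToolChoice

# The provider decides for itself whether to call a tool (the old default):
auto_choice = cast(ToolChoice, {"auto": {}})

# Some tool must be called; structured_output passes this alongside a single
# tool spec, so the reply always arrives as a parsable tool call:
any_choice = cast(ToolChoice, {"any": {}})

# A specific named tool must be called ("extract_person" is hypothetical):
named_choice = cast(ToolChoice, {"tool": {"name": "extract_person"}})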

src/strands/models/bedrock.py

Lines changed: 18 additions & 6 deletions
@@ -7,7 +7,7 @@
 import json
 import logging
 import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
 
 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -20,7 +20,7 @@
 from ..types.content import ContentBlock, Message, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec
+from ..types.tools import ToolChoice, ToolResult, ToolSpec
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -168,13 +168,15 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> dict[str, Any]:
         """Format a Bedrock converse stream request.
 
         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Returns:
             A Bedrock converse stream request.
@@ -197,7 +199,7 @@ def format_request(
                             else []
                         ),
                     ],
-                    "toolChoice": {"auto": {}},
+                    **({"toolChoice": tool_choice} if tool_choice else {}),
                 }
             }
             if tool_specs
@@ -355,6 +357,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Bedrock model.
@@ -366,6 +369,7 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -384,7 +388,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
         loop = asyncio.get_event_loop()
         queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
 
-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
+        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice)
         task = asyncio.create_task(thread)
 
         while True:
@@ -402,6 +406,7 @@ def _stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> None:
         """Stream conversation with the Bedrock model.
 
@@ -413,14 +418,15 @@ def _stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Raises:
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the model service is throttling requests.
         """
         try:
             logger.debug("formatting request")
-            request = self.format_request(messages, tool_specs, system_prompt)
+            request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
             logger.debug("request=<%s>", request)
 
             logger.debug("invoking model")
@@ -624,7 +630,13 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)
 
-        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
+        response = self.stream(
+            messages=prompt,
+            tool_specs=[tool_spec],
+            system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
+            **kwargs,
+        )
         async for event in streaming.process_stream(response):
             yield event
 
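One behavioral nuance in bedrock.py: format_request previously hardcoded "toolChoice": {"auto": {}} whenever tool specs were present, whereas it now emits the key only when the caller supplies a value. The Bedrock Converse API treats auto as the default when toolChoice is omitted, so plain tool-use requests should behave as before. A hedged sketch of the toolConfig block the new code produces for structured_output; "Person" and its schema are illustrative stand-ins for convert_pydantic_to_tool_spec output, not the SDK's exact formatting:

# Approximate shape of request["toolConfig"] when structured_output runs on Bedrock:
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "Person",
                "description": "Extract a person record.",
                "inputSchema": {
                    "json": {"type": "object", "properties": {"name": {"type": "string"}}}
                },
            }
        }
    ],
    # Emitted only when tool_choice is passed; structured_output sends
    # {"any": {}} to force the model to answer through the tool.
    "toolChoice": {"any": {}},
}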

src/strands/models/litellm.py

Lines changed: 4 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
@@ -109,6 +109,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LiteLLM model.
@@ -117,6 +118,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
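The LiteLLM provider (like the LlamaAPI, Mistral, and Ollama providers below) accepts the parameter purely for interface consistency and ignores it. For orientation only, not code from this commit: if a future change wired it up, the Bedrock-style union could map onto the OpenAI/LiteLLM tool_choice parameter roughly like this:

from typing import Any, Optional, Union


def to_openai_tool_choice(tool_choice: Optional[dict[str, Any]]) -> Union[str, dict[str, Any], None]:
    """Translate a Bedrock-style ToolChoice dict to OpenAI's tool_choice values."""
    if not tool_choice:
        return None
    if "auto" in tool_choice:
        return "auto"
    if "any" in tool_choice:
        return "required"  # OpenAI's "must call at least one tool"
    if "tool" in tool_choice:
        return {"type": "function", "function": {"name": tool_choice["tool"]["name"]}}
    return None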

src/strands/models/llamaapi.py

Lines changed: 4 additions & 1 deletion
@@ -18,7 +18,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StreamEvent, Usage
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -327,6 +327,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LlamaAPI model.
@@ -335,6 +336,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:

src/strands/models/mistral.py

Lines changed: 4 additions & 1 deletion
@@ -15,7 +15,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -394,6 +394,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Mistral model.
@@ -402,6 +403,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:

src/strands/models/model.py

Lines changed: 3 additions & 1 deletion
@@ -8,7 +8,7 @@
 
 from ..types.content import Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +70,7 @@ def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncIterable[StreamEvent]:
         """Stream conversation with the model.
@@ -84,6 +85,7 @@ def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
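Because the parameter lands on the Model base class, every provider implementation must now accept it, including those that ignore it. A minimal conforming stub, assuming the abstract surface in this SDK era is update_config / get_config / stream / structured_output; EchoModel and its canned events are hypothetical:

from typing import Any, AsyncGenerator, Optional, Type, TypeVar

from pydantic import BaseModel

from strands.models.model import Model
from strands.types.content import Messages
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolChoice, ToolSpec

T = TypeVar("T", bound=BaseModel)


class EchoModel(Model):
    def update_config(self, **model_config: Any) -> None:
        pass

    def get_config(self) -> dict[str, Any]:
        return {}

    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        tool_choice: Optional[ToolChoice] = None,  # new in this commit; may be ignored
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        yield {"contentBlockDelta": {"delta": {"text": "echo"}}}

    async def structured_output(
        self,
        output_model: Type[T],
        prompt: Messages,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        # model_construct skips validation; a placeholder, not real extraction.
        yield {"output": output_model.model_construct()}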

src/strands/models/ollama.py

Lines changed: 4 additions & 1 deletion
@@ -13,7 +13,7 @@
 
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -284,6 +284,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Ollama model.
@@ -292,6 +293,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
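An end-to-end sketch of what changes for callers: on Bedrock and Anthropic, structured_output now forces the tool call, so extraction no longer depends on the model choosing to use the generated tool. The model id is illustrative, the final {"output": ...} event follows process_stream's convention in this SDK era, and configured AWS credentials are assumed:

import asyncio

from pydantic import BaseModel

from strands.models.bedrock import BedrockModel


class Person(BaseModel):
    name: str
    age: int


async def main() -> None:
    model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")
    async for event in model.structured_output(
        Person,
        [{"role": "user", "content": [{"text": "John is 30 years old."}]}],
    ):
        if "output" in event:
            print(event["output"])  # e.g. Person(name='John', age=30)


asyncio.run(main())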
