Skip to content

Commit 83b6157

Browse files
committed
fix: move to complete mapping like other providers
1 parent ae1175e commit 83b6157

File tree

3 files changed

+194
-181
lines changed

3 files changed

+194
-181
lines changed

src/strands/models/bedrock.py

Lines changed: 186 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..event_loop import streaming
1919
from ..tools import convert_pydantic_to_tool_spec
20-
from ..types.content import ContentBlock, Message, Messages, _ContentBlockType
20+
from ..types.content import ContentBlock, Messages
2121
from ..types.exceptions import (
2222
ContextWindowOverflowException,
2323
ModelThrottledException,
@@ -43,83 +43,6 @@
4343
"anthropic.claude",
4444
]
4545

46-
# Allowed fields for each Bedrock content block type to prevent validation exceptions
47-
# Bedrock strictly validates content blocks and throws exceptions for unknown fields
48-
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
49-
_BEDROCK_CONTENT_BLOCK_FIELDS: dict[_ContentBlockType, set[str]] = {
50-
"image": {
51-
"format",
52-
"source",
53-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html
54-
"toolResult": {
55-
"content",
56-
"toolUseId",
57-
"status",
58-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
59-
"toolUse": {
60-
"input",
61-
"name",
62-
"toolUseId",
63-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html
64-
"document": {
65-
"name",
66-
"source",
67-
"citations",
68-
"context",
69-
"format",
70-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
71-
"video": {
72-
"format",
73-
"source",
74-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html
75-
"reasoningContent": {
76-
"reasoningText",
77-
"redactedContent",
78-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
79-
"citationsContent": {
80-
"citations",
81-
"content",
82-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
83-
"cachePoint": {"type"}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
84-
"guardContent": {
85-
"image",
86-
"text",
87-
}, # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html
88-
# Note: text is handled as a primitive (string)
89-
}
90-
_BEDROCK_CONTENT_BLOCK_TYPES: set[_ContentBlockType] = set(_BEDROCK_CONTENT_BLOCK_FIELDS.keys())
91-
92-
# Nested schemas for deep filtering of Bedrock content blocks
93-
_BEDROCK_CONTENT_BLOCK_SCHEMAS: dict[_ContentBlockType, dict[str, Any]] = {
94-
"image": {
95-
"format": True,
96-
"source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
97-
},
98-
"toolResult": {"content": True, "toolUseId": True, "status": True},
99-
"toolUse": {"input": True, "name": True, "toolUseId": True},
100-
"document": {
101-
"name": True,
102-
"source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
103-
"format": True,
104-
"citations": True,
105-
"context": True,
106-
},
107-
"video": {
108-
"format": True,
109-
"source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
110-
},
111-
"reasoningContent": {"reasoningText": {"text": True, "signature": True}, "redactedContent": True},
112-
"citationsContent": {"citations": True, "content": True},
113-
"cachePoint": {"type": True},
114-
"guardContent": {
115-
"image": {
116-
"format": True,
117-
"source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
118-
},
119-
"text": {"qualifiers": True, "text": True},
120-
},
121-
}
122-
12346
T = TypeVar("T", bound=BaseModel)
12447

12548
DEFAULT_READ_TIMEOUT = 120
@@ -258,17 +181,6 @@ def get_config(self) -> BedrockConfig:
258181
"""
259182
return self.config
260183

261-
def _should_include_tool_result_status(self) -> bool:
    """Decide whether tool results sent to Bedrock keep their ``status`` field.

    Reads the ``include_tool_result_status`` config entry, treating a missing entry as ``"auto"``.
    An explicit ``True`` or ``False`` is honoured directly. Any other value falls back to checking
    whether the configured ``model_id`` contains one of the ``_MODELS_INCLUDE_STATUS`` entries.

    Returns:
        True if the status field should be included, False otherwise.
    """
    setting = self.config.get("include_tool_result_status", "auto")

    # Identity comparison keeps non-bool values such as 1 or 0 on the "auto" path.
    if setting is True or setting is False:
        return setting

    for marker in _MODELS_INCLUDE_STATUS:
        if marker in self.config["model_id"]:
            return True
    return False
271-
272184
def format_request(
273185
self,
274186
messages: Messages,
@@ -352,7 +264,7 @@ def format_request(
352264
),
353265
}
354266

355-
def _format_bedrock_messages(self, messages: Messages) -> Messages:
267+
def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
356268
"""Format messages for Bedrock API compatibility.
357269
358270
This function ensures messages conform to Bedrock's expected format by:
@@ -373,13 +285,13 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
373285
content blocks to remove any additional fields before sending to Bedrock.
374286
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
375287
"""
376-
cleaned_messages = []
288+
cleaned_messages: list[dict[str, Any]] = []
377289

378290
filtered_unknown_members = False
379291
dropped_deepseek_reasoning_content = False
380292

381293
for message in messages:
382-
cleaned_content: list[ContentBlock] = []
294+
cleaned_content: list[dict[str, Any]] = []
383295

384296
for content_block in message["content"]:
385297
# Filter out SDK_UNKNOWN_MEMBER content blocks
@@ -393,28 +305,13 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
393305
dropped_deepseek_reasoning_content = True
394306
continue
395307

396-
# Clean content blocks that need field filtering for Bedrock API compatibility
397-
filterable_block_types = set(content_block.keys()) & _BEDROCK_CONTENT_BLOCK_TYPES
398-
399-
if filterable_block_types:
400-
# Should only be one block type per content block since it is a discriminated union
401-
block_type = cast(_ContentBlockType, next(iter(filterable_block_types)))
402-
block_data = content_block[block_type]
403-
schema = _BEDROCK_CONTENT_BLOCK_SCHEMAS[block_type].copy()
404-
405-
if block_type == "toolResult" and not self._should_include_tool_result_status():
406-
schema.pop("status", None)
407-
408-
cleaned_data = _deep_filter(block_data, schema)
409-
cleaned_content.append(cast(ContentBlock, {block_type: cleaned_data}))
410-
else:
411-
# Keep other content blocks as-is
412-
cleaned_content.append(content_block)
308+
# Format content blocks for Bedrock API compatibility
309+
formatted_content = self._format_request_message_content(content_block)
310+
cleaned_content.append(formatted_content)
413311

414312
# Create new message with cleaned content (skip if empty)
415313
if cleaned_content:
416-
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
417-
cleaned_messages.append(cleaned_message)
314+
cleaned_messages.append({"content": cleaned_content, "role": message["role"]})
418315

419316
if filtered_unknown_members:
420317
logger.warning(
@@ -427,6 +324,184 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
427324

428325
return cleaned_messages
429326

327+
def _should_include_tool_result_status(self) -> bool:
    """Report whether the ``status`` field of a tool result should be sent to Bedrock.

    The ``include_tool_result_status`` config entry controls the answer and defaults to ``"auto"``
    when unset. A boolean entry is returned unchanged. Otherwise the answer is whether any
    ``_MODELS_INCLUDE_STATUS`` entry occurs inside the configured ``model_id``.

    Returns:
        True if the status field should be included, False otherwise.
    """
    mode = self.config.get("include_tool_result_status", "auto")

    # Only the two bool singletons short-circuit; everything else is handled as "auto".
    if isinstance(mode, bool):
        return mode

    return any(name in self.config["model_id"] for name in _MODELS_INCLUDE_STATUS)
337+
338+
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
339+
"""Format a Bedrock content block.
340+
341+
Bedrock strictly validates content blocks and throws exceptions for unknown fields.
342+
This function extracts only the fields that Bedrock supports for each content type.
343+
344+
Args:
345+
content: Content block to format.
346+
347+
Returns:
348+
Bedrock formatted content block.
349+
350+
Raises:
351+
TypeError: If the content block type is not supported by Bedrock.
352+
"""
353+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
354+
if "cachePoint" in content:
355+
return {"cachePoint": {"type": content["cachePoint"]["type"]}}
356+
357+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
358+
if "document" in content:
359+
document = content["document"]
360+
result: dict[str, Any] = {}
361+
362+
# Handle required fields (all optional due to total=False)
363+
if "name" in document:
364+
result["name"] = document["name"]
365+
if "format" in document:
366+
result["format"] = document["format"]
367+
368+
# Handle source
369+
if "source" in document:
370+
result["source"] = {"bytes": document["source"]["bytes"]}
371+
372+
# Handle optional fields
373+
if "citations" in document and document["citations"] is not None:
374+
result["citations"] = {"enabled": document["citations"]["enabled"]}
375+
if "context" in document:
376+
result["context"] = document["context"]
377+
378+
return {"document": result}
379+
380+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html
381+
if "guardContent" in content:
382+
guard = content["guardContent"]
383+
guard_text = guard["text"]
384+
result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}}
385+
return {"guardContent": result}
386+
387+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html
388+
if "image" in content:
389+
image = content["image"]
390+
source = image["source"]
391+
formatted_source = {}
392+
if "bytes" in source:
393+
formatted_source = {"bytes": source["bytes"]}
394+
result = {"format": image["format"], "source": formatted_source}
395+
return {"image": result}
396+
397+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
398+
if "reasoningContent" in content:
399+
reasoning = content["reasoningContent"]
400+
result = {}
401+
402+
if "reasoningText" in reasoning:
403+
reasoning_text = reasoning["reasoningText"]
404+
result["reasoningText"] = {}
405+
if "text" in reasoning_text:
406+
result["reasoningText"]["text"] = reasoning_text["text"]
407+
# Only include signature if truthy (avoid empty strings)
408+
if reasoning_text.get("signature"):
409+
result["reasoningText"]["signature"] = reasoning_text["signature"]
410+
411+
if "redactedContent" in reasoning:
412+
result["redactedContent"] = reasoning["redactedContent"]
413+
414+
return {"reasoningContent": result}
415+
416+
# Pass through text and other simple content types
417+
if "text" in content:
418+
return {"text": content["text"]}
419+
420+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
421+
if "toolResult" in content:
422+
tool_result = content["toolResult"]
423+
formatted_content: list[dict[str, Any]] = []
424+
for tool_result_content in tool_result["content"]:
425+
if "json" in tool_result_content:
426+
# Handle json field since not in ContentBlock but valid in ToolResultContent
427+
formatted_content.append({"json": tool_result_content["json"]})
428+
else:
429+
formatted_content.append(
430+
self._format_request_message_content(cast(ContentBlock, tool_result_content))
431+
)
432+
433+
result = {
434+
"content": formatted_content,
435+
"toolUseId": tool_result["toolUseId"],
436+
}
437+
if "status" in tool_result and self._should_include_tool_result_status():
438+
result["status"] = tool_result["status"]
439+
return {"toolResult": result}
440+
441+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html
442+
if "toolUse" in content:
443+
tool_use = content["toolUse"]
444+
return {
445+
"toolUse": {
446+
"input": tool_use["input"],
447+
"name": tool_use["name"],
448+
"toolUseId": tool_use["toolUseId"],
449+
}
450+
}
451+
452+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html
453+
if "video" in content:
454+
video = content["video"]
455+
source = video["source"]
456+
formatted_source = {}
457+
if "bytes" in source:
458+
formatted_source = {"bytes": source["bytes"]}
459+
result = {"format": video["format"], "source": formatted_source}
460+
return {"video": result}
461+
462+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
463+
if "citationsContent" in content:
464+
citations = content["citationsContent"]
465+
result = {}
466+
467+
if "citations" in citations:
468+
result["citations"] = []
469+
for citation in citations["citations"]:
470+
filtered_citation: dict[str, Any] = {}
471+
if "location" in citation:
472+
location = citation["location"]
473+
filtered_location = {}
474+
# Filter location fields to only include Bedrock-supported ones
475+
if "documentIndex" in location:
476+
filtered_location["documentIndex"] = location["documentIndex"]
477+
if "start" in location:
478+
filtered_location["start"] = location["start"]
479+
if "end" in location:
480+
filtered_location["end"] = location["end"]
481+
filtered_citation["location"] = filtered_location
482+
if "sourceContent" in citation:
483+
filtered_source_content: list[dict[str, Any]] = []
484+
for source_content in citation["sourceContent"]:
485+
if "text" in source_content:
486+
filtered_source_content.append({"text": source_content["text"]})
487+
if filtered_source_content:
488+
filtered_citation["sourceContent"] = filtered_source_content
489+
if "title" in citation:
490+
filtered_citation["title"] = citation["title"]
491+
result["citations"].append(filtered_citation)
492+
493+
if "content" in citations:
494+
filtered_content: list[dict[str, Any]] = []
495+
for generated_content in citations["content"]:
496+
if "text" in generated_content:
497+
filtered_content.append({"text": generated_content["text"]})
498+
if filtered_content:
499+
result["content"] = filtered_content
500+
501+
return {"citationsContent": result}
502+
503+
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
504+
430505
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
431506
"""Check if guardrail data contains any blocked policies.
432507
@@ -836,34 +911,3 @@ async def structured_output(
836911
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
837912

838913
yield {"output": output_model(**output_response)}
839-
840-
841-
def _deep_filter(data: Union[dict[str, Any], Any], schema: dict[str, Any]) -> dict[str, Any]:
842-
"""Fast recursive filtering using nested dict schemas.
843-
844-
Args:
845-
data: Input data to filter (content block or nested dict)
846-
schema: Schema defining allowed fields and nested structure
847-
848-
Returns:
849-
Filtered dictionary containing only schema-defined fields
850-
"""
851-
if not isinstance(data, dict):
852-
return {}
853-
854-
result = {}
855-
for key in data.keys() & schema.keys():
856-
value = data[key]
857-
schema_spec = schema[key]
858-
859-
if schema_spec is True:
860-
result[key] = value
861-
elif isinstance(schema_spec, dict) and isinstance(value, dict):
862-
filtered = _deep_filter(value, schema_spec)
863-
if filtered:
864-
result[key] = filtered
865-
elif isinstance(schema_spec, dict) and isinstance(value, list):
866-
result[key] = [_deep_filter(item, schema_spec) for item in value if isinstance(item, dict)]
867-
else:
868-
result[key] = value
869-
return result

0 commit comments

Comments
 (0)