|
89 | 89 | } |
90 | 90 | _BEDROCK_CONTENT_BLOCK_TYPES: set[_ContentBlockType] = set(_BEDROCK_CONTENT_BLOCK_FIELDS.keys()) |
91 | 91 |
|
# Nested schemas for deep filtering of Bedrock content blocks.
#
# Each top-level key is a content block type; its value describes which fields of that block are kept
# when the block is passed through `_deep_filter`:
#   - `True` keeps the field's value unchanged (no filtering below that point).
#   - a nested dict keeps the field but filters its value recursively against that nested schema.
# Fields that are not listed are dropped.
#
# NOTE(review): the `s3Location` sub-fields listed below are bucket/key/region/version. The Bedrock
# Converse API reference appears to describe S3 locations with `uri`/`bucketOwner` instead - TODO confirm
# which shape callers actually send, since any key not listed here is stripped.
_BEDROCK_CONTENT_BLOCK_SCHEMAS: dict[_ContentBlockType, dict[str, Any]] = {
    "image": {
        "format": True,
        "source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
    },
    # `content` and `input` are kept whole: their inner structure is not filtered.
    "toolResult": {"content": True, "toolUseId": True, "status": True},
    "toolUse": {"input": True, "name": True, "toolUseId": True},
    "document": {
        "name": True,
        "source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
        "format": True,
        "citations": True,
        "context": True,
    },
    "video": {
        "format": True,
        "source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
    },
    "reasoningContent": {"reasoningText": {"text": True, "signature": True}, "redactedContent": True},
    "citationsContent": {"citations": True, "content": True},
    "cachePoint": {"type": True},
    "guardContent": {
        "image": {
            "format": True,
            "source": {"bytes": True, "s3Location": {"bucket": True, "key": True, "region": True, "version": True}},
        },
        "text": {"qualifiers": True, "text": True},
    },
}
| 122 | + |
92 | 123 | T = TypeVar("T", bound=BaseModel) |
93 | 124 |
|
94 | 125 | DEFAULT_READ_TIMEOUT = 120 |
@@ -369,12 +400,12 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: |
369 | 400 | # Should only be one block type per content block since it is a discriminated union |
370 | 401 | block_type = cast(_ContentBlockType, next(iter(filterable_block_types))) |
371 | 402 | block_data = content_block[block_type] |
372 | | - allowed_fields = _BEDROCK_CONTENT_BLOCK_FIELDS[block_type].copy() |
| 403 | + schema = _BEDROCK_CONTENT_BLOCK_SCHEMAS[block_type].copy() |
373 | 404 |
|
374 | 405 | if block_type == "toolResult" and not self._should_include_tool_result_status(): |
375 | | - allowed_fields.discard("status") |
| 406 | + schema.pop("status", None) |
376 | 407 |
|
377 | | - cleaned_data = {k: v for k, v in block_data.items() if k in allowed_fields} |
| 408 | + cleaned_data = _deep_filter(block_data, schema) |
378 | 409 | cleaned_content.append(cast(ContentBlock, {block_type: cleaned_data})) |
379 | 410 | else: |
380 | 411 | # Keep other content blocks as-is |
@@ -805,3 +836,34 @@ async def structured_output( |
805 | 836 | raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") |
806 | 837 |
|
807 | 838 | yield {"output": output_model(**output_response)} |
| 839 | + |
| 840 | + |
| 841 | +def _deep_filter(data: Union[dict[str, Any], Any], schema: dict[str, Any]) -> dict[str, Any]: |
| 842 | + """Fast recursive filtering using nested dict schemas. |
| 843 | +
|
| 844 | + Args: |
| 845 | + data: Input data to filter (content block or nested dict) |
| 846 | + schema: Schema defining allowed fields and nested structure |
| 847 | +
|
| 848 | + Returns: |
| 849 | + Filtered dictionary containing only schema-defined fields |
| 850 | + """ |
| 851 | + if not isinstance(data, dict): |
| 852 | + return {} |
| 853 | + |
| 854 | + result = {} |
| 855 | + for key in data.keys() & schema.keys(): |
| 856 | + value = data[key] |
| 857 | + schema_spec = schema[key] |
| 858 | + |
| 859 | + if schema_spec is True: |
| 860 | + result[key] = value |
| 861 | + elif isinstance(schema_spec, dict) and isinstance(value, dict): |
| 862 | + filtered = _deep_filter(value, schema_spec) |
| 863 | + if filtered: |
| 864 | + result[key] = filtered |
| 865 | + elif isinstance(schema_spec, dict) and isinstance(value, list): |
| 866 | + result[key] = [_deep_filter(item, schema_spec) for item in value if isinstance(item, dict)] |
| 867 | + else: |
| 868 | + result[key] = value |
| 869 | + return result |
0 commit comments