Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,23 +493,26 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
if "citationsContent" in content:
citations = content["citationsContent"]
result = {}
result: dict[str, Any] = {}

if "citations" in citations:
result["citations"] = []
for citation in citations["citations"]:
filtered_citation: dict[str, Any] = {}
if "location" in citation:
location = citation["location"]
filtered_location = {}
# Filter location fields to only include Bedrock-supported ones
if "documentIndex" in location:
filtered_location["documentIndex"] = location["documentIndex"]
if "start" in location:
filtered_location["start"] = location["start"]
if "end" in location:
filtered_location["end"] = location["end"]
filtered_citation["location"] = filtered_location
filtered_location: dict[str, Any] = {}
# Handle web-based citations
if "web" in location:
filtered_location["web"] = {
k: v for k, v in location["web"].items() if k in ("url", "domain")
}
# Handle document-based citations
for field in ("documentIndex", "start", "end"):
if field in location:
filtered_location[field] = location[field]
if filtered_location:
filtered_citation["location"] = filtered_location
if "sourceContent" in citation:
filtered_source_content: list[dict[str, Any]] = []
for source_content in citation["sourceContent"]:
Expand Down Expand Up @@ -831,20 +834,21 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
# For non-streaming citations, emit text and metadata deltas in sequence
# to match streaming behavior where they flow naturally
if "content" in content["citationsContent"]:
text_content = "".join([content["text"] for content in content["citationsContent"]["content"]])
text_content = "".join([c["text"] for c in content["citationsContent"]["content"]])
yield {
"contentBlockDelta": {"delta": {"text": text_content}},
}

for citation in content["citationsContent"]["citations"]:
# Then emit citation metadata (for structure)

citation_metadata: CitationsDelta = {
"title": citation["title"],
"location": citation["location"],
"sourceContent": citation["sourceContent"],
}
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}
# Emit citation metadata with only present fields
citation_metadata: dict[str, Any] = {}
if "title" in citation:
citation_metadata["title"] = citation["title"]
if "location" in citation:
citation_metadata["location"] = citation["location"]
if "sourceContent" in citation:
citation_metadata["sourceContent"] = citation["sourceContent"]
yield {"contentBlockDelta": {"delta": {"citation": cast(CitationsDelta, citation_metadata)}}}

# Yield contentBlockStop event
yield {"contentBlockStop": {}}
Expand Down
16 changes: 15 additions & 1 deletion src/strands/types/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,22 @@ class DocumentPageLocation(TypedDict, total=False):
end: int


class WebLocation(TypedDict, total=False):
"""Specifies a web-based location for cited content.

Provides location information for content cited from web sources.

Attributes:
url: The URL of the web page containing the cited content.
domain: The domain of the web page containing the cited content.
"""

url: str
domain: str


# Union type for citation locations
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, WebLocation]


class CitationSourceContent(TypedDict, total=False):
Expand Down
72 changes: 72 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,3 +2070,75 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model
"system": [{"text": system_prompt}],
}
bedrock_client.converse_stream.assert_called_once_with(**expected_request)


def test_format_request_message_content_web_citation(model):
"""Test that web citations are correctly filtered to include only url and domain."""
content = {
"citationsContent": {
"citations": [
{
"title": "Web Citation",
"location": {"web": {"url": "https://example.com", "domain": "example.com", "extra": "ignored"}},
"sourceContent": [{"text": "Content"}],
}
],
"content": [{"text": "Generated text"}],
}
}

result = model._format_request_message_content(content)

citation = result["citationsContent"]["citations"][0]
assert citation["location"]["web"] == {"url": "https://example.com", "domain": "example.com"}


def test_format_request_message_content_document_citation(model):
"""Test that document citations preserve documentIndex, start, and end fields."""
content = {
"citationsContent": {
"citations": [
{
"title": "Doc Citation",
"location": {"documentIndex": 0, "start": 100, "end": 200},
"sourceContent": [{"text": "Excerpt"}],
}
],
"content": [{"text": "Generated text"}],
}
}

result = model._format_request_message_content(content)

assert result["citationsContent"]["citations"][0]["location"] == {"documentIndex": 0, "start": 100, "end": 200}


def test_format_request_message_content_citation_optional_fields(model):
"""Test that citations with missing optional fields are handled correctly."""
content = {
"citationsContent": {
"citations": [{"title": "Minimal", "location": {"web": {"url": "https://example.com"}}}],
"content": [{"text": "Text"}],
}
}

result = model._format_request_message_content(content)

citation = result["citationsContent"]["citations"][0]
assert citation["title"] == "Minimal"
assert citation["location"]["web"]["url"] == "https://example.com"
assert "sourceContent" not in citation


def test_format_request_message_content_citation_empty_location(model):
"""Test that citations with invalid locations exclude the location field."""
content = {
"citationsContent": {
"citations": [{"title": "No valid location", "location": {"unknown": "value"}}],
"content": [{"text": "Text"}],
}
}

result = model._format_request_message_content(content)

assert "location" not in result["citationsContent"]["citations"][0]