Skip to content

Commit 34840ef

Browse files
author
Danilo Poccia
committed
feat(citations): Add support for web-based citations in Bedrock Converse API
Add support for web-based citations in addition to document-based citations:
- Added WebLocation TypedDict to citations.py with url and domain fields
- Updated CitationLocation union to include WebLocation
- Updated bedrock.py to filter web citation fields (url, domain only)
- Handled optional citation fields gracefully (title, location, sourceContent)
- Added tests for web citations, document citations, and edge cases
1 parent a64a851 commit 34840ef

File tree

3 files changed

+110
-20
lines changed

3 files changed

+110
-20
lines changed

src/strands/models/bedrock.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -493,23 +493,26 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
493493
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
494494
if "citationsContent" in content:
495495
citations = content["citationsContent"]
496-
result = {}
496+
result: dict[str, Any] = {}
497497

498498
if "citations" in citations:
499499
result["citations"] = []
500500
for citation in citations["citations"]:
501501
filtered_citation: dict[str, Any] = {}
502502
if "location" in citation:
503503
location = citation["location"]
504-
filtered_location = {}
505-
# Filter location fields to only include Bedrock-supported ones
506-
if "documentIndex" in location:
507-
filtered_location["documentIndex"] = location["documentIndex"]
508-
if "start" in location:
509-
filtered_location["start"] = location["start"]
510-
if "end" in location:
511-
filtered_location["end"] = location["end"]
512-
filtered_citation["location"] = filtered_location
504+
filtered_location: dict[str, Any] = {}
505+
# Handle web-based citations
506+
if "web" in location:
507+
filtered_location["web"] = {
508+
k: v for k, v in location["web"].items() if k in ("url", "domain")
509+
}
510+
# Handle document-based citations
511+
for field in ("documentIndex", "start", "end"):
512+
if field in location:
513+
filtered_location[field] = location[field]
514+
if filtered_location:
515+
filtered_citation["location"] = filtered_location
513516
if "sourceContent" in citation:
514517
filtered_source_content: list[dict[str, Any]] = []
515518
for source_content in citation["sourceContent"]:
@@ -831,20 +834,21 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
831834
# For non-streaming citations, emit text and metadata deltas in sequence
832835
# to match streaming behavior where they flow naturally
833836
if "content" in content["citationsContent"]:
834-
text_content = "".join([content["text"] for content in content["citationsContent"]["content"]])
837+
text_content = "".join([c["text"] for c in content["citationsContent"]["content"]])
835838
yield {
836839
"contentBlockDelta": {"delta": {"text": text_content}},
837840
}
838841

839842
for citation in content["citationsContent"]["citations"]:
840-
# Then emit citation metadata (for structure)
841-
842-
citation_metadata: CitationsDelta = {
843-
"title": citation["title"],
844-
"location": citation["location"],
845-
"sourceContent": citation["sourceContent"],
846-
}
847-
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}
843+
# Emit citation metadata with only present fields
844+
citation_metadata: dict[str, Any] = {}
845+
if "title" in citation:
846+
citation_metadata["title"] = citation["title"]
847+
if "location" in citation:
848+
citation_metadata["location"] = citation["location"]
849+
if "sourceContent" in citation:
850+
citation_metadata["sourceContent"] = citation["sourceContent"]
851+
yield {"contentBlockDelta": {"delta": {"citation": cast(CitationsDelta, citation_metadata)}}}
848852

849853
# Yield contentBlockStop event
850854
yield {"contentBlockStop": {}}

src/strands/types/citations.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,22 @@ class DocumentPageLocation(TypedDict, total=False):
7777
end: int
7878

7979

80+
class WebLocation(TypedDict, total=False):
81+
"""Specifies a web-based location for cited content.
82+
83+
Provides location information for content cited from web sources.
84+
85+
Attributes:
86+
url: The URL of the web page containing the cited content.
87+
domain: The domain of the web page containing the cited content.
88+
"""
89+
90+
url: str
91+
domain: str
92+
93+
8094
# Union type for citation locations
81-
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
95+
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, WebLocation]
8296

8397

8498
class CitationSourceContent(TypedDict, total=False):

tests/strands/models/test_bedrock.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,3 +2070,75 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model
20702070
"system": [{"text": system_prompt}],
20712071
}
20722072
bedrock_client.converse_stream.assert_called_once_with(**expected_request)
2073+
2074+
2075+
def test_format_request_message_content_web_citation(model):
2076+
"""Test that web citations are correctly filtered to include only url and domain."""
2077+
content = {
2078+
"citationsContent": {
2079+
"citations": [
2080+
{
2081+
"title": "Web Citation",
2082+
"location": {"web": {"url": "https://example.com", "domain": "example.com", "extra": "ignored"}},
2083+
"sourceContent": [{"text": "Content"}],
2084+
}
2085+
],
2086+
"content": [{"text": "Generated text"}],
2087+
}
2088+
}
2089+
2090+
result = model._format_request_message_content(content)
2091+
2092+
citation = result["citationsContent"]["citations"][0]
2093+
assert citation["location"]["web"] == {"url": "https://example.com", "domain": "example.com"}
2094+
2095+
2096+
def test_format_request_message_content_document_citation(model):
2097+
"""Test that document citations preserve documentIndex, start, and end fields."""
2098+
content = {
2099+
"citationsContent": {
2100+
"citations": [
2101+
{
2102+
"title": "Doc Citation",
2103+
"location": {"documentIndex": 0, "start": 100, "end": 200},
2104+
"sourceContent": [{"text": "Excerpt"}],
2105+
}
2106+
],
2107+
"content": [{"text": "Generated text"}],
2108+
}
2109+
}
2110+
2111+
result = model._format_request_message_content(content)
2112+
2113+
assert result["citationsContent"]["citations"][0]["location"] == {"documentIndex": 0, "start": 100, "end": 200}
2114+
2115+
2116+
def test_format_request_message_content_citation_optional_fields(model):
2117+
"""Test that citations with missing optional fields are handled correctly."""
2118+
content = {
2119+
"citationsContent": {
2120+
"citations": [{"title": "Minimal", "location": {"web": {"url": "https://example.com"}}}],
2121+
"content": [{"text": "Text"}],
2122+
}
2123+
}
2124+
2125+
result = model._format_request_message_content(content)
2126+
2127+
citation = result["citationsContent"]["citations"][0]
2128+
assert citation["title"] == "Minimal"
2129+
assert citation["location"]["web"]["url"] == "https://example.com"
2130+
assert "sourceContent" not in citation
2131+
2132+
2133+
def test_format_request_message_content_citation_empty_location(model):
2134+
"""Test that citations with invalid locations exclude the location field."""
2135+
content = {
2136+
"citationsContent": {
2137+
"citations": [{"title": "No valid location", "location": {"unknown": "value"}}],
2138+
"content": [{"text": "Text"}],
2139+
}
2140+
}
2141+
2142+
result = model._format_request_message_content(content)
2143+
2144+
assert "location" not in result["citationsContent"]["citations"][0]

0 commit comments

Comments
 (0)